Bugfix/normalize core comparisons (and Python 3 update fixes) (#952)

* Correct implementation of comparisons to be Pythonic

If a comparison cannot be made return NotImplemented, this allows the
Python interpreter to try other comparisons (eg __ieq__) and either
return False (in the case of __eq__) or raise a TypeError
exception (eg in the case of __lt__).

* Python 3 updates

* Add a more helpful message in exception if platform is not defined

* Added a basic pre-commit check
This commit is contained in:
Tim Savage 2020-01-14 09:35:55 +11:00 committed by Brandon Davidson
parent 3b689ef39c
commit 30ecb58e06
4 changed files with 67 additions and 56 deletions

11
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,11 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: flake8

View file

@ -168,34 +168,34 @@ class TimePeriod:
return self.days or 0
def __eq__(self, other):
if not isinstance(other, TimePeriod):
raise ValueError("other must be TimePeriod")
return self.total_microseconds == other.total_microseconds
if isinstance(other, TimePeriod):
return self.total_microseconds == other.total_microseconds
return NotImplemented
def __ne__(self, other):
if not isinstance(other, TimePeriod):
raise ValueError("other must be TimePeriod")
return self.total_microseconds != other.total_microseconds
if isinstance(other, TimePeriod):
return self.total_microseconds != other.total_microseconds
return NotImplemented
def __lt__(self, other):
if not isinstance(other, TimePeriod):
raise ValueError("other must be TimePeriod")
return self.total_microseconds < other.total_microseconds
if isinstance(other, TimePeriod):
return self.total_microseconds < other.total_microseconds
return NotImplemented
def __gt__(self, other):
if not isinstance(other, TimePeriod):
raise ValueError("other must be TimePeriod")
return self.total_microseconds > other.total_microseconds
if isinstance(other, TimePeriod):
return self.total_microseconds > other.total_microseconds
return NotImplemented
def __le__(self, other):
if not isinstance(other, TimePeriod):
raise ValueError("other must be TimePeriod")
return self.total_microseconds <= other.total_microseconds
if isinstance(other, TimePeriod):
return self.total_microseconds <= other.total_microseconds
return NotImplemented
def __ge__(self, other):
if not isinstance(other, TimePeriod):
raise ValueError("other must be TimePeriod")
return self.total_microseconds >= other.total_microseconds
if isinstance(other, TimePeriod):
return self.total_microseconds >= other.total_microseconds
return NotImplemented
class TimePeriodMicroseconds(TimePeriod):
@ -264,7 +264,7 @@ class ID:
else:
self.is_manual = is_manual
self.is_declaration = is_declaration
self.type = type # type: Optional['MockObjClass']
self.type: Optional['MockObjClass'] = type
def resolve(self, registered_ids):
from esphome.config_validation import RESERVED_IDS
@ -282,13 +282,13 @@ class ID:
return self.id
def __repr__(self):
return 'ID<{} declaration={}, type={}, manual={}>'.format(
self.id, self.is_declaration, self.type, self.is_manual)
return (f'ID<{self.id} declaration={self.is_declaration}, '
f'type={self.type}, manual={self.is_manual}>')
def __eq__(self, other):
if not isinstance(other, ID):
raise ValueError("other must be ID {} {}".format(type(other), other))
return self.id == other.id
if isinstance(other, ID):
return self.id == other.id
return NotImplemented
def __hash__(self):
return hash(self.id)
@ -299,11 +299,10 @@ class ID:
class DocumentLocation:
def __init__(self, document, line, column):
# type: (str, int, int) -> None
self.document = document # type: str
self.line = line # type: int
self.column = column # type: int
def __init__(self, document: str, line: int, column: int):
self.document: str = document
self.line: int = line
self.column: int = column
@classmethod
def from_mark(cls, mark):
@ -318,10 +317,9 @@ class DocumentLocation:
class DocumentRange:
def __init__(self, start_mark, end_mark):
# type: (DocumentLocation, DocumentLocation) -> None
self.start_mark = start_mark # type: DocumentLocation
self.end_mark = end_mark # type: DocumentLocation
def __init__(self, start_mark: DocumentLocation, end_mark: DocumentLocation):
self.start_mark: DocumentLocation = start_mark
self.end_mark: DocumentLocation = end_mark
@classmethod
def from_marks(cls, start_mark, end_mark):
@ -359,7 +357,9 @@ class Define:
return hash(self.as_tuple)
def __eq__(self, other):
return isinstance(self, type(other)) and self.as_tuple == other.as_tuple
if isinstance(other, Define):
return self.as_tuple == other.as_tuple
return NotImplemented
class Library:
@ -381,7 +381,9 @@ class Library:
return hash(self.as_tuple)
def __eq__(self, other):
return isinstance(self, type(other)) and self.as_tuple == other.as_tuple
if isinstance(other, Library):
return self.as_tuple == other.as_tuple
return NotImplemented
def coroutine(func):
@ -462,19 +464,19 @@ class EsphomeCore:
self.vscode = False
self.ace = False
# The name of the node
self.name = None # type: str
self.name: Optional[str] = None
# The relative path to the configuration YAML
self.config_path = None # type: str
self.config_path: Optional[str] = None
# The relative path to where all build files are stored
self.build_path = None # type: str
self.build_path: Optional[str] = None
# The platform (ESP8266, ESP32) of this device
self.esp_platform = None # type: str
self.esp_platform: Optional[str] = None
# The board that's used (for example nodemcuv2)
self.board = None # type: str
self.board: Optional[str] = None
# The full raw configuration
self.raw_config = {} # type: ConfigType
self.raw_config: ConfigType = {}
# The validated configuration, this is None until the config has been validated
self.config = {} # type: ConfigType
self.config: ConfigType = {}
# The pending tasks in the task queue (mostly for C++ generation)
# This is a priority queue (with heapq)
# Each item is a tuple of form: (-priority, unique number, task)
@ -482,20 +484,20 @@ class EsphomeCore:
# Task counter for pending tasks
self.task_counter = 0
# The variable cache, for each ID this holds a MockObj of the variable obj
self.variables = {} # type: Dict[str, 'MockObj']
self.variables: Dict[str, 'MockObj'] = {}
# A list of statements that go in the main setup() block
self.main_statements = [] # type: List['Statement']
self.main_statements: List['Statement'] = []
# A list of statements to insert in the global block (includes and global variables)
self.global_statements = [] # type: List['Statement']
self.global_statements: List['Statement'] = []
# A set of platformio libraries to add to the project
self.libraries = [] # type: List[Library]
self.libraries: List[Library] = []
# A set of build flags to set in the platformio project
self.build_flags = set() # type: Set[str]
self.build_flags: Set[str] = set()
# A set of defines to set for the compile process in esphome/core/defines.h
self.defines = set() # type: Set['Define']
self.defines: Set['Define'] = set()
# A dictionary of started coroutines, used to warn when a coroutine was not
# awaited.
self.active_coroutines = {} # type: Dict[int, Any]
self.active_coroutines: Dict[int, Any] = {}
# A set of strings of names of loaded integrations, used to find namespace ID conflicts
self.loaded_integrations = set()
# A set of component IDs to track what Component subclasses are declared
@ -525,7 +527,7 @@ class EsphomeCore:
self.component_ids = set()
@property
def address(self): # type: () -> str
def address(self) -> Optional[str]:
if 'wifi' in self.config:
return self.config[CONF_WIFI][CONF_USE_ADDRESS]
@ -535,7 +537,7 @@ class EsphomeCore:
return None
@property
def comment(self): # type: () -> str
def comment(self) -> Optional[str]:
if CONF_COMMENT in self.config[CONF_ESPHOME]:
return self.config[CONF_ESPHOME][CONF_COMMENT]
@ -548,7 +550,7 @@ class EsphomeCore:
self.active_coroutines.pop(instance_id)
@property
def arduino_version(self): # type: () -> str
def arduino_version(self) -> str:
return self.config[CONF_ESPHOME][CONF_ARDUINO_VERSION]
@property
@ -587,13 +589,13 @@ class EsphomeCore:
@property
def is_esp8266(self):
if self.esp_platform is None:
raise ValueError
raise ValueError("No platform specified")
return self.esp_platform == 'ESP8266'
@property
def is_esp32(self):
if self.esp_platform is None:
raise ValueError
raise ValueError("No platform specified")
return self.esp_platform == 'ESP32'
def add_job(self, func, *args, **kwargs):

View file

@ -11,7 +11,6 @@ classifier =
Intended Audience :: End Users/Desktop
License :: OSI Approved :: MIT License
Programming Language :: C++
Programming Language :: Python :: 2
Programming Language :: Python :: 3
Topic :: Home Automation
Topic :: Home Automation

View file

@ -50,7 +50,6 @@ CLASSIFIERS = [
'Intended Audience :: End Users/Desktop',
'License :: OSI Approved :: MIT License',
'Programming Language :: C++',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 3',
'Topic :: Home Automation',
]