mirror of
https://github.com/esphome/esphome.git
synced 2024-11-09 16:57:47 +01:00
Bugfix/normalize core comparisons (and Python 3 update fixes) (#952)
* Correct implementation of comparisons to be Pythonic If a comparison cannot be made return NotImplemented, this allows the Python interpreter to try other comparisons (eg __ieq__) and either return False (in the case of __eq__) or raise a TypeError exception (eg in the case of __lt__). * Python 3 updates * Add a more helpful message in exception if platform is not defined * Added a basic pre-commit check
This commit is contained in:
parent
3b689ef39c
commit
30ecb58e06
4 changed files with 67 additions and 56 deletions
11
.pre-commit-config.yaml
Normal file
11
.pre-commit-config.yaml
Normal file
|
@ -0,0 +1,11 @@
|
|||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v2.4.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- id: flake8
|
110
esphome/core.py
110
esphome/core.py
|
@ -168,34 +168,34 @@ class TimePeriod:
|
|||
return self.days or 0
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, TimePeriod):
|
||||
raise ValueError("other must be TimePeriod")
|
||||
return self.total_microseconds == other.total_microseconds
|
||||
if isinstance(other, TimePeriod):
|
||||
return self.total_microseconds == other.total_microseconds
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
if not isinstance(other, TimePeriod):
|
||||
raise ValueError("other must be TimePeriod")
|
||||
return self.total_microseconds != other.total_microseconds
|
||||
if isinstance(other, TimePeriod):
|
||||
return self.total_microseconds != other.total_microseconds
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, TimePeriod):
|
||||
raise ValueError("other must be TimePeriod")
|
||||
return self.total_microseconds < other.total_microseconds
|
||||
if isinstance(other, TimePeriod):
|
||||
return self.total_microseconds < other.total_microseconds
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other):
|
||||
if not isinstance(other, TimePeriod):
|
||||
raise ValueError("other must be TimePeriod")
|
||||
return self.total_microseconds > other.total_microseconds
|
||||
if isinstance(other, TimePeriod):
|
||||
return self.total_microseconds > other.total_microseconds
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other):
|
||||
if not isinstance(other, TimePeriod):
|
||||
raise ValueError("other must be TimePeriod")
|
||||
return self.total_microseconds <= other.total_microseconds
|
||||
if isinstance(other, TimePeriod):
|
||||
return self.total_microseconds <= other.total_microseconds
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, other):
|
||||
if not isinstance(other, TimePeriod):
|
||||
raise ValueError("other must be TimePeriod")
|
||||
return self.total_microseconds >= other.total_microseconds
|
||||
if isinstance(other, TimePeriod):
|
||||
return self.total_microseconds >= other.total_microseconds
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class TimePeriodMicroseconds(TimePeriod):
|
||||
|
@ -264,7 +264,7 @@ class ID:
|
|||
else:
|
||||
self.is_manual = is_manual
|
||||
self.is_declaration = is_declaration
|
||||
self.type = type # type: Optional['MockObjClass']
|
||||
self.type: Optional['MockObjClass'] = type
|
||||
|
||||
def resolve(self, registered_ids):
|
||||
from esphome.config_validation import RESERVED_IDS
|
||||
|
@ -282,13 +282,13 @@ class ID:
|
|||
return self.id
|
||||
|
||||
def __repr__(self):
|
||||
return 'ID<{} declaration={}, type={}, manual={}>'.format(
|
||||
self.id, self.is_declaration, self.type, self.is_manual)
|
||||
return (f'ID<{self.id} declaration={self.is_declaration}, '
|
||||
f'type={self.type}, manual={self.is_manual}>')
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, ID):
|
||||
raise ValueError("other must be ID {} {}".format(type(other), other))
|
||||
return self.id == other.id
|
||||
if isinstance(other, ID):
|
||||
return self.id == other.id
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
@ -299,11 +299,10 @@ class ID:
|
|||
|
||||
|
||||
class DocumentLocation:
|
||||
def __init__(self, document, line, column):
|
||||
# type: (str, int, int) -> None
|
||||
self.document = document # type: str
|
||||
self.line = line # type: int
|
||||
self.column = column # type: int
|
||||
def __init__(self, document: str, line: int, column: int):
|
||||
self.document: str = document
|
||||
self.line: int = line
|
||||
self.column: int = column
|
||||
|
||||
@classmethod
|
||||
def from_mark(cls, mark):
|
||||
|
@ -318,10 +317,9 @@ class DocumentLocation:
|
|||
|
||||
|
||||
class DocumentRange:
|
||||
def __init__(self, start_mark, end_mark):
|
||||
# type: (DocumentLocation, DocumentLocation) -> None
|
||||
self.start_mark = start_mark # type: DocumentLocation
|
||||
self.end_mark = end_mark # type: DocumentLocation
|
||||
def __init__(self, start_mark: DocumentLocation, end_mark: DocumentLocation):
|
||||
self.start_mark: DocumentLocation = start_mark
|
||||
self.end_mark: DocumentLocation = end_mark
|
||||
|
||||
@classmethod
|
||||
def from_marks(cls, start_mark, end_mark):
|
||||
|
@ -359,7 +357,9 @@ class Define:
|
|||
return hash(self.as_tuple)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(self, type(other)) and self.as_tuple == other.as_tuple
|
||||
if isinstance(other, Define):
|
||||
return self.as_tuple == other.as_tuple
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class Library:
|
||||
|
@ -381,7 +381,9 @@ class Library:
|
|||
return hash(self.as_tuple)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(self, type(other)) and self.as_tuple == other.as_tuple
|
||||
if isinstance(other, Library):
|
||||
return self.as_tuple == other.as_tuple
|
||||
return NotImplemented
|
||||
|
||||
|
||||
def coroutine(func):
|
||||
|
@ -462,19 +464,19 @@ class EsphomeCore:
|
|||
self.vscode = False
|
||||
self.ace = False
|
||||
# The name of the node
|
||||
self.name = None # type: str
|
||||
self.name: Optional[str] = None
|
||||
# The relative path to the configuration YAML
|
||||
self.config_path = None # type: str
|
||||
self.config_path: Optional[str] = None
|
||||
# The relative path to where all build files are stored
|
||||
self.build_path = None # type: str
|
||||
self.build_path: Optional[str] = None
|
||||
# The platform (ESP8266, ESP32) of this device
|
||||
self.esp_platform = None # type: str
|
||||
self.esp_platform: Optional[str] = None
|
||||
# The board that's used (for example nodemcuv2)
|
||||
self.board = None # type: str
|
||||
self.board: Optional[str] = None
|
||||
# The full raw configuration
|
||||
self.raw_config = {} # type: ConfigType
|
||||
self.raw_config: ConfigType = {}
|
||||
# The validated configuration, this is None until the config has been validated
|
||||
self.config = {} # type: ConfigType
|
||||
self.config: ConfigType = {}
|
||||
# The pending tasks in the task queue (mostly for C++ generation)
|
||||
# This is a priority queue (with heapq)
|
||||
# Each item is a tuple of form: (-priority, unique number, task)
|
||||
|
@ -482,20 +484,20 @@ class EsphomeCore:
|
|||
# Task counter for pending tasks
|
||||
self.task_counter = 0
|
||||
# The variable cache, for each ID this holds a MockObj of the variable obj
|
||||
self.variables = {} # type: Dict[str, 'MockObj']
|
||||
self.variables: Dict[str, 'MockObj'] = {}
|
||||
# A list of statements that go in the main setup() block
|
||||
self.main_statements = [] # type: List['Statement']
|
||||
self.main_statements: List['Statement'] = []
|
||||
# A list of statements to insert in the global block (includes and global variables)
|
||||
self.global_statements = [] # type: List['Statement']
|
||||
self.global_statements: List['Statement'] = []
|
||||
# A set of platformio libraries to add to the project
|
||||
self.libraries = [] # type: List[Library]
|
||||
self.libraries: List[Library] = []
|
||||
# A set of build flags to set in the platformio project
|
||||
self.build_flags = set() # type: Set[str]
|
||||
self.build_flags: Set[str] = set()
|
||||
# A set of defines to set for the compile process in esphome/core/defines.h
|
||||
self.defines = set() # type: Set['Define']
|
||||
self.defines: Set['Define'] = set()
|
||||
# A dictionary of started coroutines, used to warn when a coroutine was not
|
||||
# awaited.
|
||||
self.active_coroutines = {} # type: Dict[int, Any]
|
||||
self.active_coroutines: Dict[int, Any] = {}
|
||||
# A set of strings of names of loaded integrations, used to find namespace ID conflicts
|
||||
self.loaded_integrations = set()
|
||||
# A set of component IDs to track what Component subclasses are declared
|
||||
|
@ -525,7 +527,7 @@ class EsphomeCore:
|
|||
self.component_ids = set()
|
||||
|
||||
@property
|
||||
def address(self): # type: () -> str
|
||||
def address(self) -> Optional[str]:
|
||||
if 'wifi' in self.config:
|
||||
return self.config[CONF_WIFI][CONF_USE_ADDRESS]
|
||||
|
||||
|
@ -535,7 +537,7 @@ class EsphomeCore:
|
|||
return None
|
||||
|
||||
@property
|
||||
def comment(self): # type: () -> str
|
||||
def comment(self) -> Optional[str]:
|
||||
if CONF_COMMENT in self.config[CONF_ESPHOME]:
|
||||
return self.config[CONF_ESPHOME][CONF_COMMENT]
|
||||
|
||||
|
@ -548,7 +550,7 @@ class EsphomeCore:
|
|||
self.active_coroutines.pop(instance_id)
|
||||
|
||||
@property
|
||||
def arduino_version(self): # type: () -> str
|
||||
def arduino_version(self) -> str:
|
||||
return self.config[CONF_ESPHOME][CONF_ARDUINO_VERSION]
|
||||
|
||||
@property
|
||||
|
@ -587,13 +589,13 @@ class EsphomeCore:
|
|||
@property
|
||||
def is_esp8266(self):
|
||||
if self.esp_platform is None:
|
||||
raise ValueError
|
||||
raise ValueError("No platform specified")
|
||||
return self.esp_platform == 'ESP8266'
|
||||
|
||||
@property
|
||||
def is_esp32(self):
|
||||
if self.esp_platform is None:
|
||||
raise ValueError
|
||||
raise ValueError("No platform specified")
|
||||
return self.esp_platform == 'ESP32'
|
||||
|
||||
def add_job(self, func, *args, **kwargs):
|
||||
|
|
|
@ -11,7 +11,6 @@ classifier =
|
|||
Intended Audience :: End Users/Desktop
|
||||
License :: OSI Approved :: MIT License
|
||||
Programming Language :: C++
|
||||
Programming Language :: Python :: 2
|
||||
Programming Language :: Python :: 3
|
||||
Topic :: Home Automation
|
||||
Topic :: Home Automation
|
||||
|
|
1
setup.py
1
setup.py
|
@ -50,7 +50,6 @@ CLASSIFIERS = [
|
|||
'Intended Audience :: End Users/Desktop',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: C++',
|
||||
'Programming Language :: Python :: 2',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Topic :: Home Automation',
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue