Handle nanoseconds in config (#5695)

This commit is contained in:
Jesse Hills 2023-11-08 21:34:44 +13:00
parent 966c6a4531
commit dbb1263a36
No known key found for this signature in database
GPG key ID: BEAAE804EFD8E83A
4 changed files with 57 additions and 12 deletions

View file

@ -66,6 +66,7 @@ from esphome.core import (
TimePeriod, TimePeriod,
TimePeriodMicroseconds, TimePeriodMicroseconds,
TimePeriodMilliseconds, TimePeriodMilliseconds,
TimePeriodNanoseconds,
TimePeriodSeconds, TimePeriodSeconds,
TimePeriodMinutes, TimePeriodMinutes,
) )
@ -718,6 +719,8 @@ def time_period_str_unit(value):
raise Invalid("Expected string for time period with unit.") raise Invalid("Expected string for time period with unit.")
unit_to_kwarg = { unit_to_kwarg = {
"ns": "nanoseconds",
"nanoseconds": "nanoseconds",
"us": "microseconds", "us": "microseconds",
"microseconds": "microseconds", "microseconds": "microseconds",
"ms": "milliseconds", "ms": "milliseconds",
@ -739,7 +742,10 @@ def time_period_str_unit(value):
raise Invalid(f"Expected time period with unit, got {value}") raise Invalid(f"Expected time period with unit, got {value}")
kwarg = unit_to_kwarg[one_of(*unit_to_kwarg)(match.group(2))] kwarg = unit_to_kwarg[one_of(*unit_to_kwarg)(match.group(2))]
try:
return TimePeriod(**{kwarg: float(match.group(1))}) return TimePeriod(**{kwarg: float(match.group(1))})
except ValueError as e:
raise Invalid(e) from e
def time_period_in_milliseconds_(value): def time_period_in_milliseconds_(value):
@ -749,10 +755,18 @@ def time_period_in_milliseconds_(value):
def time_period_in_microseconds_(value): def time_period_in_microseconds_(value):
if value.nanoseconds is not None and value.nanoseconds != 0:
raise Invalid("Maximum precision is microseconds")
return TimePeriodMicroseconds(**value.as_dict()) return TimePeriodMicroseconds(**value.as_dict())
def time_period_in_nanoseconds_(value):
return TimePeriodNanoseconds(**value.as_dict())
def time_period_in_seconds_(value): def time_period_in_seconds_(value):
if value.nanoseconds is not None and value.nanoseconds != 0:
raise Invalid("Maximum precision is seconds")
if value.microseconds is not None and value.microseconds != 0: if value.microseconds is not None and value.microseconds != 0:
raise Invalid("Maximum precision is seconds") raise Invalid("Maximum precision is seconds")
if value.milliseconds is not None and value.milliseconds != 0: if value.milliseconds is not None and value.milliseconds != 0:
@ -761,6 +775,8 @@ def time_period_in_seconds_(value):
def time_period_in_minutes_(value): def time_period_in_minutes_(value):
if value.nanoseconds is not None and value.nanoseconds != 0:
raise Invalid("Maximum precision is minutes")
if value.microseconds is not None and value.microseconds != 0: if value.microseconds is not None and value.microseconds != 0:
raise Invalid("Maximum precision is minutes") raise Invalid("Maximum precision is minutes")
if value.milliseconds is not None and value.milliseconds != 0: if value.milliseconds is not None and value.milliseconds != 0:
@ -787,6 +803,9 @@ time_period_microseconds = All(time_period, time_period_in_microseconds_)
positive_time_period_microseconds = All( positive_time_period_microseconds = All(
positive_time_period, time_period_in_microseconds_ positive_time_period, time_period_in_microseconds_
) )
positive_time_period_nanoseconds = All(
positive_time_period, time_period_in_nanoseconds_
)
positive_not_null_time_period = All( positive_not_null_time_period = All(
time_period, Range(min=TimePeriod(), min_included=False) time_period, Range(min=TimePeriod(), min_included=False)
) )

View file

@ -87,6 +87,7 @@ def is_approximately_integer(value):
class TimePeriod: class TimePeriod:
def __init__( def __init__(
self, self,
nanoseconds=None,
microseconds=None, microseconds=None,
milliseconds=None, milliseconds=None,
seconds=None, seconds=None,
@ -136,13 +137,23 @@ class TimePeriod:
if microseconds is not None: if microseconds is not None:
if not is_approximately_integer(microseconds): if not is_approximately_integer(microseconds):
raise ValueError("Maximum precision is microseconds") frac_microseconds, microseconds = math.modf(microseconds)
nanoseconds = (nanoseconds or 0) + frac_microseconds * 1000
self.microseconds = int(round(microseconds)) self.microseconds = int(round(microseconds))
else: else:
self.microseconds = None self.microseconds = None
if nanoseconds is not None:
if not is_approximately_integer(nanoseconds):
raise ValueError("Maximum precision is nanoseconds")
self.nanoseconds = int(round(nanoseconds))
else:
self.nanoseconds = None
def as_dict(self): def as_dict(self):
out = OrderedDict() out = OrderedDict()
if self.nanoseconds is not None:
out["nanoseconds"] = self.nanoseconds
if self.microseconds is not None: if self.microseconds is not None:
out["microseconds"] = self.microseconds out["microseconds"] = self.microseconds
if self.milliseconds is not None: if self.milliseconds is not None:
@ -158,6 +169,8 @@ class TimePeriod:
return out return out
def __str__(self): def __str__(self):
if self.nanoseconds is not None:
return f"{self.total_nanoseconds}ns"
if self.microseconds is not None: if self.microseconds is not None:
return f"{self.total_microseconds}us" return f"{self.total_microseconds}us"
if self.milliseconds is not None: if self.milliseconds is not None:
@ -173,7 +186,11 @@ class TimePeriod:
return "0s" return "0s"
def __repr__(self): def __repr__(self):
return f"TimePeriod<{self.total_microseconds}>" return f"TimePeriod<{self.total_nanoseconds}ns>"
@property
def total_nanoseconds(self):
return self.total_microseconds * 1000 + (self.nanoseconds or 0)
@property @property
def total_microseconds(self): def total_microseconds(self):
@ -201,35 +218,39 @@ class TimePeriod:
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, TimePeriod): if isinstance(other, TimePeriod):
return self.total_microseconds == other.total_microseconds return self.total_nanoseconds == other.total_nanoseconds
return NotImplemented return NotImplemented
def __ne__(self, other): def __ne__(self, other):
if isinstance(other, TimePeriod): if isinstance(other, TimePeriod):
return self.total_microseconds != other.total_microseconds return self.total_nanoseconds != other.total_nanoseconds
return NotImplemented return NotImplemented
def __lt__(self, other): def __lt__(self, other):
if isinstance(other, TimePeriod): if isinstance(other, TimePeriod):
return self.total_microseconds < other.total_microseconds return self.total_nanoseconds < other.total_nanoseconds
return NotImplemented return NotImplemented
def __gt__(self, other): def __gt__(self, other):
if isinstance(other, TimePeriod): if isinstance(other, TimePeriod):
return self.total_microseconds > other.total_microseconds return self.total_nanoseconds > other.total_nanoseconds
return NotImplemented return NotImplemented
def __le__(self, other): def __le__(self, other):
if isinstance(other, TimePeriod): if isinstance(other, TimePeriod):
return self.total_microseconds <= other.total_microseconds return self.total_nanoseconds <= other.total_nanoseconds
return NotImplemented return NotImplemented
def __ge__(self, other): def __ge__(self, other):
if isinstance(other, TimePeriod): if isinstance(other, TimePeriod):
return self.total_microseconds >= other.total_microseconds return self.total_nanoseconds >= other.total_nanoseconds
return NotImplemented return NotImplemented
class TimePeriodNanoseconds(TimePeriod):
pass
class TimePeriodMicroseconds(TimePeriod): class TimePeriodMicroseconds(TimePeriod):
pass pass

View file

@ -17,6 +17,7 @@ from esphome.core import (
TimePeriodMicroseconds, TimePeriodMicroseconds,
TimePeriodMilliseconds, TimePeriodMilliseconds,
TimePeriodMinutes, TimePeriodMinutes,
TimePeriodNanoseconds,
TimePeriodSeconds, TimePeriodSeconds,
) )
from esphome.helpers import cpp_string_escape, indent_all_but_first_and_last from esphome.helpers import cpp_string_escape, indent_all_but_first_and_last
@ -351,6 +352,8 @@ def safe_exp(obj: SafeExpType) -> Expression:
return IntLiteral(obj) return IntLiteral(obj)
if isinstance(obj, float): if isinstance(obj, float):
return FloatLiteral(obj) return FloatLiteral(obj)
if isinstance(obj, TimePeriodNanoseconds):
return IntLiteral(int(obj.total_nanoseconds))
if isinstance(obj, TimePeriodMicroseconds): if isinstance(obj, TimePeriodMicroseconds):
return IntLiteral(int(obj.total_microseconds)) return IntLiteral(int(obj.total_microseconds))
if isinstance(obj, TimePeriodMilliseconds): if isinstance(obj, TimePeriodMilliseconds):

View file

@ -116,14 +116,16 @@ class TestTimePeriod:
assert actual == expected assert actual == expected
def test_init__microseconds_with_fraction(self): def test_init__nanoseconds_with_fraction(self):
with pytest.raises(ValueError, match="Maximum precision is microseconds"): with pytest.raises(ValueError, match="Maximum precision is nanoseconds"):
core.TimePeriod(microseconds=1.1) core.TimePeriod(nanoseconds=1.1)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kwargs, expected", "kwargs, expected",
( (
({}, "0s"), ({}, "0s"),
({"nanoseconds": 1}, "1ns"),
({"nanoseconds": 1.0001}, "1ns"),
({"microseconds": 1}, "1us"), ({"microseconds": 1}, "1us"),
({"microseconds": 1.0001}, "1us"), ({"microseconds": 1.0001}, "1us"),
({"milliseconds": 2}, "2ms"), ({"milliseconds": 2}, "2ms"),