From e4f164f1799f8a7777955b80090694ee8e3ad2b3 Mon Sep 17 00:00:00 2001 From: Jimmy Hedman Date: Mon, 2 Sep 2024 20:38:42 +0200 Subject: [PATCH] Use ipaddress standard library - Remove IPAddress ESPHome Python class and use standard ipaddress from standard library --- esphome/config_validation.py | 14 ++++++-- esphome/core/__init__.py | 12 ------- esphome/yaml_util.py | 4 +-- tests/unit_tests/test_config_validation.py | 30 +++++++++++++--- tests/unit_tests/test_core.py | 41 ---------------------- 5 files changed, 39 insertions(+), 62 deletions(-) diff --git a/esphome/config_validation.py b/esphome/config_validation.py index a3f5611a78..8a7192339e 100644 --- a/esphome/config_validation.py +++ b/esphome/config_validation.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime +from ipaddress import AddressValueError, IPv4Address, ip_address import logging import os import re @@ -66,7 +67,6 @@ from esphome.const import ( from esphome.core import ( CORE, HexInt, - IPAddress, Lambda, TimePeriod, TimePeriodMicroseconds, @@ -1159,11 +1159,19 @@ def ssid(value): def ipv4address(value): - return IPAddress(value, allow_ipv6=False) + try: + address = IPv4Address(value) + except AddressValueError as exc: + raise Invalid(f"{value} is not a valid IPv4 address") from exc + return address def ipaddress(value): - return IPAddress(value, allow_ipv6=True) + try: + address = ip_address(value) + except ValueError as exc: + raise Invalid(f"{value} is not a valid IP address") from exc + return address def _valid_topic(value): diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index 41a7505129..f26c3da483 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -1,4 +1,3 @@ -import ipaddress import logging import math import os @@ -55,17 +54,6 @@ class HexInt(int): return f"{sign}0x{value:X}" -class IPAddress: - def __init__(self, arg, allow_ipv6=False): - if allow_ipv6: - self.args = str(ipaddress.ip_address(arg)) - else: - self.args = str(ipaddress.IPv4Address(arg)) - - def __str__(self): - return self.args - - class MACAddress: def __init__(self, *parts): if len(parts) != 6: diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index d67511dfec..b27ce4c3e3 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -4,6 +4,7 @@ import fnmatch import functools import inspect from io import TextIOWrapper +from ipaddress import _BaseAddress import logging import math import os @@ -25,7 +26,6 @@ from esphome.core import ( CORE, DocumentRange, EsphomeError, - IPAddress, Lambda, MACAddress, TimePeriod, @@ -576,7 +576,7 @@ ESPHomeDumper.add_multi_representer(bool, ESPHomeDumper.represent_bool) ESPHomeDumper.add_multi_representer(str, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(int, ESPHomeDumper.represent_int) ESPHomeDumper.add_multi_representer(float, ESPHomeDumper.represent_float) -ESPHomeDumper.add_multi_representer(IPAddress, ESPHomeDumper.represent_stringify) +ESPHomeDumper.add_multi_representer(_BaseAddress, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(MACAddress, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(TimePeriod, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(Lambda, ESPHomeDumper.represent_lambda) diff --git a/tests/unit_tests/test_config_validation.py b/tests/unit_tests/test_config_validation.py index 34f70be2fb..93ae67754a 100644 --- a/tests/unit_tests/test_config_validation.py +++ b/tests/unit_tests/test_config_validation.py @@ -1,12 +1,12 @@ -import pytest import string -from hypothesis import given, example -from hypothesis.strategies import one_of, text, integers, builds +from hypothesis import example, given +from hypothesis.strategies import builds, integers, ip_addresses, one_of, text +import pytest from esphome import config_validation from esphome.config_validation import Invalid -from esphome.core import CORE, Lambda, HexInt +from esphome.core import CORE, HexInt, Lambda def test_check_not_templatable__invalid(): @@ -145,6 +145,28 @@ def test_boolean__invalid(value): config_validation.boolean(value) +@given(value=ip_addresses(v=4).map(str)) +def test_ipv4__valid(value): + config_validation.ipv4address(value) + + +@pytest.mark.parametrize("value", ("127.0.0", "localhost", "")) +def test_ipv4__invalid(value): + with pytest.raises(Invalid, match="is not a valid IPv4 address"): + config_validation.ipv4address(value) + + +@given(value=ip_addresses(v=6).map(str)) +def test_ipv6__valid(value): + config_validation.ipaddress(value) + + +@pytest.mark.parametrize("value", ("127.0.0", "localhost", "", "2001:db8::2::3")) +def test_ipv6__invalid(value): + with pytest.raises(Invalid, match="is not a valid IP address"): + config_validation.ipaddress(value) + + # TODO: ensure_list @given(integers()) def hex_int__valid(value): diff --git a/tests/unit_tests/test_core.py b/tests/unit_tests/test_core.py index b2f2288c4b..4f2a6453b4 100644 --- a/tests/unit_tests/test_core.py +++ b/tests/unit_tests/test_core.py @@ -1,7 +1,4 @@ -from ipaddress import AddressValueError - from hypothesis import given -from hypothesis.strategies import ip_addresses import pytest from strategies import mac_addr_strings @@ -27,44 +24,6 @@ class TestHexInt: assert actual == expected -class TestIP4Address: - @given(value=ip_addresses(v=4).map(str)) - def test_init__valid(self, value): - core.IPAddress(value, allow_ipv6=False) - - @pytest.mark.parametrize("value", ("127.0.0", "localhost", "")) - def test_init__invalid(self, value): - with pytest.raises((ValueError, AddressValueError)): - core.IPAddress(value, allow_ipv6=False) - - @given(value=ip_addresses(v=4).map(str)) - def test_str(self, value): - target = core.IPAddress(value, allow_ipv6=False) - - actual = str(target) - - assert actual == value - - -class TestIP6Address: - @given(value=ip_addresses(v=6).map(str)) - def test_init__valid(self, value): - core.IPAddress(value, allow_ipv6=True) - - @pytest.mark.parametrize("value", ("127.0.0", "localhost", "", "2001:db8::2::3")) - def test_init__invalid(self, value): - with pytest.raises((ValueError, AddressValueError)): - core.IPAddress(value, allow_ipv6=True) - - @given(value=ip_addresses(v=6).map(str)) - def test_str(self, value): - target = core.IPAddress(value, allow_ipv6=True) - - actual = str(target) - - assert actual == value - - class TestMACAddress: @given(value=mac_addr_strings()) def test_init__valid(self, value):