From d69926485c6ed33b476aef8994b41ac50b31ece9 Mon Sep 17 00:00:00 2001 From: Jimmy Hedman Date: Sun, 12 Jan 2025 20:12:38 +0100 Subject: [PATCH] Convert IPAddress to use Pythonmodule ipaddress (#8072) --- esphome/components/ethernet/__init__.py | 20 ++++++------- esphome/components/udp/__init__.py | 2 +- esphome/components/wifi/__init__.py | 12 ++++---- esphome/components/wireguard/__init__.py | 4 +-- esphome/config_validation.py | 33 +++++++++++----------- esphome/core/__init__.py | 10 ------- esphome/yaml_util.py | 4 +-- tests/unit_tests/test_config_validation.py | 30 +++++++++++++++++--- tests/unit_tests/test_core.py | 25 ++-------------- 9 files changed, 65 insertions(+), 75 deletions(-) diff --git a/esphome/components/ethernet/__init__.py b/esphome/components/ethernet/__init__.py index dca37b8dc2..ab760a9b6c 100644 --- a/esphome/components/ethernet/__init__.py +++ b/esphome/components/ethernet/__init__.py @@ -94,11 +94,11 @@ CLK_MODES = { MANUAL_IP_SCHEMA = cv.Schema( { - cv.Required(CONF_STATIC_IP): cv.ipv4, - cv.Required(CONF_GATEWAY): cv.ipv4, - cv.Required(CONF_SUBNET): cv.ipv4, - cv.Optional(CONF_DNS1, default="0.0.0.0"): cv.ipv4, - cv.Optional(CONF_DNS2, default="0.0.0.0"): cv.ipv4, + cv.Required(CONF_STATIC_IP): cv.ipv4address, + cv.Required(CONF_GATEWAY): cv.ipv4address, + cv.Required(CONF_SUBNET): cv.ipv4address, + cv.Optional(CONF_DNS1, default="0.0.0.0"): cv.ipv4address, + cv.Optional(CONF_DNS2, default="0.0.0.0"): cv.ipv4address, } ) @@ -255,11 +255,11 @@ FINAL_VALIDATE_SCHEMA = _final_validate def manual_ip(config): return cg.StructInitializer( ManualIP, - ("static_ip", IPAddress(*config[CONF_STATIC_IP].args)), - ("gateway", IPAddress(*config[CONF_GATEWAY].args)), - ("subnet", IPAddress(*config[CONF_SUBNET].args)), - ("dns1", IPAddress(*config[CONF_DNS1].args)), - ("dns2", IPAddress(*config[CONF_DNS2].args)), + ("static_ip", IPAddress(str(config[CONF_STATIC_IP]))), + ("gateway", IPAddress(str(config[CONF_GATEWAY]))), + ("subnet", IPAddress(str(config[CONF_SUBNET]))), + ("dns1", IPAddress(str(config[CONF_DNS1]))), + ("dns2", IPAddress(str(config[CONF_DNS2]))), ) diff --git a/esphome/components/udp/__init__.py b/esphome/components/udp/__init__.py index ca15be2a80..e189975ade 100644 --- a/esphome/components/udp/__init__.py +++ b/esphome/components/udp/__init__.py @@ -85,7 +85,7 @@ CONFIG_SCHEMA = cv.All( cv.GenerateID(): cv.declare_id(UDPComponent), cv.Optional(CONF_PORT, default=18511): cv.port, cv.Optional(CONF_ADDRESSES, default=["255.255.255.255"]): cv.ensure_list( - cv.ipv4 + cv.ipv4address, ), cv.Optional(CONF_ROLLING_CODE_ENABLE, default=False): cv.boolean, cv.Optional(CONF_PING_PONG_ENABLE, default=False): cv.boolean, diff --git a/esphome/components/wifi/__init__.py b/esphome/components/wifi/__init__.py index ad1a4f5262..582b826de0 100644 --- a/esphome/components/wifi/__init__.py +++ b/esphome/components/wifi/__init__.py @@ -93,16 +93,16 @@ def validate_channel(value): AP_MANUAL_IP_SCHEMA = cv.Schema( { - cv.Required(CONF_STATIC_IP): cv.ipv4, - cv.Required(CONF_GATEWAY): cv.ipv4, - cv.Required(CONF_SUBNET): cv.ipv4, + cv.Required(CONF_STATIC_IP): cv.ipv4address, + cv.Required(CONF_GATEWAY): cv.ipv4address, + cv.Required(CONF_SUBNET): cv.ipv4address, } ) STA_MANUAL_IP_SCHEMA = AP_MANUAL_IP_SCHEMA.extend( { - cv.Optional(CONF_DNS1, default="0.0.0.0"): cv.ipv4, - cv.Optional(CONF_DNS2, default="0.0.0.0"): cv.ipv4, + cv.Optional(CONF_DNS1, default="0.0.0.0"): cv.ipv4address, + cv.Optional(CONF_DNS2, default="0.0.0.0"): cv.ipv4address, } ) @@ -364,7 +364,7 @@ def eap_auth(config): def safe_ip(ip): if ip is None: return IPAddress(0, 0, 0, 0) - return IPAddress(*ip.args) + return IPAddress(str(ip)) def manual_ip(config): diff --git a/esphome/components/wireguard/__init__.py b/esphome/components/wireguard/__init__.py index 5e34a8a19b..fc0e4e0538 100644 --- a/esphome/components/wireguard/__init__.py +++ b/esphome/components/wireguard/__init__.py @@ -67,8 +67,8 @@ CONFIG_SCHEMA = cv.Schema( { cv.GenerateID(): cv.declare_id(Wireguard), cv.GenerateID(CONF_TIME_ID): cv.use_id(time.RealTimeClock), - cv.Required(CONF_ADDRESS): cv.ipv4, - cv.Optional(CONF_NETMASK, default="255.255.255.255"): cv.ipv4, + cv.Required(CONF_ADDRESS): cv.ipv4address, + cv.Optional(CONF_NETMASK, default="255.255.255.255"): cv.ipv4address, cv.Required(CONF_PRIVATE_KEY): _wireguard_key, cv.Required(CONF_PEER_ENDPOINT): cv.string, cv.Required(CONF_PEER_PUBLIC_KEY): _wireguard_key, diff --git a/esphome/config_validation.py b/esphome/config_validation.py index 38fd677a2a..20a0774ccb 100644 --- a/esphome/config_validation.py +++ b/esphome/config_validation.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime +from ipaddress import AddressValueError, IPv4Address, ip_address import logging import os import re @@ -67,7 +68,6 @@ from esphome.const import ( from esphome.core import ( CORE, HexInt, - IPAddress, Lambda, TimePeriod, TimePeriodMicroseconds, @@ -1130,7 +1130,7 @@ def domain(value): if re.match(vol.DOMAIN_REGEX, value) is not None: return value try: - return str(ipv4(value)) + return str(ipaddress(value)) except Invalid as err: raise Invalid(f"Invalid domain: {value}") from err @@ -1160,21 +1160,20 @@ def ssid(value): return value -def ipv4(value): - if isinstance(value, list): - parts = value - elif isinstance(value, str): - parts = value.split(".") - elif isinstance(value, IPAddress): - return value - else: - raise Invalid("IPv4 address must consist of either string or integer list") - if len(parts) != 4: - raise Invalid("IPv4 address must consist of four point-separated integers") - parts_ = list(map(int, parts)) - if not all(0 <= x < 256 for x in parts_): - raise Invalid("IPv4 address parts must be in range from 0 to 255") - return IPAddress(*parts_) +def ipv4address(value): + try: + address = IPv4Address(value) + except AddressValueError as exc: + raise Invalid(f"{value} is not a valid IPv4 address") from exc + return address + + +def ipaddress(value): + try: + address = ip_address(value) + except ValueError as exc: + raise Invalid(f"{value} is not a valid IP address") from exc + return address def _valid_topic(value): diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index a97c3b18c9..f26c3da483 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -54,16 +54,6 @@ class HexInt(int): return f"{sign}0x{value:X}" -class IPAddress: - def __init__(self, *args): - if len(args) != 4: - raise ValueError("IPAddress must consist of 4 items") - self.args = args - - def __str__(self): - return ".".join(str(x) for x in self.args) - - class MACAddress: def __init__(self, *parts): if len(parts) != 6: diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index d67511dfec..b27ce4c3e3 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -4,6 +4,7 @@ import fnmatch import functools import inspect from io import TextIOWrapper +from ipaddress import _BaseAddress import logging import math import os @@ -25,7 +26,6 @@ from esphome.core import ( CORE, DocumentRange, EsphomeError, - IPAddress, Lambda, MACAddress, TimePeriod, @@ -576,7 +576,7 @@ ESPHomeDumper.add_multi_representer(bool, ESPHomeDumper.represent_bool) ESPHomeDumper.add_multi_representer(str, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(int, ESPHomeDumper.represent_int) ESPHomeDumper.add_multi_representer(float, ESPHomeDumper.represent_float) -ESPHomeDumper.add_multi_representer(IPAddress, ESPHomeDumper.represent_stringify) +ESPHomeDumper.add_multi_representer(_BaseAddress, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(MACAddress, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(TimePeriod, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(Lambda, ESPHomeDumper.represent_lambda) diff --git a/tests/unit_tests/test_config_validation.py b/tests/unit_tests/test_config_validation.py index 34f70be2fb..93ae67754a 100644 --- a/tests/unit_tests/test_config_validation.py +++ b/tests/unit_tests/test_config_validation.py @@ -1,12 +1,12 @@ -import pytest import string -from hypothesis import given, example -from hypothesis.strategies import one_of, text, integers, builds +from hypothesis import example, given +from hypothesis.strategies import builds, integers, ip_addresses, one_of, text +import pytest from esphome import config_validation from esphome.config_validation import Invalid -from esphome.core import CORE, Lambda, HexInt +from esphome.core import CORE, HexInt, Lambda def test_check_not_templatable__invalid(): @@ -145,6 +145,28 @@ def test_boolean__invalid(value): config_validation.boolean(value) +@given(value=ip_addresses(v=4).map(str)) +def test_ipv4__valid(value): + config_validation.ipv4address(value) + + +@pytest.mark.parametrize("value", ("127.0.0", "localhost", "")) +def test_ipv4__invalid(value): + with pytest.raises(Invalid, match="is not a valid IPv4 address"): + config_validation.ipv4address(value) + + +@given(value=ip_addresses(v=6).map(str)) +def test_ipv6__valid(value): + config_validation.ipaddress(value) + + +@pytest.mark.parametrize("value", ("127.0.0", "localhost", "", "2001:db8::2::3")) +def test_ipv6__invalid(value): + with pytest.raises(Invalid, match="is not a valid IP address"): + config_validation.ipaddress(value) + + # TODO: ensure_list @given(integers()) def hex_int__valid(value): diff --git a/tests/unit_tests/test_core.py b/tests/unit_tests/test_core.py index 2860486efe..4f2a6453b4 100644 --- a/tests/unit_tests/test_core.py +++ b/tests/unit_tests/test_core.py @@ -1,10 +1,8 @@ -import pytest - from hypothesis import given -from hypothesis.strategies import ip_addresses +import pytest from strategies import mac_addr_strings -from esphome import core, const +from esphome import const, core class TestHexInt: @@ -26,25 +24,6 @@ class TestHexInt: assert actual == expected -class TestIPAddress: - @given(value=ip_addresses(v=4).map(str)) - def test_init__valid(self, value): - core.IPAddress(*value.split(".")) - - @pytest.mark.parametrize("value", ("127.0.0", "localhost", "")) - def test_init__invalid(self, value): - with pytest.raises(ValueError, match="IPAddress must consist of 4 items"): - core.IPAddress(*value.split(".")) - - @given(value=ip_addresses(v=4).map(str)) - def test_str(self, value): - target = core.IPAddress(*value.split(".")) - - actual = str(target) - - assert actual == value - - class TestMACAddress: @given(value=mac_addr_strings()) def test_init__valid(self, value):