1
0
mirror of https://github.com/esphome/esphome.git synced 2025-01-31 02:00:55 +00:00

Convert IPAddress to use Pythonmodule ipaddress (#8072)

This commit is contained in:
Jimmy Hedman 2025-01-12 20:12:38 +01:00 committed by GitHub
parent fe80750743
commit d69926485c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 65 additions and 75 deletions

View File

@ -94,11 +94,11 @@ CLK_MODES = {
MANUAL_IP_SCHEMA = cv.Schema( MANUAL_IP_SCHEMA = cv.Schema(
{ {
cv.Required(CONF_STATIC_IP): cv.ipv4, cv.Required(CONF_STATIC_IP): cv.ipv4address,
cv.Required(CONF_GATEWAY): cv.ipv4, cv.Required(CONF_GATEWAY): cv.ipv4address,
cv.Required(CONF_SUBNET): cv.ipv4, cv.Required(CONF_SUBNET): cv.ipv4address,
cv.Optional(CONF_DNS1, default="0.0.0.0"): cv.ipv4, cv.Optional(CONF_DNS1, default="0.0.0.0"): cv.ipv4address,
cv.Optional(CONF_DNS2, default="0.0.0.0"): cv.ipv4, cv.Optional(CONF_DNS2, default="0.0.0.0"): cv.ipv4address,
} }
) )
@ -255,11 +255,11 @@ FINAL_VALIDATE_SCHEMA = _final_validate
def manual_ip(config): def manual_ip(config):
return cg.StructInitializer( return cg.StructInitializer(
ManualIP, ManualIP,
("static_ip", IPAddress(*config[CONF_STATIC_IP].args)), ("static_ip", IPAddress(str(config[CONF_STATIC_IP]))),
("gateway", IPAddress(*config[CONF_GATEWAY].args)), ("gateway", IPAddress(str(config[CONF_GATEWAY]))),
("subnet", IPAddress(*config[CONF_SUBNET].args)), ("subnet", IPAddress(str(config[CONF_SUBNET]))),
("dns1", IPAddress(*config[CONF_DNS1].args)), ("dns1", IPAddress(str(config[CONF_DNS1]))),
("dns2", IPAddress(*config[CONF_DNS2].args)), ("dns2", IPAddress(str(config[CONF_DNS2]))),
) )

View File

@ -85,7 +85,7 @@ CONFIG_SCHEMA = cv.All(
cv.GenerateID(): cv.declare_id(UDPComponent), cv.GenerateID(): cv.declare_id(UDPComponent),
cv.Optional(CONF_PORT, default=18511): cv.port, cv.Optional(CONF_PORT, default=18511): cv.port,
cv.Optional(CONF_ADDRESSES, default=["255.255.255.255"]): cv.ensure_list( cv.Optional(CONF_ADDRESSES, default=["255.255.255.255"]): cv.ensure_list(
cv.ipv4 cv.ipv4address,
), ),
cv.Optional(CONF_ROLLING_CODE_ENABLE, default=False): cv.boolean, cv.Optional(CONF_ROLLING_CODE_ENABLE, default=False): cv.boolean,
cv.Optional(CONF_PING_PONG_ENABLE, default=False): cv.boolean, cv.Optional(CONF_PING_PONG_ENABLE, default=False): cv.boolean,

View File

@ -93,16 +93,16 @@ def validate_channel(value):
AP_MANUAL_IP_SCHEMA = cv.Schema( AP_MANUAL_IP_SCHEMA = cv.Schema(
{ {
cv.Required(CONF_STATIC_IP): cv.ipv4, cv.Required(CONF_STATIC_IP): cv.ipv4address,
cv.Required(CONF_GATEWAY): cv.ipv4, cv.Required(CONF_GATEWAY): cv.ipv4address,
cv.Required(CONF_SUBNET): cv.ipv4, cv.Required(CONF_SUBNET): cv.ipv4address,
} }
) )
STA_MANUAL_IP_SCHEMA = AP_MANUAL_IP_SCHEMA.extend( STA_MANUAL_IP_SCHEMA = AP_MANUAL_IP_SCHEMA.extend(
{ {
cv.Optional(CONF_DNS1, default="0.0.0.0"): cv.ipv4, cv.Optional(CONF_DNS1, default="0.0.0.0"): cv.ipv4address,
cv.Optional(CONF_DNS2, default="0.0.0.0"): cv.ipv4, cv.Optional(CONF_DNS2, default="0.0.0.0"): cv.ipv4address,
} }
) )
@ -364,7 +364,7 @@ def eap_auth(config):
def safe_ip(ip): def safe_ip(ip):
if ip is None: if ip is None:
return IPAddress(0, 0, 0, 0) return IPAddress(0, 0, 0, 0)
return IPAddress(*ip.args) return IPAddress(str(ip))
def manual_ip(config): def manual_ip(config):

View File

@ -67,8 +67,8 @@ CONFIG_SCHEMA = cv.Schema(
{ {
cv.GenerateID(): cv.declare_id(Wireguard), cv.GenerateID(): cv.declare_id(Wireguard),
cv.GenerateID(CONF_TIME_ID): cv.use_id(time.RealTimeClock), cv.GenerateID(CONF_TIME_ID): cv.use_id(time.RealTimeClock),
cv.Required(CONF_ADDRESS): cv.ipv4, cv.Required(CONF_ADDRESS): cv.ipv4address,
cv.Optional(CONF_NETMASK, default="255.255.255.255"): cv.ipv4, cv.Optional(CONF_NETMASK, default="255.255.255.255"): cv.ipv4address,
cv.Required(CONF_PRIVATE_KEY): _wireguard_key, cv.Required(CONF_PRIVATE_KEY): _wireguard_key,
cv.Required(CONF_PEER_ENDPOINT): cv.string, cv.Required(CONF_PEER_ENDPOINT): cv.string,
cv.Required(CONF_PEER_PUBLIC_KEY): _wireguard_key, cv.Required(CONF_PEER_PUBLIC_KEY): _wireguard_key,

View File

@ -3,6 +3,7 @@
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from ipaddress import AddressValueError, IPv4Address, ip_address
import logging import logging
import os import os
import re import re
@ -67,7 +68,6 @@ from esphome.const import (
from esphome.core import ( from esphome.core import (
CORE, CORE,
HexInt, HexInt,
IPAddress,
Lambda, Lambda,
TimePeriod, TimePeriod,
TimePeriodMicroseconds, TimePeriodMicroseconds,
@ -1130,7 +1130,7 @@ def domain(value):
if re.match(vol.DOMAIN_REGEX, value) is not None: if re.match(vol.DOMAIN_REGEX, value) is not None:
return value return value
try: try:
return str(ipv4(value)) return str(ipaddress(value))
except Invalid as err: except Invalid as err:
raise Invalid(f"Invalid domain: {value}") from err raise Invalid(f"Invalid domain: {value}") from err
@ -1160,21 +1160,20 @@ def ssid(value):
return value return value
def ipv4(value): def ipv4address(value):
if isinstance(value, list): try:
parts = value address = IPv4Address(value)
elif isinstance(value, str): except AddressValueError as exc:
parts = value.split(".") raise Invalid(f"{value} is not a valid IPv4 address") from exc
elif isinstance(value, IPAddress): return address
return value
else:
raise Invalid("IPv4 address must consist of either string or integer list") def ipaddress(value):
if len(parts) != 4: try:
raise Invalid("IPv4 address must consist of four point-separated integers") address = ip_address(value)
parts_ = list(map(int, parts)) except ValueError as exc:
if not all(0 <= x < 256 for x in parts_): raise Invalid(f"{value} is not a valid IP address") from exc
raise Invalid("IPv4 address parts must be in range from 0 to 255") return address
return IPAddress(*parts_)
def _valid_topic(value): def _valid_topic(value):

View File

@ -54,16 +54,6 @@ class HexInt(int):
return f"{sign}0x{value:X}" return f"{sign}0x{value:X}"
class IPAddress:
def __init__(self, *args):
if len(args) != 4:
raise ValueError("IPAddress must consist of 4 items")
self.args = args
def __str__(self):
return ".".join(str(x) for x in self.args)
class MACAddress: class MACAddress:
def __init__(self, *parts): def __init__(self, *parts):
if len(parts) != 6: if len(parts) != 6:

View File

@ -4,6 +4,7 @@ import fnmatch
import functools import functools
import inspect import inspect
from io import TextIOWrapper from io import TextIOWrapper
from ipaddress import _BaseAddress
import logging import logging
import math import math
import os import os
@ -25,7 +26,6 @@ from esphome.core import (
CORE, CORE,
DocumentRange, DocumentRange,
EsphomeError, EsphomeError,
IPAddress,
Lambda, Lambda,
MACAddress, MACAddress,
TimePeriod, TimePeriod,
@ -576,7 +576,7 @@ ESPHomeDumper.add_multi_representer(bool, ESPHomeDumper.represent_bool)
ESPHomeDumper.add_multi_representer(str, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(str, ESPHomeDumper.represent_stringify)
ESPHomeDumper.add_multi_representer(int, ESPHomeDumper.represent_int) ESPHomeDumper.add_multi_representer(int, ESPHomeDumper.represent_int)
ESPHomeDumper.add_multi_representer(float, ESPHomeDumper.represent_float) ESPHomeDumper.add_multi_representer(float, ESPHomeDumper.represent_float)
ESPHomeDumper.add_multi_representer(IPAddress, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(_BaseAddress, ESPHomeDumper.represent_stringify)
ESPHomeDumper.add_multi_representer(MACAddress, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(MACAddress, ESPHomeDumper.represent_stringify)
ESPHomeDumper.add_multi_representer(TimePeriod, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(TimePeriod, ESPHomeDumper.represent_stringify)
ESPHomeDumper.add_multi_representer(Lambda, ESPHomeDumper.represent_lambda) ESPHomeDumper.add_multi_representer(Lambda, ESPHomeDumper.represent_lambda)

View File

@ -1,12 +1,12 @@
import pytest
import string import string
from hypothesis import given, example from hypothesis import example, given
from hypothesis.strategies import one_of, text, integers, builds from hypothesis.strategies import builds, integers, ip_addresses, one_of, text
import pytest
from esphome import config_validation from esphome import config_validation
from esphome.config_validation import Invalid from esphome.config_validation import Invalid
from esphome.core import CORE, Lambda, HexInt from esphome.core import CORE, HexInt, Lambda
def test_check_not_templatable__invalid(): def test_check_not_templatable__invalid():
@ -145,6 +145,28 @@ def test_boolean__invalid(value):
config_validation.boolean(value) config_validation.boolean(value)
@given(value=ip_addresses(v=4).map(str))
def test_ipv4__valid(value):
config_validation.ipv4address(value)
@pytest.mark.parametrize("value", ("127.0.0", "localhost", ""))
def test_ipv4__invalid(value):
with pytest.raises(Invalid, match="is not a valid IPv4 address"):
config_validation.ipv4address(value)
@given(value=ip_addresses(v=6).map(str))
def test_ipv6__valid(value):
config_validation.ipaddress(value)
@pytest.mark.parametrize("value", ("127.0.0", "localhost", "", "2001:db8::2::3"))
def test_ipv6__invalid(value):
with pytest.raises(Invalid, match="is not a valid IP address"):
config_validation.ipaddress(value)
# TODO: ensure_list # TODO: ensure_list
@given(integers()) @given(integers())
def hex_int__valid(value): def hex_int__valid(value):

View File

@ -1,10 +1,8 @@
import pytest
from hypothesis import given from hypothesis import given
from hypothesis.strategies import ip_addresses import pytest
from strategies import mac_addr_strings from strategies import mac_addr_strings
from esphome import core, const from esphome import const, core
class TestHexInt: class TestHexInt:
@ -26,25 +24,6 @@ class TestHexInt:
assert actual == expected assert actual == expected
class TestIPAddress:
@given(value=ip_addresses(v=4).map(str))
def test_init__valid(self, value):
core.IPAddress(*value.split("."))
@pytest.mark.parametrize("value", ("127.0.0", "localhost", ""))
def test_init__invalid(self, value):
with pytest.raises(ValueError, match="IPAddress must consist of 4 items"):
core.IPAddress(*value.split("."))
@given(value=ip_addresses(v=4).map(str))
def test_str(self, value):
target = core.IPAddress(*value.split("."))
actual = str(target)
assert actual == value
class TestMACAddress: class TestMACAddress:
@given(value=mac_addr_strings()) @given(value=mac_addr_strings())
def test_init__valid(self, value): def test_init__valid(self, value):