Merge branch 'integration' into memory_api
@@ -114,6 +114,14 @@ class Purpose(StrEnum):
    LOGGING = "logging"


def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]:
    """Resolve an address using cache if available, otherwise return the address itself."""
    if CORE.address_cache and (cached := CORE.address_cache.get_addresses(address)):
        _LOGGER.debug("Using cached addresses for %s: %s", purpose.value, cached)
        return cached
    return [address]


def choose_upload_log_host(
    default: list[str] | str | None,
    check_default: str | None,
@@ -142,7 +150,7 @@ def choose_upload_log_host(
                (purpose == Purpose.LOGGING and has_api())
                or (purpose == Purpose.UPLOADING and has_ota())
            ):
                resolved.append(CORE.address)
                resolved.extend(_resolve_with_cache(CORE.address, purpose))

            if purpose == Purpose.LOGGING:
                if has_api() and has_mqtt_ip_lookup():
@@ -152,15 +160,14 @@ def choose_upload_log_host(
                    resolved.append("MQTT")

                if has_api() and has_non_ip_address():
                    resolved.append(CORE.address)
                    resolved.extend(_resolve_with_cache(CORE.address, purpose))

            elif purpose == Purpose.UPLOADING:
                if has_ota() and has_mqtt_ip_lookup():
                    resolved.append("MQTTIP")

                if has_ota() and has_non_ip_address():
                    resolved.append(CORE.address)
                    resolved.extend(_resolve_with_cache(CORE.address, purpose))
        else:
            resolved.append(device)
    if not resolved:
@@ -972,6 +979,18 @@ def parse_args(argv):
        help="Add a substitution",
        metavar=("key", "value"),
    )
    options_parser.add_argument(
        "--mdns-address-cache",
        help="mDNS address cache mapping in format 'hostname=ip1,ip2'",
        action="append",
        default=[],
    )
    options_parser.add_argument(
        "--dns-address-cache",
        help="DNS address cache mapping in format 'hostname=ip1,ip2'",
        action="append",
        default=[],
    )

    parser = argparse.ArgumentParser(
        description=f"ESPHome {const.__version__}", parents=[options_parser]
@@ -1230,9 +1249,15 @@ def parse_args(argv):


def run_esphome(argv):
    from esphome.address_cache import AddressCache

    args = parse_args(argv)
    CORE.dashboard = args.dashboard

    # Create address cache from command-line arguments
    CORE.address_cache = AddressCache.from_cli_args(
        args.mdns_address_cache, args.dns_address_cache
    )
    # Override log level if verbose is set
    if args.verbose:
        args.log_level = "DEBUG"
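As an illustration of how these flags feed the cache (device names and IPs below are made up): with argparse's action="append", each repeated flag contributes one 'hostname=ip1,ip2' string, and from_cli_args turns those strings into normalized lookup tables.

# Illustration only: how repeated --mdns-address-cache / --dns-address-cache
# values reach AddressCache.from_cli_args. Hostnames and IPs are examples.
from esphome.address_cache import AddressCache

cache = AddressCache.from_cli_args(
    ["living-room.local=192.168.1.23"],         # from --mdns-address-cache
    ["ota.example.com=10.0.0.5,10.0.0.6"],      # from --dns-address-cache
)
print(cache.get_addresses("Living-Room.local."))  # ['192.168.1.23']
print(cache.get_addresses("ota.example.com"))     # ['10.0.0.5', '10.0.0.6']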
esphome/address_cache.py (new file, 142 lines)
@@ -0,0 +1,142 @@
"""Address cache for DNS and mDNS lookups."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Iterable

_LOGGER = logging.getLogger(__name__)


def normalize_hostname(hostname: str) -> str:
    """Normalize hostname for cache lookups.

    Removes trailing dots and converts to lowercase.
    """
    return hostname.rstrip(".").lower()


class AddressCache:
    """Cache for DNS and mDNS address lookups.

    This cache stores pre-resolved addresses from command-line arguments
    to avoid slow DNS/mDNS lookups during builds.
    """

    def __init__(
        self,
        mdns_cache: dict[str, list[str]] | None = None,
        dns_cache: dict[str, list[str]] | None = None,
    ) -> None:
        """Initialize the address cache.

        Args:
            mdns_cache: Pre-populated mDNS addresses (hostname -> IPs)
            dns_cache: Pre-populated DNS addresses (hostname -> IPs)
        """
        self.mdns_cache = mdns_cache or {}
        self.dns_cache = dns_cache or {}

    def _get_cached_addresses(
        self, hostname: str, cache: dict[str, list[str]], cache_type: str
    ) -> list[str] | None:
        """Get cached addresses from a specific cache.

        Args:
            hostname: The hostname to look up
            cache: The cache dictionary to check
            cache_type: Type of cache for logging ("mDNS" or "DNS")

        Returns:
            List of IP addresses if found in cache, None otherwise
        """
        normalized = normalize_hostname(hostname)
        if addresses := cache.get(normalized):
            _LOGGER.debug("Using %s cache for %s: %s", cache_type, hostname, addresses)
            return addresses
        return None

    def get_mdns_addresses(self, hostname: str) -> list[str] | None:
        """Get cached mDNS addresses for a hostname.

        Args:
            hostname: The hostname to look up (should end with .local)

        Returns:
            List of IP addresses if found in cache, None otherwise
        """
        return self._get_cached_addresses(hostname, self.mdns_cache, "mDNS")

    def get_dns_addresses(self, hostname: str) -> list[str] | None:
        """Get cached DNS addresses for a hostname.

        Args:
            hostname: The hostname to look up

        Returns:
            List of IP addresses if found in cache, None otherwise
        """
        return self._get_cached_addresses(hostname, self.dns_cache, "DNS")

    def get_addresses(self, hostname: str) -> list[str] | None:
        """Get cached addresses for a hostname.

        Checks mDNS cache for .local domains, DNS cache otherwise.

        Args:
            hostname: The hostname to look up

        Returns:
            List of IP addresses if found in cache, None otherwise
        """
        normalized = normalize_hostname(hostname)
        if normalized.endswith(".local"):
            return self.get_mdns_addresses(hostname)
        return self.get_dns_addresses(hostname)

    def has_cache(self) -> bool:
        """Check if any cache entries exist."""
        return bool(self.mdns_cache or self.dns_cache)

    @classmethod
    def from_cli_args(
        cls, mdns_args: Iterable[str], dns_args: Iterable[str]
    ) -> AddressCache:
        """Create cache from command-line arguments.

        Args:
            mdns_args: List of mDNS cache entries like ['host=ip1,ip2']
            dns_args: List of DNS cache entries like ['host=ip1,ip2']

        Returns:
            Configured AddressCache instance
        """
        mdns_cache = cls._parse_cache_args(mdns_args)
        dns_cache = cls._parse_cache_args(dns_args)
        return cls(mdns_cache=mdns_cache, dns_cache=dns_cache)

    @staticmethod
    def _parse_cache_args(cache_args: Iterable[str]) -> dict[str, list[str]]:
        """Parse cache arguments into a dictionary.

        Args:
            cache_args: List of cache mappings like ['host1=ip1,ip2', 'host2=ip3']

        Returns:
            Dictionary mapping normalized hostnames to list of IP addresses
        """
        cache: dict[str, list[str]] = {}
        for arg in cache_args:
            if "=" not in arg:
                _LOGGER.warning(
                    "Invalid cache format: %s (expected 'hostname=ip1,ip2')", arg
                )
                continue
            hostname, ips = arg.split("=", 1)
            # Normalize hostname for consistent lookups
            normalized = normalize_hostname(hostname)
            cache[normalized] = [ip.strip() for ip in ips.split(",")]
        return cache
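A quick check of the parsing and normalization rules above (entries are invented): mappings without '=' are skipped with a warning, IPs are split on commas and stripped, and hostnames are lowercased with trailing dots removed.

# Illustration of the parsing/normalization behavior; all inputs are examples.
from esphome.address_cache import AddressCache, normalize_hostname

cache = AddressCache.from_cli_args(
    ["Device.LOCAL.=192.168.1.10, 192.168.1.11", "bad-entry-without-equals"],
    [],
)
print(cache.mdns_cache)                     # {'device.local': ['192.168.1.10', '192.168.1.11']}
print(normalize_hostname("Device.LOCAL."))  # 'device.local'
print(cache.has_cache())                    # True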
@@ -7,7 +7,7 @@ service APIConnection {
    option (needs_setup_connection) = false;
    option (needs_authentication) = false;
  }
  rpc connect (ConnectRequest) returns (ConnectResponse) {
  rpc authenticate (AuthenticationRequest) returns (AuthenticationResponse) {
    option (needs_setup_connection) = false;
    option (needs_authentication) = false;
  }
@@ -129,7 +129,7 @@ message HelloResponse {

// Message sent at the beginning of each connection to authenticate the client
// Can only be sent by the client and only at the beginning of the connection
message ConnectRequest {
message AuthenticationRequest {
  option (id) = 3;
  option (source) = SOURCE_CLIENT;
  option (no_delay) = true;
@@ -141,7 +141,7 @@ message ConnectRequest {

// Confirmation of successful connection. After this the connection is available for all traffic.
// Can only be sent by the server and only at the beginning of the connection
message ConnectResponse {
message AuthenticationResponse {
  option (id) = 4;
  option (source) = SOURCE_SERVER;
  option (no_delay) = true;
@@ -1387,14 +1387,14 @@ bool APIConnection::send_hello_response(const HelloRequest &msg) {
  return this->send_message(resp, HelloResponse::MESSAGE_TYPE);
}
#ifdef USE_API_PASSWORD
bool APIConnection::send_connect_response(const ConnectRequest &msg) {
  ConnectResponse resp;
bool APIConnection::send_authenticate_response(const AuthenticationRequest &msg) {
  AuthenticationResponse resp;
  // bool invalid_password = 1;
  resp.invalid_password = !this->parent_->check_password(msg.password);
  if (!resp.invalid_password) {
    this->complete_authentication_();
  }
  return this->send_message(resp, ConnectResponse::MESSAGE_TYPE);
  return this->send_message(resp, AuthenticationResponse::MESSAGE_TYPE);
}
#endif  // USE_API_PASSWORD
@@ -198,7 +198,7 @@ class APIConnection final : public APIServerConnection {
#endif
  bool send_hello_response(const HelloRequest &msg) override;
#ifdef USE_API_PASSWORD
  bool send_connect_response(const ConnectRequest &msg) override;
  bool send_authenticate_response(const AuthenticationRequest &msg) override;
#endif
  bool send_disconnect_response(const DisconnectRequest &msg) override;
  bool send_ping_response(const PingRequest &msg) override;
@@ -43,7 +43,7 @@ void HelloResponse::calculate_size(ProtoSize &size) const {
  size.add_length(1, this->name_ref_.size());
}
#ifdef USE_API_PASSWORD
bool ConnectRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) {
bool AuthenticationRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) {
  switch (field_id) {
    case 1:
      this->password = value.as_string();
@@ -53,8 +53,8 @@ bool ConnectRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value
  }
  return true;
}
void ConnectResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->invalid_password); }
void ConnectResponse::calculate_size(ProtoSize &size) const { size.add_bool(1, this->invalid_password); }
void AuthenticationResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->invalid_password); }
void AuthenticationResponse::calculate_size(ProtoSize &size) const { size.add_bool(1, this->invalid_password); }
#endif
#ifdef USE_AREAS
void AreaInfo::encode(ProtoWriteBuffer buffer) const {
@@ -361,12 +361,12 @@ class HelloResponse final : public ProtoMessage {
 protected:
};
#ifdef USE_API_PASSWORD
class ConnectRequest final : public ProtoDecodableMessage {
class AuthenticationRequest final : public ProtoDecodableMessage {
 public:
  static constexpr uint8_t MESSAGE_TYPE = 3;
  static constexpr uint8_t ESTIMATED_SIZE = 9;
#ifdef HAS_PROTO_MESSAGE_DUMP
  const char *message_name() const override { return "connect_request"; }
  const char *message_name() const override { return "authentication_request"; }
#endif
  std::string password{};
#ifdef HAS_PROTO_MESSAGE_DUMP
@@ -376,12 +376,12 @@ class ConnectRequest final : public ProtoDecodableMessage {
 protected:
  bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;
};
class ConnectResponse final : public ProtoMessage {
class AuthenticationResponse final : public ProtoMessage {
 public:
  static constexpr uint8_t MESSAGE_TYPE = 4;
  static constexpr uint8_t ESTIMATED_SIZE = 2;
#ifdef HAS_PROTO_MESSAGE_DUMP
  const char *message_name() const override { return "connect_response"; }
  const char *message_name() const override { return "authentication_response"; }
#endif
  bool invalid_password{false};
  void encode(ProtoWriteBuffer buffer) const override;
@@ -670,8 +670,11 @@ void HelloResponse::dump_to(std::string &out) const {
  dump_field(out, "name", this->name_ref_);
}
#ifdef USE_API_PASSWORD
void ConnectRequest::dump_to(std::string &out) const { dump_field(out, "password", this->password); }
void ConnectResponse::dump_to(std::string &out) const { dump_field(out, "invalid_password", this->invalid_password); }
void AuthenticationRequest::dump_to(std::string &out) const { dump_field(out, "password", this->password); }
void AuthenticationResponse::dump_to(std::string &out) const {
  MessageDumpHelper helper(out, "AuthenticationResponse");
  dump_field(out, "invalid_password", this->invalid_password);
}
#endif
void DisconnectRequest::dump_to(std::string &out) const { out.append("DisconnectRequest {}"); }
void DisconnectResponse::dump_to(std::string &out) const { out.append("DisconnectResponse {}"); }
@@ -25,13 +25,13 @@ void APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type,
      break;
    }
#ifdef USE_API_PASSWORD
    case ConnectRequest::MESSAGE_TYPE: {
      ConnectRequest msg;
    case AuthenticationRequest::MESSAGE_TYPE: {
      AuthenticationRequest msg;
      msg.decode(msg_data, msg_size);
#ifdef HAS_PROTO_MESSAGE_DUMP
      ESP_LOGVV(TAG, "on_connect_request: %s", msg.dump().c_str());
      ESP_LOGVV(TAG, "on_authentication_request: %s", msg.dump().c_str());
#endif
      this->on_connect_request(msg);
      this->on_authentication_request(msg);
      break;
    }
#endif
@@ -600,8 +600,8 @@ void APIServerConnection::on_hello_request(const HelloRequest &msg) {
  }
}
#ifdef USE_API_PASSWORD
void APIServerConnection::on_connect_request(const ConnectRequest &msg) {
  if (!this->send_connect_response(msg)) {
void APIServerConnection::on_authentication_request(const AuthenticationRequest &msg) {
  if (!this->send_authenticate_response(msg)) {
    this->on_fatal_error();
  }
}
@@ -27,7 +27,7 @@ class APIServerConnectionBase : public ProtoService {
  virtual void on_hello_request(const HelloRequest &value){};

#ifdef USE_API_PASSWORD
  virtual void on_connect_request(const ConnectRequest &value){};
  virtual void on_authentication_request(const AuthenticationRequest &value){};
#endif

  virtual void on_disconnect_request(const DisconnectRequest &value){};
@@ -216,7 +216,7 @@ class APIServerConnection : public APIServerConnectionBase {
 public:
  virtual bool send_hello_response(const HelloRequest &msg) = 0;
#ifdef USE_API_PASSWORD
  virtual bool send_connect_response(const ConnectRequest &msg) = 0;
  virtual bool send_authenticate_response(const AuthenticationRequest &msg) = 0;
#endif
  virtual bool send_disconnect_response(const DisconnectRequest &msg) = 0;
  virtual bool send_ping_response(const PingRequest &msg) = 0;
@@ -339,7 +339,7 @@ class APIServerConnection : public APIServerConnectionBase {
 protected:
  void on_hello_request(const HelloRequest &msg) override;
#ifdef USE_API_PASSWORD
  void on_connect_request(const ConnectRequest &msg) override;
  void on_authentication_request(const AuthenticationRequest &msg) override;
#endif
  void on_disconnect_request(const DisconnectRequest &msg) override;
  void on_ping_request(const PingRequest &msg) override;
@@ -18,6 +18,7 @@ from esphome.const import (
    DEVICE_CLASS_TEMPERATURE,
    DEVICE_CLASS_VOLTAGE,
    STATE_CLASS_MEASUREMENT,
    STATE_CLASS_TOTAL_INCREASING,
    UNIT_AMPERE,
    UNIT_CELSIUS,
    UNIT_VOLT,
@@ -162,7 +163,7 @@ INA2XX_SCHEMA = cv.Schema(
                unit_of_measurement=UNIT_WATT_HOURS,
                accuracy_decimals=8,
                device_class=DEVICE_CLASS_ENERGY,
                state_class=STATE_CLASS_MEASUREMENT,
                state_class=STATE_CLASS_TOTAL_INCREASING,
            ),
            key=CONF_NAME,
        ),
@@ -170,7 +171,8 @@ INA2XX_SCHEMA = cv.Schema(
            sensor.sensor_schema(
                unit_of_measurement=UNIT_JOULE,
                accuracy_decimals=8,
                state_class=STATE_CLASS_MEASUREMENT,
                device_class=DEVICE_CLASS_ENERGY,
                state_class=STATE_CLASS_TOTAL_INCREASING,
            ),
            key=CONF_NAME,
        ),
@@ -39,6 +39,8 @@ from esphome.helpers import ensure_unique_string, get_str_env, is_ha_addon
from esphome.util import OrderedDict

if TYPE_CHECKING:
    from esphome.address_cache import AddressCache

    from ..cpp_generator import MockObj, MockObjClass, Statement
    from ..types import ConfigType, EntityMetadata

@@ -583,6 +585,8 @@ class EsphomeCore:
        self.id_classes = {}
        # The current component being processed during validation
        self.current_component: str | None = None
        # Address cache for DNS and mDNS lookups from command line arguments
        self.address_cache: AddressCache | None = None

    def reset(self):
        from esphome.pins import PIN_SCHEMA_REGISTRY
@@ -610,6 +614,7 @@ class EsphomeCore:
        self.platform_counts = defaultdict(int)
        self.unique_ids = {}
        self.current_component = None
        self.address_cache = None
        PIN_SCHEMA_REGISTRY.reset()

    @contextmanager
@@ -28,6 +28,21 @@ class DNSCache:
        self._cache: dict[str, tuple[float, list[str] | Exception]] = {}
        self._ttl = ttl

    def get_cached_addresses(
        self, hostname: str, now_monotonic: float
    ) -> list[str] | None:
        """Get cached addresses without triggering resolution.

        Returns None if not in cache, list of addresses if found.
        """
        # Normalize hostname for consistent lookups
        normalized = hostname.rstrip(".").lower()
        if expire_time_addresses := self._cache.get(normalized):
            expire_time, addresses = expire_time_addresses
            if expire_time > now_monotonic and not isinstance(addresses, Exception):
                return addresses
        return None

    async def async_resolve(
        self, hostname: str, now_monotonic: float
    ) -> list[str] | Exception:
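The new get_cached_addresses path only reads existing entries and honors the TTL. A minimal sketch of that expiry check in isolation (the dict layout mirrors the cache above; hostname, IP, and TTL are invented):

# Sketch only: TTL-gated, read-only lookup over a dict mapping a normalized
# hostname to (expiry_monotonic, addresses-or-Exception).
import time

cache: dict[str, tuple[float, list[str] | Exception]] = {
    "printer.example.com": (time.monotonic() + 120.0, ["10.0.0.9"]),
}

def get_cached(hostname: str, now: float) -> list[str] | None:
    normalized = hostname.rstrip(".").lower()
    if entry := cache.get(normalized):
        expire_time, addresses = entry
        if expire_time > now and not isinstance(addresses, Exception):
            return addresses
    return None  # expired, errored, or never resolved

print(get_cached("Printer.example.com.", time.monotonic()))  # ['10.0.0.9']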
@@ -4,6 +4,9 @@ import asyncio
import logging
import typing

from zeroconf import AddressResolver, IPVersion

from esphome.address_cache import normalize_hostname
from esphome.zeroconf import (
    ESPHOME_SERVICE_TYPE,
    AsyncEsphomeZeroconf,
@@ -50,6 +53,30 @@ class MDNSStatus:
            return await aiozc.async_resolve_host(host_name)
        return None

    def get_cached_addresses(self, host_name: str) -> list[str] | None:
        """Get cached addresses for a host without triggering resolution.

        Returns None if not in cache or no zeroconf available.
        """
        if not self.aiozc:
            _LOGGER.debug("No zeroconf instance available for %s", host_name)
            return None

        # Normalize hostname and get the base name
        normalized = normalize_hostname(host_name)
        base_name = normalized.partition(".")[0]

        # Try to load from zeroconf cache without triggering resolution
        resolver_name = f"{base_name}.local."
        info = AddressResolver(resolver_name)
        # Let zeroconf use its own current time for cache checking
        if info.load_from_cache(self.aiozc.zeroconf):
            addresses = info.parsed_scoped_addresses(IPVersion.All)
            _LOGGER.debug("Found %s in zeroconf cache: %s", resolver_name, addresses)
            return addresses
        _LOGGER.debug("Not found in zeroconf cache: %s", resolver_name)
        return None

    async def async_refresh_hosts(self) -> None:
        """Refresh the hosts to track."""
        dashboard = self.dashboard
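Outside the dashboard, the same cache-only zeroconf pattern looks roughly like this (the device name is a made-up example; it reports only addresses already held in the local zeroconf cache and never sends a query):

# Sketch of a cache-only mDNS lookup using the same zeroconf calls as above.
from zeroconf import AddressResolver, IPVersion, Zeroconf

zc = Zeroconf()
info = AddressResolver("living-room.local.")
if info.load_from_cache(zc):  # True only if records are already cached
    print(info.parsed_scoped_addresses(IPVersion.All))
else:
    print("not in cache; a live mDNS query would be needed")
zc.close()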
@@ -50,8 +50,8 @@ from esphome.util import get_serial_ports, shlex_quote
from esphome.yaml_util import FastestAvailableSafeLoader

from .const import DASHBOARD_COMMAND
from .core import DASHBOARD
from .entries import UNKNOWN_STATE, entry_state_to_bool
from .core import DASHBOARD, ESPHomeDashboard
from .entries import UNKNOWN_STATE, DashboardEntry, entry_state_to_bool
from .util.file import write_file
from .util.subprocess import async_run_system_command
from .util.text import friendly_name_slugify
@@ -314,6 +314,73 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
        raise NotImplementedError


def build_cache_arguments(
    entry: DashboardEntry | None,
    dashboard: ESPHomeDashboard,
    now: float,
) -> list[str]:
    """Build cache arguments for passing to CLI.

    Args:
        entry: Dashboard entry for the configuration
        dashboard: Dashboard instance with cache access
        now: Current monotonic time for DNS cache expiry checks

    Returns:
        List of cache arguments to pass to CLI
    """
    cache_args: list[str] = []

    if not entry:
        return cache_args

    _LOGGER.debug(
        "Building cache for entry (address=%s, name=%s)",
        entry.address,
        entry.name,
    )

    def add_cache_entry(hostname: str, addresses: list[str], cache_type: str) -> None:
        """Add a cache entry to the command arguments."""
        if not addresses:
            return
        normalized = hostname.rstrip(".").lower()
        cache_args.extend(
            [
                f"--{cache_type}-address-cache",
                f"{normalized}={','.join(sort_ip_addresses(addresses))}",
            ]
        )

    # Check entry.address for cached addresses
    if use_address := entry.address:
        if use_address.endswith(".local"):
            # mDNS cache for .local addresses
            if (mdns := dashboard.mdns_status) and (
                cached := mdns.get_cached_addresses(use_address)
            ):
                _LOGGER.debug("mDNS cache hit for %s: %s", use_address, cached)
                add_cache_entry(use_address, cached, "mdns")
        # DNS cache for non-.local addresses
        elif cached := dashboard.dns_cache.get_cached_addresses(use_address, now):
            _LOGGER.debug("DNS cache hit for %s: %s", use_address, cached)
            add_cache_entry(use_address, cached, "dns")

    # Check entry.name if we haven't already cached via address
    # For mDNS devices, entry.name typically doesn't have .local suffix
    if entry.name and not use_address:
        mdns_name = (
            f"{entry.name}.local" if not entry.name.endswith(".local") else entry.name
        )
        if (mdns := dashboard.mdns_status) and (
            cached := mdns.get_cached_addresses(mdns_name)
        ):
            _LOGGER.debug("mDNS cache hit for %s: %s", mdns_name, cached)
            add_cache_entry(mdns_name, cached, "mdns")

    return cache_args
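For a rough sense of the output: with a made-up entry whose address is living-room.local and a warm mDNS cache, build_cache_arguments would return something like the list below, which the websocket handler splices in front of the subcommand.

# Illustrative only: example return value (device name and IP are invented).
cache_args = [
    "--mdns-address-cache",
    "living-room.local=192.168.1.23",
]
# The handler then builds roughly:
#   [*DASHBOARD_COMMAND, *cache_args, *args, config_file, "--device", port]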
class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
    """Base class for commands that require a port."""

@@ -326,52 +393,22 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket):
        configuration = json_message["configuration"]
        config_file = settings.rel_path(configuration)
        port = json_message["port"]
        addresses: list[str] = []

        # Build cache arguments to pass to CLI
        cache_args: list[str] = []

        if (
            port == "OTA"  # pylint: disable=too-many-boolean-expressions
            and (entry := entries.get(config_file))
            and entry.loaded_integrations
            and "api" in entry.loaded_integrations
        ):
            # First priority: entry.address AKA use_address
            if (
                (use_address := entry.address)
                and (
                    address_list := await dashboard.dns_cache.async_resolve(
                        use_address, time.monotonic()
                    )
                )
                and not isinstance(address_list, Exception)
            ):
                addresses.extend(sort_ip_addresses(address_list))
            cache_args = build_cache_arguments(entry, dashboard, time.monotonic())

            # Second priority: mDNS
            if (
                (mdns := dashboard.mdns_status)
                and (address_list := await mdns.async_resolve_host(entry.name))
                and (
                    new_addresses := [
                        addr for addr in address_list if addr not in addresses
                    ]
                )
            ):
                # Use the IP address if available but only
                # if the API is loaded and the device is online
                # since MQTT logging will not work otherwise
                addresses.extend(sort_ip_addresses(new_addresses))

        if not addresses:
            # If no address was found, use the port directly
            # as otherwise they will get the chooser which
            # does not work with the dashboard as there is no
            # interactive way to get keyboard input
            addresses = [port]

        device_args: list[str] = [
            arg for address in addresses for arg in ("--device", address)
        ]

        return [*DASHBOARD_COMMAND, *args, config_file, *device_args]
        # Cache arguments must come before the subcommand
        cmd = [*DASHBOARD_COMMAND, *cache_args, *args, config_file, "--device", port]
        _LOGGER.debug("Built command: %s", cmd)
        return cmd


class EsphomeLogsHandler(EsphomePortCommandWebSocket):
@@ -311,10 +311,14 @@ def perform_ota(
def run_ota_impl_(
    remote_host: str | list[str], remote_port: int, password: str, filename: str
) -> tuple[int, str | None]:
    from esphome.core import CORE

    # Handle both single host and list of hosts
    try:
        # Resolve all hosts at once for parallel DNS resolution
        res = resolve_ip_address(remote_host, remote_port)
        res = resolve_ip_address(
            remote_host, remote_port, address_cache=CORE.address_cache
        )
    except EsphomeError as err:
        _LOGGER.error(
            "Error resolving IP address of %s. Is it connected to WiFi?",
@@ -9,10 +9,14 @@ from pathlib import Path
import platform
import re
import tempfile
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from esphome.const import __version__ as ESPHOME_VERSION

if TYPE_CHECKING:
    from esphome.address_cache import AddressCache

# Type aliases for socket address information
AddrInfo = tuple[
    int,  # family (AF_INET, AF_INET6, etc.)
@@ -173,7 +177,24 @@ def addr_preference_(res: AddrInfo) -> int:
    return 1


def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]:
def _add_ip_addresses_to_addrinfo(
    addresses: list[str], port: int, res: list[AddrInfo]
) -> None:
    """Helper to add IP addresses to addrinfo results with error handling."""
    import socket

    for addr in addresses:
        try:
            res += socket.getaddrinfo(
                addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST
            )
        except OSError:
            _LOGGER.debug("Failed to parse IP address '%s'", addr)


def resolve_ip_address(
    host: str | list[str], port: int, address_cache: AddressCache | None = None
) -> list[AddrInfo]:
    import socket

    # There are five cases here. The host argument could be one of:
@@ -194,47 +215,69 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]:
        hosts = [host]

    res: list[AddrInfo] = []

    # Fast path: if all hosts are already IP addresses
    if all(is_ip_address(h) for h in hosts):
        # Fast path: all are IP addresses, use socket.getaddrinfo with AI_NUMERICHOST
        for addr in hosts:
            try:
                res += socket.getaddrinfo(
                    addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST
                )
            except OSError:
                _LOGGER.debug("Failed to parse IP address '%s'", addr)
        _add_ip_addresses_to_addrinfo(hosts, port, res)
        # Sort by preference
        res.sort(key=addr_preference_)
        return res

    from esphome.resolver import AsyncResolver
    # Process hosts
    cached_addresses: list[str] = []
    uncached_hosts: list[str] = []
    has_cache = address_cache is not None

    resolver = AsyncResolver(hosts, port)
    addr_infos = resolver.resolve()
    # Convert aioesphomeapi AddrInfo to our format
    for addr_info in addr_infos:
        sockaddr = addr_info.sockaddr
        if addr_info.family == socket.AF_INET6:
            # IPv6
            sockaddr_tuple = (
                sockaddr.address,
                sockaddr.port,
                sockaddr.flowinfo,
                sockaddr.scope_id,
            )
    for h in hosts:
        if is_ip_address(h):
            if has_cache:
                # If we have a cache, treat IPs as cached
                cached_addresses.append(h)
            else:
                # If no cache, pass IPs through to resolver with hostnames
                uncached_hosts.append(h)
        elif address_cache and (cached := address_cache.get_addresses(h)):
            # Found in cache
            cached_addresses.extend(cached)
        else:
            # IPv4
            sockaddr_tuple = (sockaddr.address, sockaddr.port)
            # Not cached, need to resolve
            if address_cache and address_cache.has_cache():
                _LOGGER.info("Host %s not in cache, will need to resolve", h)
            uncached_hosts.append(h)

        res.append(
            (
                addr_info.family,
                addr_info.type,
                addr_info.proto,
                "",  # canonname
                sockaddr_tuple,
    # Process cached addresses (includes direct IPs and cached lookups)
    _add_ip_addresses_to_addrinfo(cached_addresses, port, res)

    # If we have uncached hosts (only non-IP hostnames), resolve them
    if uncached_hosts:
        from esphome.resolver import AsyncResolver

        resolver = AsyncResolver(uncached_hosts, port)
        addr_infos = resolver.resolve()
        # Convert aioesphomeapi AddrInfo to our format
        for addr_info in addr_infos:
            sockaddr = addr_info.sockaddr
            if addr_info.family == socket.AF_INET6:
                # IPv6
                sockaddr_tuple = (
                    sockaddr.address,
                    sockaddr.port,
                    sockaddr.flowinfo,
                    sockaddr.scope_id,
                )
            else:
                # IPv4
                sockaddr_tuple = (sockaddr.address, sockaddr.port)

            res.append(
                (
                    addr_info.family,
                    addr_info.type,
                    addr_info.proto,
                    "",  # canonname
                    sockaddr_tuple,
                )
            )
        )

    # Sort by preference
    res.sort(key=addr_preference_)
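A rough illustration of the split the new resolve_ip_address performs (hostnames, IPs, and cache contents are invented): cached names and literal IPs bypass the resolver, and only the remaining hostnames are handed to AsyncResolver.

# Illustration only: hosts served from the cache or given as literal IPs skip
# resolution; "printer.example.com" would still hit AsyncResolver.
from esphome.address_cache import AddressCache
from esphome.helpers import resolve_ip_address

cache = AddressCache(mdns_cache={"bedroom.local": ["192.168.1.42"]})

addr_infos = resolve_ip_address(
    ["bedroom.local", "printer.example.com", "10.0.0.7"],
    3232,  # example OTA port
    address_cache=cache,
)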
@@ -256,14 +299,7 @@ def sort_ip_addresses(address_list: list[str]) -> list[str]:
    # First "resolve" all the IP addresses to getaddrinfo() tuples of the form
    # (family, type, proto, canonname, sockaddr)
    res: list[AddrInfo] = []
    for addr in address_list:
        # This should always work as these are supposed to be IP addresses
        try:
            res += socket.getaddrinfo(
                addr, 0, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST
            )
        except OSError:
            _LOGGER.info("Failed to parse IP address '%s'", addr)
    _add_ip_addresses_to_addrinfo(address_list, 0, res)

    # Now use that information to sort them.
    res.sort(key=addr_preference_)