diff --git a/esphome/__main__.py b/esphome/__main__.py index 404fdded6e..f009f876d1 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -13,7 +13,6 @@ import time from typing import Protocol import argcomplete -from zeroconf import ServiceBrowser, ServiceStateChange, Zeroconf # Note: Do not import modules from esphome.components here, as this would # cause them to be loaded before external components are processed, resulting @@ -61,7 +60,7 @@ from esphome.util import ( run_external_process, safe_print, ) -from esphome.zeroconf import ESPHOME_SERVICE_TYPE +from esphome.zeroconf import discover_mdns_devices _LOGGER = logging.getLogger(__name__) @@ -235,32 +234,8 @@ def choose_upload_log_host( (f"{port.path} ({port.description})", port.path) for port in get_serial_ports() ] - if purpose == Purpose.LOGGING: - if has_mqtt_logging(): - mqtt_config = CORE.config[CONF_MQTT] - options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) - - if has_api(): - if has_name_add_mac_suffix() and has_mdns() and has_non_ip_address(): - # Discover devices via mDNS when name_add_mac_suffix is enabled - safe_print("Discovering devices...") - discovered = discover_mdns_devices(CORE.name) - for device_addr in discovered: - options.append((f"Over The Air ({device_addr})", device_addr)) - if not discovered and has_resolvable_address(): - # No devices found, show base address as fallback - options.append( - ( - f"Over The Air ({CORE.address}) (no devices found)", - CORE.address, - ) - ) - elif has_resolvable_address(): - options.append((f"Over The Air ({CORE.address})", CORE.address)) - if has_mqtt_ip_lookup(): - options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) - - elif purpose == Purpose.UPLOADING and has_ota(): + def add_ota_options() -> None: + """Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled.""" if has_name_add_mac_suffix() and has_mdns() and has_non_ip_address(): # Discover devices via mDNS when name_add_mac_suffix is enabled safe_print("Discovering devices...") @@ -277,6 +252,17 @@ def choose_upload_log_host( if has_mqtt_ip_lookup(): options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) + if purpose == Purpose.LOGGING: + if has_mqtt_logging(): + mqtt_config = CORE.config[CONF_MQTT] + options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) + + if has_api(): + add_ota_options() + + elif purpose == Purpose.UPLOADING and has_ota(): + add_ota_options() + if check_default is not None and check_default in [opt[1] for opt in options]: return [check_default] return [choose_prompt(options, purpose=purpose)] @@ -370,46 +356,6 @@ def has_name_add_mac_suffix() -> bool: return esphome_config.get(CONF_NAME_ADD_MAC_SUFFIX, False) -def discover_mdns_devices(base_name: str, timeout: float = 5.0) -> list[str]: - """Discover ESPHome devices via mDNS that match the base name pattern. - - When name_add_mac_suffix is enabled, devices advertise as -.local. - This function discovers all such devices on the network. - - Args: - base_name: The base device name (without MAC suffix) - timeout: How long to wait for mDNS responses (default 5 seconds) - - Returns: - List of discovered device addresses (e.g., ['device-abc123.local']) - """ - discovered: list[str] = [] - prefix = f"{base_name}-" - - def on_service_state_change( - zeroconf: Zeroconf, - service_type: str, - name: str, - state_change: ServiceStateChange, - ) -> None: - if state_change in (ServiceStateChange.Added, ServiceStateChange.Updated): - # Extract device name from service name (removes service type suffix) - device_name = name.partition(".")[0] - # Check if this device matches our base name pattern - if device_name.startswith(prefix) and device_name not in discovered: - discovered.append(device_name) - - zc = Zeroconf() - try: - ServiceBrowser(zc, ESPHOME_SERVICE_TYPE, handlers=[on_service_state_change]) - # Wait for discovery - time.sleep(timeout) - finally: - zc.close() - - return [f"{name}.local" for name in sorted(discovered)] - - def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str): from esphome import mqtt diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index dc4ca77eb4..ae8ec541ef 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -4,10 +4,12 @@ import asyncio from collections.abc import Callable from dataclasses import dataclass import logging +import time from zeroconf import ( AddressResolver, IPVersion, + ServiceBrowser, ServiceInfo, ServiceStateChange, Zeroconf, @@ -200,3 +202,53 @@ class AsyncEsphomeZeroconf(AsyncZeroconf): ) and (addresses := info.parsed_scoped_addresses(IPVersion.All)): return addresses return None + + +def discover_mdns_devices(base_name: str, timeout: float = 5.0) -> list[str]: + """Discover ESPHome devices via mDNS that match the base name pattern. + + When name_add_mac_suffix is enabled, devices advertise as -.local. + This function discovers all such devices on the network. + + Args: + base_name: The base device name (without MAC suffix) + timeout: How long to wait for mDNS responses (default 5 seconds) + + Returns: + List of discovered device addresses (e.g., ['device-abc123.local']) + """ + discovered: list[str] = [] + prefix = f"{base_name}-" + + def on_service_state_change( + zeroconf: Zeroconf, + service_type: str, + name: str, + state_change: ServiceStateChange, + ) -> None: + if state_change in (ServiceStateChange.Added, ServiceStateChange.Updated): + # Extract device name from service name (removes service type suffix) + device_name = name.partition(".")[0] + # Check if this device matches our base name pattern + if device_name.startswith(prefix) and device_name not in discovered: + discovered.append(device_name) + + try: + zc = Zeroconf() + except Exception as err: + _LOGGER.warning("mDNS discovery failed to initialize: %s", err) + return [] + + browser: ServiceBrowser | None = None + try: + browser = ServiceBrowser( + zc, ESPHOME_SERVICE_TYPE, handlers=[on_service_state_change] + ) + # Wait for discovery + time.sleep(timeout) + finally: + if browser is not None: + browser.cancel() + zc.close() + + return [f"{name}.local" for name in sorted(discovered)] diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index 240e6638a8..bab2ea8672 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -27,7 +27,6 @@ from esphome.__main__ import ( command_wizard, compile_program, detect_external_components, - discover_mdns_devices, get_port_type, has_ip_address, has_mqtt, @@ -72,6 +71,7 @@ from esphome.const import ( PLATFORM_RP2040, ) from esphome.core import CORE, EsphomeError +from esphome.zeroconf import discover_mdns_devices def strip_ansi_codes(text: str) -> str: @@ -1686,14 +1686,16 @@ def test_has_name_add_mac_suffix() -> None: def mock_mdns_discovery() -> Generator[MagicMock]: """Fixture to mock mDNS discovery infrastructure.""" with ( - patch("esphome.__main__.Zeroconf") as mock_zeroconf_class, - patch("esphome.__main__.ServiceBrowser") as mock_browser_class, - patch("esphome.__main__.time.sleep"), + patch("esphome.zeroconf.Zeroconf") as mock_zeroconf_class, + patch("esphome.zeroconf.ServiceBrowser") as mock_browser_class, + patch("esphome.zeroconf.time.sleep"), ): mock_zc = MagicMock() mock_zeroconf_class.return_value = mock_zc + mock_browser = MagicMock() # Store references for test access mock_zc._mock_browser_class = mock_browser_class + mock_zc._mock_browser = mock_browser yield mock_zc @@ -1743,18 +1745,20 @@ def test_discover_mdns_devices( expected: list[str], ) -> None: """Test discover_mdns_devices function with various scenarios.""" + mock_browser = mock_mdns_discovery._mock_browser def capture_callback(zc, service_type, handlers): callback = handlers[0] for service_name, state_change in discovered_services: callback(mock_mdns_discovery, service_type, service_name, state_change) - return MagicMock() + return mock_browser mock_mdns_discovery._mock_browser_class.side_effect = capture_callback result = discover_mdns_devices(base_name, timeout=0.1) assert result == expected + mock_browser.cancel.assert_called_once() mock_mdns_discovery.close.assert_called_once()