diff --git a/esphome/components/esphome/ota/__init__.py b/esphome/components/esphome/ota/__init__.py index 72a690b926..e6f249e021 100644 --- a/esphome/components/esphome/ota/__init__.py +++ b/esphome/components/esphome/ota/__init__.py @@ -140,13 +140,14 @@ async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) cg.add(var.set_port(config[CONF_PORT])) - # Only include SHA256 support on platforms that have it - if supports_sha256(): - cg.add_define("USE_OTA_SHA256") - if CONF_PASSWORD in config: cg.add(var.set_auth_password(config[CONF_PASSWORD])) cg.add_define("USE_OTA_PASSWORD") + # Only include hash algorithms when password is configured + cg.add_define("USE_OTA_MD5") + # Only include SHA256 support on platforms that have it + if supports_sha256(): + cg.add_define("USE_OTA_SHA256") cg.add_define("USE_OTA_VERSION", config[CONF_VERSION]) await cg.register_component(var, config) diff --git a/esphome/components/esphome/ota/ota_esphome.cpp b/esphome/components/esphome/ota/ota_esphome.cpp index 206905d0d8..f503ff795e 100644 --- a/esphome/components/esphome/ota/ota_esphome.cpp +++ b/esphome/components/esphome/ota/ota_esphome.cpp @@ -1,6 +1,8 @@ #include "ota_esphome.h" #ifdef USE_OTA +#ifdef USE_OTA_MD5 #include "esphome/components/md5/md5.h" +#endif #ifdef USE_OTA_SHA256 #include "esphome/components/sha256/sha256.h" #endif @@ -267,12 +269,14 @@ void ESPHomeOTAComponent::handle_data_() { if (client_supports_sha256) { sha256::SHA256 sha_hasher; auth_success = this->perform_hash_auth_(&sha_hasher, this->password_, 16, ota::OTA_RESPONSE_REQUEST_SHA256_AUTH, - LOG_STR("SHA256")); + LOG_STR("SHA256"), sbuf); } else { +#ifdef USE_OTA_MD5 ESP_LOGW(TAG, "Using MD5 auth for compatibility (deprecated)"); md5::MD5Digest md5_hasher; - auth_success = - this->perform_hash_auth_(&md5_hasher, this->password_, 8, ota::OTA_RESPONSE_REQUEST_AUTH, LOG_STR("MD5")); + auth_success = this->perform_hash_auth_(&md5_hasher, this->password_, 8, ota::OTA_RESPONSE_REQUEST_AUTH, + LOG_STR("MD5"), sbuf); +#endif // USE_OTA_MD5 } #else // Strict mode: SHA256 required on capable platforms (future default) @@ -281,13 +285,18 @@ void ESPHomeOTAComponent::handle_data_() { error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID; goto error; // NOLINT(cppcoreguidelines-avoid-goto) } - auth_success = this->perform_hash_auth_(this->password_); + sha256::SHA256 sha_hasher; + auth_success = this->perform_hash_auth_(&sha_hasher, this->password_, 16, ota::OTA_RESPONSE_REQUEST_SHA256_AUTH, + LOG_STR("SHA256"), sbuf); #endif // ALLOW_OTA_DOWNGRADE_MD5 #else // Platform only supports MD5 - use it as the only available option // This is not a security downgrade as the platform cannot support SHA256 +#ifdef USE_OTA_MD5 md5::MD5Digest md5_hasher; - auth_success = this->perform_hash_auth_(&md5_hasher, this->password_, 8, ota::OTA_RESPONSE_REQUEST_AUTH); + auth_success = + this->perform_hash_auth_(&md5_hasher, this->password_, 8, ota::OTA_RESPONSE_REQUEST_AUTH, LOG_STR("MD5"), sbuf); +#endif // USE_OTA_MD5 #endif // USE_OTA_SHA256 if (!auth_success) { @@ -514,29 +523,24 @@ void ESPHomeOTAComponent::yield_and_feed_watchdog_() { delay(1); } +#ifdef USE_OTA_PASSWORD void ESPHomeOTAComponent::log_auth_warning_(const LogString *action, const LogString *hash_name) { ESP_LOGW(TAG, "Auth: %s %s failed", LOG_STR_ARG(action), LOG_STR_ARG(hash_name)); } // Non-template function definition to reduce binary size bool ESPHomeOTAComponent::perform_hash_auth_(HashBase *hasher, const std::string &password, size_t nonce_size, - uint8_t auth_request, const LogString *name) { + uint8_t auth_request, const LogString *name, char *buf) { // Get sizes from the hasher const size_t hex_size = hasher->get_hex_size(); - // Use fixed-size buffers for the maximum possible hash size (SHA256 = 64 chars) - // This avoids dynamic allocation overhead - static constexpr size_t MAX_HEX_SIZE = 65; // SHA256 hex + null terminator - char hex_buffer1[MAX_HEX_SIZE]; // Used for: nonce -> expected result - char hex_buffer2[MAX_HEX_SIZE]; // Used for: cnonce -> response + // Use the provided buffer for all hex operations - // Small stack buffer for auth request and nonce seed bytes - uint8_t buf[1]; + // Small stack buffer for nonce seed bytes uint8_t nonce_bytes[8]; // Max 8 bytes (2 x uint32_t for SHA256) // Send auth request type - buf[0] = auth_request; - this->writeall_(buf, 1); + this->writeall_(&auth_request, 1); hasher->init(); @@ -562,56 +566,55 @@ bool ESPHomeOTAComponent::perform_hash_auth_(HashBase *hasher, const std::string } hasher->calculate(); - // Use hex_buffer1 for nonce - hasher->get_hex(hex_buffer1); - hex_buffer1[hex_size] = '\0'; - ESP_LOGV(TAG, "Auth: %s Nonce is %s", LOG_STR_ARG(name), hex_buffer1); + // Generate and send nonce + hasher->get_hex(buf); + buf[hex_size] = '\0'; + ESP_LOGV(TAG, "Auth: %s Nonce is %s", LOG_STR_ARG(name), buf); - // Send nonce - if (!this->writeall_(reinterpret_cast(hex_buffer1), hex_size)) { + if (!this->writeall_(reinterpret_cast(buf), hex_size)) { this->log_auth_warning_(LOG_STR("Writing nonce"), name); return false; } - // Prepare challenge + // Start challenge: password + nonce hasher->init(); hasher->add(password.c_str(), password.length()); - hasher->add(hex_buffer1, hex_size); // Add nonce + hasher->add(buf, hex_size); - // Receive cnonce into hex_buffer2 - if (!this->readall_(reinterpret_cast(hex_buffer2), hex_size)) { + // Read cnonce and add to hash + if (!this->readall_(reinterpret_cast(buf), hex_size)) { this->log_auth_warning_(LOG_STR("Reading cnonce"), name); return false; } - hex_buffer2[hex_size] = '\0'; - ESP_LOGV(TAG, "Auth: %s CNonce is %s", LOG_STR_ARG(name), hex_buffer2); + buf[hex_size] = '\0'; + ESP_LOGV(TAG, "Auth: %s CNonce is %s", LOG_STR_ARG(name), buf); - // Add cnonce to hash - hasher->add(hex_buffer2, hex_size); - - // Calculate result - reuse hex_buffer1 for expected + hasher->add(buf, hex_size); hasher->calculate(); - hasher->get_hex(hex_buffer1); - hex_buffer1[hex_size] = '\0'; - ESP_LOGV(TAG, "Auth: %s Result is %s", LOG_STR_ARG(name), hex_buffer1); - // Receive response - reuse hex_buffer2 - if (!this->readall_(reinterpret_cast(hex_buffer2), hex_size)) { + // Log expected result (digest is already in hasher) + hasher->get_hex(buf); + buf[hex_size] = '\0'; + ESP_LOGV(TAG, "Auth: %s Result is %s", LOG_STR_ARG(name), buf); + + // Read response into the buffer + if (!this->readall_(reinterpret_cast(buf), hex_size)) { this->log_auth_warning_(LOG_STR("Reading response"), name); return false; } - hex_buffer2[hex_size] = '\0'; - ESP_LOGV(TAG, "Auth: %s Response is %s", LOG_STR_ARG(name), hex_buffer2); + buf[hex_size] = '\0'; + ESP_LOGV(TAG, "Auth: %s Response is %s", LOG_STR_ARG(name), buf); - // Compare - bool matches = memcmp(hex_buffer1, hex_buffer2, hex_size) == 0; + // Compare response directly with digest in hasher + bool matches = hasher->equals_hex(buf); if (!matches) { - ESP_LOGW(TAG, "Auth failed! %s passwords do not match", LOG_STR_ARG(name)); + this->log_auth_warning_(LOG_STR("Password mismatch"), name); } return matches; } +#endif // USE_OTA_PASSWORD } // namespace esphome #endif diff --git a/esphome/components/esphome/ota/ota_esphome.h b/esphome/components/esphome/ota/ota_esphome.h index 5d806028ac..39f2f878de 100644 --- a/esphome/components/esphome/ota/ota_esphome.h +++ b/esphome/components/esphome/ota/ota_esphome.h @@ -31,14 +31,16 @@ class ESPHomeOTAComponent : public ota::OTAComponent { protected: void handle_handshake_(); void handle_data_(); +#ifdef USE_OTA_PASSWORD bool perform_hash_auth_(HashBase *hasher, const std::string &password, size_t nonce_size, uint8_t auth_request, - const LogString *name); + const LogString *name, char *buf); + void log_auth_warning_(const LogString *action, const LogString *hash_name); +#endif // USE_OTA_PASSWORD bool readall_(uint8_t *buf, size_t len); bool writeall_(const uint8_t *buf, size_t len); void log_socket_error_(const LogString *msg); void log_read_error_(const LogString *what); void log_start_(const LogString *phase); - void log_auth_warning_(const LogString *action, const LogString *hash_name); void cleanup_connection_(); void yield_and_feed_watchdog_(); diff --git a/esphome/components/sha256/__init__.py b/esphome/components/sha256/__init__.py index 91d4929a4f..f07157416d 100644 --- a/esphome/components/sha256/__init__.py +++ b/esphome/components/sha256/__init__.py @@ -13,12 +13,10 @@ CONFIG_SCHEMA = cv.Schema({}) async def to_code(config: ConfigType) -> None: # Add OpenSSL library for host platform - if CORE.is_host: - if IS_MACOS: - # macOS needs special handling for Homebrew OpenSSL - cg.add_build_flag("-I/opt/homebrew/opt/openssl/include") - cg.add_build_flag("-L/opt/homebrew/opt/openssl/lib") - cg.add_build_flag("-lcrypto") - else: - # Linux and other Unix systems usually have OpenSSL in standard paths - cg.add_build_flag("-lcrypto") + if not CORE.is_host: + return + if IS_MACOS: + # macOS needs special handling for Homebrew OpenSSL + cg.add_build_flag("-I/opt/homebrew/opt/openssl/include") + cg.add_build_flag("-L/opt/homebrew/opt/openssl/lib") + cg.add_build_flag("-lcrypto") diff --git a/esphome/writer.py b/esphome/writer.py index c0d4379b3a..6d34d8f751 100644 --- a/esphome/writer.py +++ b/esphome/writer.py @@ -1,5 +1,6 @@ import importlib import logging +import os from pathlib import Path import re @@ -301,6 +302,11 @@ def clean_cmake_cache(): def clean_build(): import shutil + # Allow skipping cache cleaning for integration tests + if os.environ.get("ESPHOME_SKIP_CLEAN_BUILD"): + _LOGGER.warning("Skipping build cleaning (ESPHOME_SKIP_CLEAN_BUILD set)") + return + pioenvs = CORE.relative_pioenvs_path() if pioenvs.is_dir(): _LOGGER.info("Deleting %s", pioenvs) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 94632f8439..965363972f 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -58,6 +58,8 @@ def _get_platformio_env(cache_dir: Path) -> dict[str, str]: env["PLATFORMIO_CORE_DIR"] = str(cache_dir) env["PLATFORMIO_CACHE_DIR"] = str(cache_dir / ".cache") env["PLATFORMIO_LIBDEPS_DIR"] = str(cache_dir / "libdeps") + # Prevent cache cleaning during integration tests + env["ESPHOME_SKIP_CLEAN_BUILD"] = "1" return env @@ -68,6 +70,11 @@ def shared_platformio_cache() -> Generator[Path]: test_cache_dir = Path.home() / ".esphome-integration-tests" cache_dir = test_cache_dir / "platformio" + # Create the temp directory that PlatformIO uses to avoid race conditions + # This ensures it exists and won't be deleted by parallel processes + platformio_tmp_dir = cache_dir / ".cache" / "tmp" + platformio_tmp_dir.mkdir(parents=True, exist_ok=True) + # Use a lock file in the home directory to ensure only one process initializes the cache # This is needed when running with pytest-xdist # The lock file must be in a directory that already exists to avoid race conditions @@ -83,17 +90,11 @@ def shared_platformio_cache() -> Generator[Path]: test_cache_dir.mkdir(exist_ok=True) with tempfile.TemporaryDirectory() as tmpdir: - # Create a basic host config + # Use the cache_init fixture for initialization init_dir = Path(tmpdir) + fixture_path = Path(__file__).parent / "fixtures" / "cache_init.yaml" config_path = init_dir / "cache_init.yaml" - config_path.write_text("""esphome: - name: cache-init -host: -api: - encryption: - key: "IIevImVI42I0FGos5nLqFK91jrJehrgidI0ArwMLr8w=" -logger: -""") + config_path.write_text(fixture_path.read_text()) # Run compilation to populate the cache # We must succeed here to avoid race conditions where multiple @@ -346,7 +347,8 @@ async def wait_and_connect_api_client( noise_psk: str | None = None, client_info: str = "integration-test", timeout: float = API_CONNECTION_TIMEOUT, -) -> AsyncGenerator[APIClient]: + return_disconnect_event: bool = False, +) -> AsyncGenerator[APIClient | tuple[APIClient, asyncio.Event]]: """Wait for API to be available and connect.""" client = APIClient( address=address, @@ -359,14 +361,17 @@ async def wait_and_connect_api_client( # Create a future to signal when connected loop = asyncio.get_running_loop() connected_future: asyncio.Future[None] = loop.create_future() + disconnect_event = asyncio.Event() async def on_connect() -> None: """Called when successfully connected.""" + disconnect_event.clear() # Clear the disconnect event on new connection if not connected_future.done(): connected_future.set_result(None) async def on_disconnect(expected_disconnect: bool) -> None: """Called when disconnected.""" + disconnect_event.set() if not connected_future.done() and not expected_disconnect: connected_future.set_exception( APIConnectionError("Disconnected before fully connected") @@ -397,7 +402,10 @@ async def wait_and_connect_api_client( except TimeoutError: raise TimeoutError(f"Failed to connect to API after {timeout} seconds") - yield client + if return_disconnect_event: + yield client, disconnect_event + else: + yield client finally: # Stop reconnect logic and disconnect await reconnect_logic.stop() @@ -430,6 +438,33 @@ async def api_client_connected( yield _connect_client +@pytest_asyncio.fixture +async def api_client_connected_with_disconnect( + unused_tcp_port: int, +) -> AsyncGenerator: + """Factory for creating connected API client context managers with disconnect event.""" + + def _connect_client_with_disconnect( + address: str = LOCALHOST, + port: int | None = None, + password: str = "", + noise_psk: str | None = None, + client_info: str = "integration-test", + timeout: float = API_CONNECTION_TIMEOUT, + ): + return wait_and_connect_api_client( + address=address, + port=port if port is not None else unused_tcp_port, + password=password, + noise_psk=noise_psk, + client_info=client_info, + timeout=timeout, + return_disconnect_event=True, + ) + + yield _connect_client_with_disconnect + + async def _read_stream_lines( stream: asyncio.StreamReader, lines: list[str], diff --git a/tests/integration/fixtures/cache_init.yaml b/tests/integration/fixtures/cache_init.yaml new file mode 100644 index 0000000000..de208196cd --- /dev/null +++ b/tests/integration/fixtures/cache_init.yaml @@ -0,0 +1,10 @@ +esphome: + name: cache-init + +host: + +api: + encryption: + key: "IIevImVI42I0FGos5nLqFK91jrJehrgidI0ArwMLr8w=" + +logger: diff --git a/tests/integration/fixtures/noise_corrupt_encrypted_frame.yaml b/tests/integration/fixtures/noise_corrupt_encrypted_frame.yaml new file mode 100644 index 0000000000..6f0266c6fd --- /dev/null +++ b/tests/integration/fixtures/noise_corrupt_encrypted_frame.yaml @@ -0,0 +1,11 @@ +esphome: + name: oversized-noise + +host: + +api: + encryption: + key: N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU= + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/fixtures/oversized_payload_noise.yaml b/tests/integration/fixtures/oversized_payload_noise.yaml new file mode 100644 index 0000000000..6f0266c6fd --- /dev/null +++ b/tests/integration/fixtures/oversized_payload_noise.yaml @@ -0,0 +1,11 @@ +esphome: + name: oversized-noise + +host: + +api: + encryption: + key: N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU= + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/fixtures/oversized_payload_plaintext.yaml b/tests/integration/fixtures/oversized_payload_plaintext.yaml new file mode 100644 index 0000000000..44ece4f770 --- /dev/null +++ b/tests/integration/fixtures/oversized_payload_plaintext.yaml @@ -0,0 +1,9 @@ +esphome: + name: oversized-plaintext + +host: + +api: + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/fixtures/oversized_protobuf_message_id_noise.yaml b/tests/integration/fixtures/oversized_protobuf_message_id_noise.yaml new file mode 100644 index 0000000000..6f0266c6fd --- /dev/null +++ b/tests/integration/fixtures/oversized_protobuf_message_id_noise.yaml @@ -0,0 +1,11 @@ +esphome: + name: oversized-noise + +host: + +api: + encryption: + key: N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU= + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/fixtures/oversized_protobuf_message_id_plaintext.yaml b/tests/integration/fixtures/oversized_protobuf_message_id_plaintext.yaml new file mode 100644 index 0000000000..1e9eadfdc5 --- /dev/null +++ b/tests/integration/fixtures/oversized_protobuf_message_id_plaintext.yaml @@ -0,0 +1,9 @@ +esphome: + name: oversized-protobuf-plaintext + +host: + +api: + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/test_oversized_payloads.py b/tests/integration/test_oversized_payloads.py new file mode 100644 index 0000000000..f3e422620c --- /dev/null +++ b/tests/integration/test_oversized_payloads.py @@ -0,0 +1,335 @@ +"""Integration tests for oversized payloads and headers that should cause disconnection.""" + +from __future__ import annotations + +import asyncio + +import pytest + +from .types import APIClientConnectedWithDisconnectFactory, RunCompiledFunction + + +@pytest.mark.asyncio +async def test_oversized_payload_plaintext( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that oversized payloads (>100KiB) from client cause disconnection without crashing.""" + process_exited = False + helper_log_found = False + + def check_logs(line: str) -> None: + nonlocal process_exited, helper_log_found + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + # Check for HELPER_LOG message about message size exceeding maximum + if ( + "[VV]" in line + and "Bad packet: message size" in line + and "exceeds maximum" in line + ): + helper_log_found = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect() as (client, disconnect_event): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-plaintext" + + # Create an oversized payload (>100KiB) + oversized_data = b"X" * (100 * 1024 + 1) # 100KiB + 1 byte + + # Access the internal connection to send raw data + frame_helper = client._connection._frame_helper + # Create a message with oversized payload + # Using message type 1 (DeviceInfoRequest) as an example + message_type = 1 + frame_helper.write_packets([(message_type, oversized_data)], True) + + # Wait for the connection to be closed by ESPHome + await asyncio.wait_for(disconnect_event.wait(), timeout=5.0) + + # After disconnection, verify process didn't crash + assert not process_exited, "ESPHome process should not crash" + # Verify we saw the expected HELPER_LOG message + assert helper_log_found, ( + "Expected to see HELPER_LOG about message size exceeding maximum" + ) + + # Try to reconnect to verify the process is still running + async with api_client_connected_with_disconnect() as (client2, _): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-plaintext" + + +@pytest.mark.asyncio +async def test_oversized_protobuf_message_id_plaintext( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that protobuf messages with ID > UINT16_MAX cause disconnection without crashing. + + This tests the message type limit - message IDs must fit in a uint16_t (0-65535). + """ + process_exited = False + helper_log_found = False + + def check_logs(line: str) -> None: + nonlocal process_exited, helper_log_found + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + # Check for HELPER_LOG message about message type exceeding maximum + if ( + "[VV]" in line + and "Bad packet: message type" in line + and "exceeds maximum" in line + ): + helper_log_found = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect() as (client, disconnect_event): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-protobuf-plaintext" + + # Access the internal connection to send raw message with large ID + frame_helper = client._connection._frame_helper + # Message ID that exceeds uint16_t limit (> 65535) + large_message_id = 65536 # 2^16, exceeds UINT16_MAX + # Small payload for the test + payload = b"test" + + # This should cause disconnection due to oversized varint + frame_helper.write_packets([(large_message_id, payload)], True) + + # Wait for the connection to be closed by ESPHome + await asyncio.wait_for(disconnect_event.wait(), timeout=5.0) + + # After disconnection, verify process didn't crash + assert not process_exited, "ESPHome process should not crash" + # Verify we saw the expected HELPER_LOG message + assert helper_log_found, ( + "Expected to see HELPER_LOG about message type exceeding maximum" + ) + + # Try to reconnect to verify the process is still running + async with api_client_connected_with_disconnect() as (client2, _): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-protobuf-plaintext" + + +@pytest.mark.asyncio +async def test_oversized_payload_noise( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that oversized payloads (>100KiB) from client cause disconnection without crashing with noise encryption.""" + noise_key = "N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU=" + process_exited = False + cipherstate_failed = False + + def check_logs(line: str) -> None: + nonlocal process_exited, cipherstate_failed + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + # Check for the expected warning about decryption failure + if ( + "[W][api.connection" in line + and "Reading failed CIPHERSTATE_DECRYPT_FAILED" in line + ): + cipherstate_failed = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client, + disconnect_event, + ): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + # Create an oversized payload (>100KiB) + oversized_data = b"Y" * (100 * 1024 + 1) # 100KiB + 1 byte + + # Access the internal connection to send raw data + frame_helper = client._connection._frame_helper + # For noise connections, we still send through write_packets + # but the frame helper will handle encryption + # Using message type 1 (DeviceInfoRequest) as an example + message_type = 1 + frame_helper.write_packets([(message_type, oversized_data)], True) + + # Wait for the connection to be closed by ESPHome + await asyncio.wait_for(disconnect_event.wait(), timeout=5.0) + + # After disconnection, verify process didn't crash + assert not process_exited, "ESPHome process should not crash" + # Verify we saw the expected warning message + assert cipherstate_failed, ( + "Expected to see warning about CIPHERSTATE_DECRYPT_FAILED" + ) + + # Try to reconnect to verify the process is still running + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client2, + _, + ): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + +@pytest.mark.asyncio +async def test_oversized_protobuf_message_id_noise( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that the noise protocol handles unknown message types correctly. + + With noise encryption, message types are stored as uint16_t (2 bytes) after decryption. + Unknown message types should be ignored without disconnecting, as ESPHome needs to + read the full message to maintain encryption stream continuity. + """ + noise_key = "N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU=" + process_exited = False + + def check_logs(line: str) -> None: + nonlocal process_exited + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client, + disconnect_event, + ): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + # With noise, message types are uint16_t, so we test with an unknown but valid value + frame_helper = client._connection._frame_helper + + # Test with an unknown message type (65535 is not used by ESPHome) + unknown_message_id = 65535 # Valid uint16_t but unknown to ESPHome + payload = b"test" + + # Send the unknown message type - ESPHome should read and ignore it + frame_helper.write_packets([(unknown_message_id, payload)], True) + + # Give ESPHome a moment to process (but expect no disconnection) + # The connection should stay alive as ESPHome ignores unknown message types + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(disconnect_event.wait(), timeout=0.5) + + # Connection should still be alive - unknown types are ignored, not fatal + assert client._connection.is_connected, ( + "Connection should remain open for unknown message types" + ) + + # Verify we can still communicate by sending a valid request + device_info2 = await client.device_info() + assert device_info2 is not None + assert device_info2.name == "oversized-noise" + + # After test, verify process didn't crash + assert not process_exited, "ESPHome process should not crash" + + # Verify we can still reconnect + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client2, + _, + ): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + +@pytest.mark.asyncio +async def test_noise_corrupt_encrypted_frame( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that noise protocol properly handles corrupt encrypted frames. + + Send a frame with valid size but corrupt encrypted content (garbage bytes). + This should fail decryption and cause disconnection. + """ + noise_key = "N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU=" + process_exited = False + cipherstate_failed = False + + def check_logs(line: str) -> None: + nonlocal process_exited, cipherstate_failed + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + # Check for the expected warning about decryption failure + if ( + "[W][api.connection" in line + and "Reading failed CIPHERSTATE_DECRYPT_FAILED" in line + ): + cipherstate_failed = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client, + disconnect_event, + ): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + # Get the socket to send raw corrupt data + socket = client._connection._socket + + # Send a corrupt noise frame directly to the socket + # Format: [indicator=0x01][size_high][size_low][garbage_encrypted_data] + # Size of 32 bytes (reasonable size for a noise frame with MAC) + corrupt_frame = bytes( + [ + 0x01, # Noise indicator + 0x00, # Size high byte + 0x20, # Size low byte (32 bytes) + ] + ) + bytes(32) # 32 bytes of zeros (invalid encrypted data) + + # Send the corrupt frame + socket.sendall(corrupt_frame) + + # Wait for ESPHome to disconnect due to decryption failure + await asyncio.wait_for(disconnect_event.wait(), timeout=5.0) + + # After disconnection, verify process didn't crash + assert not process_exited, ( + "ESPHome process should not crash on corrupt encrypted frames" + ) + # Verify we saw the expected warning message + assert cipherstate_failed, ( + "Expected to see warning about CIPHERSTATE_DECRYPT_FAILED" + ) + + # Verify we can still reconnect after handling the corrupt frame + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client2, + _, + ): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" diff --git a/tests/integration/types.py b/tests/integration/types.py index 5e4bfaa29d..b6728a2fcb 100644 --- a/tests/integration/types.py +++ b/tests/integration/types.py @@ -54,3 +54,17 @@ class APIClientConnectedFactory(Protocol): client_info: str = "integration-test", timeout: float = 30, ) -> AbstractAsyncContextManager[APIClient]: ... + + +class APIClientConnectedWithDisconnectFactory(Protocol): + """Protocol for connected API client factory that returns disconnect event.""" + + def __call__( # noqa: E704 + self, + address: str = "localhost", + port: int | None = None, + password: str = "", + noise_psk: str | None = None, + client_info: str = "integration-test", + timeout: float = 30, + ) -> AbstractAsyncContextManager[tuple[APIClient, asyncio.Event]]: ... diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index e84ba299a2..e8d9c02524 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -94,3 +94,10 @@ def mock_run_git_command() -> Generator[Mock, None, None]: """Mock run_git_command for git module.""" with patch("esphome.git.run_git_command") as mock: yield mock + + +@pytest.fixture +def mock_get_idedata() -> Generator[Mock, None, None]: + """Mock get_idedata for platformio_api.""" + with patch("esphome.platformio_api.get_idedata") as mock: + yield mock diff --git a/tests/unit_tests/test_git.py b/tests/unit_tests/test_git.py index ebe7177bd2..6a51206ec2 100644 --- a/tests/unit_tests/test_git.py +++ b/tests/unit_tests/test_git.py @@ -15,7 +15,7 @@ def test_clone_or_update_with_never_refresh( ) -> None: """Test that NEVER_REFRESH skips updates for existing repos.""" # Set up CORE.config_path so data_dir uses tmp_path - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" # Compute the expected repo directory path url = "https://github.com/test/repo" @@ -56,7 +56,7 @@ def test_clone_or_update_with_refresh_updates_old_repo( ) -> None: """Test that refresh triggers update for old repos.""" # Set up CORE.config_path so data_dir uses tmp_path - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" # Compute the expected repo directory path url = "https://github.com/test/repo" @@ -110,7 +110,7 @@ def test_clone_or_update_with_refresh_skips_fresh_repo( ) -> None: """Test that refresh doesn't update fresh repos.""" # Set up CORE.config_path so data_dir uses tmp_path - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" # Compute the expected repo directory path url = "https://github.com/test/repo" @@ -156,7 +156,7 @@ def test_clone_or_update_clones_missing_repo( ) -> None: """Test that missing repos are cloned regardless of refresh setting.""" # Set up CORE.config_path so data_dir uses tmp_path - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" # Compute the expected repo directory path url = "https://github.com/test/repo" @@ -198,7 +198,7 @@ def test_clone_or_update_with_none_refresh_always_updates( ) -> None: """Test that refresh=None always updates existing repos.""" # Set up CORE.config_path so data_dir uses tmp_path - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" # Compute the expected repo directory path url = "https://github.com/test/repo" diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index da280b1fd8..bb047d063c 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -12,6 +12,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from pytest import CaptureFixture +from esphome import platformio_api from esphome.__main__ import ( Purpose, choose_upload_log_host, @@ -28,7 +29,9 @@ from esphome.__main__ import ( mqtt_get_ip, show_logs, upload_program, + upload_using_esptool, ) +from esphome.components.esp32.const import KEY_ESP32, KEY_VARIANT, VARIANT_ESP32 from esphome.const import ( CONF_API, CONF_BROKER, @@ -220,6 +223,14 @@ def mock_run_external_process() -> Generator[Mock]: yield mock +@pytest.fixture +def mock_run_external_command() -> Generator[Mock]: + """Mock run_external_command for testing.""" + with patch("esphome.__main__.run_external_command") as mock: + mock.return_value = 0 # Default to success + yield mock + + def test_choose_upload_log_host_with_string_default() -> None: """Test with a single string default device.""" setup_core() @@ -818,6 +829,122 @@ def test_upload_program_serial_esp8266_with_file( ) +def test_upload_using_esptool_path_conversion( + tmp_path: Path, + mock_run_external_command: Mock, + mock_get_idedata: Mock, +) -> None: + """Test upload_using_esptool properly converts Path objects to strings for esptool. + + This test ensures that img.path (Path object) is converted to string before + passing to esptool, preventing AttributeError. + """ + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path, name="test") + + # Set up ESP32-specific data required by get_esp32_variant() + CORE.data[KEY_ESP32] = {KEY_VARIANT: VARIANT_ESP32} + + # Create mock IDEData with Path objects + mock_idedata = MagicMock(spec=platformio_api.IDEData) + mock_idedata.firmware_bin_path = tmp_path / "firmware.bin" + mock_idedata.extra_flash_images = [ + platformio_api.FlashImage(path=tmp_path / "bootloader.bin", offset="0x1000"), + platformio_api.FlashImage(path=tmp_path / "partitions.bin", offset="0x8000"), + ] + + mock_get_idedata.return_value = mock_idedata + + # Create the actual firmware files so they exist + (tmp_path / "firmware.bin").touch() + (tmp_path / "bootloader.bin").touch() + (tmp_path / "partitions.bin").touch() + + config = {CONF_ESPHOME: {"platformio_options": {}}} + + # Call upload_using_esptool without custom file argument + result = upload_using_esptool(config, "/dev/ttyUSB0", None, None) + + assert result == 0 + + # Verify that run_external_command was called + assert mock_run_external_command.call_count == 1 + + # Get the actual call arguments + call_args = mock_run_external_command.call_args[0] + + # The first argument should be esptool.main function, + # followed by the command arguments + assert len(call_args) > 1 + + # Find the indices of the flash image arguments + # They should come after "write-flash" and "-z" + cmd_list = list(call_args[1:]) # Skip the esptool.main function + + # Verify all paths are strings, not Path objects + # The firmware and flash images should be at specific positions + write_flash_idx = cmd_list.index("write-flash") + + # After write-flash we have: -z, --flash-size, detect, then offset/path pairs + # Check firmware at offset 0x10000 (ESP32) + firmware_offset_idx = write_flash_idx + 4 + assert cmd_list[firmware_offset_idx] == "0x10000" + firmware_path = cmd_list[firmware_offset_idx + 1] + assert isinstance(firmware_path, str) + assert firmware_path.endswith("firmware.bin") + + # Check bootloader + bootloader_offset_idx = firmware_offset_idx + 2 + assert cmd_list[bootloader_offset_idx] == "0x1000" + bootloader_path = cmd_list[bootloader_offset_idx + 1] + assert isinstance(bootloader_path, str) + assert bootloader_path.endswith("bootloader.bin") + + # Check partitions + partitions_offset_idx = bootloader_offset_idx + 2 + assert cmd_list[partitions_offset_idx] == "0x8000" + partitions_path = cmd_list[partitions_offset_idx + 1] + assert isinstance(partitions_path, str) + assert partitions_path.endswith("partitions.bin") + + +def test_upload_using_esptool_with_file_path( + tmp_path: Path, + mock_run_external_command: Mock, +) -> None: + """Test upload_using_esptool with a custom file that's a Path object.""" + setup_core(platform=PLATFORM_ESP8266, tmp_path=tmp_path, name="test") + + # Create a test firmware file + firmware_file = tmp_path / "custom_firmware.bin" + firmware_file.touch() + + config = {CONF_ESPHOME: {"platformio_options": {}}} + + # Call with a Path object as the file argument (though usually it's a string) + result = upload_using_esptool(config, "/dev/ttyUSB0", str(firmware_file), None) + + assert result == 0 + + # Verify that run_external_command was called + mock_run_external_command.assert_called_once() + + # Get the actual call arguments + call_args = mock_run_external_command.call_args[0] + cmd_list = list(call_args[1:]) # Skip the esptool.main function + + # Find the firmware path in the command + write_flash_idx = cmd_list.index("write-flash") + + # For custom file, it should be at offset 0x0 + firmware_offset_idx = write_flash_idx + 4 + assert cmd_list[firmware_offset_idx] == "0x0" + firmware_path = cmd_list[firmware_offset_idx + 1] + + # Verify it's a string, not a Path object + assert isinstance(firmware_path, str) + assert firmware_path.endswith("custom_firmware.bin") + + @pytest.mark.parametrize( "platform,device", [