diff --git a/esphome/espota2.py b/esphome/espota2.py index 5f906e4d08..36eb4d68ea 100644 --- a/esphome/espota2.py +++ b/esphome/espota2.py @@ -91,18 +91,26 @@ class OTAError(EsphomeError): pass -def recv_decode(sock, amount, decode=True): +def recv_decode( + sock: socket.socket, amount: int, decode: bool = True +) -> bytes | list[int]: data = sock.recv(amount) if not decode: return data return list(data) -def receive_exactly(sock, amount, msg, expect, decode=True): - data = [] if decode else b"" +def receive_exactly( + sock: socket.socket, + amount: int, + msg: str, + expect: int | list[int] | None, + decode: bool = True, +) -> list[int] | bytes: + data: list[int] | bytes = [] if decode else b"" try: - data += recv_decode(sock, 1, decode=decode) + data += recv_decode(sock, 1, decode=decode) # type: ignore[operator] except OSError as err: raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err @@ -114,13 +122,13 @@ def receive_exactly(sock, amount, msg, expect, decode=True): while len(data) < amount: try: - data += recv_decode(sock, amount - len(data), decode=decode) + data += recv_decode(sock, amount - len(data), decode=decode) # type: ignore[operator] except OSError as err: raise OTAError(f"Error receiving {msg}: {err}") from err return data -def check_error(data, expect): +def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None: if not expect: return dat = data[0] @@ -187,7 +195,9 @@ def check_error(data, expect): raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}") -def send_check(sock, data, msg): +def send_check( + sock: socket.socket, data: list[int] | tuple[int, ...] | int | str | bytes, msg: str +) -> None: try: if isinstance(data, (list, tuple)): data = bytes(data) @@ -239,7 +249,7 @@ def perform_ota( def perform_auth( sock: socket.socket, password: str, - hash_func: Any, + hash_func: Callable[[], Any], nonce_size: int, hash_name: str, ) -> None: diff --git a/tests/unit_tests/test_espota2.py b/tests/unit_tests/test_espota2.py index f74ca1e4e8..c036a5de8e 100644 --- a/tests/unit_tests/test_espota2.py +++ b/tests/unit_tests/test_espota2.py @@ -416,6 +416,29 @@ def test_perform_ota_sha256_auth_without_password(mock_socket: Mock) -> None: espota2.perform_ota(mock_socket, "", mock_file, "test.bin") +def test_perform_ota_unexpected_auth_response(mock_socket: Mock) -> None: + """Test OTA fails when device sends an unexpected auth response.""" + mock_file = io.BytesIO(b"firmware") + + # Use 0x03 which is not in the expected auth responses + # This will be caught by check_error and raise "Unexpected response from ESP" + UNKNOWN_AUTH_METHOD = 0x03 + + responses = [ + bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]), + bytes([espota2.RESPONSE_HEADER_OK]), + bytes([UNKNOWN_AUTH_METHOD]), # Unknown auth method + ] + + mock_socket.recv.side_effect = responses + + # This will actually raise "Unexpected response from ESP" from check_error + with pytest.raises( + espota2.OTAError, match=r"Error auth: Unexpected response from ESP: 0x03" + ): + espota2.perform_ota(mock_socket, "password", mock_file, "test.bin") + + def test_perform_ota_unsupported_version(mock_socket: Mock) -> None: """Test OTA fails with unsupported version.""" mock_file = io.BytesIO(b"firmware")