1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-22 13:12:22 +01:00
This commit is contained in:
J. Nick Koston
2025-09-21 10:25:48 -06:00
parent 0e71662158
commit 97bc627d41
2 changed files with 41 additions and 8 deletions

View File

@@ -91,18 +91,26 @@ class OTAError(EsphomeError):
pass pass
def recv_decode(sock, amount, decode=True): def recv_decode(
sock: socket.socket, amount: int, decode: bool = True
) -> bytes | list[int]:
data = sock.recv(amount) data = sock.recv(amount)
if not decode: if not decode:
return data return data
return list(data) return list(data)
def receive_exactly(sock, amount, msg, expect, decode=True): def receive_exactly(
data = [] if decode else b"" sock: socket.socket,
amount: int,
msg: str,
expect: int | list[int] | None,
decode: bool = True,
) -> list[int] | bytes:
data: list[int] | bytes = [] if decode else b""
try: try:
data += recv_decode(sock, 1, decode=decode) data += recv_decode(sock, 1, decode=decode) # type: ignore[operator]
except OSError as err: except OSError as err:
raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err
@@ -114,13 +122,13 @@ def receive_exactly(sock, amount, msg, expect, decode=True):
while len(data) < amount: while len(data) < amount:
try: try:
data += recv_decode(sock, amount - len(data), decode=decode) data += recv_decode(sock, amount - len(data), decode=decode) # type: ignore[operator]
except OSError as err: except OSError as err:
raise OTAError(f"Error receiving {msg}: {err}") from err raise OTAError(f"Error receiving {msg}: {err}") from err
return data return data
def check_error(data, expect): def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None:
if not expect: if not expect:
return return
dat = data[0] dat = data[0]
@@ -187,7 +195,9 @@ def check_error(data, expect):
raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}") raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}")
def send_check(sock, data, msg): def send_check(
sock: socket.socket, data: list[int] | tuple[int, ...] | int | str | bytes, msg: str
) -> None:
try: try:
if isinstance(data, (list, tuple)): if isinstance(data, (list, tuple)):
data = bytes(data) data = bytes(data)
@@ -239,7 +249,7 @@ def perform_ota(
def perform_auth( def perform_auth(
sock: socket.socket, sock: socket.socket,
password: str, password: str,
hash_func: Any, hash_func: Callable[[], Any],
nonce_size: int, nonce_size: int,
hash_name: str, hash_name: str,
) -> None: ) -> None:

View File

@@ -416,6 +416,29 @@ def test_perform_ota_sha256_auth_without_password(mock_socket: Mock) -> None:
espota2.perform_ota(mock_socket, "", mock_file, "test.bin") espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
def test_perform_ota_unexpected_auth_response(mock_socket: Mock) -> None:
"""Test OTA fails when device sends an unexpected auth response."""
mock_file = io.BytesIO(b"firmware")
# Use 0x03 which is not in the expected auth responses
# This will be caught by check_error and raise "Unexpected response from ESP"
UNKNOWN_AUTH_METHOD = 0x03
responses = [
bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]),
bytes([espota2.RESPONSE_HEADER_OK]),
bytes([UNKNOWN_AUTH_METHOD]), # Unknown auth method
]
mock_socket.recv.side_effect = responses
# This will actually raise "Unexpected response from ESP" from check_error
with pytest.raises(
espota2.OTAError, match=r"Error auth: Unexpected response from ESP: 0x03"
):
espota2.perform_ota(mock_socket, "password", mock_file, "test.bin")
def test_perform_ota_unsupported_version(mock_socket: Mock) -> None: def test_perform_ota_unsupported_version(mock_socket: Mock) -> None:
"""Test OTA fails with unsupported version.""" """Test OTA fails with unsupported version."""
mock_file = io.BytesIO(b"firmware") mock_file = io.BytesIO(b"firmware")