1
0
mirror of https://github.com/esphome/esphome.git synced 2025-10-18 17:53:47 +01:00
This commit is contained in:
J. Nick Koston
2025-09-21 09:52:39 -06:00
parent 139577f96a
commit 6c8b66df96

View File

@@ -14,9 +14,20 @@ from esphome import espota2
from esphome.core import EsphomeError
def test_recv_decode_with_decode() -> None:
@pytest.fixture
def mock_socket():
"""Create a mock socket for testing."""
socket = Mock()
socket.close = Mock()
socket.recv = Mock()
socket.sendall = Mock()
socket.settimeout = Mock()
socket.connect = Mock()
return socket
def test_recv_decode_with_decode(mock_socket) -> None:
"""Test recv_decode with decode=True returns list."""
mock_socket = Mock()
mock_socket.recv.return_value = b"\x01\x02\x03"
result = espota2.recv_decode(mock_socket, 3, decode=True)
@@ -25,9 +36,8 @@ def test_recv_decode_with_decode() -> None:
mock_socket.recv.assert_called_once_with(3)
def test_recv_decode_without_decode() -> None:
def test_recv_decode_without_decode(mock_socket) -> None:
"""Test recv_decode with decode=False returns bytes."""
mock_socket = Mock()
mock_socket.recv.return_value = b"\x01\x02\x03"
result = espota2.recv_decode(mock_socket, 3, decode=False)
@@ -36,9 +46,8 @@ def test_recv_decode_without_decode() -> None:
mock_socket.recv.assert_called_once_with(3)
def test_receive_exactly_success() -> None:
def test_receive_exactly_success(mock_socket) -> None:
"""Test receive_exactly successfully receives expected data."""
mock_socket = Mock()
mock_socket.recv.side_effect = [b"\x00", b"\x01\x02"]
result = espota2.receive_exactly(mock_socket, 3, "test", espota2.RESPONSE_OK)
@@ -47,11 +56,9 @@ def test_receive_exactly_success() -> None:
assert mock_socket.recv.call_count == 2
def test_receive_exactly_with_error_response() -> None:
def test_receive_exactly_with_error_response(mock_socket) -> None:
"""Test receive_exactly raises OTAError on error response."""
mock_socket = Mock()
mock_socket.recv.return_value = bytes([espota2.RESPONSE_ERROR_AUTH_INVALID])
mock_socket.close = Mock()
with pytest.raises(espota2.OTAError, match="Error auth:.*Authentication invalid"):
espota2.receive_exactly(mock_socket, 1, "auth", [espota2.RESPONSE_OK])
@@ -59,9 +66,8 @@ def test_receive_exactly_with_error_response() -> None:
mock_socket.close.assert_called_once()
def test_receive_exactly_socket_error() -> None:
def test_receive_exactly_socket_error(mock_socket) -> None:
"""Test receive_exactly handles socket errors."""
mock_socket = Mock()
mock_socket.recv.side_effect = OSError("Connection reset")
with pytest.raises(espota2.OTAError, match="Error receiving acknowledge test"):
@@ -119,9 +125,8 @@ def test_check_error_unexpected_response() -> None:
espota2.check_error([0x7F], [espota2.RESPONSE_OK, espota2.RESPONSE_AUTH_OK])
def test_send_check_with_various_data_types() -> None:
def test_send_check_with_various_data_types(mock_socket) -> None:
"""Test send_check handles different data types."""
mock_socket = Mock()
# Test with list/tuple
espota2.send_check(mock_socket, [0x01, 0x02], "list")
@@ -140,18 +145,16 @@ def test_send_check_with_various_data_types() -> None:
mock_socket.sendall.assert_called_with(b"\xaa\xbb")
def test_send_check_socket_error() -> None:
def test_send_check_socket_error(mock_socket) -> None:
"""Test send_check handles socket errors."""
mock_socket = Mock()
mock_socket.sendall.side_effect = OSError("Broken pipe")
with pytest.raises(espota2.OTAError, match="Error sending test"):
espota2.send_check(mock_socket, b"data", "test")
def test_perform_ota_successful_md5_auth() -> None:
def test_perform_ota_successful_md5_auth(mock_socket) -> None:
"""Test successful OTA with MD5 authentication."""
mock_socket = Mock()
mock_file = io.BytesIO(b"firmware content here")
# Mock random for predictable cnonce
@@ -183,9 +186,14 @@ def test_perform_ota_successful_md5_auth() -> None:
# Verify magic bytes were sent
assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES))
# Verify features were sent
# Verify features were sent (compression + SHA256 support)
assert mock_socket.sendall.call_args_list[1] == call(
bytes([espota2.FEATURE_SUPPORTS_COMPRESSION])
bytes(
[
espota2.FEATURE_SUPPORTS_COMPRESSION
| espota2.FEATURE_SUPPORTS_SHA256_AUTH
]
)
)
# Verify cnonce was sent (MD5 of random.random())
@@ -201,9 +209,8 @@ def test_perform_ota_successful_md5_auth() -> None:
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
def test_perform_ota_no_auth() -> None:
def test_perform_ota_no_auth(mock_socket) -> None:
"""Test OTA without authentication."""
mock_socket = Mock()
mock_file = io.BytesIO(b"firmware")
with patch("time.sleep"), patch("time.perf_counter", side_effect=[0, 1]):
@@ -231,9 +238,8 @@ def test_perform_ota_no_auth() -> None:
assert len(auth_calls) == 0
def test_perform_ota_with_compression() -> None:
def test_perform_ota_with_compression(mock_socket) -> None:
"""Test OTA with compression support."""
mock_socket = Mock()
original_content = b"firmware" * 100 # Repeating content for compression
mock_file = io.BytesIO(original_content)
@@ -274,9 +280,8 @@ def test_perform_ota_with_compression() -> None:
assert sent_size == len(compressed)
def test_perform_ota_auth_without_password() -> None:
def test_perform_ota_auth_without_password(mock_socket) -> None:
"""Test OTA fails when auth is required but no password provided."""
mock_socket = Mock()
mock_file = io.BytesIO(b"firmware")
responses = [
@@ -293,9 +298,8 @@ def test_perform_ota_auth_without_password() -> None:
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
def test_perform_ota_unsupported_version() -> None:
def test_perform_ota_unsupported_version(mock_socket) -> None:
"""Test OTA fails with unsupported version."""
mock_socket = Mock()
mock_file = io.BytesIO(b"firmware")
responses = [
@@ -308,9 +312,8 @@ def test_perform_ota_unsupported_version() -> None:
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
def test_perform_ota_upload_error() -> None:
def test_perform_ota_upload_error(mock_socket) -> None:
"""Test OTA handles upload errors."""
mock_socket = Mock()
mock_file = io.BytesIO(b"firmware")
with patch("time.perf_counter", side_effect=[0, 1]):
@@ -511,9 +514,8 @@ def test_perform_ota_sha256_fallback_to_md5() -> None:
pass # Implementation depends on final PR merge
def test_perform_ota_version_differences() -> None:
def test_perform_ota_version_differences(mock_socket) -> None:
"""Test OTA behavior differences between version 1.0 and 2.0."""
mock_socket = Mock()
mock_file = io.BytesIO(b"firmware")
with patch("time.sleep"), patch("time.perf_counter", side_effect=[0, 1]):