1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-23 13:42:27 +01:00
This commit is contained in:
J. Nick Koston
2025-09-21 10:09:35 -06:00
parent 17704f712e
commit eee8b11119

View File

@@ -2,12 +2,13 @@
from __future__ import annotations
from collections.abc import Generator
import gzip
import hashlib
import io
from pathlib import Path
import socket
from typing import TYPE_CHECKING
import struct
from unittest.mock import MagicMock, Mock, call, patch
import pytest
@@ -16,20 +17,18 @@ from pytest import CaptureFixture
from esphome import espota2
from esphome.core import EsphomeError
if TYPE_CHECKING:
from unittest.mock import Mock as MockType
@pytest.fixture
def mock_socket() -> MockType:
def mock_socket() -> Mock:
"""Create a mock socket for testing."""
socket = Mock()
socket.close = Mock()
socket.recv = Mock()
socket.sendall = Mock()
socket.settimeout = Mock()
socket.connect = Mock()
return socket
socket_mock = Mock()
socket_mock.close = Mock()
socket_mock.recv = Mock()
socket_mock.sendall = Mock()
socket_mock.settimeout = Mock()
socket_mock.connect = Mock()
socket_mock.setsockopt = Mock()
return socket_mock
@pytest.fixture
@@ -39,21 +38,21 @@ def mock_file() -> io.BytesIO:
@pytest.fixture
def mock_time():
def mock_time() -> Generator[None]:
"""Mock time-related functions for consistent testing."""
with patch("time.sleep"), patch("time.perf_counter", side_effect=[0, 1]):
yield
@pytest.fixture
def mock_random():
def mock_random() -> Generator[Mock]:
"""Mock random for predictable test values."""
with patch("random.random", return_value=0.123456) as mock_rand:
yield mock_rand
@pytest.fixture
def mock_resolve_ip():
def mock_resolve_ip() -> Generator[Mock]:
"""Mock resolve_ip_address for testing."""
with patch("esphome.espota2.resolve_ip_address") as mock:
mock.return_value = [
@@ -63,14 +62,14 @@ def mock_resolve_ip():
@pytest.fixture
def mock_perform_ota():
def mock_perform_ota() -> Generator[Mock]:
"""Mock perform_ota function for testing."""
with patch("esphome.espota2.perform_ota") as mock:
yield mock
@pytest.fixture
def mock_run_ota_impl():
def mock_run_ota_impl() -> Generator[Mock]:
"""Mock run_ota_impl_ function for testing."""
with patch("esphome.espota2.run_ota_impl_") as mock:
mock.return_value = (0, "192.168.1.100")
@@ -78,7 +77,7 @@ def mock_run_ota_impl():
@pytest.fixture
def mock_open_file():
def mock_open_file() -> Generator[tuple[Mock, MagicMock]]:
"""Mock file opening for testing."""
with patch("builtins.open", create=True) as mock_open:
mock_file = MagicMock()
@@ -86,7 +85,7 @@ def mock_open_file():
yield mock_open, mock_file
def test_recv_decode_with_decode(mock_socket) -> None:
def test_recv_decode_with_decode(mock_socket: Mock) -> None:
"""Test recv_decode with decode=True returns list."""
mock_socket.recv.return_value = b"\x01\x02\x03"
@@ -96,7 +95,7 @@ def test_recv_decode_with_decode(mock_socket) -> None:
mock_socket.recv.assert_called_once_with(3)
def test_recv_decode_without_decode(mock_socket) -> None:
def test_recv_decode_without_decode(mock_socket: Mock) -> None:
"""Test recv_decode with decode=False returns bytes."""
mock_socket.recv.return_value = b"\x01\x02\x03"
@@ -106,7 +105,7 @@ def test_recv_decode_without_decode(mock_socket) -> None:
mock_socket.recv.assert_called_once_with(3)
def test_receive_exactly_success(mock_socket) -> None:
def test_receive_exactly_success(mock_socket: Mock) -> None:
"""Test receive_exactly successfully receives expected data."""
mock_socket.recv.side_effect = [b"\x00", b"\x01\x02"]
@@ -116,7 +115,7 @@ def test_receive_exactly_success(mock_socket) -> None:
assert mock_socket.recv.call_count == 2
def test_receive_exactly_with_error_response(mock_socket) -> None:
def test_receive_exactly_with_error_response(mock_socket: Mock) -> None:
"""Test receive_exactly raises OTAError on error response."""
mock_socket.recv.return_value = bytes([espota2.RESPONSE_ERROR_AUTH_INVALID])
@@ -126,7 +125,7 @@ def test_receive_exactly_with_error_response(mock_socket) -> None:
mock_socket.close.assert_called_once()
def test_receive_exactly_socket_error(mock_socket) -> None:
def test_receive_exactly_socket_error(mock_socket: Mock) -> None:
"""Test receive_exactly handles socket errors."""
mock_socket.recv.side_effect = OSError("Connection reset")
@@ -185,7 +184,7 @@ def test_check_error_unexpected_response() -> None:
espota2.check_error([0x7F], [espota2.RESPONSE_OK, espota2.RESPONSE_AUTH_OK])
def test_send_check_with_various_data_types(mock_socket) -> None:
def test_send_check_with_various_data_types(mock_socket: Mock) -> None:
"""Test send_check handles different data types."""
# Test with list/tuple
@@ -205,7 +204,7 @@ def test_send_check_with_various_data_types(mock_socket) -> None:
mock_socket.sendall.assert_called_with(b"\xaa\xbb")
def test_send_check_socket_error(mock_socket) -> None:
def test_send_check_socket_error(mock_socket: Mock) -> None:
"""Test send_check handles socket errors."""
mock_socket.sendall.side_effect = OSError("Broken pipe")
@@ -213,8 +212,9 @@ def test_send_check_socket_error(mock_socket) -> None:
espota2.send_check(mock_socket, b"data", "test")
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_successful_md5_auth(
mock_socket, mock_file, mock_time, mock_random
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None:
"""Test successful OTA with MD5 authentication."""
# Setup socket responses for recv calls
@@ -263,7 +263,8 @@ def test_perform_ota_successful_md5_auth(
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
def test_perform_ota_no_auth(mock_socket, mock_file, mock_time) -> None:
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_no_auth(mock_socket: Mock, mock_file: io.BytesIO) -> None:
"""Test OTA without authentication."""
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
@@ -289,7 +290,8 @@ def test_perform_ota_no_auth(mock_socket, mock_file, mock_time) -> None:
assert len(auth_calls) == 0
def test_perform_ota_with_compression(mock_socket, mock_time) -> None:
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_with_compression(mock_socket: Mock) -> None:
"""Test OTA with compression support."""
original_content = b"firmware" * 100 # Repeating content for compression
mock_file = io.BytesIO(original_content)
@@ -312,12 +314,7 @@ def test_perform_ota_with_compression(mock_socket, mock_time) -> None:
# Verify compressed content was sent
# Get the binary size that was sent (4 bytes after features)
size_bytes = mock_socket.sendall.call_args_list[2][0][0]
sent_size = (
(size_bytes[0] << 24)
| (size_bytes[1] << 16)
| (size_bytes[2] << 8)
| size_bytes[3]
)
sent_size = struct.unpack(">I", size_bytes)[0]
# Size should be less than original due to compression
assert sent_size < len(original_content)
@@ -327,7 +324,7 @@ def test_perform_ota_with_compression(mock_socket, mock_time) -> None:
assert sent_size == len(compressed)
def test_perform_ota_auth_without_password(mock_socket) -> None:
def test_perform_ota_auth_without_password(mock_socket: Mock) -> None:
"""Test OTA fails when auth is required but no password provided."""
mock_file = io.BytesIO(b"firmware")
@@ -345,7 +342,7 @@ def test_perform_ota_auth_without_password(mock_socket) -> None:
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
def test_perform_ota_unsupported_version(mock_socket) -> None:
def test_perform_ota_unsupported_version(mock_socket: Mock) -> None:
"""Test OTA fails with unsupported version."""
mock_file = io.BytesIO(b"firmware")
@@ -359,7 +356,8 @@ def test_perform_ota_unsupported_version(mock_socket) -> None:
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
def test_perform_ota_upload_error(mock_socket, mock_file, mock_time) -> None:
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_upload_error(mock_socket: Mock, mock_file: io.BytesIO) -> None:
"""Test OTA handles upload errors."""
# Setup responses - provide enough for the recv calls
recv_responses = [
@@ -380,7 +378,7 @@ def test_perform_ota_upload_error(mock_socket, mock_file, mock_time) -> None:
def test_run_ota_impl_successful(
mock_socket, tmp_path: Path, mock_resolve_ip, mock_perform_ota
mock_socket: Mock, tmp_path: Path, mock_resolve_ip: Mock, mock_perform_ota: Mock
) -> None:
"""Test run_ota_impl_ with successful upload."""
# Create a real firmware file
@@ -413,7 +411,7 @@ def test_run_ota_impl_successful(
def test_run_ota_impl_connection_failed(
mock_socket, tmp_path: Path, mock_resolve_ip
mock_socket: Mock, tmp_path: Path, mock_resolve_ip: Mock
) -> None:
"""Test run_ota_impl_ when connection fails."""
mock_socket.connect.side_effect = OSError("Connection refused")
@@ -432,7 +430,7 @@ def test_run_ota_impl_connection_failed(
mock_socket.close.assert_called_once()
def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip) -> None:
def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip: Mock) -> None:
"""Test run_ota_impl_ when DNS resolution fails."""
# Create a real firmware file
firmware_file = tmp_path / "firmware.bin"
@@ -446,7 +444,7 @@ def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip) -> None:
)
def test_run_ota_wrapper(mock_run_ota_impl) -> None:
def test_run_ota_wrapper(mock_run_ota_impl: Mock) -> None:
"""Test run_ota wrapper function."""
# Test successful case
mock_run_ota_impl.return_value = (0, "192.168.1.100")
@@ -494,8 +492,9 @@ def test_progress_bar(capsys: CaptureFixture[str]) -> None:
# Tests for SHA256 authentication
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_successful_sha256_auth(
mock_socket, mock_file, mock_time, mock_random
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None:
"""Test successful OTA with SHA256 authentication."""
# Setup socket responses for recv calls
@@ -546,8 +545,9 @@ def test_perform_ota_successful_sha256_auth(
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_sha256_fallback_to_md5(
mock_socket, mock_file, mock_time, mock_random
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None:
"""Test SHA256-capable client falls back to MD5 for compatibility."""
# This test verifies the temporary backward compatibility
@@ -594,7 +594,10 @@ def test_perform_ota_sha256_fallback_to_md5(
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
def test_perform_ota_version_differences(mock_socket, mock_file, mock_time) -> None:
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_version_differences(
mock_socket: Mock, mock_file: io.BytesIO
) -> None:
"""Test OTA behavior differences between version 1.0 and 2.0."""
# Test version 1.0 - no chunk acknowledgments
recv_responses = [