1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-23 13:42:27 +01:00
This commit is contained in:
J. Nick Koston
2025-09-21 10:09:35 -06:00
parent 17704f712e
commit eee8b11119

View File

@@ -2,12 +2,13 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator
import gzip import gzip
import hashlib import hashlib
import io import io
from pathlib import Path from pathlib import Path
import socket import socket
from typing import TYPE_CHECKING import struct
from unittest.mock import MagicMock, Mock, call, patch from unittest.mock import MagicMock, Mock, call, patch
import pytest import pytest
@@ -16,20 +17,18 @@ from pytest import CaptureFixture
from esphome import espota2 from esphome import espota2
from esphome.core import EsphomeError from esphome.core import EsphomeError
if TYPE_CHECKING:
from unittest.mock import Mock as MockType
@pytest.fixture @pytest.fixture
def mock_socket() -> MockType: def mock_socket() -> Mock:
"""Create a mock socket for testing.""" """Create a mock socket for testing."""
socket = Mock() socket_mock = Mock()
socket.close = Mock() socket_mock.close = Mock()
socket.recv = Mock() socket_mock.recv = Mock()
socket.sendall = Mock() socket_mock.sendall = Mock()
socket.settimeout = Mock() socket_mock.settimeout = Mock()
socket.connect = Mock() socket_mock.connect = Mock()
return socket socket_mock.setsockopt = Mock()
return socket_mock
@pytest.fixture @pytest.fixture
@@ -39,21 +38,21 @@ def mock_file() -> io.BytesIO:
@pytest.fixture @pytest.fixture
def mock_time(): def mock_time() -> Generator[None]:
"""Mock time-related functions for consistent testing.""" """Mock time-related functions for consistent testing."""
with patch("time.sleep"), patch("time.perf_counter", side_effect=[0, 1]): with patch("time.sleep"), patch("time.perf_counter", side_effect=[0, 1]):
yield yield
@pytest.fixture @pytest.fixture
def mock_random(): def mock_random() -> Generator[Mock]:
"""Mock random for predictable test values.""" """Mock random for predictable test values."""
with patch("random.random", return_value=0.123456) as mock_rand: with patch("random.random", return_value=0.123456) as mock_rand:
yield mock_rand yield mock_rand
@pytest.fixture @pytest.fixture
def mock_resolve_ip(): def mock_resolve_ip() -> Generator[Mock]:
"""Mock resolve_ip_address for testing.""" """Mock resolve_ip_address for testing."""
with patch("esphome.espota2.resolve_ip_address") as mock: with patch("esphome.espota2.resolve_ip_address") as mock:
mock.return_value = [ mock.return_value = [
@@ -63,14 +62,14 @@ def mock_resolve_ip():
@pytest.fixture @pytest.fixture
def mock_perform_ota(): def mock_perform_ota() -> Generator[Mock]:
"""Mock perform_ota function for testing.""" """Mock perform_ota function for testing."""
with patch("esphome.espota2.perform_ota") as mock: with patch("esphome.espota2.perform_ota") as mock:
yield mock yield mock
@pytest.fixture @pytest.fixture
def mock_run_ota_impl(): def mock_run_ota_impl() -> Generator[Mock]:
"""Mock run_ota_impl_ function for testing.""" """Mock run_ota_impl_ function for testing."""
with patch("esphome.espota2.run_ota_impl_") as mock: with patch("esphome.espota2.run_ota_impl_") as mock:
mock.return_value = (0, "192.168.1.100") mock.return_value = (0, "192.168.1.100")
@@ -78,7 +77,7 @@ def mock_run_ota_impl():
@pytest.fixture @pytest.fixture
def mock_open_file(): def mock_open_file() -> Generator[tuple[Mock, MagicMock]]:
"""Mock file opening for testing.""" """Mock file opening for testing."""
with patch("builtins.open", create=True) as mock_open: with patch("builtins.open", create=True) as mock_open:
mock_file = MagicMock() mock_file = MagicMock()
@@ -86,7 +85,7 @@ def mock_open_file():
yield mock_open, mock_file yield mock_open, mock_file
def test_recv_decode_with_decode(mock_socket) -> None: def test_recv_decode_with_decode(mock_socket: Mock) -> None:
"""Test recv_decode with decode=True returns list.""" """Test recv_decode with decode=True returns list."""
mock_socket.recv.return_value = b"\x01\x02\x03" mock_socket.recv.return_value = b"\x01\x02\x03"
@@ -96,7 +95,7 @@ def test_recv_decode_with_decode(mock_socket) -> None:
mock_socket.recv.assert_called_once_with(3) mock_socket.recv.assert_called_once_with(3)
def test_recv_decode_without_decode(mock_socket) -> None: def test_recv_decode_without_decode(mock_socket: Mock) -> None:
"""Test recv_decode with decode=False returns bytes.""" """Test recv_decode with decode=False returns bytes."""
mock_socket.recv.return_value = b"\x01\x02\x03" mock_socket.recv.return_value = b"\x01\x02\x03"
@@ -106,7 +105,7 @@ def test_recv_decode_without_decode(mock_socket) -> None:
mock_socket.recv.assert_called_once_with(3) mock_socket.recv.assert_called_once_with(3)
def test_receive_exactly_success(mock_socket) -> None: def test_receive_exactly_success(mock_socket: Mock) -> None:
"""Test receive_exactly successfully receives expected data.""" """Test receive_exactly successfully receives expected data."""
mock_socket.recv.side_effect = [b"\x00", b"\x01\x02"] mock_socket.recv.side_effect = [b"\x00", b"\x01\x02"]
@@ -116,7 +115,7 @@ def test_receive_exactly_success(mock_socket) -> None:
assert mock_socket.recv.call_count == 2 assert mock_socket.recv.call_count == 2
def test_receive_exactly_with_error_response(mock_socket) -> None: def test_receive_exactly_with_error_response(mock_socket: Mock) -> None:
"""Test receive_exactly raises OTAError on error response.""" """Test receive_exactly raises OTAError on error response."""
mock_socket.recv.return_value = bytes([espota2.RESPONSE_ERROR_AUTH_INVALID]) mock_socket.recv.return_value = bytes([espota2.RESPONSE_ERROR_AUTH_INVALID])
@@ -126,7 +125,7 @@ def test_receive_exactly_with_error_response(mock_socket) -> None:
mock_socket.close.assert_called_once() mock_socket.close.assert_called_once()
def test_receive_exactly_socket_error(mock_socket) -> None: def test_receive_exactly_socket_error(mock_socket: Mock) -> None:
"""Test receive_exactly handles socket errors.""" """Test receive_exactly handles socket errors."""
mock_socket.recv.side_effect = OSError("Connection reset") mock_socket.recv.side_effect = OSError("Connection reset")
@@ -185,7 +184,7 @@ def test_check_error_unexpected_response() -> None:
espota2.check_error([0x7F], [espota2.RESPONSE_OK, espota2.RESPONSE_AUTH_OK]) espota2.check_error([0x7F], [espota2.RESPONSE_OK, espota2.RESPONSE_AUTH_OK])
def test_send_check_with_various_data_types(mock_socket) -> None: def test_send_check_with_various_data_types(mock_socket: Mock) -> None:
"""Test send_check handles different data types.""" """Test send_check handles different data types."""
# Test with list/tuple # Test with list/tuple
@@ -205,7 +204,7 @@ def test_send_check_with_various_data_types(mock_socket) -> None:
mock_socket.sendall.assert_called_with(b"\xaa\xbb") mock_socket.sendall.assert_called_with(b"\xaa\xbb")
def test_send_check_socket_error(mock_socket) -> None: def test_send_check_socket_error(mock_socket: Mock) -> None:
"""Test send_check handles socket errors.""" """Test send_check handles socket errors."""
mock_socket.sendall.side_effect = OSError("Broken pipe") mock_socket.sendall.side_effect = OSError("Broken pipe")
@@ -213,8 +212,9 @@ def test_send_check_socket_error(mock_socket) -> None:
espota2.send_check(mock_socket, b"data", "test") espota2.send_check(mock_socket, b"data", "test")
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_successful_md5_auth( def test_perform_ota_successful_md5_auth(
mock_socket, mock_file, mock_time, mock_random mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None: ) -> None:
"""Test successful OTA with MD5 authentication.""" """Test successful OTA with MD5 authentication."""
# Setup socket responses for recv calls # Setup socket responses for recv calls
@@ -263,7 +263,8 @@ def test_perform_ota_successful_md5_auth(
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode()) assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
def test_perform_ota_no_auth(mock_socket, mock_file, mock_time) -> None: @pytest.mark.usefixtures("mock_time")
def test_perform_ota_no_auth(mock_socket: Mock, mock_file: io.BytesIO) -> None:
"""Test OTA without authentication.""" """Test OTA without authentication."""
recv_responses = [ recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response bytes([espota2.RESPONSE_OK]), # First byte of version response
@@ -289,7 +290,8 @@ def test_perform_ota_no_auth(mock_socket, mock_file, mock_time) -> None:
assert len(auth_calls) == 0 assert len(auth_calls) == 0
def test_perform_ota_with_compression(mock_socket, mock_time) -> None: @pytest.mark.usefixtures("mock_time")
def test_perform_ota_with_compression(mock_socket: Mock) -> None:
"""Test OTA with compression support.""" """Test OTA with compression support."""
original_content = b"firmware" * 100 # Repeating content for compression original_content = b"firmware" * 100 # Repeating content for compression
mock_file = io.BytesIO(original_content) mock_file = io.BytesIO(original_content)
@@ -312,12 +314,7 @@ def test_perform_ota_with_compression(mock_socket, mock_time) -> None:
# Verify compressed content was sent # Verify compressed content was sent
# Get the binary size that was sent (4 bytes after features) # Get the binary size that was sent (4 bytes after features)
size_bytes = mock_socket.sendall.call_args_list[2][0][0] size_bytes = mock_socket.sendall.call_args_list[2][0][0]
sent_size = ( sent_size = struct.unpack(">I", size_bytes)[0]
(size_bytes[0] << 24)
| (size_bytes[1] << 16)
| (size_bytes[2] << 8)
| size_bytes[3]
)
# Size should be less than original due to compression # Size should be less than original due to compression
assert sent_size < len(original_content) assert sent_size < len(original_content)
@@ -327,7 +324,7 @@ def test_perform_ota_with_compression(mock_socket, mock_time) -> None:
assert sent_size == len(compressed) assert sent_size == len(compressed)
def test_perform_ota_auth_without_password(mock_socket) -> None: def test_perform_ota_auth_without_password(mock_socket: Mock) -> None:
"""Test OTA fails when auth is required but no password provided.""" """Test OTA fails when auth is required but no password provided."""
mock_file = io.BytesIO(b"firmware") mock_file = io.BytesIO(b"firmware")
@@ -345,7 +342,7 @@ def test_perform_ota_auth_without_password(mock_socket) -> None:
espota2.perform_ota(mock_socket, "", mock_file, "test.bin") espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
def test_perform_ota_unsupported_version(mock_socket) -> None: def test_perform_ota_unsupported_version(mock_socket: Mock) -> None:
"""Test OTA fails with unsupported version.""" """Test OTA fails with unsupported version."""
mock_file = io.BytesIO(b"firmware") mock_file = io.BytesIO(b"firmware")
@@ -359,7 +356,8 @@ def test_perform_ota_unsupported_version(mock_socket) -> None:
espota2.perform_ota(mock_socket, "", mock_file, "test.bin") espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
def test_perform_ota_upload_error(mock_socket, mock_file, mock_time) -> None: @pytest.mark.usefixtures("mock_time")
def test_perform_ota_upload_error(mock_socket: Mock, mock_file: io.BytesIO) -> None:
"""Test OTA handles upload errors.""" """Test OTA handles upload errors."""
# Setup responses - provide enough for the recv calls # Setup responses - provide enough for the recv calls
recv_responses = [ recv_responses = [
@@ -380,7 +378,7 @@ def test_perform_ota_upload_error(mock_socket, mock_file, mock_time) -> None:
def test_run_ota_impl_successful( def test_run_ota_impl_successful(
mock_socket, tmp_path: Path, mock_resolve_ip, mock_perform_ota mock_socket: Mock, tmp_path: Path, mock_resolve_ip: Mock, mock_perform_ota: Mock
) -> None: ) -> None:
"""Test run_ota_impl_ with successful upload.""" """Test run_ota_impl_ with successful upload."""
# Create a real firmware file # Create a real firmware file
@@ -413,7 +411,7 @@ def test_run_ota_impl_successful(
def test_run_ota_impl_connection_failed( def test_run_ota_impl_connection_failed(
mock_socket, tmp_path: Path, mock_resolve_ip mock_socket: Mock, tmp_path: Path, mock_resolve_ip: Mock
) -> None: ) -> None:
"""Test run_ota_impl_ when connection fails.""" """Test run_ota_impl_ when connection fails."""
mock_socket.connect.side_effect = OSError("Connection refused") mock_socket.connect.side_effect = OSError("Connection refused")
@@ -432,7 +430,7 @@ def test_run_ota_impl_connection_failed(
mock_socket.close.assert_called_once() mock_socket.close.assert_called_once()
def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip) -> None: def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip: Mock) -> None:
"""Test run_ota_impl_ when DNS resolution fails.""" """Test run_ota_impl_ when DNS resolution fails."""
# Create a real firmware file # Create a real firmware file
firmware_file = tmp_path / "firmware.bin" firmware_file = tmp_path / "firmware.bin"
@@ -446,7 +444,7 @@ def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip) -> None:
) )
def test_run_ota_wrapper(mock_run_ota_impl) -> None: def test_run_ota_wrapper(mock_run_ota_impl: Mock) -> None:
"""Test run_ota wrapper function.""" """Test run_ota wrapper function."""
# Test successful case # Test successful case
mock_run_ota_impl.return_value = (0, "192.168.1.100") mock_run_ota_impl.return_value = (0, "192.168.1.100")
@@ -494,8 +492,9 @@ def test_progress_bar(capsys: CaptureFixture[str]) -> None:
# Tests for SHA256 authentication # Tests for SHA256 authentication
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_successful_sha256_auth( def test_perform_ota_successful_sha256_auth(
mock_socket, mock_file, mock_time, mock_random mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None: ) -> None:
"""Test successful OTA with SHA256 authentication.""" """Test successful OTA with SHA256 authentication."""
# Setup socket responses for recv calls # Setup socket responses for recv calls
@@ -546,8 +545,9 @@ def test_perform_ota_successful_sha256_auth(
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode()) assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_sha256_fallback_to_md5( def test_perform_ota_sha256_fallback_to_md5(
mock_socket, mock_file, mock_time, mock_random mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None: ) -> None:
"""Test SHA256-capable client falls back to MD5 for compatibility.""" """Test SHA256-capable client falls back to MD5 for compatibility."""
# This test verifies the temporary backward compatibility # This test verifies the temporary backward compatibility
@@ -594,7 +594,10 @@ def test_perform_ota_sha256_fallback_to_md5(
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode()) assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
def test_perform_ota_version_differences(mock_socket, mock_file, mock_time) -> None: @pytest.mark.usefixtures("mock_time")
def test_perform_ota_version_differences(
mock_socket: Mock, mock_file: io.BytesIO
) -> None:
"""Test OTA behavior differences between version 1.0 and 2.0.""" """Test OTA behavior differences between version 1.0 and 2.0."""
# Test version 1.0 - no chunk acknowledgments # Test version 1.0 - no chunk acknowledgments
recv_responses = [ recv_responses = [