mirror of
https://github.com/esphome/esphome.git
synced 2025-09-23 13:42:27 +01:00
preen
This commit is contained in:
@@ -2,12 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
import gzip
|
||||
import hashlib
|
||||
import io
|
||||
from pathlib import Path
|
||||
import socket
|
||||
from typing import TYPE_CHECKING
|
||||
import struct
|
||||
from unittest.mock import MagicMock, Mock, call, patch
|
||||
|
||||
import pytest
|
||||
@@ -16,20 +17,18 @@ from pytest import CaptureFixture
|
||||
from esphome import espota2
|
||||
from esphome.core import EsphomeError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unittest.mock import Mock as MockType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_socket() -> MockType:
|
||||
def mock_socket() -> Mock:
|
||||
"""Create a mock socket for testing."""
|
||||
socket = Mock()
|
||||
socket.close = Mock()
|
||||
socket.recv = Mock()
|
||||
socket.sendall = Mock()
|
||||
socket.settimeout = Mock()
|
||||
socket.connect = Mock()
|
||||
return socket
|
||||
socket_mock = Mock()
|
||||
socket_mock.close = Mock()
|
||||
socket_mock.recv = Mock()
|
||||
socket_mock.sendall = Mock()
|
||||
socket_mock.settimeout = Mock()
|
||||
socket_mock.connect = Mock()
|
||||
socket_mock.setsockopt = Mock()
|
||||
return socket_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -39,21 +38,21 @@ def mock_file() -> io.BytesIO:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_time():
|
||||
def mock_time() -> Generator[None]:
|
||||
"""Mock time-related functions for consistent testing."""
|
||||
with patch("time.sleep"), patch("time.perf_counter", side_effect=[0, 1]):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_random():
|
||||
def mock_random() -> Generator[Mock]:
|
||||
"""Mock random for predictable test values."""
|
||||
with patch("random.random", return_value=0.123456) as mock_rand:
|
||||
yield mock_rand
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_resolve_ip():
|
||||
def mock_resolve_ip() -> Generator[Mock]:
|
||||
"""Mock resolve_ip_address for testing."""
|
||||
with patch("esphome.espota2.resolve_ip_address") as mock:
|
||||
mock.return_value = [
|
||||
@@ -63,14 +62,14 @@ def mock_resolve_ip():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_perform_ota():
|
||||
def mock_perform_ota() -> Generator[Mock]:
|
||||
"""Mock perform_ota function for testing."""
|
||||
with patch("esphome.espota2.perform_ota") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_run_ota_impl():
|
||||
def mock_run_ota_impl() -> Generator[Mock]:
|
||||
"""Mock run_ota_impl_ function for testing."""
|
||||
with patch("esphome.espota2.run_ota_impl_") as mock:
|
||||
mock.return_value = (0, "192.168.1.100")
|
||||
@@ -78,7 +77,7 @@ def mock_run_ota_impl():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_open_file():
|
||||
def mock_open_file() -> Generator[tuple[Mock, MagicMock]]:
|
||||
"""Mock file opening for testing."""
|
||||
with patch("builtins.open", create=True) as mock_open:
|
||||
mock_file = MagicMock()
|
||||
@@ -86,7 +85,7 @@ def mock_open_file():
|
||||
yield mock_open, mock_file
|
||||
|
||||
|
||||
def test_recv_decode_with_decode(mock_socket) -> None:
|
||||
def test_recv_decode_with_decode(mock_socket: Mock) -> None:
|
||||
"""Test recv_decode with decode=True returns list."""
|
||||
mock_socket.recv.return_value = b"\x01\x02\x03"
|
||||
|
||||
@@ -96,7 +95,7 @@ def test_recv_decode_with_decode(mock_socket) -> None:
|
||||
mock_socket.recv.assert_called_once_with(3)
|
||||
|
||||
|
||||
def test_recv_decode_without_decode(mock_socket) -> None:
|
||||
def test_recv_decode_without_decode(mock_socket: Mock) -> None:
|
||||
"""Test recv_decode with decode=False returns bytes."""
|
||||
mock_socket.recv.return_value = b"\x01\x02\x03"
|
||||
|
||||
@@ -106,7 +105,7 @@ def test_recv_decode_without_decode(mock_socket) -> None:
|
||||
mock_socket.recv.assert_called_once_with(3)
|
||||
|
||||
|
||||
def test_receive_exactly_success(mock_socket) -> None:
|
||||
def test_receive_exactly_success(mock_socket: Mock) -> None:
|
||||
"""Test receive_exactly successfully receives expected data."""
|
||||
mock_socket.recv.side_effect = [b"\x00", b"\x01\x02"]
|
||||
|
||||
@@ -116,7 +115,7 @@ def test_receive_exactly_success(mock_socket) -> None:
|
||||
assert mock_socket.recv.call_count == 2
|
||||
|
||||
|
||||
def test_receive_exactly_with_error_response(mock_socket) -> None:
|
||||
def test_receive_exactly_with_error_response(mock_socket: Mock) -> None:
|
||||
"""Test receive_exactly raises OTAError on error response."""
|
||||
mock_socket.recv.return_value = bytes([espota2.RESPONSE_ERROR_AUTH_INVALID])
|
||||
|
||||
@@ -126,7 +125,7 @@ def test_receive_exactly_with_error_response(mock_socket) -> None:
|
||||
mock_socket.close.assert_called_once()
|
||||
|
||||
|
||||
def test_receive_exactly_socket_error(mock_socket) -> None:
|
||||
def test_receive_exactly_socket_error(mock_socket: Mock) -> None:
|
||||
"""Test receive_exactly handles socket errors."""
|
||||
mock_socket.recv.side_effect = OSError("Connection reset")
|
||||
|
||||
@@ -185,7 +184,7 @@ def test_check_error_unexpected_response() -> None:
|
||||
espota2.check_error([0x7F], [espota2.RESPONSE_OK, espota2.RESPONSE_AUTH_OK])
|
||||
|
||||
|
||||
def test_send_check_with_various_data_types(mock_socket) -> None:
|
||||
def test_send_check_with_various_data_types(mock_socket: Mock) -> None:
|
||||
"""Test send_check handles different data types."""
|
||||
|
||||
# Test with list/tuple
|
||||
@@ -205,7 +204,7 @@ def test_send_check_with_various_data_types(mock_socket) -> None:
|
||||
mock_socket.sendall.assert_called_with(b"\xaa\xbb")
|
||||
|
||||
|
||||
def test_send_check_socket_error(mock_socket) -> None:
|
||||
def test_send_check_socket_error(mock_socket: Mock) -> None:
|
||||
"""Test send_check handles socket errors."""
|
||||
mock_socket.sendall.side_effect = OSError("Broken pipe")
|
||||
|
||||
@@ -213,8 +212,9 @@ def test_send_check_socket_error(mock_socket) -> None:
|
||||
espota2.send_check(mock_socket, b"data", "test")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_time")
|
||||
def test_perform_ota_successful_md5_auth(
|
||||
mock_socket, mock_file, mock_time, mock_random
|
||||
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
|
||||
) -> None:
|
||||
"""Test successful OTA with MD5 authentication."""
|
||||
# Setup socket responses for recv calls
|
||||
@@ -263,7 +263,8 @@ def test_perform_ota_successful_md5_auth(
|
||||
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
|
||||
|
||||
|
||||
def test_perform_ota_no_auth(mock_socket, mock_file, mock_time) -> None:
|
||||
@pytest.mark.usefixtures("mock_time")
|
||||
def test_perform_ota_no_auth(mock_socket: Mock, mock_file: io.BytesIO) -> None:
|
||||
"""Test OTA without authentication."""
|
||||
recv_responses = [
|
||||
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||
@@ -289,7 +290,8 @@ def test_perform_ota_no_auth(mock_socket, mock_file, mock_time) -> None:
|
||||
assert len(auth_calls) == 0
|
||||
|
||||
|
||||
def test_perform_ota_with_compression(mock_socket, mock_time) -> None:
|
||||
@pytest.mark.usefixtures("mock_time")
|
||||
def test_perform_ota_with_compression(mock_socket: Mock) -> None:
|
||||
"""Test OTA with compression support."""
|
||||
original_content = b"firmware" * 100 # Repeating content for compression
|
||||
mock_file = io.BytesIO(original_content)
|
||||
@@ -312,12 +314,7 @@ def test_perform_ota_with_compression(mock_socket, mock_time) -> None:
|
||||
# Verify compressed content was sent
|
||||
# Get the binary size that was sent (4 bytes after features)
|
||||
size_bytes = mock_socket.sendall.call_args_list[2][0][0]
|
||||
sent_size = (
|
||||
(size_bytes[0] << 24)
|
||||
| (size_bytes[1] << 16)
|
||||
| (size_bytes[2] << 8)
|
||||
| size_bytes[3]
|
||||
)
|
||||
sent_size = struct.unpack(">I", size_bytes)[0]
|
||||
|
||||
# Size should be less than original due to compression
|
||||
assert sent_size < len(original_content)
|
||||
@@ -327,7 +324,7 @@ def test_perform_ota_with_compression(mock_socket, mock_time) -> None:
|
||||
assert sent_size == len(compressed)
|
||||
|
||||
|
||||
def test_perform_ota_auth_without_password(mock_socket) -> None:
|
||||
def test_perform_ota_auth_without_password(mock_socket: Mock) -> None:
|
||||
"""Test OTA fails when auth is required but no password provided."""
|
||||
mock_file = io.BytesIO(b"firmware")
|
||||
|
||||
@@ -345,7 +342,7 @@ def test_perform_ota_auth_without_password(mock_socket) -> None:
|
||||
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||
|
||||
|
||||
def test_perform_ota_unsupported_version(mock_socket) -> None:
|
||||
def test_perform_ota_unsupported_version(mock_socket: Mock) -> None:
|
||||
"""Test OTA fails with unsupported version."""
|
||||
mock_file = io.BytesIO(b"firmware")
|
||||
|
||||
@@ -359,7 +356,8 @@ def test_perform_ota_unsupported_version(mock_socket) -> None:
|
||||
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||
|
||||
|
||||
def test_perform_ota_upload_error(mock_socket, mock_file, mock_time) -> None:
|
||||
@pytest.mark.usefixtures("mock_time")
|
||||
def test_perform_ota_upload_error(mock_socket: Mock, mock_file: io.BytesIO) -> None:
|
||||
"""Test OTA handles upload errors."""
|
||||
# Setup responses - provide enough for the recv calls
|
||||
recv_responses = [
|
||||
@@ -380,7 +378,7 @@ def test_perform_ota_upload_error(mock_socket, mock_file, mock_time) -> None:
|
||||
|
||||
|
||||
def test_run_ota_impl_successful(
|
||||
mock_socket, tmp_path: Path, mock_resolve_ip, mock_perform_ota
|
||||
mock_socket: Mock, tmp_path: Path, mock_resolve_ip: Mock, mock_perform_ota: Mock
|
||||
) -> None:
|
||||
"""Test run_ota_impl_ with successful upload."""
|
||||
# Create a real firmware file
|
||||
@@ -413,7 +411,7 @@ def test_run_ota_impl_successful(
|
||||
|
||||
|
||||
def test_run_ota_impl_connection_failed(
|
||||
mock_socket, tmp_path: Path, mock_resolve_ip
|
||||
mock_socket: Mock, tmp_path: Path, mock_resolve_ip: Mock
|
||||
) -> None:
|
||||
"""Test run_ota_impl_ when connection fails."""
|
||||
mock_socket.connect.side_effect = OSError("Connection refused")
|
||||
@@ -432,7 +430,7 @@ def test_run_ota_impl_connection_failed(
|
||||
mock_socket.close.assert_called_once()
|
||||
|
||||
|
||||
def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip) -> None:
|
||||
def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip: Mock) -> None:
|
||||
"""Test run_ota_impl_ when DNS resolution fails."""
|
||||
# Create a real firmware file
|
||||
firmware_file = tmp_path / "firmware.bin"
|
||||
@@ -446,7 +444,7 @@ def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip) -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_run_ota_wrapper(mock_run_ota_impl) -> None:
|
||||
def test_run_ota_wrapper(mock_run_ota_impl: Mock) -> None:
|
||||
"""Test run_ota wrapper function."""
|
||||
# Test successful case
|
||||
mock_run_ota_impl.return_value = (0, "192.168.1.100")
|
||||
@@ -494,8 +492,9 @@ def test_progress_bar(capsys: CaptureFixture[str]) -> None:
|
||||
|
||||
|
||||
# Tests for SHA256 authentication
|
||||
@pytest.mark.usefixtures("mock_time")
|
||||
def test_perform_ota_successful_sha256_auth(
|
||||
mock_socket, mock_file, mock_time, mock_random
|
||||
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
|
||||
) -> None:
|
||||
"""Test successful OTA with SHA256 authentication."""
|
||||
# Setup socket responses for recv calls
|
||||
@@ -546,8 +545,9 @@ def test_perform_ota_successful_sha256_auth(
|
||||
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_time")
|
||||
def test_perform_ota_sha256_fallback_to_md5(
|
||||
mock_socket, mock_file, mock_time, mock_random
|
||||
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
|
||||
) -> None:
|
||||
"""Test SHA256-capable client falls back to MD5 for compatibility."""
|
||||
# This test verifies the temporary backward compatibility
|
||||
@@ -594,7 +594,10 @@ def test_perform_ota_sha256_fallback_to_md5(
|
||||
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
|
||||
|
||||
|
||||
def test_perform_ota_version_differences(mock_socket, mock_file, mock_time) -> None:
|
||||
@pytest.mark.usefixtures("mock_time")
|
||||
def test_perform_ota_version_differences(
|
||||
mock_socket: Mock, mock_file: io.BytesIO
|
||||
) -> None:
|
||||
"""Test OTA behavior differences between version 1.0 and 2.0."""
|
||||
# Test version 1.0 - no chunk acknowledgments
|
||||
recv_responses = [
|
||||
|
Reference in New Issue
Block a user