diff --git a/tests/unit_tests/test_espota2.py b/tests/unit_tests/test_espota2.py index 539f4ecc42..24b6a57b63 100644 --- a/tests/unit_tests/test_espota2.py +++ b/tests/unit_tests/test_espota2.py @@ -95,6 +95,13 @@ def mock_open_file() -> Generator[tuple[Mock, MagicMock]]: yield mock_open, mock_file +@pytest.fixture +def mock_socket_constructor(mock_socket: Mock) -> Generator[Mock]: + """Mock socket.socket constructor to return our mock socket.""" + with patch("socket.socket", return_value=mock_socket) as mock_constructor: + yield mock_constructor + + def test_recv_decode_with_decode(mock_socket: Mock) -> None: """Test recv_decode with decode=True returns list.""" mock_socket.recv.return_value = b"\x01\x02\x03" @@ -387,42 +394,41 @@ def test_perform_ota_upload_error(mock_socket: Mock, mock_file: io.BytesIO) -> N espota2.perform_ota(mock_socket, "", mock_file, "test.bin") +@pytest.mark.usefixtures("mock_socket_constructor", "mock_resolve_ip") def test_run_ota_impl_successful( - mock_socket: Mock, tmp_path: Path, mock_resolve_ip: Mock, mock_perform_ota: Mock + mock_socket: Mock, tmp_path: Path, mock_perform_ota: Mock ) -> None: """Test run_ota_impl_ with successful upload.""" # Create a real firmware file firmware_file = tmp_path / "firmware.bin" firmware_file.write_bytes(b"firmware content") - with patch("socket.socket", return_value=mock_socket): - # Run OTA with real file path - result_code, result_host = espota2.run_ota_impl_( - "test.local", 3232, "password", str(firmware_file) - ) + # Run OTA with real file path + result_code, result_host = espota2.run_ota_impl_( + "test.local", 3232, "password", str(firmware_file) + ) - # Verify success - assert result_code == 0 - assert result_host == "192.168.1.100" + # Verify success + assert result_code == 0 + assert result_host == "192.168.1.100" - # Verify socket was configured correctly - mock_socket.settimeout.assert_called_with(10.0) - mock_socket.connect.assert_called_once_with(("192.168.1.100", 3232)) - mock_socket.close.assert_called_once() + # Verify socket was configured correctly + mock_socket.settimeout.assert_called_with(10.0) + mock_socket.connect.assert_called_once_with(("192.168.1.100", 3232)) + mock_socket.close.assert_called_once() - # Verify perform_ota was called with real file - mock_perform_ota.assert_called_once() - call_args = mock_perform_ota.call_args[0] - assert call_args[0] == mock_socket - assert call_args[1] == "password" - # The file object should be opened - assert hasattr(call_args[2], "read") - assert call_args[3] == str(firmware_file) + # Verify perform_ota was called with real file + mock_perform_ota.assert_called_once() + call_args = mock_perform_ota.call_args[0] + assert call_args[0] == mock_socket + assert call_args[1] == "password" + # The file object should be opened + assert hasattr(call_args[2], "read") + assert call_args[3] == str(firmware_file) -def test_run_ota_impl_connection_failed( - mock_socket: Mock, tmp_path: Path, mock_resolve_ip: Mock -) -> None: +@pytest.mark.usefixtures("mock_socket_constructor", "mock_resolve_ip") +def test_run_ota_impl_connection_failed(mock_socket: Mock, tmp_path: Path) -> None: """Test run_ota_impl_ when connection fails.""" mock_socket.connect.side_effect = OSError("Connection refused") @@ -430,14 +436,13 @@ def test_run_ota_impl_connection_failed( firmware_file = tmp_path / "firmware.bin" firmware_file.write_bytes(b"firmware content") - with patch("socket.socket", return_value=mock_socket): - result_code, result_host = espota2.run_ota_impl_( - "test.local", 3232, "password", str(firmware_file) - ) + result_code, result_host = espota2.run_ota_impl_( + "test.local", 3232, "password", str(firmware_file) + ) - assert result_code == 1 - assert result_host is None - mock_socket.close.assert_called_once() + assert result_code == 1 + assert result_host is None + mock_socket.close.assert_called_once() def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip: Mock) -> None: