diff --git a/tests/unit_tests/test_espota2.py b/tests/unit_tests/test_espota2.py index e1b3de6f97..6bfd20a0b2 100644 --- a/tests/unit_tests/test_espota2.py +++ b/tests/unit_tests/test_espota2.py @@ -54,6 +54,40 @@ def mock_random(): yield mock_rand +@pytest.fixture +def mock_resolve_ip(): + """Mock resolve_ip_address for testing.""" + with patch("esphome.espota2.resolve_ip_address") as mock: + mock.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.100", 3232)) + ] + yield mock + + +@pytest.fixture +def mock_perform_ota(): + """Mock perform_ota function for testing.""" + with patch("esphome.espota2.perform_ota") as mock: + yield mock + + +@pytest.fixture +def mock_run_ota_impl(): + """Mock run_ota_impl_ function for testing.""" + with patch("esphome.espota2.run_ota_impl_") as mock: + mock.return_value = (0, "192.168.1.100") + yield mock + + +@pytest.fixture +def mock_open_file(): + """Mock file opening for testing.""" + with patch("builtins.open", create=True) as mock_open: + mock_file = MagicMock() + mock_open.return_value.__enter__.return_value = mock_file + yield mock_open, mock_file + + def test_recv_decode_with_decode(mock_socket) -> None: """Test recv_decode with decode=True returns list.""" mock_socket.recv.return_value = b"\x01\x02\x03" @@ -365,26 +399,25 @@ def test_perform_ota_upload_error(mock_socket) -> None: espota2.perform_ota(mock_socket, "", mock_file, "test.bin") -def test_run_ota_impl_successful() -> None: +def test_run_ota_impl_successful(mock_socket, tmp_path) -> None: """Test run_ota_impl_ with successful upload.""" - mock_socket = Mock() + # Create a real firmware file + firmware_file = tmp_path / "firmware.bin" + firmware_file.write_bytes(b"firmware content") with ( patch("socket.socket", return_value=mock_socket), patch("esphome.espota2.resolve_ip_address") as mock_resolve, - patch("builtins.open", create=True) as mock_open, patch("esphome.espota2.perform_ota") as mock_perform, ): # Setup mocks mock_resolve.return_value = [ (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.100", 3232)) ] - mock_file = MagicMock() - mock_open.return_value.__enter__.return_value = mock_file - # Run OTA + # Run OTA with real file path result_code, result_host = espota2.run_ota_impl_( - "test.local", 3232, "password", "firmware.bin" + "test.local", 3232, "password", str(firmware_file) ) # Verify success @@ -396,17 +429,24 @@ def test_run_ota_impl_successful() -> None: mock_socket.connect.assert_called_once_with(("192.168.1.100", 3232)) mock_socket.close.assert_called_once() - # Verify perform_ota was called - mock_perform.assert_called_once_with( - mock_socket, "password", mock_file, "firmware.bin" - ) + # Verify perform_ota was called with real file + mock_perform.assert_called_once() + call_args = mock_perform.call_args[0] + assert call_args[0] == mock_socket + assert call_args[1] == "password" + # The file object should be opened + assert hasattr(call_args[2], "read") + assert call_args[3] == str(firmware_file) -def test_run_ota_impl_connection_failed() -> None: +def test_run_ota_impl_connection_failed(mock_socket, tmp_path) -> None: """Test run_ota_impl_ when connection fails.""" - mock_socket = Mock() mock_socket.connect.side_effect = OSError("Connection refused") + # Create a real firmware file + firmware_file = tmp_path / "firmware.bin" + firmware_file.write_bytes(b"firmware content") + with ( patch("socket.socket", return_value=mock_socket), patch("esphome.espota2.resolve_ip_address") as mock_resolve, @@ -416,7 +456,7 @@ def test_run_ota_impl_connection_failed() -> None: ] result_code, result_host = espota2.run_ota_impl_( - "test.local", 3232, "password", "firmware.bin" + "test.local", 3232, "password", str(firmware_file) ) assert result_code == 1 @@ -424,14 +464,18 @@ def test_run_ota_impl_connection_failed() -> None: mock_socket.close.assert_called_once() -def test_run_ota_impl_resolve_failed() -> None: +def test_run_ota_impl_resolve_failed(tmp_path) -> None: """Test run_ota_impl_ when DNS resolution fails.""" + # Create a real firmware file + firmware_file = tmp_path / "firmware.bin" + firmware_file.write_bytes(b"firmware content") + with patch("esphome.espota2.resolve_ip_address") as mock_resolve: mock_resolve.side_effect = EsphomeError("DNS resolution failed") with pytest.raises(espota2.OTAError, match="DNS resolution failed"): result_code, result_host = espota2.run_ota_impl_( - "unknown.host", 3232, "password", "firmware.bin" + "unknown.host", 3232, "password", str(firmware_file) ) @@ -484,9 +528,8 @@ def test_progress_bar(capsys: CaptureFixture[str]) -> None: # Tests for SHA256 authentication (for when PR is merged) -def test_perform_ota_successful_sha256_auth() -> None: +def test_perform_ota_successful_sha256_auth(mock_socket) -> None: """Test successful OTA with SHA256 authentication (future support).""" - mock_socket = Mock() # Mock random for predictable cnonce with patch("random.random", return_value=0.123456):