Mirror of https://github.com/esphome/esphome.git (synced 2025-09-23 21:52:23 +01:00)
Merge branch 'sha256_ota' into integration
@@ -140,13 +140,14 @@ async def to_code(config):
     var = cg.new_Pvariable(config[CONF_ID])
     cg.add(var.set_port(config[CONF_PORT]))

-    # Only include SHA256 support on platforms that have it
-    if supports_sha256():
-        cg.add_define("USE_OTA_SHA256")
-
     if CONF_PASSWORD in config:
         cg.add(var.set_auth_password(config[CONF_PASSWORD]))
         cg.add_define("USE_OTA_PASSWORD")
+        # Only include hash algorithms when password is configured
+        cg.add_define("USE_OTA_MD5")
+        # Only include SHA256 support on platforms that have it
+        if supports_sha256():
+            cg.add_define("USE_OTA_SHA256")
     cg.add_define("USE_OTA_VERSION", config[CONF_VERSION])

     await cg.register_component(var, config)
@@ -1,6 +1,8 @@
 #include "ota_esphome.h"
 #ifdef USE_OTA
+#ifdef USE_OTA_MD5
 #include "esphome/components/md5/md5.h"
+#endif
 #ifdef USE_OTA_SHA256
 #include "esphome/components/sha256/sha256.h"
 #endif
@@ -267,12 +269,14 @@ void ESPHomeOTAComponent::handle_data_() {
     if (client_supports_sha256) {
       sha256::SHA256 sha_hasher;
       auth_success = this->perform_hash_auth_(&sha_hasher, this->password_, 16, ota::OTA_RESPONSE_REQUEST_SHA256_AUTH,
-                                              LOG_STR("SHA256"));
+                                              LOG_STR("SHA256"), sbuf);
     } else {
+#ifdef USE_OTA_MD5
       ESP_LOGW(TAG, "Using MD5 auth for compatibility (deprecated)");
       md5::MD5Digest md5_hasher;
-      auth_success =
-          this->perform_hash_auth_(&md5_hasher, this->password_, 8, ota::OTA_RESPONSE_REQUEST_AUTH, LOG_STR("MD5"));
+      auth_success = this->perform_hash_auth_(&md5_hasher, this->password_, 8, ota::OTA_RESPONSE_REQUEST_AUTH,
+                                              LOG_STR("MD5"), sbuf);
+#endif  // USE_OTA_MD5
     }
 #else
     // Strict mode: SHA256 required on capable platforms (future default)
@@ -281,13 +285,18 @@ void ESPHomeOTAComponent::handle_data_() {
       error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
       goto error;  // NOLINT(cppcoreguidelines-avoid-goto)
     }
-    auth_success = this->perform_hash_auth_<sha256::SHA256>(this->password_);
+    sha256::SHA256 sha_hasher;
+    auth_success = this->perform_hash_auth_(&sha_hasher, this->password_, 16, ota::OTA_RESPONSE_REQUEST_SHA256_AUTH,
+                                            LOG_STR("SHA256"), sbuf);
 #endif  // ALLOW_OTA_DOWNGRADE_MD5
 #else
     // Platform only supports MD5 - use it as the only available option
     // This is not a security downgrade as the platform cannot support SHA256
+#ifdef USE_OTA_MD5
     md5::MD5Digest md5_hasher;
-    auth_success = this->perform_hash_auth_(&md5_hasher, this->password_, 8, ota::OTA_RESPONSE_REQUEST_AUTH);
+    auth_success =
+        this->perform_hash_auth_(&md5_hasher, this->password_, 8, ota::OTA_RESPONSE_REQUEST_AUTH, LOG_STR("MD5"), sbuf);
+#endif  // USE_OTA_MD5
 #endif  // USE_OTA_SHA256

     if (!auth_success) {
@@ -514,29 +523,24 @@ void ESPHomeOTAComponent::yield_and_feed_watchdog_() {
   delay(1);
 }

+#ifdef USE_OTA_PASSWORD
 void ESPHomeOTAComponent::log_auth_warning_(const LogString *action, const LogString *hash_name) {
   ESP_LOGW(TAG, "Auth: %s %s failed", LOG_STR_ARG(action), LOG_STR_ARG(hash_name));
 }

 // Non-template function definition to reduce binary size
 bool ESPHomeOTAComponent::perform_hash_auth_(HashBase *hasher, const std::string &password, size_t nonce_size,
-                                             uint8_t auth_request, const LogString *name) {
+                                             uint8_t auth_request, const LogString *name, char *buf) {
   // Get sizes from the hasher
   const size_t hex_size = hasher->get_hex_size();

-  // Use fixed-size buffers for the maximum possible hash size (SHA256 = 64 chars)
-  // This avoids dynamic allocation overhead
-  static constexpr size_t MAX_HEX_SIZE = 65;  // SHA256 hex + null terminator
-  char hex_buffer1[MAX_HEX_SIZE];  // Used for: nonce -> expected result
-  char hex_buffer2[MAX_HEX_SIZE];  // Used for: cnonce -> response
+  // Use the provided buffer for all hex operations

-  // Small stack buffer for auth request and nonce seed bytes
-  uint8_t buf[1];
+  // Small stack buffer for nonce seed bytes
   uint8_t nonce_bytes[8];  // Max 8 bytes (2 x uint32_t for SHA256)

   // Send auth request type
-  buf[0] = auth_request;
-  this->writeall_(buf, 1);
+  this->writeall_(&auth_request, 1);

   hasher->init();

@@ -562,56 +566,55 @@ bool ESPHomeOTAComponent::perform_hash_auth_(HashBase *hasher, const std::string
   }
   hasher->calculate();

-  // Use hex_buffer1 for nonce
-  hasher->get_hex(hex_buffer1);
-  hex_buffer1[hex_size] = '\0';
-  ESP_LOGV(TAG, "Auth: %s Nonce is %s", LOG_STR_ARG(name), hex_buffer1);
+  // Generate and send nonce
+  hasher->get_hex(buf);
+  buf[hex_size] = '\0';
+  ESP_LOGV(TAG, "Auth: %s Nonce is %s", LOG_STR_ARG(name), buf);

-  // Send nonce
-  if (!this->writeall_(reinterpret_cast<uint8_t *>(hex_buffer1), hex_size)) {
+  if (!this->writeall_(reinterpret_cast<uint8_t *>(buf), hex_size)) {
     this->log_auth_warning_(LOG_STR("Writing nonce"), name);
     return false;
   }

-  // Prepare challenge
+  // Start challenge: password + nonce
   hasher->init();
   hasher->add(password.c_str(), password.length());
-  hasher->add(hex_buffer1, hex_size);  // Add nonce
+  hasher->add(buf, hex_size);

-  // Receive cnonce into hex_buffer2
-  if (!this->readall_(reinterpret_cast<uint8_t *>(hex_buffer2), hex_size)) {
+  // Read cnonce and add to hash
+  if (!this->readall_(reinterpret_cast<uint8_t *>(buf), hex_size)) {
     this->log_auth_warning_(LOG_STR("Reading cnonce"), name);
     return false;
   }
-  hex_buffer2[hex_size] = '\0';
-  ESP_LOGV(TAG, "Auth: %s CNonce is %s", LOG_STR_ARG(name), hex_buffer2);
-
-  // Add cnonce to hash
-  hasher->add(hex_buffer2, hex_size);
+  buf[hex_size] = '\0';
+  ESP_LOGV(TAG, "Auth: %s CNonce is %s", LOG_STR_ARG(name), buf);
+  hasher->add(buf, hex_size);

-  // Calculate result - reuse hex_buffer1 for expected
   hasher->calculate();
-  hasher->get_hex(hex_buffer1);
-  hex_buffer1[hex_size] = '\0';
-  ESP_LOGV(TAG, "Auth: %s Result is %s", LOG_STR_ARG(name), hex_buffer1);

-  // Receive response - reuse hex_buffer2
-  if (!this->readall_(reinterpret_cast<uint8_t *>(hex_buffer2), hex_size)) {
+  // Log expected result (digest is already in hasher)
+  hasher->get_hex(buf);
+  buf[hex_size] = '\0';
+  ESP_LOGV(TAG, "Auth: %s Result is %s", LOG_STR_ARG(name), buf);
+
+  // Read response into the buffer
+  if (!this->readall_(reinterpret_cast<uint8_t *>(buf), hex_size)) {
     this->log_auth_warning_(LOG_STR("Reading response"), name);
     return false;
   }
-  hex_buffer2[hex_size] = '\0';
-  ESP_LOGV(TAG, "Auth: %s Response is %s", LOG_STR_ARG(name), hex_buffer2);
+  buf[hex_size] = '\0';
+  ESP_LOGV(TAG, "Auth: %s Response is %s", LOG_STR_ARG(name), buf);

-  // Compare
-  bool matches = memcmp(hex_buffer1, hex_buffer2, hex_size) == 0;
+  // Compare response directly with digest in hasher
+  bool matches = hasher->equals_hex(buf);

   if (!matches) {
-    ESP_LOGW(TAG, "Auth failed! %s passwords do not match", LOG_STR_ARG(name));
+    this->log_auth_warning_(LOG_STR("Password mismatch"), name);
   }

   return matches;
 }
+#endif  // USE_OTA_PASSWORD

 }  // namespace esphome
 #endif
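For reference, the handshake `perform_hash_auth_()` implements is a nonce/cnonce challenge-response: the device sends a hex nonce, the client replies with a hex cnonce followed by the hex digest of password + nonce + cnonce, and the device compares that response against its own digest via `equals_hex()`. A client-side sketch in Python (assumptions: `sock` is a connected socket and the matching auth-request byte has already been exchanged; helper names are illustrative):

```python
import hashlib
import os


def _recv_exact(sock, n: int) -> bytes:
    data = b""
    while len(data) < n:
        chunk = sock.recv(n - len(data))
        if not chunk:
            raise ConnectionError("device closed the connection")
        data += chunk
    return data


def answer_challenge(sock, password: str, algo: str = "sha256") -> None:
    hex_size = hashlib.new(algo).digest_size * 2
    nonce = _recv_exact(sock, hex_size)               # device nonce, hex encoded
    cnonce = hashlib.new(algo, os.urandom(16)).hexdigest().encode()
    sock.sendall(cnonce)                              # client nonce, hex encoded
    digest = hashlib.new(algo)
    digest.update(password.encode())
    digest.update(nonce)
    digest.update(cnonce)
    sock.sendall(digest.hexdigest().encode())         # device compares via equals_hex()
```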
@@ -31,14 +31,16 @@ class ESPHomeOTAComponent : public ota::OTAComponent {
 protected:
  void handle_handshake_();
  void handle_data_();
+#ifdef USE_OTA_PASSWORD
  bool perform_hash_auth_(HashBase *hasher, const std::string &password, size_t nonce_size, uint8_t auth_request,
-                         const LogString *name);
+                         const LogString *name, char *buf);
+ void log_auth_warning_(const LogString *action, const LogString *hash_name);
+#endif  // USE_OTA_PASSWORD
  bool readall_(uint8_t *buf, size_t len);
  bool writeall_(const uint8_t *buf, size_t len);
  void log_socket_error_(const LogString *msg);
  void log_read_error_(const LogString *what);
  void log_start_(const LogString *phase);
- void log_auth_warning_(const LogString *action, const LogString *hash_name);
  void cleanup_connection_();
  void yield_and_feed_watchdog_();

@@ -13,12 +13,10 @@ CONFIG_SCHEMA = cv.Schema({})

 async def to_code(config: ConfigType) -> None:
     # Add OpenSSL library for host platform
-    if CORE.is_host:
+    if not CORE.is_host:
+        return
     if IS_MACOS:
         # macOS needs special handling for Homebrew OpenSSL
         cg.add_build_flag("-I/opt/homebrew/opt/openssl/include")
         cg.add_build_flag("-L/opt/homebrew/opt/openssl/lib")
     cg.add_build_flag("-lcrypto")
-    else:
-        # Linux and other Unix systems usually have OpenSSL in standard paths
-        cg.add_build_flag("-lcrypto")
@@ -1,5 +1,6 @@
 import importlib
 import logging
+import os
 from pathlib import Path
 import re

@@ -301,6 +302,11 @@ def clean_cmake_cache():
 def clean_build():
     import shutil

+    # Allow skipping cache cleaning for integration tests
+    if os.environ.get("ESPHOME_SKIP_CLEAN_BUILD"):
+        _LOGGER.warning("Skipping build cleaning (ESPHOME_SKIP_CLEAN_BUILD set)")
+        return
+
     pioenvs = CORE.relative_pioenvs_path()
     if pioenvs.is_dir():
         _LOGGER.info("Deleting %s", pioenvs)
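`ESPHOME_SKIP_CLEAN_BUILD` is an escape hatch for the integration tests: `_get_platformio_env()` in the next hunk exports it so repeated compiles reuse the shared build tree. A sketch of driving the same switch externally (the YAML path is a placeholder):

```python
import os
import subprocess

# Reuse the existing build tree instead of cleaning it first.
env = dict(os.environ, ESPHOME_SKIP_CLEAN_BUILD="1")
subprocess.run(["esphome", "compile", "config.yaml"], env=env, check=True)
```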
|
@@ -58,6 +58,8 @@ def _get_platformio_env(cache_dir: Path) -> dict[str, str]:
|
|||||||
env["PLATFORMIO_CORE_DIR"] = str(cache_dir)
|
env["PLATFORMIO_CORE_DIR"] = str(cache_dir)
|
||||||
env["PLATFORMIO_CACHE_DIR"] = str(cache_dir / ".cache")
|
env["PLATFORMIO_CACHE_DIR"] = str(cache_dir / ".cache")
|
||||||
env["PLATFORMIO_LIBDEPS_DIR"] = str(cache_dir / "libdeps")
|
env["PLATFORMIO_LIBDEPS_DIR"] = str(cache_dir / "libdeps")
|
||||||
|
# Prevent cache cleaning during integration tests
|
||||||
|
env["ESPHOME_SKIP_CLEAN_BUILD"] = "1"
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
@@ -68,6 +70,11 @@ def shared_platformio_cache() -> Generator[Path]:
     test_cache_dir = Path.home() / ".esphome-integration-tests"
     cache_dir = test_cache_dir / "platformio"

+    # Create the temp directory that PlatformIO uses to avoid race conditions
+    # This ensures it exists and won't be deleted by parallel processes
+    platformio_tmp_dir = cache_dir / ".cache" / "tmp"
+    platformio_tmp_dir.mkdir(parents=True, exist_ok=True)
+
     # Use a lock file in the home directory to ensure only one process initializes the cache
     # This is needed when running with pytest-xdist
     # The lock file must be in a directory that already exists to avoid race conditions
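The pre-created `.cache/tmp` directory plus the lock file described in the context lines are the usual pattern for sharing one cache across pytest-xdist workers: `mkdir(..., exist_ok=True)` has no check-then-create race, and a file lock serializes the one-time initialization. A sketch of that pattern, assuming the `filelock` package (the actual conftest may use a different locking mechanism):

```python
from pathlib import Path

from filelock import FileLock

cache_dir = Path.home() / ".esphome-integration-tests" / "platformio"
# Safe under concurrency: no window between the existence check and creation.
(cache_dir / ".cache" / "tmp").mkdir(parents=True, exist_ok=True)

with FileLock(Path.home() / ".esphome-integration-tests.lock"):
    # Only one xdist worker at a time runs the expensive cache initialization.
    pass
```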
@@ -83,17 +90,11 @@ def shared_platformio_cache() -> Generator[Path]:
         test_cache_dir.mkdir(exist_ok=True)

         with tempfile.TemporaryDirectory() as tmpdir:
-            # Create a basic host config
+            # Use the cache_init fixture for initialization
             init_dir = Path(tmpdir)
+            fixture_path = Path(__file__).parent / "fixtures" / "cache_init.yaml"
             config_path = init_dir / "cache_init.yaml"
-            config_path.write_text("""esphome:
-  name: cache-init
-host:
-api:
-  encryption:
-    key: "IIevImVI42I0FGos5nLqFK91jrJehrgidI0ArwMLr8w="
-logger:
-""")
+            config_path.write_text(fixture_path.read_text())

             # Run compilation to populate the cache
             # We must succeed here to avoid race conditions where multiple
@@ -346,7 +347,8 @@ async def wait_and_connect_api_client(
     noise_psk: str | None = None,
     client_info: str = "integration-test",
     timeout: float = API_CONNECTION_TIMEOUT,
-) -> AsyncGenerator[APIClient]:
+    return_disconnect_event: bool = False,
+) -> AsyncGenerator[APIClient | tuple[APIClient, asyncio.Event]]:
     """Wait for API to be available and connect."""
     client = APIClient(
         address=address,
@@ -359,14 +361,17 @@ async def wait_and_connect_api_client(
     # Create a future to signal when connected
     loop = asyncio.get_running_loop()
     connected_future: asyncio.Future[None] = loop.create_future()
+    disconnect_event = asyncio.Event()

     async def on_connect() -> None:
         """Called when successfully connected."""
+        disconnect_event.clear()  # Clear the disconnect event on new connection
         if not connected_future.done():
             connected_future.set_result(None)

     async def on_disconnect(expected_disconnect: bool) -> None:
         """Called when disconnected."""
+        disconnect_event.set()
         if not connected_future.done() and not expected_disconnect:
             connected_future.set_exception(
                 APIConnectionError("Disconnected before fully connected")
@@ -397,6 +402,9 @@ async def wait_and_connect_api_client(
         except TimeoutError:
             raise TimeoutError(f"Failed to connect to API after {timeout} seconds")

+        if return_disconnect_event:
+            yield client, disconnect_event
+        else:
             yield client
     finally:
         # Stop reconnect logic and disconnect
@@ -430,6 +438,33 @@ async def api_client_connected(
     yield _connect_client


+@pytest_asyncio.fixture
+async def api_client_connected_with_disconnect(
+    unused_tcp_port: int,
+) -> AsyncGenerator:
+    """Factory for creating connected API client context managers with disconnect event."""
+
+    def _connect_client_with_disconnect(
+        address: str = LOCALHOST,
+        port: int | None = None,
+        password: str = "",
+        noise_psk: str | None = None,
+        client_info: str = "integration-test",
+        timeout: float = API_CONNECTION_TIMEOUT,
+    ):
+        return wait_and_connect_api_client(
+            address=address,
+            port=port if port is not None else unused_tcp_port,
+            password=password,
+            noise_psk=noise_psk,
+            client_info=client_info,
+            timeout=timeout,
+            return_disconnect_event=True,
+        )
+
+    yield _connect_client_with_disconnect
+
+
 async def _read_stream_lines(
     stream: asyncio.StreamReader,
     lines: list[str],
tests/integration/fixtures/cache_init.yaml (new file, 10 lines)
@@ -0,0 +1,10 @@
+esphome:
+  name: cache-init
+
+host:
+
+api:
+  encryption:
+    key: "IIevImVI42I0FGos5nLqFK91jrJehrgidI0ArwMLr8w="
+
+logger:
@@ -0,0 +1,11 @@
+esphome:
+  name: oversized-noise
+
+host:
+
+api:
+  encryption:
+    key: N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU=
+
+logger:
+  level: VERY_VERBOSE
tests/integration/fixtures/oversized_payload_noise.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
+esphome:
+  name: oversized-noise
+
+host:
+
+api:
+  encryption:
+    key: N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU=
+
+logger:
+  level: VERY_VERBOSE
@@ -0,0 +1,9 @@
+esphome:
+  name: oversized-plaintext
+
+host:
+
+api:
+
+logger:
+  level: VERY_VERBOSE
@@ -0,0 +1,11 @@
+esphome:
+  name: oversized-noise
+
+host:
+
+api:
+  encryption:
+    key: N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU=
+
+logger:
+  level: VERY_VERBOSE
@@ -0,0 +1,9 @@
+esphome:
+  name: oversized-protobuf-plaintext
+
+host:
+
+api:
+
+logger:
+  level: VERY_VERBOSE
tests/integration/test_oversized_payloads.py (new file, 335 lines)
@@ -0,0 +1,335 @@
+"""Integration tests for oversized payloads and headers that should cause disconnection."""
+
+from __future__ import annotations
+
+import asyncio
+
+import pytest
+
+from .types import APIClientConnectedWithDisconnectFactory, RunCompiledFunction
+
+
+@pytest.mark.asyncio
+async def test_oversized_payload_plaintext(
+    yaml_config: str,
+    run_compiled: RunCompiledFunction,
+    api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory,
+) -> None:
+    """Test that oversized payloads (>100KiB) from client cause disconnection without crashing."""
+    process_exited = False
+    helper_log_found = False
+
+    def check_logs(line: str) -> None:
+        nonlocal process_exited, helper_log_found
+        # Check for signs that the process exited/crashed
+        if "Segmentation fault" in line or "core dumped" in line:
+            process_exited = True
+        # Check for HELPER_LOG message about message size exceeding maximum
+        if (
+            "[VV]" in line
+            and "Bad packet: message size" in line
+            and "exceeds maximum" in line
+        ):
+            helper_log_found = True
+
+    async with run_compiled(yaml_config, line_callback=check_logs):
+        async with api_client_connected_with_disconnect() as (client, disconnect_event):
+            # Verify basic connection works first
+            device_info = await client.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-plaintext"
+
+            # Create an oversized payload (>100KiB)
+            oversized_data = b"X" * (100 * 1024 + 1)  # 100KiB + 1 byte
+
+            # Access the internal connection to send raw data
+            frame_helper = client._connection._frame_helper
+            # Create a message with oversized payload
+            # Using message type 1 (DeviceInfoRequest) as an example
+            message_type = 1
+            frame_helper.write_packets([(message_type, oversized_data)], True)
+
+            # Wait for the connection to be closed by ESPHome
+            await asyncio.wait_for(disconnect_event.wait(), timeout=5.0)
+
+        # After disconnection, verify process didn't crash
+        assert not process_exited, "ESPHome process should not crash"
+        # Verify we saw the expected HELPER_LOG message
+        assert helper_log_found, (
+            "Expected to see HELPER_LOG about message size exceeding maximum"
+        )
+
+        # Try to reconnect to verify the process is still running
+        async with api_client_connected_with_disconnect() as (client2, _):
+            device_info = await client2.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-plaintext"
+
+
+@pytest.mark.asyncio
+async def test_oversized_protobuf_message_id_plaintext(
+    yaml_config: str,
+    run_compiled: RunCompiledFunction,
+    api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory,
+) -> None:
+    """Test that protobuf messages with ID > UINT16_MAX cause disconnection without crashing.
+
+    This tests the message type limit - message IDs must fit in a uint16_t (0-65535).
+    """
+    process_exited = False
+    helper_log_found = False
+
+    def check_logs(line: str) -> None:
+        nonlocal process_exited, helper_log_found
+        # Check for signs that the process exited/crashed
+        if "Segmentation fault" in line or "core dumped" in line:
+            process_exited = True
+        # Check for HELPER_LOG message about message type exceeding maximum
+        if (
+            "[VV]" in line
+            and "Bad packet: message type" in line
+            and "exceeds maximum" in line
+        ):
+            helper_log_found = True
+
+    async with run_compiled(yaml_config, line_callback=check_logs):
+        async with api_client_connected_with_disconnect() as (client, disconnect_event):
+            # Verify basic connection works first
+            device_info = await client.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-protobuf-plaintext"
+
+            # Access the internal connection to send raw message with large ID
+            frame_helper = client._connection._frame_helper
+            # Message ID that exceeds uint16_t limit (> 65535)
+            large_message_id = 65536  # 2^16, exceeds UINT16_MAX
+            # Small payload for the test
+            payload = b"test"
+
+            # This should cause disconnection due to oversized varint
+            frame_helper.write_packets([(large_message_id, payload)], True)
+
+            # Wait for the connection to be closed by ESPHome
+            await asyncio.wait_for(disconnect_event.wait(), timeout=5.0)
+
+        # After disconnection, verify process didn't crash
+        assert not process_exited, "ESPHome process should not crash"
+        # Verify we saw the expected HELPER_LOG message
+        assert helper_log_found, (
+            "Expected to see HELPER_LOG about message type exceeding maximum"
+        )
+
+        # Try to reconnect to verify the process is still running
+        async with api_client_connected_with_disconnect() as (client2, _):
+            device_info = await client2.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-protobuf-plaintext"
+
+
+@pytest.mark.asyncio
+async def test_oversized_payload_noise(
+    yaml_config: str,
+    run_compiled: RunCompiledFunction,
+    api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory,
+) -> None:
+    """Test that oversized payloads (>100KiB) from client cause disconnection without crashing with noise encryption."""
+    noise_key = "N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU="
+    process_exited = False
+    cipherstate_failed = False
+
+    def check_logs(line: str) -> None:
+        nonlocal process_exited, cipherstate_failed
+        # Check for signs that the process exited/crashed
+        if "Segmentation fault" in line or "core dumped" in line:
+            process_exited = True
+        # Check for the expected warning about decryption failure
+        if (
+            "[W][api.connection" in line
+            and "Reading failed CIPHERSTATE_DECRYPT_FAILED" in line
+        ):
+            cipherstate_failed = True
+
+    async with run_compiled(yaml_config, line_callback=check_logs):
+        async with api_client_connected_with_disconnect(noise_psk=noise_key) as (
+            client,
+            disconnect_event,
+        ):
+            # Verify basic connection works first
+            device_info = await client.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-noise"
+
+            # Create an oversized payload (>100KiB)
+            oversized_data = b"Y" * (100 * 1024 + 1)  # 100KiB + 1 byte
+
+            # Access the internal connection to send raw data
+            frame_helper = client._connection._frame_helper
+            # For noise connections, we still send through write_packets
+            # but the frame helper will handle encryption
+            # Using message type 1 (DeviceInfoRequest) as an example
+            message_type = 1
+            frame_helper.write_packets([(message_type, oversized_data)], True)
+
+            # Wait for the connection to be closed by ESPHome
+            await asyncio.wait_for(disconnect_event.wait(), timeout=5.0)
+
+        # After disconnection, verify process didn't crash
+        assert not process_exited, "ESPHome process should not crash"
+        # Verify we saw the expected warning message
+        assert cipherstate_failed, (
+            "Expected to see warning about CIPHERSTATE_DECRYPT_FAILED"
+        )
+
+        # Try to reconnect to verify the process is still running
+        async with api_client_connected_with_disconnect(noise_psk=noise_key) as (
+            client2,
+            _,
+        ):
+            device_info = await client2.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-noise"
+
+
+@pytest.mark.asyncio
+async def test_oversized_protobuf_message_id_noise(
+    yaml_config: str,
+    run_compiled: RunCompiledFunction,
+    api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory,
+) -> None:
+    """Test that the noise protocol handles unknown message types correctly.
+
+    With noise encryption, message types are stored as uint16_t (2 bytes) after decryption.
+    Unknown message types should be ignored without disconnecting, as ESPHome needs to
+    read the full message to maintain encryption stream continuity.
+    """
+    noise_key = "N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU="
+    process_exited = False
+
+    def check_logs(line: str) -> None:
+        nonlocal process_exited
+        # Check for signs that the process exited/crashed
+        if "Segmentation fault" in line or "core dumped" in line:
+            process_exited = True
+
+    async with run_compiled(yaml_config, line_callback=check_logs):
+        async with api_client_connected_with_disconnect(noise_psk=noise_key) as (
+            client,
+            disconnect_event,
+        ):
+            # Verify basic connection works first
+            device_info = await client.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-noise"
+
+            # With noise, message types are uint16_t, so we test with an unknown but valid value
+            frame_helper = client._connection._frame_helper
+
+            # Test with an unknown message type (65535 is not used by ESPHome)
+            unknown_message_id = 65535  # Valid uint16_t but unknown to ESPHome
+            payload = b"test"
+
+            # Send the unknown message type - ESPHome should read and ignore it
+            frame_helper.write_packets([(unknown_message_id, payload)], True)
+
+            # Give ESPHome a moment to process (but expect no disconnection)
+            # The connection should stay alive as ESPHome ignores unknown message types
+            with pytest.raises(asyncio.TimeoutError):
+                await asyncio.wait_for(disconnect_event.wait(), timeout=0.5)
+
+            # Connection should still be alive - unknown types are ignored, not fatal
+            assert client._connection.is_connected, (
+                "Connection should remain open for unknown message types"
+            )
+
+            # Verify we can still communicate by sending a valid request
+            device_info2 = await client.device_info()
+            assert device_info2 is not None
+            assert device_info2.name == "oversized-noise"
+
+        # After test, verify process didn't crash
+        assert not process_exited, "ESPHome process should not crash"
+
+        # Verify we can still reconnect
+        async with api_client_connected_with_disconnect(noise_psk=noise_key) as (
+            client2,
+            _,
+        ):
+            device_info = await client2.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-noise"
+
+
+@pytest.mark.asyncio
+async def test_noise_corrupt_encrypted_frame(
+    yaml_config: str,
+    run_compiled: RunCompiledFunction,
+    api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory,
+) -> None:
+    """Test that noise protocol properly handles corrupt encrypted frames.
+
+    Send a frame with valid size but corrupt encrypted content (garbage bytes).
+    This should fail decryption and cause disconnection.
+    """
+    noise_key = "N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU="
+    process_exited = False
+    cipherstate_failed = False
+
+    def check_logs(line: str) -> None:
+        nonlocal process_exited, cipherstate_failed
+        # Check for signs that the process exited/crashed
+        if "Segmentation fault" in line or "core dumped" in line:
+            process_exited = True
+        # Check for the expected warning about decryption failure
+        if (
+            "[W][api.connection" in line
+            and "Reading failed CIPHERSTATE_DECRYPT_FAILED" in line
+        ):
+            cipherstate_failed = True
+
+    async with run_compiled(yaml_config, line_callback=check_logs):
+        async with api_client_connected_with_disconnect(noise_psk=noise_key) as (
+            client,
+            disconnect_event,
+        ):
+            # Verify basic connection works first
+            device_info = await client.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-noise"
+
+            # Get the socket to send raw corrupt data
+            socket = client._connection._socket
+
+            # Send a corrupt noise frame directly to the socket
+            # Format: [indicator=0x01][size_high][size_low][garbage_encrypted_data]
+            # Size of 32 bytes (reasonable size for a noise frame with MAC)
+            corrupt_frame = bytes(
+                [
+                    0x01,  # Noise indicator
+                    0x00,  # Size high byte
+                    0x20,  # Size low byte (32 bytes)
+                ]
+            ) + bytes(32)  # 32 bytes of zeros (invalid encrypted data)
+
+            # Send the corrupt frame
+            socket.sendall(corrupt_frame)
+
+            # Wait for ESPHome to disconnect due to decryption failure
+            await asyncio.wait_for(disconnect_event.wait(), timeout=5.0)
+
+        # After disconnection, verify process didn't crash
+        assert not process_exited, (
+            "ESPHome process should not crash on corrupt encrypted frames"
+        )
+        # Verify we saw the expected warning message
+        assert cipherstate_failed, (
+            "Expected to see warning about CIPHERSTATE_DECRYPT_FAILED"
+        )
+
+        # Verify we can still reconnect after handling the corrupt frame
+        async with api_client_connected_with_disconnect(noise_psk=noise_key) as (
+            client2,
+            _,
+        ):
+            device_info = await client2.device_info()
+            assert device_info is not None
+            assert device_info.name == "oversized-noise"
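The plaintext tests above go through `write_packets()`, which frames each message as a zero indicator byte, a varint payload size, a varint message type, and the payload; the oversized size varint is what trips the "Bad packet: message size ... exceeds maximum" check on the device. A sketch of that framing (format as used by the aioesphomeapi plaintext helper):

```python
def _varint(value: int) -> bytes:
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        out.append(byte | (0x80 if value else 0))
        if not value:
            return bytes(out)


def plaintext_frame(msg_type: int, payload: bytes) -> bytes:
    # [0x00 indicator][varint size][varint type][payload]
    return b"\x00" + _varint(len(payload)) + _varint(msg_type) + payload


frame = plaintext_frame(1, b"X" * (100 * 1024 + 1))  # oversized DeviceInfoRequest
```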
@@ -54,3 +54,17 @@ class APIClientConnectedFactory(Protocol):
         client_info: str = "integration-test",
         timeout: float = 30,
     ) -> AbstractAsyncContextManager[APIClient]: ...
+
+
+class APIClientConnectedWithDisconnectFactory(Protocol):
+    """Protocol for connected API client factory that returns disconnect event."""
+
+    def __call__(  # noqa: E704
+        self,
+        address: str = "localhost",
+        port: int | None = None,
+        password: str = "",
+        noise_psk: str | None = None,
+        client_info: str = "integration-test",
+        timeout: float = 30,
+    ) -> AbstractAsyncContextManager[tuple[APIClient, asyncio.Event]]: ...
@@ -94,3 +94,10 @@ def mock_run_git_command() -> Generator[Mock, None, None]:
     """Mock run_git_command for git module."""
     with patch("esphome.git.run_git_command") as mock:
         yield mock
+
+
+@pytest.fixture
+def mock_get_idedata() -> Generator[Mock, None, None]:
+    """Mock get_idedata for platformio_api."""
+    with patch("esphome.platformio_api.get_idedata") as mock:
+        yield mock
@@ -15,7 +15,7 @@ def test_clone_or_update_with_never_refresh(
 ) -> None:
     """Test that NEVER_REFRESH skips updates for existing repos."""
     # Set up CORE.config_path so data_dir uses tmp_path
-    CORE.config_path = str(tmp_path / "test.yaml")
+    CORE.config_path = tmp_path / "test.yaml"

     # Compute the expected repo directory path
     url = "https://github.com/test/repo"
@@ -56,7 +56,7 @@ def test_clone_or_update_with_refresh_updates_old_repo(
 ) -> None:
     """Test that refresh triggers update for old repos."""
     # Set up CORE.config_path so data_dir uses tmp_path
-    CORE.config_path = str(tmp_path / "test.yaml")
+    CORE.config_path = tmp_path / "test.yaml"

     # Compute the expected repo directory path
     url = "https://github.com/test/repo"
@@ -110,7 +110,7 @@ def test_clone_or_update_with_refresh_skips_fresh_repo(
 ) -> None:
     """Test that refresh doesn't update fresh repos."""
     # Set up CORE.config_path so data_dir uses tmp_path
-    CORE.config_path = str(tmp_path / "test.yaml")
+    CORE.config_path = tmp_path / "test.yaml"

     # Compute the expected repo directory path
     url = "https://github.com/test/repo"
@@ -156,7 +156,7 @@ def test_clone_or_update_clones_missing_repo(
 ) -> None:
     """Test that missing repos are cloned regardless of refresh setting."""
     # Set up CORE.config_path so data_dir uses tmp_path
-    CORE.config_path = str(tmp_path / "test.yaml")
+    CORE.config_path = tmp_path / "test.yaml"

     # Compute the expected repo directory path
     url = "https://github.com/test/repo"
@@ -198,7 +198,7 @@ def test_clone_or_update_with_none_refresh_always_updates(
 ) -> None:
     """Test that refresh=None always updates existing repos."""
     # Set up CORE.config_path so data_dir uses tmp_path
-    CORE.config_path = str(tmp_path / "test.yaml")
+    CORE.config_path = tmp_path / "test.yaml"

     # Compute the expected repo directory path
    url = "https://github.com/test/repo"
@@ -12,6 +12,7 @@ from unittest.mock import MagicMock, Mock, patch
 import pytest
 from pytest import CaptureFixture

+from esphome import platformio_api
 from esphome.__main__ import (
     Purpose,
     choose_upload_log_host,
@@ -28,7 +29,9 @@ from esphome.__main__ import (
     mqtt_get_ip,
     show_logs,
     upload_program,
+    upload_using_esptool,
 )
+from esphome.components.esp32.const import KEY_ESP32, KEY_VARIANT, VARIANT_ESP32
 from esphome.const import (
     CONF_API,
     CONF_BROKER,
@@ -220,6 +223,14 @@ def mock_run_external_process() -> Generator[Mock]:
         yield mock


+@pytest.fixture
+def mock_run_external_command() -> Generator[Mock]:
+    """Mock run_external_command for testing."""
+    with patch("esphome.__main__.run_external_command") as mock:
+        mock.return_value = 0  # Default to success
+        yield mock
+
+
 def test_choose_upload_log_host_with_string_default() -> None:
     """Test with a single string default device."""
     setup_core()
@@ -818,6 +829,122 @@ def test_upload_program_serial_esp8266_with_file(
     )


+def test_upload_using_esptool_path_conversion(
+    tmp_path: Path,
+    mock_run_external_command: Mock,
+    mock_get_idedata: Mock,
+) -> None:
+    """Test upload_using_esptool properly converts Path objects to strings for esptool.
+
+    This test ensures that img.path (Path object) is converted to string before
+    passing to esptool, preventing AttributeError.
+    """
+    setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path, name="test")
+
+    # Set up ESP32-specific data required by get_esp32_variant()
+    CORE.data[KEY_ESP32] = {KEY_VARIANT: VARIANT_ESP32}
+
+    # Create mock IDEData with Path objects
+    mock_idedata = MagicMock(spec=platformio_api.IDEData)
+    mock_idedata.firmware_bin_path = tmp_path / "firmware.bin"
+    mock_idedata.extra_flash_images = [
+        platformio_api.FlashImage(path=tmp_path / "bootloader.bin", offset="0x1000"),
+        platformio_api.FlashImage(path=tmp_path / "partitions.bin", offset="0x8000"),
+    ]
+
+    mock_get_idedata.return_value = mock_idedata
+
+    # Create the actual firmware files so they exist
+    (tmp_path / "firmware.bin").touch()
+    (tmp_path / "bootloader.bin").touch()
+    (tmp_path / "partitions.bin").touch()
+
+    config = {CONF_ESPHOME: {"platformio_options": {}}}
+
+    # Call upload_using_esptool without custom file argument
+    result = upload_using_esptool(config, "/dev/ttyUSB0", None, None)
+
+    assert result == 0
+
+    # Verify that run_external_command was called
+    assert mock_run_external_command.call_count == 1
+
+    # Get the actual call arguments
+    call_args = mock_run_external_command.call_args[0]
+
+    # The first argument should be esptool.main function,
+    # followed by the command arguments
+    assert len(call_args) > 1
+
+    # Find the indices of the flash image arguments
+    # They should come after "write-flash" and "-z"
+    cmd_list = list(call_args[1:])  # Skip the esptool.main function
+
+    # Verify all paths are strings, not Path objects
+    # The firmware and flash images should be at specific positions
+    write_flash_idx = cmd_list.index("write-flash")
+
+    # After write-flash we have: -z, --flash-size, detect, then offset/path pairs
+    # Check firmware at offset 0x10000 (ESP32)
+    firmware_offset_idx = write_flash_idx + 4
+    assert cmd_list[firmware_offset_idx] == "0x10000"
+    firmware_path = cmd_list[firmware_offset_idx + 1]
+    assert isinstance(firmware_path, str)
+    assert firmware_path.endswith("firmware.bin")
+
+    # Check bootloader
+    bootloader_offset_idx = firmware_offset_idx + 2
+    assert cmd_list[bootloader_offset_idx] == "0x1000"
+    bootloader_path = cmd_list[bootloader_offset_idx + 1]
+    assert isinstance(bootloader_path, str)
+    assert bootloader_path.endswith("bootloader.bin")
+
+    # Check partitions
+    partitions_offset_idx = bootloader_offset_idx + 2
+    assert cmd_list[partitions_offset_idx] == "0x8000"
+    partitions_path = cmd_list[partitions_offset_idx + 1]
+    assert isinstance(partitions_path, str)
+    assert partitions_path.endswith("partitions.bin")
+
+
+def test_upload_using_esptool_with_file_path(
+    tmp_path: Path,
+    mock_run_external_command: Mock,
+) -> None:
+    """Test upload_using_esptool with a custom file that's a Path object."""
+    setup_core(platform=PLATFORM_ESP8266, tmp_path=tmp_path, name="test")
+
+    # Create a test firmware file
+    firmware_file = tmp_path / "custom_firmware.bin"
+    firmware_file.touch()
+
+    config = {CONF_ESPHOME: {"platformio_options": {}}}
+
+    # Call with a Path object as the file argument (though usually it's a string)
+    result = upload_using_esptool(config, "/dev/ttyUSB0", str(firmware_file), None)
+
+    assert result == 0
+
+    # Verify that run_external_command was called
+    mock_run_external_command.assert_called_once()
+
+    # Get the actual call arguments
+    call_args = mock_run_external_command.call_args[0]
+    cmd_list = list(call_args[1:])  # Skip the esptool.main function
+
+    # Find the firmware path in the command
+    write_flash_idx = cmd_list.index("write-flash")
+
+    # For custom file, it should be at offset 0x0
+    firmware_offset_idx = write_flash_idx + 4
+    assert cmd_list[firmware_offset_idx] == "0x0"
+    firmware_path = cmd_list[firmware_offset_idx + 1]
+
+    # Verify it's a string, not a Path object
+    assert isinstance(firmware_path, str)
+    assert firmware_path.endswith("custom_firmware.bin")
+
+
 @pytest.mark.parametrize(
     "platform,device",
     [