mirror of
https://github.com/esphome/esphome.git
synced 2025-10-07 20:33:47 +01:00
Merge branch 'fix_double_move' into integration
This commit is contained in:
@@ -132,26 +132,16 @@ APIError APINoiseFrameHelper::loop() {
|
||||
return APIFrameHelper::loop();
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
/** Read a packet into the rx_buf_.
|
||||
*
|
||||
* @param frame: The struct to hold the frame information in.
|
||||
* msg_start: points to the start of the payload - this pointer is only valid until the next
|
||||
* try_receive_raw_ call
|
||||
*
|
||||
* @return 0 if a full packet is in rx_buf_
|
||||
* @return -1 if error, check errno.
|
||||
* @return APIError::OK if a full packet is in rx_buf_
|
||||
*
|
||||
* errno EWOULDBLOCK: Packet could not be read without blocking. Try again later.
|
||||
* errno ENOMEM: Not enough memory for reading packet.
|
||||
* errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
|
||||
* errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase.
|
||||
*/
|
||||
APIError APINoiseFrameHelper::try_read_frame_(std::vector<uint8_t> *frame) {
|
||||
if (frame == nullptr) {
|
||||
HELPER_LOG("Bad argument for try_read_frame_");
|
||||
return APIError::BAD_ARG;
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::try_read_frame_() {
|
||||
// read header
|
||||
if (rx_header_buf_len_ < 3) {
|
||||
// no header information yet
|
||||
@@ -212,12 +202,12 @@ APIError APINoiseFrameHelper::try_read_frame_(std::vector<uint8_t> *frame) {
|
||||
}
|
||||
}
|
||||
|
||||
LOG_PACKET_RECEIVED(rx_buf_);
|
||||
*frame = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
rx_buf_len_ = 0;
|
||||
rx_header_buf_len_ = 0;
|
||||
LOG_PACKET_RECEIVED(this->rx_buf_);
|
||||
|
||||
// Clear state for next frame (rx_buf_ still contains data for caller)
|
||||
this->rx_buf_len_ = 0;
|
||||
this->rx_header_buf_len_ = 0;
|
||||
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
@@ -239,18 +229,18 @@ APIError APINoiseFrameHelper::state_action_() {
|
||||
}
|
||||
if (state_ == State::CLIENT_HELLO) {
|
||||
// waiting for client hello
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
aerr = this->try_read_frame_();
|
||||
if (aerr != APIError::OK) {
|
||||
return handle_handshake_frame_error_(aerr);
|
||||
}
|
||||
// ignore contents, may be used in future for flags
|
||||
// Resize for: existing prologue + 2 size bytes + frame data
|
||||
size_t old_size = prologue_.size();
|
||||
prologue_.resize(old_size + 2 + frame.size());
|
||||
prologue_[old_size] = (uint8_t) (frame.size() >> 8);
|
||||
prologue_[old_size + 1] = (uint8_t) frame.size();
|
||||
std::memcpy(prologue_.data() + old_size + 2, frame.data(), frame.size());
|
||||
size_t old_size = this->prologue_.size();
|
||||
this->prologue_.resize(old_size + 2 + this->rx_buf_.size());
|
||||
this->prologue_[old_size] = (uint8_t) (this->rx_buf_.size() >> 8);
|
||||
this->prologue_[old_size + 1] = (uint8_t) this->rx_buf_.size();
|
||||
std::memcpy(this->prologue_.data() + old_size + 2, this->rx_buf_.data(), this->rx_buf_.size());
|
||||
this->rx_buf_.clear();
|
||||
|
||||
state_ = State::SERVER_HELLO;
|
||||
}
|
||||
@@ -292,24 +282,23 @@ APIError APINoiseFrameHelper::state_action_() {
|
||||
int action = noise_handshakestate_get_action(handshake_);
|
||||
if (action == NOISE_ACTION_READ_MESSAGE) {
|
||||
// waiting for handshake msg
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
aerr = this->try_read_frame_();
|
||||
if (aerr != APIError::OK) {
|
||||
return handle_handshake_frame_error_(aerr);
|
||||
}
|
||||
|
||||
if (frame.empty()) {
|
||||
if (this->rx_buf_.empty()) {
|
||||
send_explicit_handshake_reject_(LOG_STR("Empty handshake message"));
|
||||
return APIError::BAD_HANDSHAKE_ERROR_BYTE;
|
||||
} else if (frame[0] != 0x00) {
|
||||
HELPER_LOG("Bad handshake error byte: %u", frame[0]);
|
||||
} else if (this->rx_buf_[0] != 0x00) {
|
||||
HELPER_LOG("Bad handshake error byte: %u", this->rx_buf_[0]);
|
||||
send_explicit_handshake_reject_(LOG_STR("Bad handshake error byte"));
|
||||
return APIError::BAD_HANDSHAKE_ERROR_BYTE;
|
||||
}
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_input(mbuf, frame.data() + 1, frame.size() - 1);
|
||||
noise_buffer_set_input(mbuf, this->rx_buf_.data() + 1, this->rx_buf_.size() - 1);
|
||||
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
|
||||
if (err != 0) {
|
||||
// Special handling for MAC failure
|
||||
@@ -318,6 +307,7 @@ APIError APINoiseFrameHelper::state_action_() {
|
||||
return handle_noise_error_(err, LOG_STR("noise_handshakestate_read_message"),
|
||||
APIError::HANDSHAKESTATE_READ_FAILED);
|
||||
}
|
||||
this->rx_buf_.clear();
|
||||
|
||||
aerr = check_handshake_finished_();
|
||||
if (aerr != APIError::OK)
|
||||
@@ -386,35 +376,33 @@ void APINoiseFrameHelper::send_explicit_handshake_reject_(const LogString *reaso
|
||||
state_ = orig_state;
|
||||
}
|
||||
APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
aerr = state_action_();
|
||||
APIError aerr = this->state_action_();
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
if (this->state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
aerr = this->try_read_frame_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size());
|
||||
err = noise_cipherstate_decrypt(recv_cipher_, &mbuf);
|
||||
noise_buffer_set_inout(mbuf, this->rx_buf_.data(), this->rx_buf_.size(), this->rx_buf_.size());
|
||||
int err = noise_cipherstate_decrypt(this->recv_cipher_, &mbuf);
|
||||
APIError decrypt_err =
|
||||
handle_noise_error_(err, LOG_STR("noise_cipherstate_decrypt"), APIError::CIPHERSTATE_DECRYPT_FAILED);
|
||||
if (decrypt_err != APIError::OK)
|
||||
if (decrypt_err != APIError::OK) {
|
||||
return decrypt_err;
|
||||
}
|
||||
|
||||
uint16_t msg_size = mbuf.size;
|
||||
uint8_t *msg_data = frame.data();
|
||||
uint8_t *msg_data = this->rx_buf_.data();
|
||||
if (msg_size < 4) {
|
||||
state_ = State::FAILED;
|
||||
this->state_ = State::FAILED;
|
||||
HELPER_LOG("Bad data packet: size %d too short", msg_size);
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
@@ -422,15 +410,16 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1];
|
||||
uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3];
|
||||
if (data_len > msg_size - 4) {
|
||||
state_ = State::FAILED;
|
||||
this->state_ = State::FAILED;
|
||||
HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size);
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
|
||||
buffer->container = std::move(frame);
|
||||
buffer->container = std::move(this->rx_buf_);
|
||||
buffer->data_offset = 4;
|
||||
buffer->data_len = data_len;
|
||||
buffer->type = type;
|
||||
this->rx_buf_.clear();
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APINoiseFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) {
|
||||
|
@@ -28,7 +28,7 @@ class APINoiseFrameHelper final : public APIFrameHelper {
|
||||
|
||||
protected:
|
||||
APIError state_action_();
|
||||
APIError try_read_frame_(std::vector<uint8_t> *frame);
|
||||
APIError try_read_frame_();
|
||||
APIError write_frame_(const uint8_t *data, uint16_t len);
|
||||
APIError init_handshake_();
|
||||
APIError check_handshake_finished_();
|
||||
|
@@ -47,21 +47,13 @@ APIError APIPlaintextFrameHelper::loop() {
|
||||
return APIFrameHelper::loop();
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
*
|
||||
* @param frame: The struct to hold the frame information in.
|
||||
* msg: store the parsed frame in that struct
|
||||
/** Read a packet into the rx_buf_.
|
||||
*
|
||||
* @return See APIError
|
||||
*
|
||||
* error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
|
||||
*/
|
||||
APIError APIPlaintextFrameHelper::try_read_frame_(std::vector<uint8_t> *frame) {
|
||||
if (frame == nullptr) {
|
||||
HELPER_LOG("Bad argument for try_read_frame_");
|
||||
return APIError::BAD_ARG;
|
||||
}
|
||||
|
||||
APIError APIPlaintextFrameHelper::try_read_frame_() {
|
||||
// read header
|
||||
while (!rx_header_parsed_) {
|
||||
// Now that we know when the socket is ready, we can read up to 3 bytes
|
||||
@@ -170,24 +162,22 @@ APIError APIPlaintextFrameHelper::try_read_frame_(std::vector<uint8_t> *frame) {
|
||||
}
|
||||
}
|
||||
|
||||
LOG_PACKET_RECEIVED(rx_buf_);
|
||||
*frame = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
rx_buf_len_ = 0;
|
||||
rx_header_buf_pos_ = 0;
|
||||
rx_header_parsed_ = false;
|
||||
LOG_PACKET_RECEIVED(this->rx_buf_);
|
||||
|
||||
// Clear state for next frame (rx_buf_ still contains data for caller)
|
||||
this->rx_buf_len_ = 0;
|
||||
this->rx_header_buf_pos_ = 0;
|
||||
this->rx_header_parsed_ = false;
|
||||
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
APIError aerr;
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
if (this->state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
APIError aerr = this->try_read_frame_();
|
||||
if (aerr != APIError::OK) {
|
||||
if (aerr == APIError::BAD_INDICATOR) {
|
||||
// Make sure to tell the remote that we don't
|
||||
@@ -220,10 +210,11 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
return aerr;
|
||||
}
|
||||
|
||||
buffer->container = std::move(frame);
|
||||
buffer->container = std::move(this->rx_buf_);
|
||||
buffer->data_offset = 0;
|
||||
buffer->data_len = rx_header_parsed_len_;
|
||||
buffer->type = rx_header_parsed_type_;
|
||||
buffer->data_len = this->rx_header_parsed_len_;
|
||||
buffer->type = this->rx_header_parsed_type_;
|
||||
this->rx_buf_.clear();
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) {
|
||||
|
@@ -24,7 +24,7 @@ class APIPlaintextFrameHelper final : public APIFrameHelper {
|
||||
APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span<const PacketInfo> packets) override;
|
||||
|
||||
protected:
|
||||
APIError try_read_frame_(std::vector<uint8_t> *frame);
|
||||
APIError try_read_frame_();
|
||||
|
||||
// Group 2-byte aligned types
|
||||
uint16_t rx_header_parsed_type_ = 0;
|
||||
|
@@ -35,7 +35,7 @@ template<typename... Ts> class UserServiceBase : public UserServiceDescriptor {
|
||||
msg.set_name(StringRef(this->name_));
|
||||
msg.key = this->key_;
|
||||
std::array<enums::ServiceArgType, sizeof...(Ts)> arg_types = {to_service_arg_type<Ts>()...};
|
||||
for (int i = 0; i < sizeof...(Ts); i++) {
|
||||
for (size_t i = 0; i < sizeof...(Ts); i++) {
|
||||
msg.args.emplace_back();
|
||||
auto &arg = msg.args.back();
|
||||
arg.type = arg_types[i];
|
||||
|
@@ -20,6 +20,23 @@ bool MCP2515::setup_internal() {
|
||||
return false;
|
||||
if (this->set_bitrate_(this->bit_rate_, this->mcp_clock_) != canbus::ERROR_OK)
|
||||
return false;
|
||||
|
||||
// setup hardware filter RXF0 accepting all standard CAN IDs
|
||||
if (this->set_filter_(RXF::RXF0, false, 0) != canbus::ERROR_OK) {
|
||||
return false;
|
||||
}
|
||||
if (this->set_filter_mask_(MASK::MASK0, false, 0) != canbus::ERROR_OK) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// setup hardware filter RXF1 accepting all extended CAN IDs
|
||||
if (this->set_filter_(RXF::RXF1, true, 0) != canbus::ERROR_OK) {
|
||||
return false;
|
||||
}
|
||||
if (this->set_filter_mask_(MASK::MASK1, true, 0) != canbus::ERROR_OK) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (this->set_mode_(this->mcp_mode_) != canbus::ERROR_OK)
|
||||
return false;
|
||||
uint8_t err_flags = this->get_error_flags_();
|
||||
|
@@ -67,6 +67,31 @@ ConfigPath = list[str | int]
|
||||
path_context = contextvars.ContextVar("Config path")
|
||||
|
||||
|
||||
def _add_auto_load_steps(result: Config, loads: list[str]) -> None:
|
||||
"""Add AutoLoadValidationStep for each component in loads that isn't already loaded."""
|
||||
for load in loads:
|
||||
if load not in result:
|
||||
result.add_validation_step(AutoLoadValidationStep(load))
|
||||
|
||||
|
||||
def _process_auto_load(
|
||||
result: Config, platform: ComponentManifest, path: ConfigPath
|
||||
) -> None:
|
||||
# Process platform's AUTO_LOAD
|
||||
auto_load = platform.auto_load
|
||||
if isinstance(auto_load, list):
|
||||
_add_auto_load_steps(result, auto_load)
|
||||
elif callable(auto_load):
|
||||
import inspect
|
||||
|
||||
if inspect.signature(auto_load).parameters:
|
||||
result.add_validation_step(
|
||||
AddDynamicAutoLoadsValidationStep(path, platform)
|
||||
)
|
||||
else:
|
||||
_add_auto_load_steps(result, auto_load())
|
||||
|
||||
|
||||
def _process_platform_config(
|
||||
result: Config,
|
||||
component_name: str,
|
||||
@@ -91,9 +116,7 @@ def _process_platform_config(
|
||||
CORE.loaded_platforms.add(f"{component_name}/{platform_name}")
|
||||
|
||||
# Process platform's AUTO_LOAD
|
||||
for load in platform.auto_load:
|
||||
if load not in result:
|
||||
result.add_validation_step(AutoLoadValidationStep(load))
|
||||
_process_auto_load(result, platform, path)
|
||||
|
||||
# Add validation steps for the platform
|
||||
p_domain = f"{component_name}.{platform_name}"
|
||||
@@ -390,9 +413,7 @@ class LoadValidationStep(ConfigValidationStep):
|
||||
result[self.domain] = self.conf = [self.conf]
|
||||
|
||||
# Process AUTO_LOAD
|
||||
for load in component.auto_load:
|
||||
if load not in result:
|
||||
result.add_validation_step(AutoLoadValidationStep(load))
|
||||
_process_auto_load(result, component, path)
|
||||
|
||||
result.add_validation_step(
|
||||
MetadataValidationStep([self.domain], self.domain, self.conf, component)
|
||||
@@ -618,6 +639,34 @@ class MetadataValidationStep(ConfigValidationStep):
|
||||
result.add_validation_step(FinalValidateValidationStep(self.path, self.comp))
|
||||
|
||||
|
||||
class AddDynamicAutoLoadsValidationStep(ConfigValidationStep):
|
||||
"""Add dynamic auto loads step.
|
||||
|
||||
This step is used to auto-load components where one component can alter its
|
||||
AUTO_LOAD based on its configuration.
|
||||
"""
|
||||
|
||||
# Has to happen after normal schema is validated and before final schema validation
|
||||
priority = -10.0
|
||||
|
||||
def __init__(self, path: ConfigPath, comp: ComponentManifest) -> None:
|
||||
self.path = path
|
||||
self.comp = comp
|
||||
|
||||
def run(self, result: Config) -> None:
|
||||
if result.errors:
|
||||
# If result already has errors, skip this step
|
||||
return
|
||||
|
||||
conf = result.get_nested_item(self.path)
|
||||
with result.catch_error(self.path):
|
||||
auto_load = self.comp.auto_load
|
||||
if not callable(auto_load):
|
||||
return
|
||||
loads = auto_load(conf)
|
||||
_add_auto_load_steps(result, loads)
|
||||
|
||||
|
||||
class SchemaValidationStep(ConfigValidationStep):
|
||||
"""Schema validation step.
|
||||
|
||||
|
@@ -77,7 +77,7 @@ bool ESPTime::strptime(const std::string &time_to_parse, ESPTime &esp_time) {
|
||||
&hour, // NOLINT
|
||||
&minute, // NOLINT
|
||||
&second, &num) == 6 && // NOLINT
|
||||
num == time_to_parse.size()) {
|
||||
num == static_cast<int>(time_to_parse.size())) {
|
||||
esp_time.year = year;
|
||||
esp_time.month = month;
|
||||
esp_time.day_of_month = day;
|
||||
@@ -87,7 +87,7 @@ bool ESPTime::strptime(const std::string &time_to_parse, ESPTime &esp_time) {
|
||||
} else if (sscanf(time_to_parse.c_str(), "%04hu-%02hhu-%02hhu %02hhu:%02hhu %n", &year, &month, &day, // NOLINT
|
||||
&hour, // NOLINT
|
||||
&minute, &num) == 5 && // NOLINT
|
||||
num == time_to_parse.size()) {
|
||||
num == static_cast<int>(time_to_parse.size())) {
|
||||
esp_time.year = year;
|
||||
esp_time.month = month;
|
||||
esp_time.day_of_month = day;
|
||||
@@ -95,17 +95,17 @@ bool ESPTime::strptime(const std::string &time_to_parse, ESPTime &esp_time) {
|
||||
esp_time.minute = minute;
|
||||
esp_time.second = 0;
|
||||
} else if (sscanf(time_to_parse.c_str(), "%02hhu:%02hhu:%02hhu %n", &hour, &minute, &second, &num) == 3 && // NOLINT
|
||||
num == time_to_parse.size()) {
|
||||
num == static_cast<int>(time_to_parse.size())) {
|
||||
esp_time.hour = hour;
|
||||
esp_time.minute = minute;
|
||||
esp_time.second = second;
|
||||
} else if (sscanf(time_to_parse.c_str(), "%02hhu:%02hhu %n", &hour, &minute, &num) == 2 && // NOLINT
|
||||
num == time_to_parse.size()) {
|
||||
num == static_cast<int>(time_to_parse.size())) {
|
||||
esp_time.hour = hour;
|
||||
esp_time.minute = minute;
|
||||
esp_time.second = 0;
|
||||
} else if (sscanf(time_to_parse.c_str(), "%04hu-%02hhu-%02hhu %n", &year, &month, &day, &num) == 3 && // NOLINT
|
||||
num == time_to_parse.size()) {
|
||||
num == static_cast<int>(time_to_parse.size())) {
|
||||
esp_time.year = year;
|
||||
esp_time.month = month;
|
||||
esp_time.day_of_month = day;
|
||||
|
@@ -82,11 +82,10 @@ class ComponentManifest:
|
||||
return getattr(self.module, "CONFLICTS_WITH", [])
|
||||
|
||||
@property
|
||||
def auto_load(self) -> list[str]:
|
||||
al = getattr(self.module, "AUTO_LOAD", [])
|
||||
if callable(al):
|
||||
return al()
|
||||
return al
|
||||
def auto_load(
|
||||
self,
|
||||
) -> list[str] | Callable[[], list[str]] | Callable[[ConfigType], list[str]]:
|
||||
return getattr(self.module, "AUTO_LOAD", [])
|
||||
|
||||
@property
|
||||
def codeowners(self) -> list[str]:
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
@@ -13,7 +14,7 @@ from esphome.const import (
|
||||
PLATFORM_ESP8266,
|
||||
)
|
||||
from esphome.core import CORE
|
||||
from esphome.loader import get_component, get_platform
|
||||
from esphome.loader import ComponentManifest, get_component, get_platform
|
||||
|
||||
|
||||
def filter_component_files(str):
|
||||
@@ -45,6 +46,29 @@ def add_item_to_components_graph(components_graph, parent, child):
|
||||
components_graph[parent].append(child)
|
||||
|
||||
|
||||
def resolve_auto_load(
|
||||
auto_load: list[str] | Callable[[], list[str]] | Callable[[dict | None], list[str]],
|
||||
config: dict | None = None,
|
||||
) -> list[str]:
|
||||
"""Resolve AUTO_LOAD to a list, handling callables with or without config parameter.
|
||||
|
||||
Args:
|
||||
auto_load: The AUTO_LOAD value (list or callable)
|
||||
config: Optional config to pass to callable AUTO_LOAD functions
|
||||
|
||||
Returns:
|
||||
List of component names to auto-load
|
||||
"""
|
||||
if not callable(auto_load):
|
||||
return auto_load
|
||||
|
||||
import inspect
|
||||
|
||||
if inspect.signature(auto_load).parameters:
|
||||
return auto_load(config)
|
||||
return auto_load()
|
||||
|
||||
|
||||
def create_components_graph():
|
||||
# The root directory of the repo
|
||||
root = Path(__file__).parent.parent
|
||||
@@ -63,7 +87,7 @@ def create_components_graph():
|
||||
|
||||
components_graph = {}
|
||||
platforms = []
|
||||
components = []
|
||||
components: list[tuple[ComponentManifest, str, Path]] = []
|
||||
|
||||
for path in components_dir.iterdir():
|
||||
if not path.is_dir():
|
||||
@@ -92,8 +116,8 @@ def create_components_graph():
|
||||
|
||||
for target_config in TARGET_CONFIGURATIONS:
|
||||
CORE.data[KEY_CORE] = target_config
|
||||
for auto_load in comp.auto_load:
|
||||
add_item_to_components_graph(components_graph, auto_load, name)
|
||||
for item in resolve_auto_load(comp.auto_load, config=None):
|
||||
add_item_to_components_graph(components_graph, item, name)
|
||||
# restore config
|
||||
CORE.data[KEY_CORE] = TARGET_CONFIGURATIONS[0]
|
||||
|
||||
@@ -114,8 +138,8 @@ def create_components_graph():
|
||||
|
||||
for target_config in TARGET_CONFIGURATIONS:
|
||||
CORE.data[KEY_CORE] = target_config
|
||||
for auto_load in platform.auto_load:
|
||||
add_item_to_components_graph(components_graph, auto_load, name)
|
||||
for item in resolve_auto_load(platform.auto_load, config={}):
|
||||
add_item_to_components_graph(components_graph, item, name)
|
||||
# restore config
|
||||
CORE.data[KEY_CORE] = TARGET_CONFIGURATIONS[0]
|
||||
|
||||
|
@@ -13,6 +13,7 @@ CONFIG_ESP_TASK_WDT=y
|
||||
CONFIG_ESP_TASK_WDT_PANIC=y
|
||||
CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU0=n
|
||||
CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU1=n
|
||||
CONFIG_AUTOSTART_ARDUINO=y
|
||||
|
||||
# esp32_ble
|
||||
CONFIG_BT_ENABLED=y
|
||||
|
@@ -101,3 +101,10 @@ def mock_get_idedata() -> Generator[Mock, None, None]:
|
||||
"""Mock get_idedata for platformio_api."""
|
||||
with patch("esphome.platformio_api.get_idedata") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_component() -> Generator[Mock, None, None]:
|
||||
"""Mock get_component for config module."""
|
||||
with patch("esphome.config.get_component") as mock:
|
||||
yield mock
|
||||
|
10
tests/unit_tests/fixtures/auto_load_dynamic.yaml
Normal file
10
tests/unit_tests/fixtures/auto_load_dynamic.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
esphome:
|
||||
name: test-device
|
||||
|
||||
esp32:
|
||||
board: esp32dev
|
||||
|
||||
# Test component with dynamic AUTO_LOAD
|
||||
test_component:
|
||||
enable_logger: true
|
||||
enable_api: false
|
8
tests/unit_tests/fixtures/auto_load_static.yaml
Normal file
8
tests/unit_tests/fixtures/auto_load_static.yaml
Normal file
@@ -0,0 +1,8 @@
|
||||
esphome:
|
||||
name: test-device
|
||||
|
||||
esp32:
|
||||
board: esp32dev
|
||||
|
||||
# Test component with static AUTO_LOAD
|
||||
test_component:
|
131
tests/unit_tests/test_config_auto_load.py
Normal file
131
tests/unit_tests/test_config_auto_load.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Tests for AUTO_LOAD functionality including dynamic AUTO_LOAD."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from esphome import config, config_validation as cv, yaml_util
|
||||
from esphome.core import CORE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixtures_dir() -> Path:
|
||||
"""Get the fixtures directory."""
|
||||
return Path(__file__).parent / "fixtures"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_component() -> Mock:
|
||||
"""Create a default mock component for unmocked components."""
|
||||
return Mock(
|
||||
auto_load=[],
|
||||
is_platform_component=False,
|
||||
is_platform=False,
|
||||
multi_conf=False,
|
||||
multi_conf_no_default=False,
|
||||
dependencies=[],
|
||||
conflicts_with=[],
|
||||
config_schema=cv.Schema({}, extra=cv.ALLOW_EXTRA),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def static_auto_load_component() -> Mock:
|
||||
"""Create a mock component with static AUTO_LOAD."""
|
||||
return Mock(
|
||||
auto_load=["logger"],
|
||||
is_platform_component=False,
|
||||
is_platform=False,
|
||||
multi_conf=False,
|
||||
multi_conf_no_default=False,
|
||||
dependencies=[],
|
||||
conflicts_with=[],
|
||||
config_schema=cv.Schema({}, extra=cv.ALLOW_EXTRA),
|
||||
)
|
||||
|
||||
|
||||
def test_static_auto_load_adds_components(
|
||||
mock_get_component: Mock,
|
||||
fixtures_dir: Path,
|
||||
static_auto_load_component: Mock,
|
||||
default_component: Mock,
|
||||
) -> None:
|
||||
"""Test that static AUTO_LOAD triggers loading of specified components."""
|
||||
CORE.config_path = fixtures_dir / "auto_load_static.yaml"
|
||||
|
||||
config_file = fixtures_dir / "auto_load_static.yaml"
|
||||
raw_config = yaml_util.load_yaml(config_file)
|
||||
|
||||
component_mocks = {"test_component": static_auto_load_component}
|
||||
mock_get_component.side_effect = lambda name: component_mocks.get(
|
||||
name, default_component
|
||||
)
|
||||
|
||||
result = config.validate_config(raw_config, {})
|
||||
|
||||
# Check for validation errors
|
||||
assert not result.errors, f"Validation errors: {result.errors}"
|
||||
|
||||
# Logger should have been auto-loaded by test_component
|
||||
assert "logger" in result
|
||||
assert "test_component" in result
|
||||
|
||||
|
||||
def test_dynamic_auto_load_with_config_param(
|
||||
mock_get_component: Mock,
|
||||
fixtures_dir: Path,
|
||||
default_component: Mock,
|
||||
) -> None:
|
||||
"""Test that dynamic AUTO_LOAD evaluates based on configuration."""
|
||||
CORE.config_path = fixtures_dir / "auto_load_dynamic.yaml"
|
||||
|
||||
config_file = fixtures_dir / "auto_load_dynamic.yaml"
|
||||
raw_config = yaml_util.load_yaml(config_file)
|
||||
|
||||
# Track if auto_load was called with config
|
||||
auto_load_calls = []
|
||||
|
||||
def dynamic_auto_load(conf: dict[str, Any]) -> list[str]:
|
||||
"""Dynamically load components based on config."""
|
||||
auto_load_calls.append(conf)
|
||||
component_map = {
|
||||
"enable_logger": "logger",
|
||||
"enable_api": "api",
|
||||
}
|
||||
return [comp for key, comp in component_map.items() if conf.get(key)]
|
||||
|
||||
dynamic_component = Mock(
|
||||
auto_load=dynamic_auto_load,
|
||||
is_platform_component=False,
|
||||
is_platform=False,
|
||||
multi_conf=False,
|
||||
multi_conf_no_default=False,
|
||||
dependencies=[],
|
||||
conflicts_with=[],
|
||||
config_schema=cv.Schema({}, extra=cv.ALLOW_EXTRA),
|
||||
)
|
||||
|
||||
component_mocks = {"test_component": dynamic_component}
|
||||
mock_get_component.side_effect = lambda name: component_mocks.get(
|
||||
name, default_component
|
||||
)
|
||||
|
||||
result = config.validate_config(raw_config, {})
|
||||
|
||||
# Check for validation errors
|
||||
assert not result.errors, f"Validation errors: {result.errors}"
|
||||
|
||||
# Verify auto_load was called with the validated config
|
||||
assert len(auto_load_calls) == 1, "auto_load should be called exactly once"
|
||||
assert auto_load_calls[0].get("enable_logger") is True
|
||||
assert auto_load_calls[0].get("enable_api") is False
|
||||
|
||||
# Only logger should be auto-loaded (enable_logger=true in YAML)
|
||||
assert "logger" in result, (
|
||||
f"Logger not found in result. Result keys: {list(result.keys())}"
|
||||
)
|
||||
# API should NOT be auto-loaded (enable_api=false in YAML)
|
||||
assert "api" not in result
|
||||
assert "test_component" in result
|
@@ -10,13 +10,6 @@ from esphome import config, yaml_util
|
||||
from esphome.core import CORE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_component() -> Generator[Mock, None, None]:
|
||||
"""Fixture for mocking get_component."""
|
||||
with patch("esphome.config.get_component") as mock_get_component:
|
||||
yield mock_get_component
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_platform() -> Generator[Mock, None, None]:
|
||||
"""Fixture for mocking get_platform."""
|
||||
|
Reference in New Issue
Block a user