mirror of
https://github.com/esphome/esphome.git
synced 2025-10-30 06:33:51 +00:00
Merge branch 'bk7200_tagged_pointer_fix' into integration
This commit is contained in:
@@ -1687,7 +1687,9 @@ void APIConnection::DeferredBatch::add_item(EntityBase *entity, MessageCreator c
|
||||
// O(n) but optimized for RAM and not performance.
|
||||
for (auto &item : items) {
|
||||
if (item.entity == entity && item.message_type == message_type) {
|
||||
// Update the existing item with the new creator
|
||||
// Clean up old creator before replacing
|
||||
item.creator.cleanup(message_type);
|
||||
// Move assign the new creator
|
||||
item.creator = std::move(creator);
|
||||
return;
|
||||
}
|
||||
@@ -1730,11 +1732,11 @@ void APIConnection::process_batch_() {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t num_items = this->deferred_batch_.items.size();
|
||||
size_t num_items = this->deferred_batch_.size();
|
||||
|
||||
// Fast path for single message - allocate exact size needed
|
||||
if (num_items == 1) {
|
||||
const auto &item = this->deferred_batch_.items[0];
|
||||
const auto &item = this->deferred_batch_[0];
|
||||
|
||||
// Let the creator calculate size and encode if it fits
|
||||
uint16_t payload_size =
|
||||
@@ -1764,7 +1766,8 @@ void APIConnection::process_batch_() {
|
||||
|
||||
// Pre-calculate exact buffer size needed based on message types
|
||||
uint32_t total_estimated_size = 0;
|
||||
for (const auto &item : this->deferred_batch_.items) {
|
||||
for (size_t i = 0; i < this->deferred_batch_.size(); i++) {
|
||||
const auto &item = this->deferred_batch_[i];
|
||||
total_estimated_size += get_estimated_message_size(item.message_type);
|
||||
}
|
||||
|
||||
@@ -1785,7 +1788,8 @@ void APIConnection::process_batch_() {
|
||||
uint32_t current_offset = 0;
|
||||
|
||||
// Process items and encode directly to buffer
|
||||
for (const auto &item : this->deferred_batch_.items) {
|
||||
for (size_t i = 0; i < this->deferred_batch_.size(); i++) {
|
||||
const auto &item = this->deferred_batch_[i];
|
||||
// Try to encode message
|
||||
// The creator will calculate overhead to determine if the message fits
|
||||
uint16_t payload_size = item.creator(item.entity, this, remaining_size, false, item.message_type);
|
||||
@@ -1840,17 +1844,15 @@ void APIConnection::process_batch_() {
|
||||
// Log messages after send attempt for VV debugging
|
||||
// It's safe to use the buffer for logging at this point regardless of send result
|
||||
for (size_t i = 0; i < items_processed; i++) {
|
||||
const auto &item = this->deferred_batch_.items[i];
|
||||
const auto &item = this->deferred_batch_[i];
|
||||
this->log_batch_item_(item);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Handle remaining items more efficiently
|
||||
if (items_processed < this->deferred_batch_.items.size()) {
|
||||
// Remove processed items from the beginning
|
||||
this->deferred_batch_.items.erase(this->deferred_batch_.items.begin(),
|
||||
this->deferred_batch_.items.begin() + items_processed);
|
||||
|
||||
if (items_processed < this->deferred_batch_.size()) {
|
||||
// Remove processed items from the beginning with proper cleanup
|
||||
this->deferred_batch_.remove_front(items_processed);
|
||||
// Reschedule for remaining items
|
||||
this->schedule_batch_();
|
||||
} else {
|
||||
@@ -1861,23 +1863,16 @@ void APIConnection::process_batch_() {
|
||||
|
||||
uint16_t APIConnection::MessageCreator::operator()(EntityBase *entity, APIConnection *conn, uint32_t remaining_size,
|
||||
bool is_single, uint16_t message_type) const {
|
||||
if (has_tagged_string_ptr_()) {
|
||||
// Handle string-based messages
|
||||
switch (message_type) {
|
||||
#ifdef USE_EVENT
|
||||
case EventResponse::MESSAGE_TYPE: {
|
||||
auto *e = static_cast<event::Event *>(entity);
|
||||
return APIConnection::try_send_event_response(e, *get_string_ptr_(), conn, remaining_size, is_single);
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
// Should not happen, return 0 to indicate no message
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
// Function pointer case
|
||||
return data_.ptr(entity, conn, remaining_size, is_single);
|
||||
// Special case: EventResponse uses string pointer
|
||||
if (message_type == EventResponse::MESSAGE_TYPE) {
|
||||
auto *e = static_cast<event::Event *>(entity);
|
||||
return APIConnection::try_send_event_response(e, *data_.string_ptr, conn, remaining_size, is_single);
|
||||
}
|
||||
#endif
|
||||
|
||||
// All other message types use function pointers
|
||||
return data_.function_ptr(entity, conn, remaining_size, is_single);
|
||||
}
|
||||
|
||||
uint16_t APIConnection::try_send_list_info_done(EntityBase *entity, APIConnection *conn, uint32_t remaining_size,
|
||||
|
||||
@@ -451,96 +451,53 @@ class APIConnection : public APIServerConnection {
|
||||
// Function pointer type for message encoding
|
||||
using MessageCreatorPtr = uint16_t (*)(EntityBase *, APIConnection *, uint32_t remaining_size, bool is_single);
|
||||
|
||||
// Optimized MessageCreator class using tagged pointer
|
||||
class MessageCreator {
|
||||
// Ensure pointer alignment allows LSB tagging
|
||||
static_assert(alignof(std::string *) > 1, "String pointer alignment must be > 1 for LSB tagging");
|
||||
|
||||
public:
|
||||
// Constructor for function pointer
|
||||
MessageCreator(MessageCreatorPtr ptr) {
|
||||
// Function pointers are always aligned, so LSB is 0
|
||||
data_.ptr = ptr;
|
||||
}
|
||||
MessageCreator(MessageCreatorPtr ptr) { data_.function_ptr = ptr; }
|
||||
|
||||
// Constructor for string state capture
|
||||
explicit MessageCreator(const std::string &str_value) {
|
||||
// Allocate string and tag the pointer
|
||||
auto *str = new std::string(str_value);
|
||||
// Set LSB to 1 to indicate string pointer
|
||||
data_.tagged = reinterpret_cast<uintptr_t>(str) | 1;
|
||||
}
|
||||
explicit MessageCreator(const std::string &str_value) { data_.string_ptr = new std::string(str_value); }
|
||||
|
||||
// Destructor
|
||||
~MessageCreator() {
|
||||
if (has_tagged_string_ptr_()) {
|
||||
delete get_string_ptr_();
|
||||
}
|
||||
}
|
||||
// No destructor - cleanup must be called explicitly with message_type
|
||||
|
||||
// Copy constructor
|
||||
MessageCreator(const MessageCreator &other) {
|
||||
if (other.has_tagged_string_ptr_()) {
|
||||
auto *str = new std::string(*other.get_string_ptr_());
|
||||
data_.tagged = reinterpret_cast<uintptr_t>(str) | 1;
|
||||
} else {
|
||||
data_ = other.data_;
|
||||
}
|
||||
}
|
||||
// Delete copy operations - MessageCreator should only be moved
|
||||
MessageCreator(const MessageCreator &other) = delete;
|
||||
MessageCreator &operator=(const MessageCreator &other) = delete;
|
||||
|
||||
// Move constructor
|
||||
MessageCreator(MessageCreator &&other) noexcept : data_(other.data_) { other.data_.ptr = nullptr; }
|
||||
|
||||
// Assignment operators (needed for batch deduplication)
|
||||
MessageCreator &operator=(const MessageCreator &other) {
|
||||
if (this != &other) {
|
||||
// Clean up current string data if needed
|
||||
if (has_tagged_string_ptr_()) {
|
||||
delete get_string_ptr_();
|
||||
}
|
||||
// Copy new data
|
||||
if (other.has_tagged_string_ptr_()) {
|
||||
auto *str = new std::string(*other.get_string_ptr_());
|
||||
data_.tagged = reinterpret_cast<uintptr_t>(str) | 1;
|
||||
} else {
|
||||
data_ = other.data_;
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
MessageCreator(MessageCreator &&other) noexcept : data_(other.data_) { other.data_.function_ptr = nullptr; }
|
||||
|
||||
// Move assignment
|
||||
MessageCreator &operator=(MessageCreator &&other) noexcept {
|
||||
if (this != &other) {
|
||||
// Clean up current string data if needed
|
||||
if (has_tagged_string_ptr_()) {
|
||||
delete get_string_ptr_();
|
||||
}
|
||||
// Move data
|
||||
// IMPORTANT: Caller must ensure cleanup() was called if this contains a string!
|
||||
// In our usage, this happens in add_item() deduplication and vector::erase()
|
||||
data_ = other.data_;
|
||||
// Reset other to safe state
|
||||
other.data_.ptr = nullptr;
|
||||
other.data_.function_ptr = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Call operator - now accepts message_type as parameter
|
||||
// Call operator - uses message_type to determine union type
|
||||
uint16_t operator()(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, bool is_single,
|
||||
uint16_t message_type) const;
|
||||
|
||||
private:
|
||||
// Check if this contains a string pointer
|
||||
bool has_tagged_string_ptr_() const { return (data_.tagged & 1) != 0; }
|
||||
|
||||
// Get the actual string pointer (clears the tag bit)
|
||||
std::string *get_string_ptr_() const {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<std::string *>(data_.tagged & ~uintptr_t(1));
|
||||
// Manual cleanup method - must be called before destruction for string types
|
||||
void cleanup(uint16_t message_type) {
|
||||
#ifdef USE_EVENT
|
||||
if (message_type == EventResponse::MESSAGE_TYPE && data_.string_ptr != nullptr) {
|
||||
delete data_.string_ptr;
|
||||
data_.string_ptr = nullptr;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
union {
|
||||
MessageCreatorPtr ptr;
|
||||
uintptr_t tagged;
|
||||
} data_; // 4 bytes on 32-bit
|
||||
private:
|
||||
union Data {
|
||||
MessageCreatorPtr function_ptr;
|
||||
std::string *string_ptr;
|
||||
} data_; // 4 bytes on 32-bit, 8 bytes on 64-bit - same as before
|
||||
};
|
||||
|
||||
// Generic batching mechanism for both state updates and entity info
|
||||
@@ -558,20 +515,46 @@ class APIConnection : public APIServerConnection {
|
||||
std::vector<BatchItem> items;
|
||||
uint32_t batch_start_time{0};
|
||||
|
||||
private:
|
||||
// Helper to cleanup items from the beginning
|
||||
void cleanup_items_(size_t count) {
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
items[i].creator.cleanup(items[i].message_type);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
DeferredBatch() {
|
||||
// Pre-allocate capacity for typical batch sizes to avoid reallocation
|
||||
items.reserve(8);
|
||||
}
|
||||
|
||||
~DeferredBatch() {
|
||||
// Ensure cleanup of any remaining items
|
||||
clear();
|
||||
}
|
||||
|
||||
// Add item to the batch
|
||||
void add_item(EntityBase *entity, MessageCreator creator, uint16_t message_type);
|
||||
// Add item to the front of the batch (for high priority messages like ping)
|
||||
void add_item_front(EntityBase *entity, MessageCreator creator, uint16_t message_type);
|
||||
|
||||
// Clear all items with proper cleanup
|
||||
void clear() {
|
||||
cleanup_items_(items.size());
|
||||
items.clear();
|
||||
batch_start_time = 0;
|
||||
}
|
||||
|
||||
// Remove processed items from the front with proper cleanup
|
||||
void remove_front(size_t count) {
|
||||
cleanup_items_(count);
|
||||
items.erase(items.begin(), items.begin() + count);
|
||||
}
|
||||
|
||||
bool empty() const { return items.empty(); }
|
||||
size_t size() const { return items.size(); }
|
||||
const BatchItem &operator[](size_t index) const { return items[index]; }
|
||||
};
|
||||
|
||||
// DeferredBatch here (16 bytes, 4-byte aligned)
|
||||
|
||||
@@ -44,3 +44,4 @@ async def to_code(config):
|
||||
cg.add_build_flag("-std=gnu++20")
|
||||
cg.add_define("ESPHOME_BOARD", "host")
|
||||
cg.add_platformio_option("platform", "platformio/native")
|
||||
cg.add_platformio_option("lib_ldf_mode", "off")
|
||||
|
||||
@@ -90,15 +90,24 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) {
|
||||
|
||||
} else {
|
||||
// data starts at 2 and length is 4 for read registers commands
|
||||
if (this->role == ModbusRole::SERVER && (function_code == 0x1 || function_code == 0x3 || function_code == 0x4)) {
|
||||
data_offset = 2;
|
||||
data_len = 4;
|
||||
}
|
||||
|
||||
// the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands
|
||||
if (function_code == 0x5 || function_code == 0x06 || function_code == 0xF || function_code == 0x10) {
|
||||
data_offset = 2;
|
||||
data_len = 4;
|
||||
if (this->role == ModbusRole::SERVER) {
|
||||
if (function_code == 0x1 || function_code == 0x3 || function_code == 0x4 || function_code == 0x6) {
|
||||
data_offset = 2;
|
||||
data_len = 4;
|
||||
} else if (function_code == 0x10) {
|
||||
if (at < 6) {
|
||||
return true;
|
||||
}
|
||||
data_offset = 2;
|
||||
// starting address (2 bytes) + quantity of registers (2 bytes) + byte count itself (1 byte) + actual byte count
|
||||
data_len = 2 + 2 + 1 + raw[6];
|
||||
}
|
||||
} else {
|
||||
// the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands
|
||||
if (function_code == 0x5 || function_code == 0x06 || function_code == 0xF || function_code == 0x10) {
|
||||
data_offset = 2;
|
||||
data_len = 4;
|
||||
}
|
||||
}
|
||||
|
||||
// Error ( msb indicates error )
|
||||
@@ -132,6 +141,7 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) {
|
||||
bool found = false;
|
||||
for (auto *device : this->devices_) {
|
||||
if (device->address_ == address) {
|
||||
found = true;
|
||||
// Is it an error response?
|
||||
if ((function_code & 0x80) == 0x80) {
|
||||
ESP_LOGD(TAG, "Modbus error function code: 0x%X exception: %d", function_code, raw[2]);
|
||||
@@ -141,13 +151,21 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) {
|
||||
// Ignore modbus exception not related to a pending command
|
||||
ESP_LOGD(TAG, "Ignoring Modbus error - not expecting a response");
|
||||
}
|
||||
} else if (this->role == ModbusRole::SERVER && (function_code == 0x3 || function_code == 0x4)) {
|
||||
device->on_modbus_read_registers(function_code, uint16_t(data[1]) | (uint16_t(data[0]) << 8),
|
||||
uint16_t(data[3]) | (uint16_t(data[2]) << 8));
|
||||
} else {
|
||||
device->on_modbus_data(data);
|
||||
continue;
|
||||
}
|
||||
found = true;
|
||||
if (this->role == ModbusRole::SERVER) {
|
||||
if (function_code == 0x3 || function_code == 0x4) {
|
||||
device->on_modbus_read_registers(function_code, uint16_t(data[1]) | (uint16_t(data[0]) << 8),
|
||||
uint16_t(data[3]) | (uint16_t(data[2]) << 8));
|
||||
continue;
|
||||
}
|
||||
if (function_code == 0x6 || function_code == 0x10) {
|
||||
device->on_modbus_write_registers(function_code, data);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// fallthrough for other function codes
|
||||
device->on_modbus_data(data);
|
||||
}
|
||||
}
|
||||
waiting_for_response = 0;
|
||||
|
||||
@@ -59,6 +59,7 @@ class ModbusDevice {
|
||||
virtual void on_modbus_data(const std::vector<uint8_t> &data) = 0;
|
||||
virtual void on_modbus_error(uint8_t function_code, uint8_t exception_code) {}
|
||||
virtual void on_modbus_read_registers(uint8_t function_code, uint16_t start_address, uint16_t number_of_registers){};
|
||||
virtual void on_modbus_write_registers(uint8_t function_code, const std::vector<uint8_t> &data){};
|
||||
void send(uint8_t function, uint16_t start_address, uint16_t number_of_entities, uint8_t payload_len = 0,
|
||||
const uint8_t *payload = nullptr) {
|
||||
this->parent_->send(this->address_, function, start_address, number_of_entities, payload_len, payload);
|
||||
|
||||
@@ -39,6 +39,7 @@ CODEOWNERS = ["@martgras"]
|
||||
AUTO_LOAD = ["modbus"]
|
||||
|
||||
CONF_READ_LAMBDA = "read_lambda"
|
||||
CONF_WRITE_LAMBDA = "write_lambda"
|
||||
CONF_SERVER_REGISTERS = "server_registers"
|
||||
MULTI_CONF = True
|
||||
|
||||
@@ -148,6 +149,7 @@ ModbusServerRegisterSchema = cv.Schema(
|
||||
cv.Required(CONF_ADDRESS): cv.positive_int,
|
||||
cv.Optional(CONF_VALUE_TYPE, default="U_WORD"): cv.enum(SENSOR_VALUE_TYPE),
|
||||
cv.Required(CONF_READ_LAMBDA): cv.returning_lambda,
|
||||
cv.Optional(CONF_WRITE_LAMBDA): cv.returning_lambda,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -318,6 +320,17 @@ async def to_code(config):
|
||||
),
|
||||
)
|
||||
)
|
||||
if CONF_WRITE_LAMBDA in server_register:
|
||||
cg.add(
|
||||
server_register_var.set_write_lambda(
|
||||
cg.TemplateArguments(cpp_type),
|
||||
await cg.process_lambda(
|
||||
server_register[CONF_WRITE_LAMBDA],
|
||||
parameters=[(cg.uint16, "address"), (cpp_type, "x")],
|
||||
return_type=cg.bool_,
|
||||
),
|
||||
)
|
||||
)
|
||||
cg.add(var.add_server_register(server_register_var))
|
||||
await register_modbus_device(var, config)
|
||||
for conf in config.get(CONF_ON_COMMAND_SENT, []):
|
||||
|
||||
@@ -152,6 +152,86 @@ void ModbusController::on_modbus_read_registers(uint8_t function_code, uint16_t
|
||||
this->send(function_code, start_address, number_of_registers, response.size(), response.data());
|
||||
}
|
||||
|
||||
void ModbusController::on_modbus_write_registers(uint8_t function_code, const std::vector<uint8_t> &data) {
|
||||
uint16_t number_of_registers;
|
||||
uint16_t payload_offset;
|
||||
|
||||
if (function_code == 0x10) {
|
||||
number_of_registers = uint16_t(data[3]) | (uint16_t(data[2]) << 8);
|
||||
if (number_of_registers == 0 || number_of_registers > 0x7B) {
|
||||
ESP_LOGW(TAG, "Invalid number of registers %d. Sending exception response.", number_of_registers);
|
||||
send_error(function_code, 3);
|
||||
return;
|
||||
}
|
||||
uint16_t payload_size = data[4];
|
||||
if (payload_size != number_of_registers * 2) {
|
||||
ESP_LOGW(TAG, "Payload size of %d bytes is not 2 times the number of registers (%d). Sending exception response.",
|
||||
payload_size, number_of_registers);
|
||||
send_error(function_code, 3);
|
||||
return;
|
||||
}
|
||||
payload_offset = 5;
|
||||
} else if (function_code == 0x06) {
|
||||
number_of_registers = 1;
|
||||
payload_offset = 2;
|
||||
} else {
|
||||
ESP_LOGW(TAG, "Invalid function code 0x%X. Sending exception response.", function_code);
|
||||
send_error(function_code, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
uint16_t start_address = uint16_t(data[1]) | (uint16_t(data[0]) << 8);
|
||||
ESP_LOGD(TAG,
|
||||
"Received write holding registers for device 0x%X. FC: 0x%X. Start address: 0x%X. Number of registers: "
|
||||
"0x%X.",
|
||||
this->address_, function_code, start_address, number_of_registers);
|
||||
|
||||
auto for_each_register = [this, start_address, number_of_registers, payload_offset](
|
||||
const std::function<bool(ServerRegister *, uint16_t offset)> &callback) -> bool {
|
||||
uint16_t offset = payload_offset;
|
||||
for (uint16_t current_address = start_address; current_address < start_address + number_of_registers;) {
|
||||
bool ok = false;
|
||||
for (auto *server_register : this->server_registers_) {
|
||||
if (server_register->address == current_address) {
|
||||
ok = callback(server_register, offset);
|
||||
current_address += server_register->register_count;
|
||||
offset += server_register->register_count * sizeof(uint16_t);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// check all registers are writable before writing to any of them:
|
||||
if (!for_each_register([](ServerRegister *server_register, uint16_t offset) -> bool {
|
||||
return server_register->write_lambda != nullptr;
|
||||
})) {
|
||||
send_error(function_code, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
// Actually write to the registers:
|
||||
if (!for_each_register([&data](ServerRegister *server_register, uint16_t offset) {
|
||||
int64_t number = payload_to_number(data, server_register->value_type, offset, 0xFFFFFFFF);
|
||||
return server_register->write_lambda(number);
|
||||
})) {
|
||||
send_error(function_code, 4);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> response;
|
||||
response.reserve(6);
|
||||
response.push_back(this->address_);
|
||||
response.push_back(function_code);
|
||||
response.insert(response.end(), data.begin(), data.begin() + 4);
|
||||
this->send_raw(response);
|
||||
}
|
||||
|
||||
SensorSet ModbusController::find_sensors_(ModbusRegisterType register_type, uint16_t start_address) const {
|
||||
auto reg_it = std::find_if(
|
||||
std::begin(this->register_ranges_), std::end(this->register_ranges_),
|
||||
|
||||
@@ -258,6 +258,7 @@ class SensorItem {
|
||||
|
||||
class ServerRegister {
|
||||
using ReadLambda = std::function<int64_t()>;
|
||||
using WriteLambda = std::function<bool(int64_t value)>;
|
||||
|
||||
public:
|
||||
ServerRegister(uint16_t address, SensorValueType value_type, uint8_t register_count) {
|
||||
@@ -277,6 +278,17 @@ class ServerRegister {
|
||||
};
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void set_write_lambda(const std::function<bool(uint16_t address, const T v)> &&user_write_lambda) {
|
||||
this->write_lambda = [this, user_write_lambda](int64_t number) {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
float float_value = bit_cast<float>(static_cast<uint32_t>(number));
|
||||
return user_write_lambda(this->address, float_value);
|
||||
}
|
||||
return user_write_lambda(this->address, static_cast<T>(number));
|
||||
};
|
||||
}
|
||||
|
||||
// Formats a raw value into a string representation based on the value type for debugging
|
||||
std::string format_value(int64_t value) const {
|
||||
switch (this->value_type) {
|
||||
@@ -304,6 +316,7 @@ class ServerRegister {
|
||||
SensorValueType value_type{SensorValueType::RAW};
|
||||
uint8_t register_count{0};
|
||||
ReadLambda read_lambda;
|
||||
WriteLambda write_lambda;
|
||||
};
|
||||
|
||||
// ModbusController::create_register_ranges_ tries to optimize register range
|
||||
@@ -485,6 +498,8 @@ class ModbusController : public PollingComponent, public modbus::ModbusDevice {
|
||||
void on_modbus_error(uint8_t function_code, uint8_t exception_code) override;
|
||||
/// called when a modbus request (function code 0x03 or 0x04) was parsed without errors
|
||||
void on_modbus_read_registers(uint8_t function_code, uint16_t start_address, uint16_t number_of_registers) final;
|
||||
/// called when a modbus request (function code 0x06 or 0x10) was parsed without errors
|
||||
void on_modbus_write_registers(uint8_t function_code, const std::vector<uint8_t> &data) final;
|
||||
/// default delegate called by process_modbus_data when a response has retrieved from the incoming queue
|
||||
void on_register_data(ModbusRegisterType register_type, uint16_t start_address, const std::vector<uint8_t> &data);
|
||||
/// default delegate called by process_modbus_data when a response for a write response has retrieved from the
|
||||
|
||||
@@ -74,7 +74,7 @@ BASE_SCHEMA = cv.All(
|
||||
{
|
||||
cv.Required(CONF_PATH): validate_yaml_filename,
|
||||
cv.Optional(CONF_VARS, default={}): cv.Schema(
|
||||
{cv.string: cv.string}
|
||||
{cv.string: object}
|
||||
),
|
||||
}
|
||||
),
|
||||
@@ -148,7 +148,6 @@ def _process_base_package(config: dict) -> dict:
|
||||
raise cv.Invalid(
|
||||
f"Current ESPHome Version is too old to use this package: {ESPHOME_VERSION} < {min_version}"
|
||||
)
|
||||
vars = {k: str(v) for k, v in vars.items()}
|
||||
new_yaml = yaml_util.substitute_vars(new_yaml, vars)
|
||||
packages[f"{filename}{idx}"] = new_yaml
|
||||
except EsphomeError as e:
|
||||
|
||||
@@ -5,6 +5,13 @@ from esphome.config_helpers import Extend, Remove, merge_config
|
||||
import esphome.config_validation as cv
|
||||
from esphome.const import CONF_SUBSTITUTIONS, VALID_SUBSTITUTIONS_CHARACTERS
|
||||
from esphome.yaml_util import ESPHomeDataBase, make_data_base
|
||||
from .jinja import (
|
||||
Jinja,
|
||||
JinjaStr,
|
||||
has_jinja,
|
||||
TemplateError,
|
||||
TemplateRuntimeError,
|
||||
)
|
||||
|
||||
CODEOWNERS = ["@esphome/core"]
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -28,7 +35,7 @@ def validate_substitution_key(value):
|
||||
|
||||
CONFIG_SCHEMA = cv.Schema(
|
||||
{
|
||||
validate_substitution_key: cv.string_strict,
|
||||
validate_substitution_key: object,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -37,7 +44,42 @@ async def to_code(config):
|
||||
pass
|
||||
|
||||
|
||||
def _expand_substitutions(substitutions, value, path, ignore_missing):
|
||||
def _expand_jinja(value, orig_value, path, jinja, ignore_missing):
|
||||
if has_jinja(value):
|
||||
# If the original value passed in to this function is a JinjaStr, it means it contains an unresolved
|
||||
# Jinja expression from a previous pass.
|
||||
if isinstance(orig_value, JinjaStr):
|
||||
# Rebuild the JinjaStr in case it was lost while replacing substitutions.
|
||||
value = JinjaStr(value, orig_value.upvalues)
|
||||
try:
|
||||
# Invoke the jinja engine to evaluate the expression.
|
||||
value, err = jinja.expand(value)
|
||||
if err is not None:
|
||||
if not ignore_missing and "password" not in path:
|
||||
_LOGGER.warning(
|
||||
"Found '%s' (see %s) which looks like an expression,"
|
||||
" but could not resolve all the variables: %s",
|
||||
value,
|
||||
"->".join(str(x) for x in path),
|
||||
err.message,
|
||||
)
|
||||
except (
|
||||
TemplateError,
|
||||
TemplateRuntimeError,
|
||||
RuntimeError,
|
||||
ArithmeticError,
|
||||
AttributeError,
|
||||
TypeError,
|
||||
) as err:
|
||||
raise cv.Invalid(
|
||||
f"{type(err).__name__} Error evaluating jinja expression '{value}': {str(err)}."
|
||||
f" See {'->'.join(str(x) for x in path)}",
|
||||
path,
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _expand_substitutions(substitutions, value, path, jinja, ignore_missing):
|
||||
if "$" not in value:
|
||||
return value
|
||||
|
||||
@@ -47,7 +89,8 @@ def _expand_substitutions(substitutions, value, path, ignore_missing):
|
||||
while True:
|
||||
m = cv.VARIABLE_PROG.search(value, i)
|
||||
if not m:
|
||||
# Nothing more to match. Done
|
||||
# No more variable substitutions found. See if the remainder looks like a jinja template
|
||||
value = _expand_jinja(value, orig_value, path, jinja, ignore_missing)
|
||||
break
|
||||
|
||||
i, j = m.span(0)
|
||||
@@ -67,8 +110,15 @@ def _expand_substitutions(substitutions, value, path, ignore_missing):
|
||||
continue
|
||||
|
||||
sub = substitutions[name]
|
||||
|
||||
if i == 0 and j == len(value):
|
||||
# The variable spans the whole expression, e.g., "${varName}". Return its resolved value directly
|
||||
# to conserve its type.
|
||||
value = sub
|
||||
break
|
||||
|
||||
tail = value[j:]
|
||||
value = value[:i] + sub
|
||||
value = value[:i] + str(sub)
|
||||
i = len(value)
|
||||
value += tail
|
||||
|
||||
@@ -77,36 +127,40 @@ def _expand_substitutions(substitutions, value, path, ignore_missing):
|
||||
if isinstance(orig_value, ESPHomeDataBase):
|
||||
# even though string can get larger or smaller, the range should point
|
||||
# to original document marks
|
||||
return make_data_base(value, orig_value)
|
||||
value = make_data_base(value, orig_value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _substitute_item(substitutions, item, path, ignore_missing):
|
||||
def _substitute_item(substitutions, item, path, jinja, ignore_missing):
|
||||
if isinstance(item, list):
|
||||
for i, it in enumerate(item):
|
||||
sub = _substitute_item(substitutions, it, path + [i], ignore_missing)
|
||||
sub = _substitute_item(substitutions, it, path + [i], jinja, ignore_missing)
|
||||
if sub is not None:
|
||||
item[i] = sub
|
||||
elif isinstance(item, dict):
|
||||
replace_keys = []
|
||||
for k, v in item.items():
|
||||
if path or k != CONF_SUBSTITUTIONS:
|
||||
sub = _substitute_item(substitutions, k, path + [k], ignore_missing)
|
||||
sub = _substitute_item(
|
||||
substitutions, k, path + [k], jinja, ignore_missing
|
||||
)
|
||||
if sub is not None:
|
||||
replace_keys.append((k, sub))
|
||||
sub = _substitute_item(substitutions, v, path + [k], ignore_missing)
|
||||
sub = _substitute_item(substitutions, v, path + [k], jinja, ignore_missing)
|
||||
if sub is not None:
|
||||
item[k] = sub
|
||||
for old, new in replace_keys:
|
||||
item[new] = merge_config(item.get(old), item.get(new))
|
||||
del item[old]
|
||||
elif isinstance(item, str):
|
||||
sub = _expand_substitutions(substitutions, item, path, ignore_missing)
|
||||
if sub != item:
|
||||
sub = _expand_substitutions(substitutions, item, path, jinja, ignore_missing)
|
||||
if isinstance(sub, JinjaStr) or sub != item:
|
||||
return sub
|
||||
elif isinstance(item, (core.Lambda, Extend, Remove)):
|
||||
sub = _expand_substitutions(substitutions, item.value, path, ignore_missing)
|
||||
sub = _expand_substitutions(
|
||||
substitutions, item.value, path, jinja, ignore_missing
|
||||
)
|
||||
if sub != item:
|
||||
item.value = sub
|
||||
return None
|
||||
@@ -116,11 +170,11 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals
|
||||
if CONF_SUBSTITUTIONS not in config and not command_line_substitutions:
|
||||
return
|
||||
|
||||
substitutions = config.get(CONF_SUBSTITUTIONS)
|
||||
if substitutions is None:
|
||||
substitutions = command_line_substitutions
|
||||
elif command_line_substitutions:
|
||||
substitutions = {**substitutions, **command_line_substitutions}
|
||||
# Merge substitutions in config, overriding with substitutions coming from command line:
|
||||
substitutions = {
|
||||
**config.get(CONF_SUBSTITUTIONS, {}),
|
||||
**(command_line_substitutions or {}),
|
||||
}
|
||||
with cv.prepend_path("substitutions"):
|
||||
if not isinstance(substitutions, dict):
|
||||
raise cv.Invalid(
|
||||
@@ -133,7 +187,7 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals
|
||||
sub = validate_substitution_key(key)
|
||||
if sub != key:
|
||||
replace_keys.append((key, sub))
|
||||
substitutions[key] = cv.string_strict(value)
|
||||
substitutions[key] = value
|
||||
for old, new in replace_keys:
|
||||
substitutions[new] = substitutions[old]
|
||||
del substitutions[old]
|
||||
@@ -141,4 +195,7 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals
|
||||
config[CONF_SUBSTITUTIONS] = substitutions
|
||||
# Move substitutions to the first place to replace substitutions in them correctly
|
||||
config.move_to_end(CONF_SUBSTITUTIONS, False)
|
||||
_substitute_item(substitutions, config, [], ignore_missing)
|
||||
|
||||
# Create a Jinja environment that will consider substitutions in scope:
|
||||
jinja = Jinja(substitutions)
|
||||
_substitute_item(substitutions, config, [], jinja, ignore_missing)
|
||||
|
||||
99
esphome/components/substitutions/jinja.py
Normal file
99
esphome/components/substitutions/jinja.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import jinja2 as jinja
|
||||
from jinja2.nativetypes import NativeEnvironment
|
||||
|
||||
TemplateError = jinja.TemplateError
|
||||
TemplateSyntaxError = jinja.TemplateSyntaxError
|
||||
TemplateRuntimeError = jinja.TemplateRuntimeError
|
||||
UndefinedError = jinja.UndefinedError
|
||||
Undefined = jinja.Undefined
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DETECT_JINJA = r"(\$\{)"
|
||||
detect_jinja_re = re.compile(
|
||||
r"<%.+?%>" # Block form expression: <% ... %>
|
||||
r"|\$\{[^}]+\}", # Braced form expression: ${ ... }
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
|
||||
|
||||
def has_jinja(st):
|
||||
return detect_jinja_re.search(st) is not None
|
||||
|
||||
|
||||
class JinjaStr(str):
|
||||
"""
|
||||
Wraps a string containing an unresolved Jinja expression,
|
||||
storing the variables visible to it when it failed to resolve.
|
||||
For example, an expression inside a package, `${ A * B }` may fail
|
||||
to resolve at package parsing time if `A` is a local package var
|
||||
but `B` is a substitution defined in the root yaml.
|
||||
Therefore, we store the value of `A` as an upvalue bound
|
||||
to the original string so we may be able to resolve `${ A * B }`
|
||||
later in the main substitutions pass.
|
||||
"""
|
||||
|
||||
def __new__(cls, value: str, upvalues=None):
|
||||
obj = super().__new__(cls, value)
|
||||
obj.upvalues = upvalues or {}
|
||||
return obj
|
||||
|
||||
def __init__(self, value: str, upvalues=None):
|
||||
self.upvalues = upvalues or {}
|
||||
|
||||
|
||||
class Jinja:
|
||||
"""
|
||||
Wraps a Jinja environment
|
||||
"""
|
||||
|
||||
def __init__(self, context_vars):
|
||||
self.env = NativeEnvironment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
block_start_string="<%",
|
||||
block_end_string="%>",
|
||||
line_statement_prefix="#",
|
||||
line_comment_prefix="##",
|
||||
variable_start_string="${",
|
||||
variable_end_string="}",
|
||||
undefined=jinja.StrictUndefined,
|
||||
)
|
||||
self.env.add_extension("jinja2.ext.do")
|
||||
self.env.globals["math"] = math # Inject entire math module
|
||||
self.context_vars = {**context_vars}
|
||||
self.env.globals = {**self.env.globals, **self.context_vars}
|
||||
|
||||
def expand(self, content_str):
|
||||
"""
|
||||
Renders a string that may contain Jinja expressions or statements
|
||||
Returns the resulting processed string if all values could be resolved.
|
||||
Otherwise, it returns a tagged (JinjaStr) string that captures variables
|
||||
in scope (upvalues), like a closure for later evaluation.
|
||||
"""
|
||||
result = None
|
||||
override_vars = {}
|
||||
if isinstance(content_str, JinjaStr):
|
||||
# If `value` is already a JinjaStr, it means we are trying to evaluate it again
|
||||
# in a parent pass.
|
||||
# Hopefully, all required variables are visible now.
|
||||
override_vars = content_str.upvalues
|
||||
try:
|
||||
template = self.env.from_string(content_str)
|
||||
result = template.render(override_vars)
|
||||
if isinstance(result, Undefined):
|
||||
# This happens when the expression is simply an undefined variable. Jinja does not
|
||||
# raise an exception, instead we get "Undefined".
|
||||
# Trigger an UndefinedError exception so we skip to below.
|
||||
print("" + result)
|
||||
except (TemplateSyntaxError, UndefinedError) as err:
|
||||
# `content_str` contains a Jinja expression that refers to a variable that is undefined
|
||||
# in this scope. Perhaps it refers to a root substitution that is not visible yet.
|
||||
# Therefore, return the original `content_str` as a JinjaStr, which contains the variables
|
||||
# that are actually visible to it at this point to postpone evaluation.
|
||||
return JinjaStr(content_str, {**self.context_vars, **override_vars}), err
|
||||
|
||||
return result, None
|
||||
@@ -789,7 +789,6 @@ def validate_config(
|
||||
result.add_output_path([CONF_SUBSTITUTIONS], CONF_SUBSTITUTIONS)
|
||||
try:
|
||||
substitutions.do_substitution_pass(config, command_line_substitutions)
|
||||
substitutions.do_substitution_pass(config, command_line_substitutions)
|
||||
except vol.Invalid as err:
|
||||
result.add_error(err)
|
||||
return result
|
||||
|
||||
@@ -292,8 +292,6 @@ class ESPHomeLoaderMixin:
|
||||
if file is None:
|
||||
raise yaml.MarkedYAMLError("Must include 'file'", node.start_mark)
|
||||
vars = fields.get(CONF_VARS)
|
||||
if vars:
|
||||
vars = {k: str(v) for k, v in vars.items()}
|
||||
return file, vars
|
||||
|
||||
if isinstance(node, yaml.nodes.MappingNode):
|
||||
|
||||
Reference in New Issue
Block a user