mirror of https://github.com/esphome/esphome.git synced 2025-11-07 02:21:51 +00:00

Compare commits


4 Commits

Author          SHA1        Message                                                                            Date
J. Nick Koston  1fbb2a12b8  Merge branch 'dev' into voice_get_config                                           2025-10-19 09:22:55 -10:00
J. Nick Koston  f1fddc058e  adjust                                                                             2025-10-08 08:06:38 -10:00
J. Nick Koston  d5ee5c7861  adjust                                                                             2025-10-08 08:05:44 -10:00
J. Nick Koston  542ca43cf6  [voice_assistant] Fix use-after-free crash with configuration StringRef pointers  2025-10-08 07:54:47 -10:00
37 changed files with 470 additions and 2118 deletions

View File

@@ -1,108 +0,0 @@
---
name: Memory Impact Comment (Forks)
on:
workflow_run:
workflows: ["CI"]
types: [completed]
permissions:
contents: read
pull-requests: write
actions: read
jobs:
memory-impact-comment:
name: Post memory impact comment (fork PRs only)
runs-on: ubuntu-24.04
# Only run for PRs from forks that had successful CI runs
if: >
github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_repository.full_name != github.repository
env:
GH_TOKEN: ${{ github.token }}
steps:
- name: Get PR details
id: pr
run: |
# Get PR details by searching for PR with matching head SHA
# The workflow_run.pull_requests field is often empty for forks
head_sha="${{ github.event.workflow_run.head_sha }}"
pr_data=$(gh api "/repos/${{ github.repository }}/commits/$head_sha/pulls" \
--jq '.[0] | {number: .number, base_ref: .base.ref}')
if [ -z "$pr_data" ] || [ "$pr_data" == "null" ]; then
echo "No PR found for SHA $head_sha, skipping"
echo "skip=true" >> $GITHUB_OUTPUT
exit 0
fi
pr_number=$(echo "$pr_data" | jq -r '.number')
base_ref=$(echo "$pr_data" | jq -r '.base_ref')
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT
echo "base_ref=$base_ref" >> $GITHUB_OUTPUT
echo "Found PR #$pr_number targeting base branch: $base_ref"
- name: Check out code from base repository
if: steps.pr.outputs.skip != 'true'
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
# Always check out from the base repository (esphome/esphome), never from forks
# Use the PR's target branch to ensure we run trusted code from the main repo
repository: ${{ github.repository }}
ref: ${{ steps.pr.outputs.base_ref }}
- name: Restore Python
if: steps.pr.outputs.skip != 'true'
uses: ./.github/actions/restore-python
with:
python-version: "3.11"
cache-key: ${{ hashFiles('.cache-key') }}
- name: Download memory analysis artifacts
if: steps.pr.outputs.skip != 'true'
run: |
run_id="${{ github.event.workflow_run.id }}"
echo "Downloading artifacts from workflow run $run_id"
mkdir -p memory-analysis
# Download target analysis artifact
if gh run download --name "memory-analysis-target" --dir memory-analysis --repo "${{ github.repository }}" "$run_id"; then
echo "Downloaded memory-analysis-target artifact."
else
echo "No memory-analysis-target artifact found."
fi
# Download PR analysis artifact
if gh run download --name "memory-analysis-pr" --dir memory-analysis --repo "${{ github.repository }}" "$run_id"; then
echo "Downloaded memory-analysis-pr artifact."
else
echo "No memory-analysis-pr artifact found."
fi
- name: Check if artifacts exist
id: check
if: steps.pr.outputs.skip != 'true'
run: |
if [ -f ./memory-analysis/memory-analysis-target.json ] && [ -f ./memory-analysis/memory-analysis-pr.json ]; then
echo "found=true" >> $GITHUB_OUTPUT
else
echo "found=false" >> $GITHUB_OUTPUT
echo "Memory analysis artifacts not found, skipping comment"
fi
- name: Post or update PR comment
if: steps.pr.outputs.skip != 'true' && steps.check.outputs.found == 'true'
env:
PR_NUMBER: ${{ steps.pr.outputs.pr_number }}
run: |
. venv/bin/activate
# Pass PR number and JSON file paths directly to Python script
# Let Python parse the JSON to avoid shell injection risks
# The script will validate and sanitize all inputs
python script/ci_memory_impact_comment.py \
--pr-number "$PR_NUMBER" \
--target-json ./memory-analysis/memory-analysis-target.json \
--pr-json ./memory-analysis/memory-analysis-pr.json

View File

@@ -432,21 +432,6 @@ jobs:
with:
python-version: ${{ env.DEFAULT_PYTHON }}
cache-key: ${{ needs.common.outputs.cache-key }}
- name: Cache platformio
if: github.ref == 'refs/heads/dev'
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
with:
path: ~/.platformio
key: platformio-test-${{ hashFiles('platformio.ini') }}
- name: Cache platformio
if: github.ref != 'refs/heads/dev'
uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
with:
path: ~/.platformio
key: platformio-test-${{ hashFiles('platformio.ini') }}
- name: Validate and compile components with intelligent grouping
run: |
. venv/bin/activate
@@ -656,12 +641,6 @@ jobs:
--output-env \
--output-json memory-analysis-target.json
# Add metadata to JSON before caching
python script/ci_add_metadata_to_json.py \
--json-file memory-analysis-target.json \
--components "$components" \
--platform "$platform"
- name: Save memory analysis to cache
if: steps.check-script.outputs.skip != 'true' && steps.cache-memory-analysis.outputs.cache-hit != 'true' && steps.build.outcome == 'success'
uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
@@ -741,13 +720,6 @@ jobs:
python script/ci_memory_impact_extract.py \
--output-env \
--output-json memory-analysis-pr.json
# Add metadata to JSON (components and platform are in shell variables above)
python script/ci_add_metadata_to_json.py \
--json-file memory-analysis-pr.json \
--components "$components" \
--platform "$platform"
- name: Upload memory analysis JSON
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
@@ -764,12 +736,10 @@ jobs:
- determine-jobs
- memory-impact-target-branch
- memory-impact-pr-branch
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && fromJSON(needs.determine-jobs.outputs.memory_impact).should_run == 'true' && needs.memory-impact-target-branch.outputs.skip != 'true'
if: github.event_name == 'pull_request' && fromJSON(needs.determine-jobs.outputs.memory_impact).should_run == 'true' && needs.memory-impact-target-branch.outputs.skip != 'true'
permissions:
contents: read
pull-requests: write
env:
GH_TOKEN: ${{ github.token }}
steps:
- name: Check out code
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
@@ -792,16 +762,52 @@ jobs:
continue-on-error: true
- name: Post or update PR comment
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
GH_TOKEN: ${{ github.token }}
COMPONENTS: ${{ toJSON(fromJSON(needs.determine-jobs.outputs.memory_impact).components) }}
PLATFORM: ${{ fromJSON(needs.determine-jobs.outputs.memory_impact).platform }}
TARGET_RAM: ${{ needs.memory-impact-target-branch.outputs.ram_usage }}
TARGET_FLASH: ${{ needs.memory-impact-target-branch.outputs.flash_usage }}
PR_RAM: ${{ needs.memory-impact-pr-branch.outputs.ram_usage }}
PR_FLASH: ${{ needs.memory-impact-pr-branch.outputs.flash_usage }}
TARGET_CACHE_HIT: ${{ needs.memory-impact-target-branch.outputs.cache_hit }}
run: |
. venv/bin/activate
# Pass JSON file paths directly to Python script
# All data is extracted from JSON files for security
# Check if analysis JSON files exist
target_json_arg=""
pr_json_arg=""
if [ -f ./memory-analysis/memory-analysis-target.json ]; then
echo "Found target analysis JSON"
target_json_arg="--target-json ./memory-analysis/memory-analysis-target.json"
else
echo "No target analysis JSON found"
fi
if [ -f ./memory-analysis/memory-analysis-pr.json ]; then
echo "Found PR analysis JSON"
pr_json_arg="--pr-json ./memory-analysis/memory-analysis-pr.json"
else
echo "No PR analysis JSON found"
fi
# Add cache flag if target was cached
cache_flag=""
if [ "$TARGET_CACHE_HIT" == "true" ]; then
cache_flag="--target-cache-hit"
fi
python script/ci_memory_impact_comment.py \
--pr-number "$PR_NUMBER" \
--target-json ./memory-analysis/memory-analysis-target.json \
--pr-json ./memory-analysis/memory-analysis-pr.json
--pr-number "${{ github.event.pull_request.number }}" \
--components "$COMPONENTS" \
--platform "$PLATFORM" \
--target-ram "$TARGET_RAM" \
--target-flash "$TARGET_FLASH" \
--pr-ram "$PR_RAM" \
--pr-flash "$PR_FLASH" \
$target_json_arg \
$pr_json_arg \
$cache_flag
ci-status:
name: CI Status

View File

@@ -1,8 +1,8 @@
#pragma once
#include <set>
#include "climate_mode.h"
#include "esphome/core/helpers.h"
#include "climate_mode.h"
#include <set>
namespace esphome {
@@ -109,12 +109,44 @@ class ClimateTraits {
void set_supported_modes(std::set<ClimateMode> modes) { this->supported_modes_ = std::move(modes); }
void add_supported_mode(ClimateMode mode) { this->supported_modes_.insert(mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_auto_mode(bool supports_auto_mode) { set_mode_support_(CLIMATE_MODE_AUTO, supports_auto_mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_cool_mode(bool supports_cool_mode) { set_mode_support_(CLIMATE_MODE_COOL, supports_cool_mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_heat_mode(bool supports_heat_mode) { set_mode_support_(CLIMATE_MODE_HEAT, supports_heat_mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_heat_cool_mode(bool supported) { set_mode_support_(CLIMATE_MODE_HEAT_COOL, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_fan_only_mode(bool supports_fan_only_mode) {
set_mode_support_(CLIMATE_MODE_FAN_ONLY, supports_fan_only_mode);
}
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_dry_mode(bool supports_dry_mode) { set_mode_support_(CLIMATE_MODE_DRY, supports_dry_mode); }
bool supports_mode(ClimateMode mode) const { return this->supported_modes_.count(mode); }
const std::set<ClimateMode> &get_supported_modes() const { return this->supported_modes_; }
void set_supported_fan_modes(std::set<ClimateFanMode> modes) { this->supported_fan_modes_ = std::move(modes); }
void add_supported_fan_mode(ClimateFanMode mode) { this->supported_fan_modes_.insert(mode); }
void add_supported_custom_fan_mode(const std::string &mode) { this->supported_custom_fan_modes_.insert(mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_on(bool supported) { set_fan_mode_support_(CLIMATE_FAN_ON, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_off(bool supported) { set_fan_mode_support_(CLIMATE_FAN_OFF, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_auto(bool supported) { set_fan_mode_support_(CLIMATE_FAN_AUTO, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_low(bool supported) { set_fan_mode_support_(CLIMATE_FAN_LOW, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_medium(bool supported) { set_fan_mode_support_(CLIMATE_FAN_MEDIUM, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_high(bool supported) { set_fan_mode_support_(CLIMATE_FAN_HIGH, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_middle(bool supported) { set_fan_mode_support_(CLIMATE_FAN_MIDDLE, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_focus(bool supported) { set_fan_mode_support_(CLIMATE_FAN_FOCUS, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_diffuse(bool supported) { set_fan_mode_support_(CLIMATE_FAN_DIFFUSE, supported); }
bool supports_fan_mode(ClimateFanMode fan_mode) const { return this->supported_fan_modes_.count(fan_mode); }
bool get_supports_fan_modes() const {
return !this->supported_fan_modes_.empty() || !this->supported_custom_fan_modes_.empty();
@@ -146,6 +178,16 @@ class ClimateTraits {
void set_supported_swing_modes(std::set<ClimateSwingMode> modes) { this->supported_swing_modes_ = std::move(modes); }
void add_supported_swing_mode(ClimateSwingMode mode) { this->supported_swing_modes_.insert(mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20")
void set_supports_swing_mode_off(bool supported) { set_swing_mode_support_(CLIMATE_SWING_OFF, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20")
void set_supports_swing_mode_both(bool supported) { set_swing_mode_support_(CLIMATE_SWING_BOTH, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20")
void set_supports_swing_mode_vertical(bool supported) { set_swing_mode_support_(CLIMATE_SWING_VERTICAL, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20")
void set_supports_swing_mode_horizontal(bool supported) {
set_swing_mode_support_(CLIMATE_SWING_HORIZONTAL, supported);
}
bool supports_swing_mode(ClimateSwingMode swing_mode) const { return this->supported_swing_modes_.count(swing_mode); }
bool get_supports_swing_modes() const { return !this->supported_swing_modes_.empty(); }
const std::set<ClimateSwingMode> &get_supported_swing_modes() const { return this->supported_swing_modes_; }
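Note on the hunk above: it restores the deprecated per-mode setters alongside the set-based API. As a usage sketch compiling against the declarations shown in this hunk (the include path is assumed from the repository layout), a call-site migration looks like this:

#include "esphome/components/climate/climate_traits.h"  // assumed header path

// Replace the deprecated per-mode setters with one set_supported_modes() call.
void configure(esphome::climate::ClimateTraits &traits) {
  // Old style (each call now emits an ESPDEPRECATED warning at compile time):
  //   traits.set_supports_cool_mode(true);
  //   traits.set_supports_heat_mode(true);
  // New style: declare every supported mode at once.
  traits.set_supported_modes({
      esphome::climate::CLIMATE_MODE_COOL,
      esphome::climate::CLIMATE_MODE_HEAT,
      esphome::climate::CLIMATE_MODE_HEAT_COOL,
  });
  traits.add_supported_fan_mode(esphome::climate::CLIMATE_FAN_AUTO);
}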

View File

@@ -6,7 +6,6 @@
#include <freertos/FreeRTOS.h>
#include <freertos/task.h>
#include <esp_idf_version.h>
#include <esp_ota_ops.h>
#include <esp_task_wdt.h>
#include <esp_timer.h>
#include <soc/rtc.h>
@@ -53,16 +52,6 @@ void arch_init() {
disableCore1WDT();
#endif
#endif
// If the bootloader was compiled with CONFIG_BOOTLOADER_APP_ROLLBACK_ENABLE the current
// partition will get rolled back unless it is marked as valid.
esp_ota_img_states_t state;
const esp_partition_t *running = esp_ota_get_running_partition();
if (esp_ota_get_state_partition(running, &state) == ESP_OK) {
if (state == ESP_OTA_IMG_PENDING_VERIFY) {
esp_ota_mark_app_valid_cancel_rollback();
}
}
}
void IRAM_ATTR HOT arch_feed_wdt() { esp_task_wdt_reset(); }

View File

@@ -14,7 +14,7 @@ void Kuntze::on_modbus_data(const std::vector<uint8_t> &data) {
auto get_16bit = [&](int i) -> uint16_t { return (uint16_t(data[i * 2]) << 8) | uint16_t(data[i * 2 + 1]); };
this->waiting_ = false;
ESP_LOGV(TAG, "Data: %s", format_hex_pretty(data).c_str());
ESP_LOGV(TAG, "Data: %s", hexencode(data).c_str());
float value = (float) get_16bit(0);
for (int i = 0; i < data[3]; i++)

View File

@@ -1,11 +1,11 @@
#pragma once
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/core/color.h"
#include "esp_color_correction.h"
#include "esp_color_view.h"
#include "esp_range_view.h"
#include "esphome/core/color.h"
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "light_output.h"
#include "light_state.h"
#include "transformers.h"
@@ -17,6 +17,8 @@
namespace esphome {
namespace light {
using ESPColor ESPDEPRECATED("esphome::light::ESPColor is deprecated, use esphome::Color instead.", "v1.21") = Color;
/// Convert the color information from a `LightColorValues` object to a `Color` object (does not apply brightness).
Color color_from_light_color_values(LightColorValues val);
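Note on the hunk above: the line "using ESPColor ESPDEPRECATED(...) = Color;" deprecates a type alias rather than a function. A minimal, self-contained sketch of the same pattern, with hypothetical names (MyColor, MY_DEPRECATED) standing in for esphome's Color and ESPDEPRECATED macro:

// C++14: [[deprecated("msg")]] may be attached to an alias-declaration.
#define MY_DEPRECATED(msg, version) [[deprecated(msg " (since " version ")")]]

struct MyColor {
  unsigned char r{0}, g{0}, b{0}, w{0};
};

// Any use of OldColor still compiles but emits a warning pointing at MyColor.
using OldColor MY_DEPRECATED("use MyColor instead", "v1.21") = MyColor;

int main() {
  OldColor c;  // warning: 'OldColor' is deprecated: use MyColor instead (since v1.21)
  (void) c;
  return 0;
}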

View File

@@ -1,7 +1,7 @@
#pragma once
#include "color_mode.h"
#include "esphome/core/helpers.h"
#include "color_mode.h"
namespace esphome {
@@ -31,6 +31,26 @@ class LightTraits {
return this->supported_color_modes_.has_capability(color_capability);
}
ESPDEPRECATED("get_supports_brightness() is deprecated, use color modes instead.", "v1.21")
bool get_supports_brightness() const { return this->supports_color_capability(ColorCapability::BRIGHTNESS); }
ESPDEPRECATED("get_supports_rgb() is deprecated, use color modes instead.", "v1.21")
bool get_supports_rgb() const { return this->supports_color_capability(ColorCapability::RGB); }
ESPDEPRECATED("get_supports_rgb_white_value() is deprecated, use color modes instead.", "v1.21")
bool get_supports_rgb_white_value() const {
return this->supports_color_mode(ColorMode::RGB_WHITE) ||
this->supports_color_mode(ColorMode::RGB_COLOR_TEMPERATURE);
}
ESPDEPRECATED("get_supports_color_temperature() is deprecated, use color modes instead.", "v1.21")
bool get_supports_color_temperature() const {
return this->supports_color_capability(ColorCapability::COLOR_TEMPERATURE);
}
ESPDEPRECATED("get_supports_color_interlock() is deprecated, use color modes instead.", "v1.21")
bool get_supports_color_interlock() const {
return this->supports_color_mode(ColorMode::RGB) &&
(this->supports_color_mode(ColorMode::WHITE) || this->supports_color_mode(ColorMode::COLD_WARM_WHITE) ||
this->supports_color_mode(ColorMode::COLOR_TEMPERATURE));
}
float get_min_mireds() const { return this->min_mireds_; }
void set_min_mireds(float min_mireds) { this->min_mireds_ = min_mireds; }
float get_max_mireds() const { return this->max_mireds_; }

View File

@@ -1291,6 +1291,9 @@ void Nextion::check_pending_waveform_() {
void Nextion::set_writer(const nextion_writer_t &writer) { this->writer_ = writer; }
ESPDEPRECATED("set_wait_for_ack(bool) deprecated, no effect", "v1.20")
void Nextion::set_wait_for_ack(bool wait_for_ack) { ESP_LOGE(TAG, "Deprecated"); }
bool Nextion::is_updating() { return this->connection_state_.is_updating_; }
} // namespace nextion

View File

@@ -45,26 +45,13 @@ def get_script(script_id):
def check_max_runs(value):
# Set default for queued mode to prevent unbounded queue growth
if CONF_MAX_RUNS not in value and value[CONF_MODE] == CONF_QUEUED:
value[CONF_MAX_RUNS] = 5
if CONF_MAX_RUNS not in value:
return value
if value[CONF_MODE] not in [CONF_QUEUED, CONF_PARALLEL]:
raise cv.Invalid(
"The option 'max_runs' is only valid in 'queued' and 'parallel' mode.",
"The option 'max_runs' is only valid in 'queue' and 'parallel' mode.",
path=[CONF_MAX_RUNS],
)
# Queued mode must have bounded queue (min 1), parallel mode can be unlimited (0)
if value[CONF_MODE] == CONF_QUEUED and value[CONF_MAX_RUNS] < 1:
raise cv.Invalid(
"The option 'max_runs' must be at least 1 for queued mode.",
path=[CONF_MAX_RUNS],
)
return value
@@ -119,7 +106,7 @@ CONFIG_SCHEMA = automation.validate_automation(
cv.Optional(CONF_MODE, default=CONF_SINGLE): cv.one_of(
*SCRIPT_MODES, lower=True
),
cv.Optional(CONF_MAX_RUNS): cv.int_range(min=0, max=100),
cv.Optional(CONF_MAX_RUNS): cv.positive_int,
cv.Optional(CONF_PARAMETERS, default={}): cv.Schema(
{
validate_parameter_name: validate_parameter_type,

View File

@@ -1,11 +1,10 @@
#pragma once
#include <memory>
#include <tuple>
#include "esphome/core/automation.h"
#include "esphome/core/component.h"
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
#include <queue>
namespace esphome {
namespace script {
@@ -97,41 +96,23 @@ template<typename... Ts> class RestartScript : public Script<Ts...> {
/** A script type that queues new instances that are created.
*
* Only one instance of the script can be active at a time.
*
* Ring buffer implementation:
* - num_queued_ tracks the number of queued (waiting) instances, NOT including the currently running one
* - queue_front_ points to the next item to execute (read position)
* - Buffer size is max_runs_ - 1 (max total instances minus the running one)
* - Write position is calculated as: (queue_front_ + num_queued_) % (max_runs_ - 1)
* - When an item finishes, queue_front_ advances: (queue_front_ + 1) % (max_runs_ - 1)
* - First execute() runs immediately without queuing (num_queued_ stays 0)
* - Subsequent executes while running are queued starting at position 0
* - Maximum total instances = max_runs_ (includes 1 running + (max_runs_ - 1) queued)
*/
template<typename... Ts> class QueueingScript : public Script<Ts...>, public Component {
public:
void execute(Ts... x) override {
if (this->is_action_running() || this->num_queued_ > 0) {
// num_queued_ is the number of *queued* instances (waiting, not including currently running)
// max_runs_ is the maximum *total* instances (running + queued)
// So we reject when num_queued_ + 1 >= max_runs_ (queued + running >= max)
if (this->num_queued_ + 1 >= this->max_runs_) {
this->esp_logw_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' max instances (running + queued) reached!"),
if (this->is_action_running() || this->num_runs_ > 0) {
// num_runs_ is the number of *queued* instances, so total number of instances is
// num_runs_ + 1
if (this->max_runs_ != 0 && this->num_runs_ + 1 >= this->max_runs_) {
this->esp_logw_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' maximum number of queued runs exceeded!"),
LOG_STR_ARG(this->name_));
return;
}
// Initialize queue on first queued item (after capacity check)
this->lazy_init_queue_();
this->esp_logd_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' queueing new instance (mode: queued)"),
LOG_STR_ARG(this->name_));
// Ring buffer: write to (queue_front_ + num_queued_) % queue_capacity
const size_t queue_capacity = static_cast<size_t>(this->max_runs_ - 1);
size_t write_pos = (this->queue_front_ + this->num_queued_) % queue_capacity;
// Use std::make_unique to replace the unique_ptr
this->var_queue_[write_pos] = std::make_unique<std::tuple<Ts...>>(x...);
this->num_queued_++;
this->num_runs_++;
this->var_queue_.push(std::make_tuple(x...));
return;
}
@@ -141,46 +122,29 @@ template<typename... Ts> class QueueingScript : public Script<Ts...>, public Com
}
void stop() override {
// Clear all queued items to free memory immediately
// Resetting the array automatically destroys all unique_ptrs and their contents
this->var_queue_.reset();
this->num_queued_ = 0;
this->queue_front_ = 0;
this->num_runs_ = 0;
Script<Ts...>::stop();
}
void loop() override {
if (this->num_queued_ != 0 && !this->is_action_running()) {
// Dequeue: decrement count, move tuple out (frees slot), advance read position
this->num_queued_--;
const size_t queue_capacity = static_cast<size_t>(this->max_runs_ - 1);
auto tuple_ptr = std::move(this->var_queue_[this->queue_front_]);
this->queue_front_ = (this->queue_front_ + 1) % queue_capacity;
this->trigger_tuple_(*tuple_ptr, typename gens<sizeof...(Ts)>::type());
if (this->num_runs_ != 0 && !this->is_action_running()) {
this->num_runs_--;
auto &vars = this->var_queue_.front();
this->var_queue_.pop();
this->trigger_tuple_(vars, typename gens<sizeof...(Ts)>::type());
}
}
void set_max_runs(int max_runs) { max_runs_ = max_runs; }
protected:
// Lazy init queue on first use - avoids setup() ordering issues and saves memory
// if script is never executed during this boot cycle
inline void lazy_init_queue_() {
if (!this->var_queue_) {
// Allocate array of max_runs_ - 1 slots for queued items (running item is separate)
// unique_ptr array is zero-initialized, so all slots start as nullptr
this->var_queue_ = std::make_unique<std::unique_ptr<std::tuple<Ts...>>[]>(this->max_runs_ - 1);
}
}
template<int... S> void trigger_tuple_(const std::tuple<Ts...> &tuple, seq<S...> /*unused*/) {
this->trigger(std::get<S>(tuple)...);
}
int num_queued_ = 0; // Number of queued instances (not including currently running)
int max_runs_ = 0; // Maximum total instances (running + queued)
size_t queue_front_ = 0; // Ring buffer read position (next item to execute)
std::unique_ptr<std::unique_ptr<std::tuple<Ts...>>[]> var_queue_; // Ring buffer of queued parameters
int num_runs_ = 0;
int max_runs_ = 0;
std::queue<std::tuple<Ts...>> var_queue_;
};
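Note on the hunk above: the documentation comment pins down the ring-buffer arithmetic — capacity is max_runs_ - 1 (the running instance lives outside the buffer), the write slot is (queue_front_ + num_queued_) % capacity, and the buffer is allocated lazily on first use. A standalone sketch of just that bookkeeping (RingQueue is hypothetical, not the esphome class):

#include <cstddef>
#include <memory>

template<typename T> class RingQueue {
 public:
  explicit RingQueue(int max_runs) : max_runs_(max_runs) {}

  // Reject when queued + the one running instance would exceed max_runs_.
  bool push(T value) {
    if (num_queued_ + 1 >= max_runs_)
      return false;
    lazy_init_();
    const std::size_t cap = static_cast<std::size_t>(max_runs_ - 1);
    slots_[(front_ + num_queued_) % cap] = std::make_unique<T>(std::move(value));
    num_queued_++;
    return true;
  }

  // Move the front item out (freeing its slot) and advance the read position.
  std::unique_ptr<T> pop() {
    if (num_queued_ == 0)
      return nullptr;
    const std::size_t cap = static_cast<std::size_t>(max_runs_ - 1);
    auto out = std::move(slots_[front_]);
    front_ = (front_ + 1) % cap;
    num_queued_--;
    return out;
  }

 private:
  // Allocate only when the first item is queued, mirroring lazy_init_queue_()
  // above: no memory is spent if the script never queues during this boot.
  void lazy_init_() {
    if (!slots_)
      slots_ = std::make_unique<std::unique_ptr<T>[]>(max_runs_ - 1);
  }

  int max_runs_;
  int num_queued_{0};
  std::size_t front_{0};
  std::unique_ptr<std::unique_ptr<T>[]> slots_;
};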
/** A script type that executes new instances in parallel.

View File

@@ -953,11 +953,21 @@ void VoiceAssistant::on_set_configuration(const std::vector<std::string> &active
}
}
}
// Mark configuration dirty to trigger rebuild on next get_configuration() call.
this->config_needs_rebuild_ = true;
}
#endif
};
const Configuration &VoiceAssistant::get_configuration() {
// Return cached configuration if it hasn't changed. This prevents a use-after-free
// race condition when API message serialization creates StringRef pointers to strings
// in config_.available_wake_words, and avoids wastefully rebuilding on every call.
if (!this->config_needs_rebuild_) {
return this->config_;
}
this->config_.available_wake_words.clear();
this->config_.active_wake_words.clear();
@@ -986,6 +996,8 @@ const Configuration &VoiceAssistant::get_configuration() {
}
#endif
// Mark configuration as clean now that we've rebuilt it
this->config_needs_rebuild_ = false;
return this->config_;
};
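Note on the hunk above: the fix caches the built Configuration and rebuilds it only when a setter marks it dirty, so StringRef pointers created during API message serialization keep referencing live strings instead of freshly cleared ones. A generic sketch of the same cache-with-dirty-flag idea, with hypothetical names (Config, ConfigHolder):

#include <string>
#include <utility>
#include <vector>

struct Config {
  std::vector<std::string> wake_words;
};

class ConfigHolder {
 public:
  void set_wake_words(std::vector<std::string> words) {
    this->words_ = std::move(words);
    this->dirty_ = true;  // invalidate the cache; next get() rebuilds
  }

  // Returns a stable reference: the strings inside config_ are not cleared or
  // reallocated between calls unless a setter marked the cache dirty, so
  // string views taken by a caller stay valid across repeated get() calls.
  const Config &get() {
    if (!this->dirty_)
      return this->config_;
    this->config_.wake_words = this->words_;  // rebuild from current state
    this->dirty_ = false;
    return this->config_;
  }

 private:
  std::vector<std::string> words_;
  Config config_{};
  bool dirty_{true};  // start dirty so the first get() builds the cache
};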

View File

@@ -112,7 +112,10 @@ class VoiceAssistant : public Component {
void set_microphone_source(microphone::MicrophoneSource *mic_source) { this->mic_source_ = mic_source; }
#ifdef USE_MICRO_WAKE_WORD
void set_micro_wake_word(micro_wake_word::MicroWakeWord *mww) { this->micro_wake_word_ = mww; }
void set_micro_wake_word(micro_wake_word::MicroWakeWord *mww) {
this->micro_wake_word_ = mww;
this->config_needs_rebuild_ = true;
}
#endif
#ifdef USE_SPEAKER
void set_speaker(speaker::Speaker *speaker) {
@@ -313,7 +316,11 @@ class VoiceAssistant : public Component {
bool udp_socket_running_{false};
bool start_udp_socket_();
// Configuration caching for safety and performance. Only rebuild when config_needs_rebuild_
// is true to prevent use-after-free race condition when StringRef pointers reference
// wake word strings during API message serialization, and to avoid wasteful rebuilding.
Configuration config_{};
bool config_needs_rebuild_{true};
#ifdef USE_MICRO_WAKE_WORD
micro_wake_word::MicroWakeWord *micro_wake_word_{nullptr};

View File

@@ -407,8 +407,7 @@ async def to_code(config):
cg.add(var.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT]))
cg.add(var.set_power_save_mode(config[CONF_POWER_SAVE_MODE]))
if config[CONF_FAST_CONNECT]:
cg.add_define("USE_WIFI_FAST_CONNECT")
cg.add(var.set_fast_connect(config[CONF_FAST_CONNECT]))
cg.add(var.set_passive_scan(config[CONF_PASSIVE_SCAN]))
if CONF_OUTPUT_POWER in config:
cg.add(var.set_output_power(config[CONF_OUTPUT_POWER]))

View File

@@ -84,9 +84,9 @@ void WiFiComponent::start() {
uint32_t hash = this->has_sta() ? fnv1_hash(App.get_compilation_time()) : 88491487UL;
this->pref_ = global_preferences->make_preference<wifi::SavedWifiSettings>(hash, true);
#ifdef USE_WIFI_FAST_CONNECT
this->fast_connect_pref_ = global_preferences->make_preference<wifi::SavedWifiFastConnectSettings>(hash + 1, false);
#endif
if (this->fast_connect_) {
this->fast_connect_pref_ = global_preferences->make_preference<wifi::SavedWifiFastConnectSettings>(hash + 1, false);
}
SavedWifiSettings save{};
if (this->pref_.load(&save)) {
@@ -108,16 +108,16 @@ void WiFiComponent::start() {
ESP_LOGV(TAG, "Setting Power Save Option failed");
}
#ifdef USE_WIFI_FAST_CONNECT
this->trying_loaded_ap_ = this->load_fast_connect_settings_();
if (!this->trying_loaded_ap_) {
this->ap_index_ = 0;
this->selected_ap_ = this->sta_[this->ap_index_];
if (this->fast_connect_) {
this->trying_loaded_ap_ = this->load_fast_connect_settings_();
if (!this->trying_loaded_ap_) {
this->ap_index_ = 0;
this->selected_ap_ = this->sta_[this->ap_index_];
}
this->start_connecting(this->selected_ap_, false);
} else {
this->start_scanning();
}
this->start_connecting(this->selected_ap_, false);
#else
this->start_scanning();
#endif
#ifdef USE_WIFI_AP
} else if (this->has_ap()) {
this->setup_ap_config_();
@@ -168,20 +168,13 @@ void WiFiComponent::loop() {
case WIFI_COMPONENT_STATE_COOLDOWN: {
this->status_set_warning(LOG_STR("waiting to reconnect"));
if (millis() - this->action_started_ > 5000) {
#ifdef USE_WIFI_FAST_CONNECT
// NOTE: This check may not make sense here as it could interfere with AP cycling
if (!this->selected_ap_.get_bssid().has_value())
this->selected_ap_ = this->sta_[0];
this->start_connecting(this->selected_ap_, false);
#else
if (this->retry_hidden_) {
if (this->fast_connect_ || this->retry_hidden_) {
if (!this->selected_ap_.get_bssid().has_value())
this->selected_ap_ = this->sta_[0];
this->start_connecting(this->selected_ap_, false);
} else {
this->start_scanning();
}
#endif
}
break;
}
@@ -251,6 +244,7 @@ WiFiComponent::WiFiComponent() { global_wifi_component = this; }
bool WiFiComponent::has_ap() const { return this->has_ap_; }
bool WiFiComponent::has_sta() const { return !this->sta_.empty(); }
void WiFiComponent::set_fast_connect(bool fast_connect) { this->fast_connect_ = fast_connect; }
#ifdef USE_WIFI_11KV_SUPPORT
void WiFiComponent::set_btm(bool btm) { this->btm_ = btm; }
void WiFiComponent::set_rrm(bool rrm) { this->rrm_ = rrm; }
@@ -613,12 +607,10 @@ void WiFiComponent::check_scanning_finished() {
for (auto &ap : this->sta_) {
if (res.matches(ap)) {
res.set_matches(true);
// Cache priority lookup - do single search instead of 2 separate searches
const bssid_t &bssid = res.get_bssid();
if (!this->has_sta_priority(bssid)) {
this->set_sta_priority(bssid, ap.get_priority());
if (!this->has_sta_priority(res.get_bssid())) {
this->set_sta_priority(res.get_bssid(), ap.get_priority());
}
res.set_priority(this->get_sta_priority(bssid));
res.set_priority(this->get_sta_priority(res.get_bssid()));
break;
}
}
@@ -637,9 +629,8 @@ void WiFiComponent::check_scanning_finished() {
return;
}
// Build connection params directly into selected_ap_ to avoid extra copy
const WiFiScanResult &scan_res = this->scan_result_[0];
WiFiAP &selected = this->selected_ap_;
WiFiAP connect_params;
WiFiScanResult scan_res = this->scan_result_[0];
for (auto &config : this->sta_) {
// search for matching STA config, at least one will match (from checks before)
if (!scan_res.matches(config)) {
@@ -648,38 +639,37 @@ void WiFiComponent::check_scanning_finished() {
if (config.get_hidden()) {
// selected network is hidden, we use the data from the config
selected.set_hidden(true);
selected.set_ssid(config.get_ssid());
// Clear channel and BSSID for hidden networks - there might be multiple hidden networks
connect_params.set_hidden(true);
connect_params.set_ssid(config.get_ssid());
// don't set BSSID and channel, there might be multiple hidden networks
// but we can't know which one is the correct one. Rely on probe-req with just SSID.
selected.set_channel(0);
selected.set_bssid(optional<bssid_t>{});
} else {
// selected network is visible, we use the data from the scan
// limit the connect params to only connect to exactly this network
// (network selection is done during scan phase).
selected.set_hidden(false);
selected.set_ssid(scan_res.get_ssid());
selected.set_channel(scan_res.get_channel());
selected.set_bssid(scan_res.get_bssid());
connect_params.set_hidden(false);
connect_params.set_ssid(scan_res.get_ssid());
connect_params.set_channel(scan_res.get_channel());
connect_params.set_bssid(scan_res.get_bssid());
}
// copy manual IP (if set)
selected.set_manual_ip(config.get_manual_ip());
connect_params.set_manual_ip(config.get_manual_ip());
#ifdef USE_WIFI_WPA2_EAP
// copy EAP parameters (if set)
selected.set_eap(config.get_eap());
connect_params.set_eap(config.get_eap());
#endif
// copy password (if set)
selected.set_password(config.get_password());
connect_params.set_password(config.get_password());
break;
}
yield();
this->start_connecting(this->selected_ap_, false);
this->selected_ap_ = connect_params;
this->start_connecting(connect_params, false);
}
void WiFiComponent::dump_config() {
@@ -729,9 +719,9 @@ void WiFiComponent::check_connecting_finished() {
this->scan_result_.shrink_to_fit();
}
#ifdef USE_WIFI_FAST_CONNECT
this->save_fast_connect_settings_();
#endif
if (this->fast_connect_) {
this->save_fast_connect_settings_();
}
return;
}
@@ -779,31 +769,31 @@ void WiFiComponent::retry_connect() {
delay(10);
if (!this->is_captive_portal_active_() && !this->is_esp32_improv_active_() &&
(this->num_retried_ > 3 || this->error_from_callback_)) {
#ifdef USE_WIFI_FAST_CONNECT
if (this->trying_loaded_ap_) {
this->trying_loaded_ap_ = false;
this->ap_index_ = 0; // Retry from the first configured AP
} else if (this->ap_index_ >= this->sta_.size() - 1) {
ESP_LOGW(TAG, "No more APs to try");
this->ap_index_ = 0;
this->restart_adapter();
if (this->fast_connect_) {
if (this->trying_loaded_ap_) {
this->trying_loaded_ap_ = false;
this->ap_index_ = 0; // Retry from the first configured AP
} else if (this->ap_index_ >= this->sta_.size() - 1) {
ESP_LOGW(TAG, "No more APs to try");
this->ap_index_ = 0;
this->restart_adapter();
} else {
// Try next AP
this->ap_index_++;
}
this->num_retried_ = 0;
this->selected_ap_ = this->sta_[this->ap_index_];
} else {
// Try next AP
this->ap_index_++;
if (this->num_retried_ > 5) {
// If retry failed for more than 5 times, let's restart STA
this->restart_adapter();
} else {
// Try hidden networks after 3 failed retries
ESP_LOGD(TAG, "Retrying with hidden networks");
this->retry_hidden_ = true;
this->num_retried_++;
}
}
this->num_retried_ = 0;
this->selected_ap_ = this->sta_[this->ap_index_];
#else
if (this->num_retried_ > 5) {
// If retry failed for more than 5 times, let's restart STA
this->restart_adapter();
} else {
// Try hidden networks after 3 failed retries
ESP_LOGD(TAG, "Retrying with hidden networks");
this->retry_hidden_ = true;
this->num_retried_++;
}
#endif
} else {
this->num_retried_++;
}
@@ -849,7 +839,6 @@ bool WiFiComponent::is_esp32_improv_active_() {
#endif
}
#ifdef USE_WIFI_FAST_CONNECT
bool WiFiComponent::load_fast_connect_settings_() {
SavedWifiFastConnectSettings fast_connect_save{};
@@ -884,7 +873,6 @@ void WiFiComponent::save_fast_connect_settings_() {
ESP_LOGD(TAG, "Saved fast_connect settings");
}
}
#endif
void WiFiAP::set_ssid(const std::string &ssid) { this->ssid_ = ssid; }
void WiFiAP::set_bssid(bssid_t bssid) { this->bssid_ = bssid; }
@@ -914,7 +902,7 @@ WiFiScanResult::WiFiScanResult(const bssid_t &bssid, std::string ssid, uint8_t c
rssi_(rssi),
with_auth_(with_auth),
is_hidden_(is_hidden) {}
bool WiFiScanResult::matches(const WiFiAP &config) const {
bool WiFiScanResult::matches(const WiFiAP &config) {
if (config.get_hidden()) {
// User configured a hidden network, only match actually hidden networks
// don't match SSID

View File

@@ -170,7 +170,7 @@ class WiFiScanResult {
public:
WiFiScanResult(const bssid_t &bssid, std::string ssid, uint8_t channel, int8_t rssi, bool with_auth, bool is_hidden);
bool matches(const WiFiAP &config) const;
bool matches(const WiFiAP &config);
bool get_matches() const;
void set_matches(bool matches);
@@ -240,6 +240,7 @@ class WiFiComponent : public Component {
void start_scanning();
void check_scanning_finished();
void start_connecting(const WiFiAP &ap, bool two);
void set_fast_connect(bool fast_connect);
void set_ap_timeout(uint32_t ap_timeout) { ap_timeout_ = ap_timeout; }
void check_connecting_finished();
@@ -363,10 +364,8 @@ class WiFiComponent : public Component {
bool is_captive_portal_active_();
bool is_esp32_improv_active_();
#ifdef USE_WIFI_FAST_CONNECT
bool load_fast_connect_settings_();
void save_fast_connect_settings_();
#endif
#ifdef USE_ESP8266
static void wifi_event_callback(System_Event_t *event);
@@ -400,9 +399,7 @@ class WiFiComponent : public Component {
WiFiAP ap_;
optional<float> output_power_;
ESPPreferenceObject pref_;
#ifdef USE_WIFI_FAST_CONNECT
ESPPreferenceObject fast_connect_pref_;
#endif
// Group all 32-bit integers together
uint32_t action_started_;
@@ -414,17 +411,14 @@ class WiFiComponent : public Component {
WiFiComponentState state_{WIFI_COMPONENT_STATE_OFF};
WiFiPowerSaveMode power_save_{WIFI_POWER_SAVE_NONE};
uint8_t num_retried_{0};
#ifdef USE_WIFI_FAST_CONNECT
uint8_t ap_index_{0};
#endif
#if USE_NETWORK_IPV6
uint8_t num_ipv6_addresses_{0};
#endif /* USE_NETWORK_IPV6 */
// Group all boolean values together
#ifdef USE_WIFI_FAST_CONNECT
bool fast_connect_{false};
bool trying_loaded_ap_{false};
#endif
bool retry_hidden_{false};
bool has_ap_{false};
bool handled_connected_state_{false};

View File

@@ -706,10 +706,10 @@ void WiFiComponent::wifi_scan_done_callback_(void *arg, STATUS status) {
this->scan_result_.init(count);
for (bss_info *it = head; it != nullptr; it = STAILQ_NEXT(it, next)) {
this->scan_result_.emplace_back(
bssid_t{it->bssid[0], it->bssid[1], it->bssid[2], it->bssid[3], it->bssid[4], it->bssid[5]},
std::string(reinterpret_cast<char *>(it->ssid), it->ssid_len), it->channel, it->rssi, it->authmode != AUTH_OPEN,
it->is_hidden != 0);
WiFiScanResult res({it->bssid[0], it->bssid[1], it->bssid[2], it->bssid[3], it->bssid[4], it->bssid[5]},
std::string(reinterpret_cast<char *>(it->ssid), it->ssid_len), it->channel, it->rssi,
it->authmode != AUTH_OPEN, it->is_hidden != 0);
this->scan_result_.push_back(res);
}
this->scan_done_ = true;
}

View File

@@ -789,8 +789,8 @@ void WiFiComponent::wifi_process_event_(IDFWiFiEvent *data) {
bssid_t bssid;
std::copy(record.bssid, record.bssid + 6, bssid.begin());
std::string ssid(reinterpret_cast<const char *>(record.ssid));
scan_result_.emplace_back(bssid, ssid, record.primary, record.rssi, record.authmode != WIFI_AUTH_OPEN,
ssid.empty());
WiFiScanResult result(bssid, ssid, record.primary, record.rssi, record.authmode != WIFI_AUTH_OPEN, ssid.empty());
scan_result_.push_back(result);
}
} else if (data->event_base == WIFI_EVENT && data->event_id == WIFI_EVENT_AP_START) {

View File

@@ -419,9 +419,9 @@ void WiFiComponent::wifi_scan_done_callback_() {
uint8_t *bssid = WiFi.BSSID(i);
int32_t channel = WiFi.channel(i);
this->scan_result_.emplace_back(bssid_t{bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5]},
std::string(ssid.c_str()), channel, rssi, authmode != WIFI_AUTH_OPEN,
ssid.length() == 0);
WiFiScanResult scan({bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5]}, std::string(ssid.c_str()),
channel, rssi, authmode != WIFI_AUTH_OPEN, ssid.length() == 0);
this->scan_result_.push_back(scan);
}
WiFi.scanDelete();
this->scan_done_ = true;

View File

@@ -12,7 +12,7 @@ from typing import Any
import voluptuous as vol
from esphome import core, loader, pins, yaml_util
from esphome.config_helpers import Extend, Remove, merge_config, merge_dicts_ordered
from esphome.config_helpers import Extend, Remove, merge_dicts_ordered
import esphome.config_validation as cv
from esphome.const import (
CONF_ESPHOME,
@@ -324,7 +324,13 @@ def iter_ids(config, path=None):
yield from iter_ids(value, path + [key])
def check_replaceme(value):
def recursive_check_replaceme(value):
if isinstance(value, list):
return cv.Schema([recursive_check_replaceme])(value)
if isinstance(value, dict):
return cv.Schema({cv.valid: recursive_check_replaceme})(value)
if isinstance(value, ESPLiteralValue):
pass
if isinstance(value, str) and value == "REPLACEME":
raise cv.Invalid(
"Found 'REPLACEME' in configuration, this is most likely an error. "
@@ -333,86 +339,7 @@ def check_replaceme(value):
"If you want to use the literal REPLACEME string, "
'please use "!literal REPLACEME"'
)
def _build_list_index(lst):
index = OrderedDict()
extensions, removals = [], set()
for item in lst:
if item is None:
removals.add(None)
continue
item_id = None
if isinstance(item, dict) and (item_id := item.get(CONF_ID)):
if isinstance(item_id, Extend):
extensions.append(item)
continue
if isinstance(item_id, Remove):
removals.add(item_id.value)
continue
if not item_id or item_id in index:
# no id or duplicate -> pass through with identity-based key
item_id = id(item)
index[item_id] = item
return index, extensions, removals
def resolve_extend_remove(value, is_key=None):
if isinstance(value, ESPLiteralValue):
return # do not check inside literal blocks
if isinstance(value, list):
index, extensions, removals = _build_list_index(value)
if extensions or removals:
# Rebuild the original list after
# processing all extensions and removals
for item in extensions:
item_id = item[CONF_ID].value
if item_id in removals:
continue
old = index.get(item_id)
if old is None:
# Failed to find source for extension
# Find index of item to show error at correct position
i = next(
(
i
for i, d in enumerate(value)
if d.get(CONF_ID) == item[CONF_ID]
)
)
with cv.prepend_path(i):
raise cv.Invalid(
f"Source for extension of ID '{item_id}' was not found."
)
item[CONF_ID] = item_id
index[item_id] = merge_config(old, item)
for item_id in removals:
index.pop(item_id, None)
value[:] = index.values()
for i, item in enumerate(value):
with cv.prepend_path(i):
resolve_extend_remove(item, False)
return
if isinstance(value, dict):
removals = []
for k, v in value.items():
with cv.prepend_path(k):
if isinstance(v, Remove):
removals.append(k)
continue
resolve_extend_remove(k, True)
resolve_extend_remove(v, False)
for k in removals:
value.pop(k, None)
return
if is_key:
return # do not check keys (yet)
check_replaceme(value)
return
return value
class ConfigValidationStep(abc.ABC):
@@ -510,6 +437,19 @@ class LoadValidationStep(ConfigValidationStep):
continue
p_name = p_config.get("platform")
if p_name is None:
p_id = p_config.get(CONF_ID)
if isinstance(p_id, Extend):
result.add_str_error(
f"Source for extension of ID '{p_id.value}' was not found.",
path + [CONF_ID],
)
continue
if isinstance(p_id, Remove):
result.add_str_error(
f"Source for removal of ID '{p_id.value}' was not found.",
path + [CONF_ID],
)
continue
result.add_str_error(
f"'{self.domain}' requires a 'platform' key but it was not specified.",
path,
@@ -994,10 +934,9 @@ def validate_config(
CORE.raw_config = config
# 1.1. Resolve !extend and !remove and check for REPLACEME
# After this step, there will not be any Extend or Remove values in the config anymore
# 1.1. Check for REPLACEME special value
try:
resolve_extend_remove(config)
recursive_check_replaceme(config)
except vol.Invalid as err:
result.add_error(err)

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from esphome.const import (
CONF_ID,
CONF_LEVEL,
CONF_LOGGER,
KEY_CORE,
@@ -74,28 +75,73 @@ class Remove:
return isinstance(b, Remove) and self.value == b.value
def merge_config(old, new):
if isinstance(new, Remove):
return new
if isinstance(new, dict):
if not isinstance(old, dict):
return new
# Preserve OrderedDict type by copying to OrderedDict if either input is OrderedDict
if isinstance(old, OrderedDict) or isinstance(new, OrderedDict):
res = OrderedDict(old)
else:
def merge_config(full_old, full_new):
def merge(old, new):
if isinstance(new, dict):
if not isinstance(old, dict):
return new
# Preserve OrderedDict type by copying to OrderedDict if either input is OrderedDict
if isinstance(old, OrderedDict) or isinstance(new, OrderedDict):
res = OrderedDict(old)
else:
res = old.copy()
for k, v in new.items():
if isinstance(v, Remove) and k in old:
del res[k]
else:
res[k] = merge(old[k], v) if k in old else v
return res
if isinstance(new, list):
if not isinstance(old, list):
return new
res = old.copy()
for k, v in new.items():
res[k] = merge_config(old.get(k), v)
return res
if isinstance(new, list):
if not isinstance(old, list):
return new
return old + new
if new is None:
return old
ids = {
v_id: i
for i, v in enumerate(res)
if isinstance(v, dict)
and (v_id := v.get(CONF_ID))
and isinstance(v_id, str)
}
extend_ids = {
v_id.value: i
for i, v in enumerate(res)
if isinstance(v, dict)
and (v_id := v.get(CONF_ID))
and isinstance(v_id, Extend)
}
return new
ids_to_delete = []
for v in new:
if isinstance(v, dict) and (new_id := v.get(CONF_ID)):
if isinstance(new_id, Extend):
new_id = new_id.value
if new_id in ids:
v[CONF_ID] = new_id
res[ids[new_id]] = merge(res[ids[new_id]], v)
continue
elif isinstance(new_id, Remove):
new_id = new_id.value
if new_id in ids:
ids_to_delete.append(ids[new_id])
continue
elif (
new_id in extend_ids
): # When a package is extending a non-packaged item
extend_res = res[extend_ids[new_id]]
extend_res[CONF_ID] = new_id
new_v = merge(v, extend_res)
res[extend_ids[new_id]] = new_v
continue
else:
ids[new_id] = len(res)
res.append(v)
return [v for i, v in enumerate(res) if i not in ids_to_delete]
if new is None:
return old
return new
return merge(full_old, full_new)
def filter_source_files_from_platform(

View File

@@ -24,6 +24,7 @@ import voluptuous as vol
from esphome import core
import esphome.codegen as cg
from esphome.config_helpers import Extend, Remove
from esphome.const import (
ALLOWED_NAME_CHARS,
CONF_AVAILABILITY,
@@ -623,6 +624,12 @@ def declare_id(type):
if value is None:
return core.ID(None, is_declaration=True, type=type)
if isinstance(value, Extend):
raise Invalid(f"Source for extension of ID '{value.value}' was not found.")
if isinstance(value, Remove):
raise Invalid(f"Source for Removal of ID '{value.value}' was not found.")
return core.ID(validate_id_name(value), is_declaration=True, type=type)
return validator

View File

@@ -199,7 +199,6 @@
#define USE_WEBSERVER_PORT 80 // NOLINT
#define USE_WEBSERVER_SORTING
#define USE_WIFI_11KV_SUPPORT
#define USE_WIFI_FAST_CONNECT
#define USB_HOST_MAX_REQUESTS 16
#ifdef USE_ARDUINO

View File

@@ -281,13 +281,13 @@ template<typename T> class FixedVector {
}
}
/// Emplace element without bounds checking - constructs in-place with arguments
/// Emplace element without bounds checking - constructs in-place
/// Caller must ensure sufficient capacity was allocated via init()
/// Returns reference to the newly constructed element
/// NOTE: Caller MUST ensure size_ < capacity_ before calling
template<typename... Args> T &emplace_back(Args &&...args) {
// Use placement new to construct the object in pre-allocated memory
new (&data_[size_]) T(std::forward<Args>(args)...);
T &emplace_back() {
// Use placement new to default-construct the object in pre-allocated memory
new (&data_[size_]) T();
size_++;
return data_[size_ - 1];
}
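Note on the hunk above: the variadic emplace_back constructs directly into preallocated storage with placement new and perfect forwarding, and it is the caller's job to keep size_ below capacity_. A minimal standalone sketch of that mechanism (FixedBuf is hypothetical; destructor calls and free are omitted for brevity, so this leaks by design):

#include <cstddef>
#include <cstdlib>
#include <new>      // placement new
#include <utility>  // std::forward

template<typename T> struct FixedBuf {
  T *data_{nullptr};
  std::size_t size_{0}, capacity_{0};

  // Allocate raw, uninitialized storage; no T constructors run here.
  void init(std::size_t n) {
    data_ = static_cast<T *>(std::malloc(n * sizeof(T)));
    capacity_ = n;
  }

  // Construct in place from arbitrary constructor arguments; the caller must
  // guarantee size_ < capacity_, exactly as the NOTE in the hunk demands.
  template<typename... Args> T &emplace_back(Args &&...args) {
    new (&data_[size_]) T(std::forward<Args>(args)...);
    return data_[size_++];
  }
};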
@@ -1158,4 +1158,18 @@ template<typename T, enable_if_t<std::is_pointer<T *>::value, int> = 0> T &id(T
///@}
/// @name Deprecated functions
///@{
ESPDEPRECATED("hexencode() is deprecated, use format_hex_pretty() instead.", "2022.1")
inline std::string hexencode(const uint8_t *data, uint32_t len) { return format_hex_pretty(data, len); }
template<typename T>
ESPDEPRECATED("hexencode() is deprecated, use format_hex_pretty() instead.", "2022.1")
std::string hexencode(const T &data) {
return hexencode(data.data(), data.size());
}
///@}
} // namespace esphome
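Note on the hunk above: taken together with the kuntze.cpp hunk earlier, the restored shim keeps old call sites compiling while steering them toward format_hex_pretty(). A call-site sketch (the header path is taken from the includes shown in earlier hunks):

#include "esphome/core/helpers.h"  // declares hexencode() and format_hex_pretty()

#include <cstdint>
#include <string>
#include <vector>

// Both lines produce the same formatted hex string, but the first one trips
// the ESPDEPRECATED warning attached to the shim above.
void log_payload(const std::vector<uint8_t> &data) {
  std::string old_style = esphome::hexencode(data);          // deprecated shim
  std::string new_style = esphome::format_hex_pretty(data);  // preferred call
  (void) old_style;
  (void) new_style;
}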

View File

@@ -1,362 +0,0 @@
"""GitHub download cache for ESPHome.
This module provides caching functionality for GitHub release downloads
to avoid redundant network I/O when switching between platforms.
"""
from __future__ import annotations
import hashlib
import json
import logging
from pathlib import Path
import shutil
import time
import urllib.error
import urllib.request
_LOGGER = logging.getLogger(__name__)
class GitHubCache:
"""Manages caching of GitHub release downloads."""
# Cache expiration time in seconds (30 days)
CACHE_EXPIRATION_SECONDS = 30 * 24 * 60 * 60
def __init__(self, cache_dir: Path | None = None):
"""Initialize the cache manager.
Args:
cache_dir: Directory to store cached files.
Defaults to ~/.esphome_cache/github
"""
if cache_dir is None:
cache_dir = Path.home() / ".esphome_cache" / "github"
self.cache_dir = cache_dir
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.metadata_file = self.cache_dir / "cache_metadata.json"
# Prune old files on initialization
try:
self._prune_old_files()
except Exception as e:
_LOGGER.debug("Failed to prune old cache files: %s", e)
def _load_metadata(self) -> dict:
"""Load cache metadata from disk."""
if self.metadata_file.exists():
try:
with open(self.metadata_file) as f:
return json.load(f)
except (OSError, ValueError, json.JSONDecodeError):
return {}
return {}
def _save_metadata(self, metadata: dict) -> None:
"""Save cache metadata to disk."""
try:
with open(self.metadata_file, "w") as f:
json.dump(metadata, f, indent=2)
except OSError as e:
_LOGGER.debug("Failed to save cache metadata: %s", e)
@staticmethod
def is_github_url(url: str) -> bool:
"""Check if URL is a GitHub release download."""
return "github.com" in url.lower() and url.endswith(".zip")
def _get_cache_key(self, url: str) -> str:
"""Get cache key (hash) for a URL."""
return hashlib.sha256(url.encode()).hexdigest()
def _get_cache_path(self, url: str) -> Path:
"""Get cache file path for a URL."""
cache_key = self._get_cache_key(url)
ext = Path(url.split("?")[0]).suffix
return self.cache_dir / f"{cache_key}{ext}"
def _check_if_modified(
self,
url: str,
last_modified: str | None = None,
etag: str | None = None,
) -> bool:
"""Check if a URL has been modified using HTTP 304.
Args:
url: URL to check
last_modified: Last-Modified header from previous response
etag: ETag header from previous response
Returns:
True if modified, False if not modified (or offline/unreachable)
"""
if not last_modified and not etag:
# No cache headers available, assume modified
return True
try:
request = urllib.request.Request(url)
request.get_method = lambda: "HEAD"
if last_modified:
request.add_header("If-Modified-Since", last_modified)
if etag:
request.add_header("If-None-Match", etag)
try:
urllib.request.urlopen(request, timeout=10)
# 200 OK = file was modified
return True
except urllib.error.HTTPError as e:
if e.code == 304:
# Not modified
_LOGGER.debug("File not modified (HTTP 304): %s", url)
return False
# Other errors, assume modified to be safe
return True
except (OSError, urllib.error.URLError):
# If check fails (offline/network error), assume not modified (use cache)
_LOGGER.info("Cannot reach server (offline?), using cached file: %s", url)
return False
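Note on _check_if_modified above: it implements HTTP revalidation — send the cached Last-Modified/ETag values as If-Modified-Since/If-None-Match on a HEAD request, treat a 304 response as "cached copy still valid", and fall back to the cache when the server is unreachable. The decision logic, reduced to a pure function (a C++ sketch with no real HTTP client; the status-code convention is an assumption of this sketch):

#include <string>

// Returns true if the cached copy must be re-downloaded.
// head_status: HTTP status of the conditional HEAD request (e.g. 200 or 304);
// a negative value stands for "offline/unreachable" in this sketch.
bool needs_redownload(const std::string &last_modified, const std::string &etag,
                      int head_status) {
  if (last_modified.empty() && etag.empty())
    return true;   // no validators cached: assume modified
  if (head_status == 304)
    return false;  // server confirmed the cached copy is current
  if (head_status < 0)
    return false;  // offline or network error: use the cache
  return true;     // 200 (or anything else): treat as modified, to be safe
}

// A real request would carry the validators as conditional headers:
//   If-Modified-Since: <last_modified>
//   If-None-Match: <etag>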
def get_cached_path(self, url: str, check_updates: bool = True) -> Path | None:
"""Get path to cached file if available and valid.
Args:
url: URL to check
check_updates: Whether to check for updates using HTTP 304
Returns:
Path to cached file if valid, None if needs download
"""
if not self.is_github_url(url):
return None
cache_path = self._get_cache_path(url)
if not cache_path.exists():
return None
# Load metadata
metadata = self._load_metadata()
cache_key = self._get_cache_key(url)
# Check if file should be re-downloaded
should_redownload = False
if check_updates and cache_key in metadata:
last_modified = metadata[cache_key].get("last_modified")
etag = metadata[cache_key].get("etag")
if self._check_if_modified(url, last_modified, etag):
# File was modified, need to re-download
_LOGGER.debug("Cached file is outdated: %s", url)
should_redownload = True
if should_redownload:
return None
# File is valid, update cached_at timestamp to keep it fresh
if cache_key in metadata:
metadata[cache_key]["cached_at"] = time.time()
self._save_metadata(metadata)
# Log appropriate message
if not check_updates:
_LOGGER.debug("Using cached file (no update check): %s", url)
elif cache_key not in metadata:
_LOGGER.debug("Using cached file (no metadata): %s", url)
else:
_LOGGER.debug("Using cached file: %s", url)
return cache_path
def save_to_cache(self, url: str, source_path: Path) -> None:
"""Save a downloaded file to cache.
Args:
url: URL the file was downloaded from
source_path: Path to the downloaded file
"""
if not self.is_github_url(url):
return
try:
cache_path = self._get_cache_path(url)
# Only copy if source and destination are different
if source_path.resolve() != cache_path.resolve():
shutil.copy2(source_path, cache_path)
# Try to get HTTP headers for caching
last_modified = None
etag = None
try:
request = urllib.request.Request(url)
request.get_method = lambda: "HEAD"
response = urllib.request.urlopen(request, timeout=10)
last_modified = response.headers.get("Last-Modified")
etag = response.headers.get("ETag")
except (OSError, urllib.error.URLError):
pass
# Update metadata
metadata = self._load_metadata()
cache_key = self._get_cache_key(url)
metadata[cache_key] = {
"url": url,
"size": cache_path.stat().st_size,
"cached_at": time.time(),
"last_modified": last_modified,
"etag": etag,
}
self._save_metadata(metadata)
_LOGGER.debug("Saved to cache: %s", url)
except OSError as e:
_LOGGER.debug("Failed to save to cache: %s", e)
def copy_from_cache(self, url: str, destination: Path) -> bool:
"""Copy a cached file to destination.
Args:
url: URL of the cached file
destination: Where to copy the file
Returns:
True if successful, False otherwise
"""
cached_path = self.get_cached_path(url, check_updates=True)
if not cached_path:
return False
try:
shutil.copy2(cached_path, destination)
_LOGGER.info("Using cached download for %s", url)
return True
except OSError as e:
_LOGGER.warning("Failed to use cache: %s", e)
return False
def cache_size(self) -> int:
"""Get total size of cached files in bytes."""
total = 0
try:
for file_path in self.cache_dir.glob("*"):
if file_path.is_file() and file_path != self.metadata_file:
total += file_path.stat().st_size
except OSError:
pass
return total
def list_cached(self) -> list[dict]:
"""List all cached files with metadata."""
cached_files = []
metadata = self._load_metadata()
for cache_key, meta in metadata.items():
cache_path = (
self.cache_dir / f"{cache_key}{Path(meta['url'].split('?')[0]).suffix}"
)
if cache_path.exists():
cached_files.append(
{
"url": meta["url"],
"path": cache_path,
"size": meta["size"],
"cached_at": meta.get("cached_at"),
"last_modified": meta.get("last_modified"),
"etag": meta.get("etag"),
}
)
return cached_files
def clear_cache(self) -> None:
"""Clear all cached files."""
try:
for file_path in self.cache_dir.glob("*"):
if file_path.is_file():
file_path.unlink()
_LOGGER.info("Cache cleared: %s", self.cache_dir)
except OSError as e:
_LOGGER.warning("Failed to clear cache: %s", e)
def _prune_old_files(self) -> None:
"""Remove cache files older than CACHE_EXPIRATION_SECONDS."""
current_time = time.time()
metadata = self._load_metadata()
removed_count = 0
removed_size = 0
# Check each file in metadata
for cache_key, meta in list(metadata.items()):
cached_at = meta.get("cached_at", 0)
age_seconds = current_time - cached_at
if age_seconds > self.CACHE_EXPIRATION_SECONDS:
# File is too old, remove it
cache_path = (
self.cache_dir
/ f"{cache_key}{Path(meta['url'].split('?')[0]).suffix}"
)
if cache_path.exists():
file_size = cache_path.stat().st_size
cache_path.unlink()
removed_size += file_size
removed_count += 1
_LOGGER.debug(
"Pruned old cache file (age: %.1f days): %s",
age_seconds / (24 * 60 * 60),
meta["url"],
)
# Remove from metadata
del metadata[cache_key]
# Also check for orphaned files (files without metadata)
for file_path in self.cache_dir.glob("*.zip"):
if file_path == self.metadata_file:
continue
# Check if file is in metadata
found_in_metadata = False
for cache_key in metadata:
if file_path.name.startswith(cache_key):
found_in_metadata = True
break
if not found_in_metadata:
# Orphaned file - check age by modification time
file_age = current_time - file_path.stat().st_mtime
if file_age > self.CACHE_EXPIRATION_SECONDS:
file_size = file_path.stat().st_size
file_path.unlink()
removed_size += file_size
removed_count += 1
_LOGGER.debug(
"Pruned orphaned cache file (age: %.1f days): %s",
file_age / (24 * 60 * 60),
file_path.name,
)
# Save updated metadata if anything was removed
if removed_count > 0:
self._save_metadata(metadata)
removed_mb = removed_size / (1024 * 1024)
_LOGGER.info(
"Pruned %d old cache file(s), freed %.2f MB",
removed_count,
removed_mb,
)
# Global cache instance
_cache: GitHubCache | None = None
def get_cache() -> GitHubCache:
"""Get the global GitHub cache instance."""
global _cache # noqa: PLW0603
if _cache is None:
_cache = GitHubCache()
return _cache

View File

@@ -5,6 +5,7 @@ import os
from pathlib import Path
import re
import subprocess
from typing import Any
from esphome.const import CONF_COMPILE_PROCESS_LIMIT, CONF_ESPHOME, KEY_CORE
from esphome.core import CORE, EsphomeError
@@ -43,168 +44,32 @@ def patch_structhash():
def patch_file_downloader():
"""Patch PlatformIO's FileDownloader to add caching and retry on PackageException errors.
"""Patch PlatformIO's FileDownloader to retry on PackageException errors."""
from platformio.package.download import FileDownloader
from platformio.package.exception import PackageException
This function attempts to patch PlatformIO's internal download mechanism.
If patching fails (due to API changes), it gracefully falls back to no caching.
"""
try:
from platformio.package.download import FileDownloader
from platformio.package.exception import PackageException
except ImportError as e:
_LOGGER.debug("Could not import PlatformIO modules for patching: %s", e)
return
original_init = FileDownloader.__init__
# Import our cache module
from esphome.github_cache import GitHubCache
def patched_init(self, *args: Any, **kwargs: Any) -> None:
max_retries = 3
_LOGGER.debug("Applying GitHub download cache patch...")
# Verify the classes have the expected methods before patching
if not hasattr(FileDownloader, "__init__") or not hasattr(FileDownloader, "start"):
_LOGGER.warning(
"PlatformIO FileDownloader API has changed, skipping cache patch"
)
return
try:
original_init = FileDownloader.__init__
original_start = FileDownloader.start
# Initialize cache in .platformio directory so it benefits from GitHub Actions cache
platformio_dir = Path.home() / ".platformio"
cache_dir = platformio_dir / "esphome_download_cache"
cache_dir_existed = cache_dir.exists()
cache = GitHubCache(cache_dir=cache_dir)
if not cache_dir_existed:
_LOGGER.info("Created GitHub download cache at: %s", cache.cache_dir)
except Exception as e:
_LOGGER.warning("Failed to initialize GitHub download cache: %s", e)
return
def patched_init(self, *args, **kwargs):
"""Patched init that checks cache before making HTTP connection."""
try:
# Extract URL from args (first positional argument)
url = args[0] if args else kwargs.get("url")
dest_dir = args[1] if len(args) > 1 else kwargs.get("dest_dir")
# Debug: Log all downloads
_LOGGER.debug("[GitHub Cache] Download request for: %s", url)
# Store URL for later use (original FileDownloader doesn't store it)
self._esphome_cache_url = url if cache.is_github_url(url) else None
# Check cache for GitHub URLs BEFORE making HTTP request
if self._esphome_cache_url:
_LOGGER.debug("[GitHub Cache] This is a GitHub URL, checking cache...")
self._esphome_use_cache = cache.get_cached_path(url, check_updates=True)
if self._esphome_use_cache:
_LOGGER.info(
"Found %s in cache, will restore instead of downloading",
Path(url.split("?")[0]).name,
)
_LOGGER.debug(
"[GitHub Cache] Found in cache: %s", self._esphome_use_cache
for attempt in range(max_retries):
try:
return original_init(self, *args, **kwargs)
except PackageException as e:
if attempt < max_retries - 1:
_LOGGER.warning(
"Package download failed: %s. Retrying... (attempt %d/%d)",
str(e),
attempt + 1,
max_retries,
)
else:
_LOGGER.debug(
"[GitHub Cache] Not in cache, will download and cache"
)
else:
self._esphome_use_cache = None
if url and str(url).startswith("http"):
_LOGGER.debug("[GitHub Cache] Not a GitHub URL, skipping cache")
# Final attempt - re-raise
raise
return None
# Only make HTTP connection if we don't have cached file
if self._esphome_use_cache:
# Skip HTTP connection, we'll handle this in start()
# Set minimal attributes to satisfy FileDownloader
# Create a mock session that can be safely closed in __del__
class MockSession:
def close(self):
pass
self._http_session = MockSession()
self._http_response = None
self._fname = Path(url.split("?")[0]).name
self._destination = self._fname
if dest_dir:
from os.path import join
self._destination = join(dest_dir, self._fname)
# Note: Actual restoration logged in patched_start
return None # Don't call original_init
# Normal initialization with retry logic
max_retries = 3
for attempt in range(max_retries):
try:
return original_init(self, *args, **kwargs)
except PackageException as e:
if attempt < max_retries - 1:
_LOGGER.warning(
"Package download failed: %s. Retrying... (attempt %d/%d)",
str(e),
attempt + 1,
max_retries,
)
else:
# Final attempt - re-raise
raise
return None
except Exception as e:
# If anything goes wrong in our cache logic, fall back to normal download
_LOGGER.debug("Cache check failed, falling back to normal download: %s", e)
self._esphome_cache_url = None
self._esphome_use_cache = None
return original_init(self, *args, **kwargs)
def patched_start(self, *args, **kwargs):
"""Patched start that uses cache when available."""
try:
import shutil
# Get the cache URL and path that were set in __init__
cache_url = getattr(self, "_esphome_cache_url", None)
cached_file = getattr(self, "_esphome_use_cache", None)
# If we're using cache, copy file instead of downloading
if cached_file:
try:
shutil.copy2(cached_file, self._destination)
_LOGGER.info(
"Restored %s from cache (avoided download)",
Path(cached_file).name,
)
return True
except OSError as e:
_LOGGER.warning("Failed to copy from cache: %s", e)
# Fall through to re-download
# Perform normal download
result = original_start(self, *args, **kwargs)
# Save to cache if it was a GitHub URL
if cache_url:
try:
cache.save_to_cache(cache_url, Path(self._destination))
except OSError as e:
_LOGGER.debug("Failed to save to cache: %s", e)
return result
except Exception as e:
# If anything goes wrong, fall back to normal download
_LOGGER.debug("Cache restoration failed, using normal download: %s", e)
return original_start(self, *args, **kwargs)
# Apply the patches
try:
FileDownloader.__init__ = patched_init
FileDownloader.start = patched_start
_LOGGER.debug("GitHub download cache patch applied successfully")
except Exception as e:
_LOGGER.warning("Failed to apply GitHub download cache patch: %s", e)
FileDownloader.__init__ = patched_init
IGNORE_LIB_WARNINGS = f"(?:{'|'.join(['Hash', 'Update'])})"
@@ -222,8 +87,6 @@ FILTER_PLATFORMIO_LINES = [
r"Memory Usage -> https://bit.ly/pio-memory-usage",
r"Found: https://platformio.org/lib/show/.*",
r"Using cache: .*",
-    # Don't filter our cache messages - let users see when cache is being used
-    # r"Using cached download for .*",
r"Installing dependencies",
r"Library Manager: Already installed, built-in library",
r"Building in .* mode",


@@ -1,164 +0,0 @@
#!/usr/bin/env python3
"""
Pre-cache PlatformIO GitHub Downloads
This script extracts GitHub URLs from platformio.ini and pre-caches them
to avoid redundant downloads when switching between ESP8266 and ESP32 builds.
Usage:
python3 script/cache_platformio_downloads.py [platformio.ini]
"""
import argparse
import configparser
from pathlib import Path
import re
import sys
# Import the cache manager
sys.path.insert(0, str(Path(__file__).parent.parent))
from esphome.github_cache import GitHubCache
def extract_github_urls(platformio_ini: Path) -> list[str]:
"""Extract all GitHub URLs from platformio.ini.
Args:
platformio_ini: Path to platformio.ini file
Returns:
List of GitHub URLs found
"""
config = configparser.ConfigParser(inline_comment_prefixes=(";",))
config.read(platformio_ini)
urls = []
github_pattern = re.compile(r"https://github\.com/[^\s;]+\.zip")
for section in config.sections():
conf = config[section]
# Check platform
if "platform" in conf:
platform_value = conf["platform"]
matches = github_pattern.findall(platform_value)
urls.extend(matches)
# Check platform_packages
if "platform_packages" in conf:
for line in conf["platform_packages"].splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
matches = github_pattern.findall(line)
urls.extend(matches)
# Remove duplicates while preserving order using dict
return list(dict.fromkeys(urls))
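    # (dict.fromkeys keeps the first occurrence and preserves insertion order:
    #  list(dict.fromkeys(["a", "b", "a"])) == ["a", "b"])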
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Pre-cache PlatformIO GitHub downloads",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
This script scans platformio.ini for GitHub URLs and pre-caches them.
This avoids redundant downloads when switching between platforms (e.g., ESP8266 and ESP32).
Examples:
# Cache downloads from default platformio.ini
%(prog)s
# Cache downloads from specific file
%(prog)s custom_platformio.ini
# Show what would be cached without downloading
%(prog)s --dry-run
""",
)
parser.add_argument(
"platformio_ini",
nargs="?",
default="platformio.ini",
help="Path to platformio.ini (default: platformio.ini)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be cached without downloading",
)
parser.add_argument(
"--cache-dir",
type=Path,
help="Cache directory (default: ~/.platformio/esphome_download_cache)",
)
parser.add_argument(
"--force",
action="store_true",
help="Force re-download even if cached",
)
args = parser.parse_args()
platformio_ini = Path(args.platformio_ini)
if not platformio_ini.exists():
print(f"Error: {platformio_ini} not found", file=sys.stderr)
return 1
# Extract URLs
print(f"Scanning {platformio_ini} for GitHub URLs...")
urls = extract_github_urls(platformio_ini)
if not urls:
print("No GitHub URLs found in platformio.ini")
return 0
print(f"Found {len(urls)} unique GitHub URL(s):")
for url in urls:
print(f" - {url}")
print()
if args.dry_run:
print("Dry run - not downloading")
return 0
# Initialize cache (use PlatformIO directory by default)
cache_dir = args.cache_dir
if cache_dir is None:
cache_dir = Path.home() / ".platformio" / "esphome_download_cache"
cache = GitHubCache(cache_dir)
# Cache each URL
success_count = 0
for i, url in enumerate(urls, 1):
print(f"[{i}/{len(urls)}] Checking {url}")
try:
# Use the download_with_progress from github_download_cache CLI
from script.github_download_cache import download_with_progress
download_with_progress(cache, url, force=args.force, check_updates=True)
success_count += 1
print()
except Exception as e:
print(f"Error caching {url}: {e}", file=sys.stderr)
print()
# Show cache stats
total_size = cache.cache_size()
size_mb = total_size / (1024 * 1024)
print("\nCache summary:")
print(f" Successfully cached: {success_count}/{len(urls)}")
print(f" Total cache size: {size_mb:.2f} MB")
print(f" Cache location: {cache.cache_dir}")
return 0 if success_count == len(urls) else 1
if __name__ == "__main__":
sys.exit(main())


@@ -1,88 +0,0 @@
#!/usr/bin/env python3
"""Add metadata to memory analysis JSON file.
This script adds components and platform metadata to an existing
memory analysis JSON file. Used by CI to ensure all required fields are present
for the comment script.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import sys
def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Add metadata to memory analysis JSON file"
)
parser.add_argument(
"--json-file",
required=True,
help="Path to JSON file to update",
)
parser.add_argument(
"--components",
required=True,
help='JSON array of component names (e.g., \'["api", "wifi"]\')',
)
parser.add_argument(
"--platform",
required=True,
help="Platform name",
)
args = parser.parse_args()
# Load existing JSON
json_path = Path(args.json_file)
if not json_path.exists():
print(f"Error: JSON file not found: {args.json_file}", file=sys.stderr)
return 1
try:
with open(json_path, encoding="utf-8") as f:
data = json.load(f)
except (json.JSONDecodeError, OSError) as e:
print(f"Error loading JSON: {e}", file=sys.stderr)
return 1
# Parse components
try:
components = json.loads(args.components)
if not isinstance(components, list):
print("Error: --components must be a JSON array", file=sys.stderr)
return 1
# Element-level validation: ensure each component is a non-empty string
for idx, comp in enumerate(components):
if not isinstance(comp, str) or not comp.strip():
print(
f"Error: component at index {idx} is not a non-empty string: {comp!r}",
file=sys.stderr,
)
return 1
except json.JSONDecodeError as e:
print(f"Error parsing components: {e}", file=sys.stderr)
return 1
# Add metadata
data["components"] = components
data["platform"] = args.platform
# Write back
try:
with open(json_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
print(f"Added metadata to {args.json_file}", file=sys.stderr)
except OSError as e:
print(f"Error writing JSON: {e}", file=sys.stderr)
return 1
return 0
if __name__ == "__main__":
sys.exit(main())
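# Illustrative example (hypothetical byte counts and platform name): given an
# analysis file containing
#   {"ram_bytes": 41836, "flash_bytes": 478392}
# running with --components '["api", "wifi"]' --platform esp32-idf yields
#   {"ram_bytes": 41836, "flash_bytes": 478392,
#    "components": ["api", "wifi"], "platform": "esp32-idf"}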


@@ -24,37 +24,6 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
# Comment marker to identify our memory impact comments
COMMENT_MARKER = "<!-- esphome-memory-impact-analysis -->"
-def run_gh_command(args: list[str], operation: str) -> subprocess.CompletedProcess:
-    """Run a gh CLI command with error handling.
-
-    Args:
-        args: Command arguments (including 'gh')
-        operation: Description of the operation for error messages
-
-    Returns:
-        CompletedProcess result
-
-    Raises:
-        subprocess.CalledProcessError: If command fails (with detailed error output)
-    """
-    try:
-        return subprocess.run(
-            args,
-            check=True,
-            capture_output=True,
-            text=True,
-        )
-    except subprocess.CalledProcessError as e:
-        print(
-            f"ERROR: {operation} failed with exit code {e.returncode}", file=sys.stderr
-        )
-        print(f"ERROR: Command: {' '.join(args)}", file=sys.stderr)
-        print(f"ERROR: stdout: {e.stdout}", file=sys.stderr)
-        print(f"ERROR: stderr: {e.stderr}", file=sys.stderr)
-        raise
# Thresholds for emoji significance indicators (percentage)
OVERALL_CHANGE_THRESHOLD = 1.0 # Overall RAM/Flash changes
COMPONENT_CHANGE_THRESHOLD = 3.0 # Component breakdown changes
@@ -269,6 +238,7 @@ def create_comment_body(
pr_analysis: dict | None = None,
target_symbols: dict | None = None,
pr_symbols: dict | None = None,
+    target_cache_hit: bool = False,
) -> str:
"""Create the comment body with memory impact analysis using Jinja2 templates.
@@ -283,6 +253,7 @@ def create_comment_body(
pr_analysis: Optional component breakdown for PR branch
target_symbols: Optional symbol map for target branch
pr_symbols: Optional symbol map for PR branch
+    target_cache_hit: Whether target branch analysis was loaded from cache
Returns:
Formatted comment body
@@ -312,6 +283,7 @@ def create_comment_body(
"flash_change": format_change(
target_flash, pr_flash, threshold=OVERALL_CHANGE_THRESHOLD
),
"target_cache_hit": target_cache_hit,
"component_change_threshold": COMPONENT_CHANGE_THRESHOLD,
}
@@ -384,7 +356,7 @@ def find_existing_comment(pr_number: str) -> str | None:
print(f"DEBUG: Looking for existing comment on PR #{pr_number}", file=sys.stderr)
# Use gh api to get comments directly - this returns the numeric id field
-    result = run_gh_command(
+    result = subprocess.run(
[
"gh",
"api",
@@ -392,7 +364,9 @@ def find_existing_comment(pr_number: str) -> str | None:
"--jq",
".[] | {id, body}",
],
operation="Get PR comments",
capture_output=True,
text=True,
check=True,
)
print(
@@ -446,8 +420,7 @@ def update_existing_comment(comment_id: str, comment_body: str) -> None:
subprocess.CalledProcessError: If gh command fails
"""
print(f"DEBUG: Updating existing comment {comment_id}", file=sys.stderr)
print(f"DEBUG: Comment body length: {len(comment_body)} bytes", file=sys.stderr)
-    result = run_gh_command(
+    result = subprocess.run(
[
"gh",
"api",
@@ -457,7 +430,9 @@ def update_existing_comment(comment_id: str, comment_body: str) -> None:
"-f",
f"body={comment_body}",
],
operation="Update PR comment",
check=True,
capture_output=True,
text=True,
)
print(f"DEBUG: Update response: {result.stdout}", file=sys.stderr)
@@ -473,10 +448,11 @@ def create_new_comment(pr_number: str, comment_body: str) -> None:
subprocess.CalledProcessError: If gh command fails
"""
print(f"DEBUG: Posting new comment on PR #{pr_number}", file=sys.stderr)
print(f"DEBUG: Comment body length: {len(comment_body)} bytes", file=sys.stderr)
-    result = run_gh_command(
+    result = subprocess.run(
["gh", "pr", "comment", pr_number, "--body", comment_body],
operation="Create PR comment",
check=True,
capture_output=True,
text=True,
)
print(f"DEBUG: Post response: {result.stdout}", file=sys.stderr)
@@ -509,128 +485,79 @@ def main() -> int:
)
parser.add_argument("--pr-number", required=True, help="PR number")
parser.add_argument(
"--target-json",
"--components",
required=True,
help="Path to target branch analysis JSON file",
help='JSON array of component names (e.g., \'["api", "wifi"]\')',
)
parser.add_argument("--platform", required=True, help="Platform name")
parser.add_argument(
"--target-ram", type=int, required=True, help="Target branch RAM usage"
)
parser.add_argument(
"--target-flash", type=int, required=True, help="Target branch flash usage"
)
parser.add_argument("--pr-ram", type=int, required=True, help="PR branch RAM usage")
parser.add_argument(
"--pr-flash", type=int, required=True, help="PR branch flash usage"
)
parser.add_argument(
"--target-json",
help="Optional path to target branch analysis JSON (for detailed analysis)",
)
parser.add_argument(
"--pr-json",
required=True,
help="Path to PR branch analysis JSON file",
help="Optional path to PR branch analysis JSON (for detailed analysis)",
)
parser.add_argument(
"--target-cache-hit",
action="store_true",
help="Indicates that target branch analysis was loaded from cache",
)
args = parser.parse_args()
# Load analysis JSON files (all data comes from JSON for security)
target_data: dict | None = load_analysis_json(args.target_json)
if not target_data:
print("Error: Failed to load target analysis JSON", file=sys.stderr)
# Parse components from JSON
try:
components = json.loads(args.components)
if not isinstance(components, list):
print("Error: --components must be a JSON array", file=sys.stderr)
sys.exit(1)
except json.JSONDecodeError as e:
print(f"Error parsing --components JSON: {e}", file=sys.stderr)
sys.exit(1)
pr_data: dict | None = load_analysis_json(args.pr_json)
if not pr_data:
print("Error: Failed to load PR analysis JSON", file=sys.stderr)
sys.exit(1)
# Load analysis JSON files
target_analysis = None
pr_analysis = None
target_symbols = None
pr_symbols = None
# Extract detailed analysis if available
target_analysis: dict | None = None
pr_analysis: dict | None = None
target_symbols: dict | None = None
pr_symbols: dict | None = None
if args.target_json:
target_data = load_analysis_json(args.target_json)
if target_data and target_data.get("detailed_analysis"):
target_analysis = target_data["detailed_analysis"].get("components")
target_symbols = target_data["detailed_analysis"].get("symbols")
if target_data.get("detailed_analysis"):
target_analysis = target_data["detailed_analysis"].get("components")
target_symbols = target_data["detailed_analysis"].get("symbols")
if pr_data.get("detailed_analysis"):
pr_analysis = pr_data["detailed_analysis"].get("components")
pr_symbols = pr_data["detailed_analysis"].get("symbols")
# Extract all values from JSON files (prevents shell injection from PR code)
components = target_data.get("components")
platform = target_data.get("platform")
target_ram = target_data.get("ram_bytes")
target_flash = target_data.get("flash_bytes")
pr_ram = pr_data.get("ram_bytes")
pr_flash = pr_data.get("flash_bytes")
# Validate required fields and types
missing_fields: list[str] = []
type_errors: list[str] = []
if components is None:
missing_fields.append("components")
elif not isinstance(components, list):
type_errors.append(
f"components must be a list, got {type(components).__name__}"
)
else:
for idx, comp in enumerate(components):
if not isinstance(comp, str):
type_errors.append(
f"components[{idx}] must be a string, got {type(comp).__name__}"
)
if platform is None:
missing_fields.append("platform")
elif not isinstance(platform, str):
type_errors.append(f"platform must be a string, got {type(platform).__name__}")
if target_ram is None:
missing_fields.append("target.ram_bytes")
elif not isinstance(target_ram, int):
type_errors.append(
f"target.ram_bytes must be an integer, got {type(target_ram).__name__}"
)
if target_flash is None:
missing_fields.append("target.flash_bytes")
elif not isinstance(target_flash, int):
type_errors.append(
f"target.flash_bytes must be an integer, got {type(target_flash).__name__}"
)
if pr_ram is None:
missing_fields.append("pr.ram_bytes")
elif not isinstance(pr_ram, int):
type_errors.append(
f"pr.ram_bytes must be an integer, got {type(pr_ram).__name__}"
)
if pr_flash is None:
missing_fields.append("pr.flash_bytes")
elif not isinstance(pr_flash, int):
type_errors.append(
f"pr.flash_bytes must be an integer, got {type(pr_flash).__name__}"
)
if missing_fields or type_errors:
if missing_fields:
print(
f"Error: JSON files missing required fields: {', '.join(missing_fields)}",
file=sys.stderr,
)
if type_errors:
print(
f"Error: Type validation failed: {'; '.join(type_errors)}",
file=sys.stderr,
)
print(f"Target JSON keys: {list(target_data.keys())}", file=sys.stderr)
print(f"PR JSON keys: {list(pr_data.keys())}", file=sys.stderr)
sys.exit(1)
if args.pr_json:
pr_data = load_analysis_json(args.pr_json)
if pr_data and pr_data.get("detailed_analysis"):
pr_analysis = pr_data["detailed_analysis"].get("components")
pr_symbols = pr_data["detailed_analysis"].get("symbols")
# Create comment body
# Note: Memory totals (RAM/Flash) are summed across all builds if multiple were run.
comment_body = create_comment_body(
components=components,
platform=platform,
target_ram=target_ram,
target_flash=target_flash,
pr_ram=pr_ram,
pr_flash=pr_flash,
platform=args.platform,
target_ram=args.target_ram,
target_flash=args.target_flash,
pr_ram=args.pr_ram,
pr_flash=args.pr_flash,
target_analysis=target_analysis,
pr_analysis=pr_analysis,
target_symbols=target_symbols,
pr_symbols=pr_symbols,
target_cache_hit=args.target_cache_hit,
)
# Post or update comment


@@ -1,195 +0,0 @@
#!/usr/bin/env python3
"""
GitHub Download Cache CLI
This script provides a command-line interface to the GitHub download cache.
The actual caching logic is in esphome/github_cache.py.
Usage:
python3 script/github_download_cache.py download URL
python3 script/github_download_cache.py list
python3 script/github_download_cache.py stats
python3 script/github_download_cache.py clear
"""
import argparse
from pathlib import Path
import sys
import urllib.request
# Add parent directory to path to import esphome modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from esphome.github_cache import GitHubCache
def download_with_progress(
cache: GitHubCache, url: str, force: bool = False, check_updates: bool = True
) -> Path:
"""Download a URL with progress indicator and caching.
Args:
cache: GitHubCache instance
url: URL to download
force: Force re-download even if cached
check_updates: Check for updates using HTTP 304
Returns:
Path to cached file
"""
# If force, skip cache check
if not force:
cached_path = cache.get_cached_path(url, check_updates=check_updates)
if cached_path:
print(f"Using cached file for {url}")
print(f" Cache: {cached_path}")
return cached_path
# Need to download
print(f"Downloading {url}")
cache_path = cache._get_cache_path(url)
print(f" Cache: {cache_path}")
# Download with progress
temp_path = cache_path.with_suffix(cache_path.suffix + ".tmp")
try:
with urllib.request.urlopen(url) as response:
total_size = int(response.headers.get("Content-Length", 0))
downloaded = 0
with open(temp_path, "wb") as f:
while True:
chunk = response.read(8192)
if not chunk:
break
f.write(chunk)
downloaded += len(chunk)
if total_size > 0:
percent = (downloaded / total_size) * 100
print(f"\r Progress: {percent:.1f}%", end="", flush=True)
print() # New line after progress
# Move to final location
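        # (Path.replace() is an atomic rename on the same filesystem, so an
        #  interrupted download never leaves a truncated file under the cache name)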
temp_path.replace(cache_path)
# Let cache handle metadata
cache.save_to_cache(url, cache_path)
return cache_path
except (OSError, urllib.error.URLError) as e:
if temp_path.exists():
temp_path.unlink()
raise RuntimeError(f"Failed to download {url}: {e}") from e
def main():
"""CLI entry point."""
parser = argparse.ArgumentParser(
description="GitHub Download Cache Manager",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Download and cache a URL
%(prog)s download https://github.com/pioarduino/registry/releases/download/0.0.1/esptoolpy-v5.1.0.zip
# List cached files
%(prog)s list
# Show cache statistics
%(prog)s stats
# Clear cache
%(prog)s clear
""",
)
parser.add_argument(
"--cache-dir",
type=Path,
help="Cache directory (default: ~/.platformio/esphome_download_cache)",
)
subparsers = parser.add_subparsers(dest="command", help="Command to execute")
# Download command
download_parser = subparsers.add_parser("download", help="Download and cache a URL")
download_parser.add_argument("url", help="URL to download")
download_parser.add_argument(
"--force", action="store_true", help="Force re-download even if cached"
)
download_parser.add_argument(
"--no-check-updates",
action="store_true",
help="Skip checking for updates (don't use HTTP 304)",
)
# List command
subparsers.add_parser("list", help="List cached files")
# Stats command
subparsers.add_parser("stats", help="Show cache statistics")
# Clear command
subparsers.add_parser("clear", help="Clear all cached files")
args = parser.parse_args()
if not args.command:
parser.print_help()
return 1
# Use PlatformIO cache directory by default
if args.cache_dir is None:
args.cache_dir = Path.home() / ".platformio" / "esphome_download_cache"
cache = GitHubCache(args.cache_dir)
if args.command == "download":
try:
check_updates = not args.no_check_updates
cache_path = download_with_progress(
cache, args.url, force=args.force, check_updates=check_updates
)
print(f"\nCached at: {cache_path}")
return 0
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return 1
elif args.command == "list":
cached = cache.list_cached()
if not cached:
print("No cached files")
return 0
print(f"Cached files ({len(cached)}):")
for item in cached:
size_mb = item["size"] / (1024 * 1024)
print(f" {item['url']}")
print(f" Size: {size_mb:.2f} MB")
print(f" Path: {item['path']}")
return 0
elif args.command == "stats":
total_size = cache.cache_size()
cached_count = len(cache.list_cached())
size_mb = total_size / (1024 * 1024)
print(f"Cache directory: {cache.cache_dir}")
print(f"Cached files: {cached_count}")
print(f"Total size: {size_mb:.2f} MB")
return 0
elif args.command == "clear":
cache.clear_cache()
return 0
return 1
if __name__ == "__main__":
sys.exit(main())


@@ -1,138 +0,0 @@
#!/usr/bin/env python3
"""
PlatformIO Download Wrapper with Caching
This script can be used as a wrapper around PlatformIO downloads to add caching.
It intercepts download operations and uses the GitHub download cache.
This is designed to be called from PlatformIO's extra_scripts if needed.
"""
from pathlib import Path
import sys
# Import the cache manager
sys.path.insert(0, str(Path(__file__).parent))
from github_download_cache import GitHubDownloadCache
def is_github_url(url: str) -> bool:
"""Check if a URL is a GitHub URL."""
return "github.com" in url.lower()
def cached_download_handler(source, target, env):
"""PlatformIO download handler that uses caching for GitHub URLs.
This function can be registered as a custom download handler in PlatformIO.
Args:
source: Source URL
target: Target file path
env: SCons environment
"""
import shutil
import urllib.request
url = str(source[0])
target_path = Path(str(target[0]))
# Only cache GitHub URLs
if not is_github_url(url):
# Fall back to default download
print(f"Downloading (no cache): {url}")
with (
urllib.request.urlopen(url) as response,
open(target_path, "wb") as out_file,
):
shutil.copyfileobj(response, out_file)
return
# Use cache for GitHub URLs
cache = GitHubDownloadCache()
print(f"Downloading with cache: {url}")
try:
cached_path = cache.download_with_cache(url, check_updates=True)
# Copy from cache to target
shutil.copy2(cached_path, target_path)
print(f" Copied to: {target_path}")
except Exception as e:
print(f"Cache download failed, using direct download: {e}")
# Fall back to direct download
with (
urllib.request.urlopen(url) as response,
open(target_path, "wb") as out_file,
):
shutil.copyfileobj(response, out_file)
def setup_platformio_caching():
"""Setup PlatformIO to use cached downloads.
This should be called from an extra_scripts file in platformio.ini.
Example extra_scripts file (e.g., platformio_hooks.py):
Import("env")
from script.platformio_download_wrapper import setup_platformio_caching
setup_platformio_caching()
"""
try:
from SCons.Script import DefaultEnvironment
DefaultEnvironment()
# Register custom download handler
# Note: This may not work with all PlatformIO versions
# as the download mechanism is internal
print("Note: Direct download interception is not fully supported.")
print("Please use the cache_platformio_downloads.py script instead.")
except ImportError:
print("Warning: SCons not available, cannot setup download caching")
if __name__ == "__main__":
# CLI mode - can be used to manually download a URL with caching
import argparse
parser = argparse.ArgumentParser(description="Download a URL with caching")
parser.add_argument("url", help="URL to download")
parser.add_argument("target", help="Target file path")
parser.add_argument("--cache-dir", type=Path, help="Cache directory")
args = parser.parse_args()
cache = GitHubDownloadCache(args.cache_dir)
target_path = Path(args.target)
try:
if is_github_url(args.url):
print(f"Downloading with cache: {args.url}")
cached_path = cache.download_with_cache(args.url)
# Copy to target
import shutil
target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(cached_path, target_path)
print(f"Copied to: {target_path}")
else:
print(f"Downloading directly (not a GitHub URL): {args.url}")
import shutil
import urllib.request
target_path.parent.mkdir(parents=True, exist_ok=True)
with (
urllib.request.urlopen(args.url) as response,
open(target_path, "wb") as out_file,
):
shutil.copyfileobj(response, out_file)
sys.exit(0)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)


@@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch
import pytest
from esphome.components.packages import do_packages_pass
-from esphome.config import resolve_extend_remove
from esphome.config_helpers import Extend, Remove
import esphome.config_validation as cv
from esphome.const import (
@@ -65,20 +64,13 @@ def fixture_basic_esphome():
return {CONF_NAME: TEST_DEVICE_NAME, CONF_PLATFORM: TEST_PLATFORM}
-def packages_pass(config):
-    """Wrapper around packages_pass that also resolves Extend and Remove."""
-    config = do_packages_pass(config)
-    resolve_extend_remove(config)
-    return config
def test_package_unused(basic_esphome, basic_wifi):
"""
Ensures do_package_pass does not change a config if packages aren't used.
"""
config = {CONF_ESPHOME: basic_esphome, CONF_WIFI: basic_wifi}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == config
@@ -91,7 +83,7 @@ def test_package_invalid_dict(basic_esphome, basic_wifi):
config = {CONF_ESPHOME: basic_esphome, CONF_PACKAGES: basic_wifi | {CONF_URL: ""}}
with pytest.raises(cv.Invalid):
-        packages_pass(config)
+        do_packages_pass(config)
def test_package_include(basic_wifi, basic_esphome):
@@ -107,7 +99,7 @@ def test_package_include(basic_wifi, basic_esphome):
expected = {CONF_ESPHOME: basic_esphome, CONF_WIFI: basic_wifi}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -132,7 +124,7 @@ def test_package_append(basic_wifi, basic_esphome):
},
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -156,7 +148,7 @@ def test_package_override(basic_wifi, basic_esphome):
},
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -185,7 +177,7 @@ def test_multiple_package_order():
},
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -241,7 +233,7 @@ def test_package_list_merge():
]
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -319,7 +311,7 @@ def test_package_list_merge_by_id():
]
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -358,13 +350,13 @@ def test_package_merge_by_id_with_list():
]
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
def test_package_merge_by_missing_id():
"""
-    Ensures that a validation error is thrown when trying to extend a missing ID.
+    Ensures that components with missing IDs are not merged.
"""
config = {
@@ -387,15 +379,25 @@ def test_package_merge_by_missing_id():
],
}
-    error_raised = False
-    try:
-        packages_pass(config)
-        assert False, "Expected validation error for missing ID"
-    except cv.Invalid as err:
-        error_raised = True
-        assert err.path == [CONF_SENSOR, 2]
-    assert error_raised
+    expected = {
+        CONF_SENSOR: [
+            {
+                CONF_ID: TEST_SENSOR_ID_1,
+                CONF_FILTERS: [{CONF_MULTIPLY: 42.0}],
+            },
+            {
+                CONF_ID: TEST_SENSOR_ID_1,
+                CONF_FILTERS: [{CONF_MULTIPLY: 10.0}],
+            },
+            {
+                CONF_ID: Extend(TEST_SENSOR_ID_2),
+                CONF_FILTERS: [{CONF_OFFSET: 146.0}],
+            },
+        ]
+    }
+
+    actual = do_packages_pass(config)
+    assert actual == expected
def test_package_list_remove_by_id():
@@ -445,7 +447,7 @@ def test_package_list_remove_by_id():
]
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -491,7 +493,7 @@ def test_multiple_package_list_remove_by_id():
]
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -512,7 +514,7 @@ def test_package_dict_remove_by_id(basic_wifi, basic_esphome):
CONF_ESPHOME: basic_esphome,
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -543,6 +545,7 @@ def test_package_remove_by_missing_id():
}
expected = {
"missing_key": Remove(),
CONF_SENSOR: [
{
CONF_ID: TEST_SENSOR_ID_1,
@@ -552,10 +555,14 @@ def test_package_remove_by_missing_id():
CONF_ID: TEST_SENSOR_ID_1,
CONF_FILTERS: [{CONF_MULTIPLY: 10.0}],
},
+            {
+                CONF_ID: Remove(TEST_SENSOR_ID_2),
+                CONF_FILTERS: [{CONF_OFFSET: 146.0}],
+            },
],
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -627,7 +634,7 @@ def test_remote_packages_with_files_list(
]
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected
@@ -723,5 +730,5 @@ def test_remote_packages_with_files_and_vars(
]
}
-    actual = packages_pass(config)
+    actual = do_packages_pass(config)
assert actual == expected


@@ -1,5 +1,4 @@
wifi:
-  fast_connect: true
networks:
- ssid: MySSID
eap:


@@ -1,170 +0,0 @@
esphome:
name: test-script-queued
host:
api:
actions:
# Test 1: Queue depth with default max_runs=5
- action: test_queue_depth
then:
- logger.log: "=== TEST 1: Queue depth (max_runs=5 means 5 total, reject 6-7) ==="
- script.execute:
id: queue_depth_script
value: 1
- script.execute:
id: queue_depth_script
value: 2
- script.execute:
id: queue_depth_script
value: 3
- script.execute:
id: queue_depth_script
value: 4
- script.execute:
id: queue_depth_script
value: 5
- script.execute:
id: queue_depth_script
value: 6
- script.execute:
id: queue_depth_script
value: 7
# Test 2: Ring buffer wrap test
- action: test_ring_buffer
then:
- logger.log: "=== TEST 2: Ring buffer wrap (should process A, B, C in order) ==="
- script.execute:
id: wrap_script
msg: "A"
- script.execute:
id: wrap_script
msg: "B"
- script.execute:
id: wrap_script
msg: "C"
# Test 3: Stop clears queue
- action: test_stop_clears
then:
- logger.log: "=== TEST 3: Stop clears queue (should only see 1, then 'STOPPED') ==="
- script.execute:
id: stop_script
num: 1
- script.execute:
id: stop_script
num: 2
- script.execute:
id: stop_script
num: 3
- delay: 50ms
- logger.log: "STOPPING script now"
- script.stop: stop_script
# Test 4: Verify rejection (max_runs=3)
- action: test_rejection
then:
- logger.log: "=== TEST 4: Verify rejection (max_runs=3 means 3 total, reject 4-8) ==="
- script.execute:
id: rejection_script
val: 1
- script.execute:
id: rejection_script
val: 2
- script.execute:
id: rejection_script
val: 3
- script.execute:
id: rejection_script
val: 4
- script.execute:
id: rejection_script
val: 5
- script.execute:
id: rejection_script
val: 6
- script.execute:
id: rejection_script
val: 7
- script.execute:
id: rejection_script
val: 8
# Test 5: No parameters test
- action: test_no_params
then:
- logger.log: "=== TEST 5: No params (should process 3 times) ==="
- script.execute: no_params_script
- script.execute: no_params_script
- script.execute: no_params_script
logger:
level: DEBUG
script:
# Test script 1: Queue depth test (default max_runs=5)
- id: queue_depth_script
mode: queued
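    # queued mode: one run executes at a time; further script.execute calls are
    # buffered, and max_runs (default 5) caps running plus queued runs in total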
parameters:
value: int
then:
- logger.log:
format: "Queue test: START item %d"
args: ['value']
- delay: 100ms
- logger.log:
format: "Queue test: END item %d"
args: ['value']
# Test script 2: Ring buffer wrap test (max_runs=3)
- id: wrap_script
mode: queued
max_runs: 3
parameters:
msg: string
then:
- logger.log:
format: "Ring buffer: START '%s'"
args: ['msg.c_str()']
- delay: 50ms
- logger.log:
format: "Ring buffer: END '%s'"
args: ['msg.c_str()']
# Test script 3: Stop test
- id: stop_script
mode: queued
max_runs: 5
parameters:
num: int
then:
- logger.log:
format: "Stop test: START %d"
args: ['num']
- delay: 100ms
- logger.log:
format: "Stop test: END %d"
args: ['num']
# Test script 4: Rejection test (max_runs=3)
- id: rejection_script
mode: queued
max_runs: 3
parameters:
val: int
then:
- logger.log:
format: "Rejection test: START %d"
args: ['val']
- delay: 200ms
- logger.log:
format: "Rejection test: END %d"
args: ['val']
# Test script 5: No parameters
- id: no_params_script
mode: queued
then:
- logger.log: "No params: START"
- delay: 50ms
- logger.log: "No params: END"


@@ -1,203 +0,0 @@
"""Test ESPHome queued script functionality."""
from __future__ import annotations
import asyncio
import re
import pytest
from .types import APIClientConnectedFactory, RunCompiledFunction
@pytest.mark.asyncio
async def test_script_queued(
yaml_config: str,
run_compiled: RunCompiledFunction,
api_client_connected: APIClientConnectedFactory,
) -> None:
"""Test comprehensive queued script functionality."""
loop = asyncio.get_running_loop()
# Track all test results
test_results = {
"queue_depth": {"processed": [], "rejections": 0},
"ring_buffer": {"start_order": [], "end_order": []},
"stop": {"processed": [], "stop_logged": False},
"rejection": {"processed": [], "rejections": 0},
"no_params": {"executions": 0},
}
# Patterns for Test 1: Queue depth
queue_start = re.compile(r"Queue test: START item (\d+)")
queue_end = re.compile(r"Queue test: END item (\d+)")
queue_reject = re.compile(r"Script 'queue_depth_script' max instances")
# Patterns for Test 2: Ring buffer
ring_start = re.compile(r"Ring buffer: START '([A-Z])'")
ring_end = re.compile(r"Ring buffer: END '([A-Z])'")
# Patterns for Test 3: Stop
stop_start = re.compile(r"Stop test: START (\d+)")
stop_log = re.compile(r"STOPPING script now")
# Patterns for Test 4: Rejection
reject_start = re.compile(r"Rejection test: START (\d+)")
reject_end = re.compile(r"Rejection test: END (\d+)")
reject_reject = re.compile(r"Script 'rejection_script' max instances")
# Patterns for Test 5: No params
no_params_end = re.compile(r"No params: END")
# Test completion futures
test1_complete = loop.create_future()
test2_complete = loop.create_future()
test3_complete = loop.create_future()
test4_complete = loop.create_future()
test5_complete = loop.create_future()
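    # Each future above is resolved from check_output() below once the
    # corresponding log line appears; the test body awaits it with a timeout.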
def check_output(line: str) -> None:
"""Check log output for all test messages."""
# Test 1: Queue depth
if match := queue_start.search(line):
item = int(match.group(1))
if item not in test_results["queue_depth"]["processed"]:
test_results["queue_depth"]["processed"].append(item)
if match := queue_end.search(line):
item = int(match.group(1))
if item == 5 and not test1_complete.done():
test1_complete.set_result(True)
if queue_reject.search(line):
test_results["queue_depth"]["rejections"] += 1
# Test 2: Ring buffer
if match := ring_start.search(line):
msg = match.group(1)
test_results["ring_buffer"]["start_order"].append(msg)
if match := ring_end.search(line):
msg = match.group(1)
test_results["ring_buffer"]["end_order"].append(msg)
if (
len(test_results["ring_buffer"]["end_order"]) == 3
and not test2_complete.done()
):
test2_complete.set_result(True)
# Test 3: Stop
if match := stop_start.search(line):
item = int(match.group(1))
if item not in test_results["stop"]["processed"]:
test_results["stop"]["processed"].append(item)
if stop_log.search(line):
test_results["stop"]["stop_logged"] = True
# Give time for any queued items to be cleared
if not test3_complete.done():
loop.call_later(
0.3,
lambda: test3_complete.set_result(True)
if not test3_complete.done()
else None,
)
# Test 4: Rejection
if match := reject_start.search(line):
item = int(match.group(1))
if item not in test_results["rejection"]["processed"]:
test_results["rejection"]["processed"].append(item)
if match := reject_end.search(line):
item = int(match.group(1))
if item == 3 and not test4_complete.done():
test4_complete.set_result(True)
if reject_reject.search(line):
test_results["rejection"]["rejections"] += 1
# Test 5: No params
if no_params_end.search(line):
test_results["no_params"]["executions"] += 1
if (
test_results["no_params"]["executions"] == 3
and not test5_complete.done()
):
test5_complete.set_result(True)
async with (
run_compiled(yaml_config, line_callback=check_output),
api_client_connected() as client,
):
# Get services
_, services = await client.list_entities_services()
# Test 1: Queue depth limit
test_service = next((s for s in services if s.name == "test_queue_depth"), None)
assert test_service is not None, "test_queue_depth service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test1_complete, timeout=2.0)
await asyncio.sleep(0.1) # Give time for rejections
# Verify Test 1
assert sorted(test_results["queue_depth"]["processed"]) == [1, 2, 3, 4, 5], (
f"Test 1: Expected to process items 1-5 (max_runs=5 means 5 total), got {sorted(test_results['queue_depth']['processed'])}"
)
assert test_results["queue_depth"]["rejections"] >= 2, (
"Test 1: Expected at least 2 rejection warnings (items 6-7 should be rejected)"
)
# Test 2: Ring buffer order
test_service = next((s for s in services if s.name == "test_ring_buffer"), None)
assert test_service is not None, "test_ring_buffer service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test2_complete, timeout=2.0)
# Verify Test 2
assert test_results["ring_buffer"]["start_order"] == ["A", "B", "C"], (
f"Test 2: Expected start order [A, B, C], got {test_results['ring_buffer']['start_order']}"
)
assert test_results["ring_buffer"]["end_order"] == ["A", "B", "C"], (
f"Test 2: Expected end order [A, B, C], got {test_results['ring_buffer']['end_order']}"
)
# Test 3: Stop clears queue
test_service = next((s for s in services if s.name == "test_stop_clears"), None)
assert test_service is not None, "test_stop_clears service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test3_complete, timeout=2.0)
# Verify Test 3
assert test_results["stop"]["stop_logged"], (
"Test 3: Stop command was not logged"
)
assert test_results["stop"]["processed"] == [1], (
f"Test 3: Expected only item 1 to process, got {test_results['stop']['processed']}"
)
# Test 4: Rejection enforcement (max_runs=3)
test_service = next((s for s in services if s.name == "test_rejection"), None)
assert test_service is not None, "test_rejection service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test4_complete, timeout=2.0)
await asyncio.sleep(0.1) # Give time for rejections
# Verify Test 4
assert sorted(test_results["rejection"]["processed"]) == [1, 2, 3], (
f"Test 4: Expected to process items 1-3 (max_runs=3 means 3 total), got {sorted(test_results['rejection']['processed'])}"
)
assert test_results["rejection"]["rejections"] == 5, (
f"Test 4: Expected 5 rejections (items 4-8), got {test_results['rejection']['rejections']}"
)
# Test 5: No parameters
test_service = next((s for s in services if s.name == "test_no_params"), None)
assert test_service is not None, "test_no_params service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test5_complete, timeout=2.0)
# Verify Test 5
assert test_results["no_params"]["executions"] == 3, (
f"Test 5: Expected 3 executions, got {test_results['no_params']['executions']}"
)


@@ -1,9 +0,0 @@
substitutions:
A: component1
B: component2
C: component3
some_component:
- id: component1
value: 2
- id: component2
value: 5


@@ -1,22 +0,0 @@
substitutions:
A: component1
B: component2
C: component3
packages:
- some_component:
- id: component1
value: 1
- id: !extend ${B}
value: 4
- id: !extend ${B}
value: 5
- id: component3
value: 6
some_component:
- id: !extend ${A}
value: 2
- id: component2
value: 3
- id: !remove ${C}


@@ -4,7 +4,6 @@ from pathlib import Path
from esphome import config as config_module, yaml_util
from esphome.components import substitutions
-from esphome.config import resolve_extend_remove
from esphome.config_helpers import merge_config
from esphome.const import CONF_PACKAGES, CONF_SUBSTITUTIONS
from esphome.core import CORE
@@ -82,8 +81,6 @@ def test_substitutions_fixtures(fixture_path):
substitutions.do_substitution_pass(config, None)
-    resolve_extend_remove(config)
# Also load expected using ESPHome's loader, or use {} if missing and DEV_MODE
if expected_path.is_file():
expected = yaml_util.load_yaml(expected_path)