1
0
mirror of https://github.com/esphome/esphome.git synced 2025-11-01 15:41:52 +00:00

Compare commits

..

2 Commits

Author SHA1 Message Date
Jesse Hills
6bf78e2e82 Merge branch 'dev' into jesserockz-2025-297 2025-10-20 07:10:49 +13:00
Jesse Hills
817ee70db0 [touchscreen] Disable loop until interrupt triggered 2025-07-16 10:59:44 +12:00
92 changed files with 823 additions and 7953 deletions

View File

@@ -1,5 +1,4 @@
[run]
omit =
esphome/components/*
esphome/analyze_memory/*
tests/integration/*

View File

@@ -1,108 +0,0 @@
---
name: Memory Impact Comment (Forks)
on:
workflow_run:
workflows: ["CI"]
types: [completed]
permissions:
contents: read
pull-requests: write
actions: read
jobs:
memory-impact-comment:
name: Post memory impact comment (fork PRs only)
runs-on: ubuntu-24.04
# Only run for PRs from forks that had successful CI runs
if: >
github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_repository.full_name != github.repository
env:
GH_TOKEN: ${{ github.token }}
steps:
- name: Get PR details
id: pr
run: |
# Get PR details by searching for PR with matching head SHA
# The workflow_run.pull_requests field is often empty for forks
head_sha="${{ github.event.workflow_run.head_sha }}"
pr_data=$(gh api "/repos/${{ github.repository }}/commits/$head_sha/pulls" \
--jq '.[0] | {number: .number, base_ref: .base.ref}')
if [ -z "$pr_data" ] || [ "$pr_data" == "null" ]; then
echo "No PR found for SHA $head_sha, skipping"
echo "skip=true" >> $GITHUB_OUTPUT
exit 0
fi
pr_number=$(echo "$pr_data" | jq -r '.number')
base_ref=$(echo "$pr_data" | jq -r '.base_ref')
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT
echo "base_ref=$base_ref" >> $GITHUB_OUTPUT
echo "Found PR #$pr_number targeting base branch: $base_ref"
- name: Check out code from base repository
if: steps.pr.outputs.skip != 'true'
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
# Always check out from the base repository (esphome/esphome), never from forks
# Use the PR's target branch to ensure we run trusted code from the main repo
repository: ${{ github.repository }}
ref: ${{ steps.pr.outputs.base_ref }}
- name: Restore Python
if: steps.pr.outputs.skip != 'true'
uses: ./.github/actions/restore-python
with:
python-version: "3.11"
cache-key: ${{ hashFiles('.cache-key') }}
- name: Download memory analysis artifacts
if: steps.pr.outputs.skip != 'true'
run: |
run_id="${{ github.event.workflow_run.id }}"
echo "Downloading artifacts from workflow run $run_id"
mkdir -p memory-analysis
# Download target analysis artifact
if gh run download --name "memory-analysis-target" --dir memory-analysis --repo "${{ github.repository }}" "$run_id"; then
echo "Downloaded memory-analysis-target artifact."
else
echo "No memory-analysis-target artifact found."
fi
# Download PR analysis artifact
if gh run download --name "memory-analysis-pr" --dir memory-analysis --repo "${{ github.repository }}" "$run_id"; then
echo "Downloaded memory-analysis-pr artifact."
else
echo "No memory-analysis-pr artifact found."
fi
- name: Check if artifacts exist
id: check
if: steps.pr.outputs.skip != 'true'
run: |
if [ -f ./memory-analysis/memory-analysis-target.json ] && [ -f ./memory-analysis/memory-analysis-pr.json ]; then
echo "found=true" >> $GITHUB_OUTPUT
else
echo "found=false" >> $GITHUB_OUTPUT
echo "Memory analysis artifacts not found, skipping comment"
fi
- name: Post or update PR comment
if: steps.pr.outputs.skip != 'true' && steps.check.outputs.found == 'true'
env:
PR_NUMBER: ${{ steps.pr.outputs.pr_number }}
run: |
. venv/bin/activate
# Pass PR number and JSON file paths directly to Python script
# Let Python parse the JSON to avoid shell injection risks
# The script will validate and sanitize all inputs
python script/ci_memory_impact_comment.py \
--pr-number "$PR_NUMBER" \
--target-json ./memory-analysis/memory-analysis-target.json \
--pr-json ./memory-analysis/memory-analysis-pr.json

View File

@@ -175,7 +175,6 @@ jobs:
changed-components-with-tests: ${{ steps.determine.outputs.changed-components-with-tests }}
directly-changed-components-with-tests: ${{ steps.determine.outputs.directly-changed-components-with-tests }}
component-test-count: ${{ steps.determine.outputs.component-test-count }}
memory_impact: ${{ steps.determine.outputs.memory-impact }}
steps:
- name: Check out code from GitHub
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
@@ -205,7 +204,6 @@ jobs:
echo "changed-components-with-tests=$(echo "$output" | jq -c '.changed_components_with_tests')" >> $GITHUB_OUTPUT
echo "directly-changed-components-with-tests=$(echo "$output" | jq -c '.directly_changed_components_with_tests')" >> $GITHUB_OUTPUT
echo "component-test-count=$(echo "$output" | jq -r '.component_test_count')" >> $GITHUB_OUTPUT
echo "memory-impact=$(echo "$output" | jq -c '.memory_impact')" >> $GITHUB_OUTPUT
integration-tests:
name: Run integration tests
@@ -432,21 +430,6 @@ jobs:
with:
python-version: ${{ env.DEFAULT_PYTHON }}
cache-key: ${{ needs.common.outputs.cache-key }}
- name: Cache platformio
if: github.ref == 'refs/heads/dev'
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
with:
path: ~/.platformio
key: platformio-test-${{ hashFiles('platformio.ini') }}
- name: Cache platformio
if: github.ref != 'refs/heads/dev'
uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
with:
path: ~/.platformio
key: platformio-test-${{ hashFiles('platformio.ini') }}
- name: Validate and compile components with intelligent grouping
run: |
. venv/bin/activate
@@ -538,271 +521,6 @@ jobs:
- uses: pre-commit-ci/lite-action@5d6cc0eb514c891a40562a58a8e71576c5c7fb43 # v1.1.0
if: always()
memory-impact-target-branch:
name: Build target branch for memory impact
runs-on: ubuntu-24.04
needs:
- common
- determine-jobs
if: github.event_name == 'pull_request' && fromJSON(needs.determine-jobs.outputs.memory_impact).should_run == 'true'
outputs:
ram_usage: ${{ steps.extract.outputs.ram_usage }}
flash_usage: ${{ steps.extract.outputs.flash_usage }}
cache_hit: ${{ steps.cache-memory-analysis.outputs.cache-hit }}
skip: ${{ steps.check-script.outputs.skip }}
steps:
- name: Check out target branch
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
ref: ${{ github.base_ref }}
# Check if memory impact extraction script exists on target branch
# If not, skip the analysis (this handles older branches that don't have the feature)
- name: Check for memory impact script
id: check-script
run: |
if [ -f "script/ci_memory_impact_extract.py" ]; then
echo "skip=false" >> $GITHUB_OUTPUT
else
echo "skip=true" >> $GITHUB_OUTPUT
echo "::warning::ci_memory_impact_extract.py not found on target branch, skipping memory impact analysis"
fi
# All remaining steps only run if script exists
- name: Generate cache key
id: cache-key
if: steps.check-script.outputs.skip != 'true'
run: |
# Get the commit SHA of the target branch
target_sha=$(git rev-parse HEAD)
# Hash the build infrastructure files (all files that affect build/analysis)
infra_hash=$(cat \
script/test_build_components.py \
script/ci_memory_impact_extract.py \
script/analyze_component_buses.py \
script/merge_component_configs.py \
script/ci_helpers.py \
.github/workflows/ci.yml \
| sha256sum | cut -d' ' -f1)
# Get platform and components from job inputs
platform="${{ fromJSON(needs.determine-jobs.outputs.memory_impact).platform }}"
components='${{ toJSON(fromJSON(needs.determine-jobs.outputs.memory_impact).components) }}'
components_hash=$(echo "$components" | sha256sum | cut -d' ' -f1)
# Combine into cache key
cache_key="memory-analysis-target-${target_sha}-${infra_hash}-${platform}-${components_hash}"
echo "cache-key=${cache_key}" >> $GITHUB_OUTPUT
echo "Cache key: ${cache_key}"
- name: Restore cached memory analysis
id: cache-memory-analysis
if: steps.check-script.outputs.skip != 'true'
uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
with:
path: memory-analysis-target.json
key: ${{ steps.cache-key.outputs.cache-key }}
- name: Cache status
if: steps.check-script.outputs.skip != 'true'
run: |
if [ "${{ steps.cache-memory-analysis.outputs.cache-hit }}" == "true" ]; then
echo "✓ Cache hit! Using cached memory analysis results."
echo " Skipping build step to save time."
else
echo "✗ Cache miss. Will build and analyze memory usage."
fi
- name: Restore Python
if: steps.check-script.outputs.skip != 'true' && steps.cache-memory-analysis.outputs.cache-hit != 'true'
uses: ./.github/actions/restore-python
with:
python-version: ${{ env.DEFAULT_PYTHON }}
cache-key: ${{ needs.common.outputs.cache-key }}
- name: Cache platformio
if: steps.check-script.outputs.skip != 'true' && steps.cache-memory-analysis.outputs.cache-hit != 'true'
uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
with:
path: ~/.platformio
key: platformio-memory-${{ fromJSON(needs.determine-jobs.outputs.memory_impact).platform }}-${{ hashFiles('platformio.ini') }}
- name: Build, compile, and analyze memory
if: steps.check-script.outputs.skip != 'true' && steps.cache-memory-analysis.outputs.cache-hit != 'true'
id: build
run: |
. venv/bin/activate
components='${{ toJSON(fromJSON(needs.determine-jobs.outputs.memory_impact).components) }}'
platform="${{ fromJSON(needs.determine-jobs.outputs.memory_impact).platform }}"
echo "Building with test_build_components.py for $platform with components:"
echo "$components" | jq -r '.[]' | sed 's/^/ - /'
# Use test_build_components.py which handles grouping automatically
# Pass components as comma-separated list
component_list=$(echo "$components" | jq -r 'join(",")')
echo "Compiling with test_build_components.py..."
# Run build and extract memory with auto-detection of build directory for detailed analysis
# Use tee to show output in CI while also piping to extraction script
python script/test_build_components.py \
-e compile \
-c "$component_list" \
-t "$platform" 2>&1 | \
tee /dev/stderr | \
python script/ci_memory_impact_extract.py \
--output-env \
--output-json memory-analysis-target.json
# Add metadata to JSON before caching
python script/ci_add_metadata_to_json.py \
--json-file memory-analysis-target.json \
--components "$components" \
--platform "$platform"
- name: Save memory analysis to cache
if: steps.check-script.outputs.skip != 'true' && steps.cache-memory-analysis.outputs.cache-hit != 'true' && steps.build.outcome == 'success'
uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
with:
path: memory-analysis-target.json
key: ${{ steps.cache-key.outputs.cache-key }}
- name: Extract memory usage for outputs
id: extract
if: steps.check-script.outputs.skip != 'true'
run: |
if [ -f memory-analysis-target.json ]; then
ram=$(jq -r '.ram_bytes' memory-analysis-target.json)
flash=$(jq -r '.flash_bytes' memory-analysis-target.json)
echo "ram_usage=${ram}" >> $GITHUB_OUTPUT
echo "flash_usage=${flash}" >> $GITHUB_OUTPUT
echo "RAM: ${ram} bytes, Flash: ${flash} bytes"
else
echo "Error: memory-analysis-target.json not found"
exit 1
fi
- name: Upload memory analysis JSON
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: memory-analysis-target
path: memory-analysis-target.json
if-no-files-found: warn
retention-days: 1
memory-impact-pr-branch:
name: Build PR branch for memory impact
runs-on: ubuntu-24.04
needs:
- common
- determine-jobs
if: github.event_name == 'pull_request' && fromJSON(needs.determine-jobs.outputs.memory_impact).should_run == 'true'
outputs:
ram_usage: ${{ steps.extract.outputs.ram_usage }}
flash_usage: ${{ steps.extract.outputs.flash_usage }}
steps:
- name: Check out PR branch
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Restore Python
uses: ./.github/actions/restore-python
with:
python-version: ${{ env.DEFAULT_PYTHON }}
cache-key: ${{ needs.common.outputs.cache-key }}
- name: Cache platformio
uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
with:
path: ~/.platformio
key: platformio-memory-${{ fromJSON(needs.determine-jobs.outputs.memory_impact).platform }}-${{ hashFiles('platformio.ini') }}
- name: Build, compile, and analyze memory
id: extract
run: |
. venv/bin/activate
components='${{ toJSON(fromJSON(needs.determine-jobs.outputs.memory_impact).components) }}'
platform="${{ fromJSON(needs.determine-jobs.outputs.memory_impact).platform }}"
echo "Building with test_build_components.py for $platform with components:"
echo "$components" | jq -r '.[]' | sed 's/^/ - /'
# Use test_build_components.py which handles grouping automatically
# Pass components as comma-separated list
component_list=$(echo "$components" | jq -r 'join(",")')
echo "Compiling with test_build_components.py..."
# Run build and extract memory with auto-detection of build directory for detailed analysis
# Use tee to show output in CI while also piping to extraction script
python script/test_build_components.py \
-e compile \
-c "$component_list" \
-t "$platform" 2>&1 | \
tee /dev/stderr | \
python script/ci_memory_impact_extract.py \
--output-env \
--output-json memory-analysis-pr.json
# Add metadata to JSON (components and platform are in shell variables above)
python script/ci_add_metadata_to_json.py \
--json-file memory-analysis-pr.json \
--components "$components" \
--platform "$platform"
- name: Upload memory analysis JSON
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: memory-analysis-pr
path: memory-analysis-pr.json
if-no-files-found: warn
retention-days: 1
memory-impact-comment:
name: Comment memory impact
runs-on: ubuntu-24.04
needs:
- common
- determine-jobs
- memory-impact-target-branch
- memory-impact-pr-branch
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository && fromJSON(needs.determine-jobs.outputs.memory_impact).should_run == 'true' && needs.memory-impact-target-branch.outputs.skip != 'true'
permissions:
contents: read
pull-requests: write
env:
GH_TOKEN: ${{ github.token }}
steps:
- name: Check out code
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Restore Python
uses: ./.github/actions/restore-python
with:
python-version: ${{ env.DEFAULT_PYTHON }}
cache-key: ${{ needs.common.outputs.cache-key }}
- name: Download target analysis JSON
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: memory-analysis-target
path: ./memory-analysis
continue-on-error: true
- name: Download PR analysis JSON
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: memory-analysis-pr
path: ./memory-analysis
continue-on-error: true
- name: Post or update PR comment
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
run: |
. venv/bin/activate
# Pass JSON file paths directly to Python script
# All data is extracted from JSON files for security
python script/ci_memory_impact_comment.py \
--pr-number "$PR_NUMBER" \
--target-json ./memory-analysis/memory-analysis-target.json \
--pr-json ./memory-analysis/memory-analysis-pr.json
ci-status:
name: CI Status
runs-on: ubuntu-24.04
@@ -817,9 +535,6 @@ jobs:
- test-build-components-splitter
- test-build-components-split
- pre-commit-ci-lite
- memory-impact-target-branch
- memory-impact-pr-branch
- memory-impact-comment
if: always()
steps:
- name: Success

View File

@@ -185,9 +185,7 @@ def choose_upload_log_host(
else:
resolved.append(device)
if not resolved:
raise EsphomeError(
f"All specified devices {defaults} could not be resolved. Is the device connected to the network?"
)
_LOGGER.error("All specified devices: %s could not be resolved.", defaults)
return resolved
# No devices specified, show interactive chooser
@@ -468,9 +466,7 @@ def write_cpp_file() -> int:
def compile_program(args: ArgsProtocol, config: ConfigType) -> int:
from esphome import platformio_api
# NOTE: "Build path:" format is parsed by script/ci_memory_impact_extract.py
# If you change this format, update the regex in that script as well
_LOGGER.info("Compiling app... Build path: %s", CORE.build_path)
_LOGGER.info("Compiling app...")
rc = platformio_api.run_compile(config, CORE.verbose)
if rc != 0:
return rc

View File

@@ -1,502 +0,0 @@
"""Memory usage analyzer for ESPHome compiled binaries."""
from collections import defaultdict
from dataclasses import dataclass, field
import logging
from pathlib import Path
import re
import subprocess
from typing import TYPE_CHECKING
from .const import (
CORE_SUBCATEGORY_PATTERNS,
DEMANGLED_PATTERNS,
ESPHOME_COMPONENT_PATTERN,
SECTION_TO_ATTR,
SYMBOL_PATTERNS,
)
from .helpers import (
get_component_class_patterns,
get_esphome_components,
map_section_name,
parse_symbol_line,
)
if TYPE_CHECKING:
from esphome.platformio_api import IDEData
_LOGGER = logging.getLogger(__name__)
# GCC global constructor/destructor prefix annotations
_GCC_PREFIX_ANNOTATIONS = {
"_GLOBAL__sub_I_": "global constructor for",
"_GLOBAL__sub_D_": "global destructor for",
}
# GCC optimization suffix pattern (e.g., $isra$0, $part$1, $constprop$2)
_GCC_OPTIMIZATION_SUFFIX_PATTERN = re.compile(r"(\$(?:isra|part|constprop)\$\d+)")
# C++ runtime patterns for categorization
_CPP_RUNTIME_PATTERNS = frozenset(["vtable", "typeinfo", "thunk"])
# libc printf/scanf family base names (used to detect variants like _printf_r, vfprintf, etc.)
_LIBC_PRINTF_SCANF_FAMILY = frozenset(["printf", "fprintf", "sprintf", "scanf"])
# Regex pattern for parsing readelf section headers
# Format: [ #] name type addr off size
_READELF_SECTION_PATTERN = re.compile(
r"\s*\[\s*\d+\]\s+([\.\w]+)\s+\w+\s+[\da-fA-F]+\s+[\da-fA-F]+\s+([\da-fA-F]+)"
)
# Component category prefixes
_COMPONENT_PREFIX_ESPHOME = "[esphome]"
_COMPONENT_PREFIX_EXTERNAL = "[external]"
_COMPONENT_CORE = f"{_COMPONENT_PREFIX_ESPHOME}core"
_COMPONENT_API = f"{_COMPONENT_PREFIX_ESPHOME}api"
# C++ namespace prefixes
_NAMESPACE_ESPHOME = "esphome::"
_NAMESPACE_STD = "std::"
# Type alias for symbol information: (symbol_name, size, component)
SymbolInfoType = tuple[str, int, str]
@dataclass
class MemorySection:
"""Represents a memory section with its symbols."""
name: str
symbols: list[SymbolInfoType] = field(default_factory=list)
total_size: int = 0
@dataclass
class ComponentMemory:
"""Tracks memory usage for a component."""
name: str
text_size: int = 0 # Code in flash
rodata_size: int = 0 # Read-only data in flash
data_size: int = 0 # Initialized data (flash + ram)
bss_size: int = 0 # Uninitialized data (ram only)
symbol_count: int = 0
@property
def flash_total(self) -> int:
"""Total flash usage (text + rodata + data)."""
return self.text_size + self.rodata_size + self.data_size
@property
def ram_total(self) -> int:
"""Total RAM usage (data + bss)."""
return self.data_size + self.bss_size
class MemoryAnalyzer:
"""Analyzes memory usage from ELF files."""
def __init__(
self,
elf_path: str,
objdump_path: str | None = None,
readelf_path: str | None = None,
external_components: set[str] | None = None,
idedata: "IDEData | None" = None,
) -> None:
"""Initialize memory analyzer.
Args:
elf_path: Path to ELF file to analyze
objdump_path: Path to objdump binary (auto-detected from idedata if not provided)
readelf_path: Path to readelf binary (auto-detected from idedata if not provided)
external_components: Set of external component names
idedata: Optional PlatformIO IDEData object to auto-detect toolchain paths
"""
self.elf_path = Path(elf_path)
if not self.elf_path.exists():
raise FileNotFoundError(f"ELF file not found: {elf_path}")
# Auto-detect toolchain paths from idedata if not provided
if idedata is not None and (objdump_path is None or readelf_path is None):
objdump_path = objdump_path or idedata.objdump_path
readelf_path = readelf_path or idedata.readelf_path
_LOGGER.debug("Using toolchain paths from PlatformIO idedata")
self.objdump_path = objdump_path or "objdump"
self.readelf_path = readelf_path or "readelf"
self.external_components = external_components or set()
self.sections: dict[str, MemorySection] = {}
self.components: dict[str, ComponentMemory] = defaultdict(
lambda: ComponentMemory("")
)
self._demangle_cache: dict[str, str] = {}
self._uncategorized_symbols: list[tuple[str, str, int]] = []
self._esphome_core_symbols: list[
tuple[str, str, int]
] = [] # Track core symbols
self._component_symbols: dict[str, list[tuple[str, str, int]]] = defaultdict(
list
) # Track symbols for all components
def analyze(self) -> dict[str, ComponentMemory]:
"""Analyze the ELF file and return component memory usage."""
self._parse_sections()
self._parse_symbols()
self._categorize_symbols()
return dict(self.components)
def _parse_sections(self) -> None:
"""Parse section headers from ELF file."""
result = subprocess.run(
[self.readelf_path, "-S", str(self.elf_path)],
capture_output=True,
text=True,
check=True,
)
# Parse section headers
for line in result.stdout.splitlines():
# Look for section entries
if not (match := _READELF_SECTION_PATTERN.match(line)):
continue
section_name = match.group(1)
size_hex = match.group(2)
size = int(size_hex, 16)
# Map to standard section name
mapped_section = map_section_name(section_name)
if not mapped_section:
continue
if mapped_section not in self.sections:
self.sections[mapped_section] = MemorySection(mapped_section)
self.sections[mapped_section].total_size += size
def _parse_symbols(self) -> None:
"""Parse symbols from ELF file."""
result = subprocess.run(
[self.objdump_path, "-t", str(self.elf_path)],
capture_output=True,
text=True,
check=True,
)
# Track seen addresses to avoid duplicates
seen_addresses: set[str] = set()
for line in result.stdout.splitlines():
if not (symbol_info := parse_symbol_line(line)):
continue
section, name, size, address = symbol_info
# Skip duplicate symbols at the same address (e.g., C1/C2 constructors)
if address in seen_addresses or section not in self.sections:
continue
self.sections[section].symbols.append((name, size, ""))
seen_addresses.add(address)
def _categorize_symbols(self) -> None:
"""Categorize symbols by component."""
# First, collect all unique symbol names for batch demangling
all_symbols = {
symbol_name
for section in self.sections.values()
for symbol_name, _, _ in section.symbols
}
# Batch demangle all symbols at once
self._batch_demangle_symbols(list(all_symbols))
# Now categorize with cached demangled names
for section_name, section in self.sections.items():
for symbol_name, size, _ in section.symbols:
component = self._identify_component(symbol_name)
if component not in self.components:
self.components[component] = ComponentMemory(component)
comp_mem = self.components[component]
comp_mem.symbol_count += 1
# Update the appropriate size attribute based on section
if attr_name := SECTION_TO_ATTR.get(section_name):
setattr(comp_mem, attr_name, getattr(comp_mem, attr_name) + size)
# Track uncategorized symbols
if component == "other" and size > 0:
demangled = self._demangle_symbol(symbol_name)
self._uncategorized_symbols.append((symbol_name, demangled, size))
# Track ESPHome core symbols for detailed analysis
if component == _COMPONENT_CORE and size > 0:
demangled = self._demangle_symbol(symbol_name)
self._esphome_core_symbols.append((symbol_name, demangled, size))
# Track all component symbols for detailed analysis
if size > 0:
demangled = self._demangle_symbol(symbol_name)
self._component_symbols[component].append(
(symbol_name, demangled, size)
)
def _identify_component(self, symbol_name: str) -> str:
"""Identify which component a symbol belongs to."""
# Demangle C++ names if needed
demangled = self._demangle_symbol(symbol_name)
# Check for special component classes first (before namespace pattern)
# This handles cases like esphome::ESPHomeOTAComponent which should map to ota
if _NAMESPACE_ESPHOME in demangled:
# Check for special component classes that include component name in the class
# For example: esphome::ESPHomeOTAComponent -> ota component
for component_name in get_esphome_components():
patterns = get_component_class_patterns(component_name)
if any(pattern in demangled for pattern in patterns):
return f"{_COMPONENT_PREFIX_ESPHOME}{component_name}"
# Check for ESPHome component namespaces
match = ESPHOME_COMPONENT_PATTERN.search(demangled)
if match:
component_name = match.group(1)
# Strip trailing underscore if present (e.g., switch_ -> switch)
component_name = component_name.rstrip("_")
# Check if this is an actual component in the components directory
if component_name in get_esphome_components():
return f"{_COMPONENT_PREFIX_ESPHOME}{component_name}"
# Check if this is a known external component from the config
if component_name in self.external_components:
return f"{_COMPONENT_PREFIX_EXTERNAL}{component_name}"
# Everything else in esphome:: namespace is core
return _COMPONENT_CORE
# Check for esphome core namespace (no component namespace)
if _NAMESPACE_ESPHOME in demangled:
# If no component match found, it's core
return _COMPONENT_CORE
# Check against symbol patterns
for component, patterns in SYMBOL_PATTERNS.items():
if any(pattern in symbol_name for pattern in patterns):
return component
# Check against demangled patterns
for component, patterns in DEMANGLED_PATTERNS.items():
if any(pattern in demangled for pattern in patterns):
return component
# Special cases that need more complex logic
# Check if spi_flash vs spi_driver
if "spi_" in symbol_name or "SPI" in symbol_name:
return "spi_flash" if "spi_flash" in symbol_name else "spi_driver"
# libc special printf variants
if (
symbol_name.startswith("_")
and symbol_name[1:].replace("_r", "").replace("v", "").replace("s", "")
in _LIBC_PRINTF_SCANF_FAMILY
):
return "libc"
# Track uncategorized symbols for analysis
return "other"
def _batch_demangle_symbols(self, symbols: list[str]) -> None:
"""Batch demangle C++ symbol names for efficiency."""
if not symbols:
return
# Try to find the appropriate c++filt for the platform
cppfilt_cmd = "c++filt"
_LOGGER.info("Demangling %d symbols", len(symbols))
_LOGGER.debug("objdump_path = %s", self.objdump_path)
# Check if we have a toolchain-specific c++filt
if self.objdump_path and self.objdump_path != "objdump":
# Replace objdump with c++filt in the path
potential_cppfilt = self.objdump_path.replace("objdump", "c++filt")
_LOGGER.info("Checking for toolchain c++filt at: %s", potential_cppfilt)
if Path(potential_cppfilt).exists():
cppfilt_cmd = potential_cppfilt
_LOGGER.info("✓ Using toolchain c++filt: %s", cppfilt_cmd)
else:
_LOGGER.info(
"✗ Toolchain c++filt not found at %s, using system c++filt",
potential_cppfilt,
)
else:
_LOGGER.info("✗ Using system c++filt (objdump_path=%s)", self.objdump_path)
# Strip GCC optimization suffixes and prefixes before demangling
# Suffixes like $isra$0, $part$0, $constprop$0 confuse c++filt
# Prefixes like _GLOBAL__sub_I_ need to be removed and tracked
symbols_stripped: list[str] = []
symbols_prefixes: list[str] = [] # Track removed prefixes
for symbol in symbols:
# Remove GCC optimization markers
stripped = _GCC_OPTIMIZATION_SUFFIX_PATTERN.sub("", symbol)
# Handle GCC global constructor/initializer prefixes
# _GLOBAL__sub_I_<mangled> -> extract <mangled> for demangling
prefix = ""
for gcc_prefix in _GCC_PREFIX_ANNOTATIONS:
if stripped.startswith(gcc_prefix):
prefix = gcc_prefix
stripped = stripped[len(prefix) :]
break
symbols_stripped.append(stripped)
symbols_prefixes.append(prefix)
try:
# Send all symbols to c++filt at once
result = subprocess.run(
[cppfilt_cmd],
input="\n".join(symbols_stripped),
capture_output=True,
text=True,
check=False,
)
except (subprocess.SubprocessError, OSError, UnicodeDecodeError) as e:
# On error, cache originals
_LOGGER.warning("Failed to batch demangle symbols: %s", e)
for symbol in symbols:
self._demangle_cache[symbol] = symbol
return
if result.returncode != 0:
_LOGGER.warning(
"c++filt exited with code %d: %s",
result.returncode,
result.stderr[:200] if result.stderr else "(no error output)",
)
# Cache originals on failure
for symbol in symbols:
self._demangle_cache[symbol] = symbol
return
# Process demangled output
self._process_demangled_output(
symbols, symbols_stripped, symbols_prefixes, result.stdout, cppfilt_cmd
)
def _process_demangled_output(
self,
symbols: list[str],
symbols_stripped: list[str],
symbols_prefixes: list[str],
demangled_output: str,
cppfilt_cmd: str,
) -> None:
"""Process demangled symbol output and populate cache.
Args:
symbols: Original symbol names
symbols_stripped: Stripped symbol names sent to c++filt
symbols_prefixes: Removed prefixes to restore
demangled_output: Output from c++filt
cppfilt_cmd: Path to c++filt command (for logging)
"""
demangled_lines = demangled_output.strip().split("\n")
failed_count = 0
for original, stripped, prefix, demangled in zip(
symbols, symbols_stripped, symbols_prefixes, demangled_lines
):
# Add back any prefix that was removed
demangled = self._restore_symbol_prefix(prefix, stripped, demangled)
# If we stripped a suffix, add it back to the demangled name for clarity
if original != stripped and not prefix:
demangled = self._restore_symbol_suffix(original, demangled)
self._demangle_cache[original] = demangled
# Log symbols that failed to demangle (stayed the same as stripped version)
if stripped == demangled and stripped.startswith("_Z"):
failed_count += 1
if failed_count <= 5: # Only log first 5 failures
_LOGGER.warning("Failed to demangle: %s", original)
if failed_count == 0:
_LOGGER.info("Successfully demangled all %d symbols", len(symbols))
return
_LOGGER.warning(
"Failed to demangle %d/%d symbols using %s",
failed_count,
len(symbols),
cppfilt_cmd,
)
@staticmethod
def _restore_symbol_prefix(prefix: str, stripped: str, demangled: str) -> str:
"""Restore prefix that was removed before demangling.
Args:
prefix: Prefix that was removed (e.g., "_GLOBAL__sub_I_")
stripped: Stripped symbol name
demangled: Demangled symbol name
Returns:
Demangled name with prefix restored/annotated
"""
if not prefix:
return demangled
# Successfully demangled - add descriptive prefix
if demangled != stripped and (
annotation := _GCC_PREFIX_ANNOTATIONS.get(prefix)
):
return f"[{annotation}: {demangled}]"
# Failed to demangle - restore original prefix
return prefix + demangled
@staticmethod
def _restore_symbol_suffix(original: str, demangled: str) -> str:
"""Restore GCC optimization suffix that was removed before demangling.
Args:
original: Original symbol name with suffix
demangled: Demangled symbol name without suffix
Returns:
Demangled name with suffix annotation
"""
if suffix_match := _GCC_OPTIMIZATION_SUFFIX_PATTERN.search(original):
return f"{demangled} [{suffix_match.group(1)}]"
return demangled
def _demangle_symbol(self, symbol: str) -> str:
"""Get demangled C++ symbol name from cache."""
return self._demangle_cache.get(symbol, symbol)
def _categorize_esphome_core_symbol(self, demangled: str) -> str:
"""Categorize ESPHome core symbols into subcategories."""
# Special patterns that need to be checked separately
if any(pattern in demangled for pattern in _CPP_RUNTIME_PATTERNS):
return "C++ Runtime (vtables/RTTI)"
if demangled.startswith(_NAMESPACE_STD):
return "C++ STL"
# Check against patterns from const.py
for category, patterns in CORE_SUBCATEGORY_PATTERNS.items():
if any(pattern in demangled for pattern in patterns):
return category
return "Other Core"
if __name__ == "__main__":
from .cli import main
main()

View File

@@ -1,6 +0,0 @@
"""Main entry point for running the memory analyzer as a module."""
from .cli import main
if __name__ == "__main__":
main()

View File

@@ -1,408 +0,0 @@
"""CLI interface for memory analysis with report generation."""
from collections import defaultdict
import sys
from . import (
_COMPONENT_API,
_COMPONENT_CORE,
_COMPONENT_PREFIX_ESPHOME,
_COMPONENT_PREFIX_EXTERNAL,
MemoryAnalyzer,
)
class MemoryAnalyzerCLI(MemoryAnalyzer):
"""Memory analyzer with CLI-specific report generation."""
# Column width constants
COL_COMPONENT: int = 29
COL_FLASH_TEXT: int = 14
COL_FLASH_DATA: int = 14
COL_RAM_DATA: int = 12
COL_RAM_BSS: int = 12
COL_TOTAL_FLASH: int = 15
COL_TOTAL_RAM: int = 12
COL_SEPARATOR: int = 3 # " | "
# Core analysis column widths
COL_CORE_SUBCATEGORY: int = 30
COL_CORE_SIZE: int = 12
COL_CORE_COUNT: int = 6
COL_CORE_PERCENT: int = 10
# Calculate table width once at class level
TABLE_WIDTH: int = (
COL_COMPONENT
+ COL_SEPARATOR
+ COL_FLASH_TEXT
+ COL_SEPARATOR
+ COL_FLASH_DATA
+ COL_SEPARATOR
+ COL_RAM_DATA
+ COL_SEPARATOR
+ COL_RAM_BSS
+ COL_SEPARATOR
+ COL_TOTAL_FLASH
+ COL_SEPARATOR
+ COL_TOTAL_RAM
)
@staticmethod
def _make_separator_line(*widths: int) -> str:
"""Create a separator line with given column widths.
Args:
widths: Column widths to create separators for
Returns:
Separator line like "----+---------+-----"
"""
return "-+-".join("-" * width for width in widths)
# Pre-computed separator lines
MAIN_TABLE_SEPARATOR: str = _make_separator_line(
COL_COMPONENT,
COL_FLASH_TEXT,
COL_FLASH_DATA,
COL_RAM_DATA,
COL_RAM_BSS,
COL_TOTAL_FLASH,
COL_TOTAL_RAM,
)
CORE_TABLE_SEPARATOR: str = _make_separator_line(
COL_CORE_SUBCATEGORY,
COL_CORE_SIZE,
COL_CORE_COUNT,
COL_CORE_PERCENT,
)
def generate_report(self, detailed: bool = False) -> str:
"""Generate a formatted memory report."""
components = sorted(
self.components.items(), key=lambda x: x[1].flash_total, reverse=True
)
# Calculate totals
total_flash = sum(c.flash_total for _, c in components)
total_ram = sum(c.ram_total for _, c in components)
# Build report
lines: list[str] = []
lines.append("=" * self.TABLE_WIDTH)
lines.append("Component Memory Analysis".center(self.TABLE_WIDTH))
lines.append("=" * self.TABLE_WIDTH)
lines.append("")
# Main table - fixed column widths
lines.append(
f"{'Component':<{self.COL_COMPONENT}} | {'Flash (text)':>{self.COL_FLASH_TEXT}} | {'Flash (data)':>{self.COL_FLASH_DATA}} | {'RAM (data)':>{self.COL_RAM_DATA}} | {'RAM (bss)':>{self.COL_RAM_BSS}} | {'Total Flash':>{self.COL_TOTAL_FLASH}} | {'Total RAM':>{self.COL_TOTAL_RAM}}"
)
lines.append(self.MAIN_TABLE_SEPARATOR)
for name, mem in components:
if mem.flash_total > 0 or mem.ram_total > 0:
flash_rodata = mem.rodata_size + mem.data_size
lines.append(
f"{name:<{self.COL_COMPONENT}} | {mem.text_size:>{self.COL_FLASH_TEXT - 2},} B | {flash_rodata:>{self.COL_FLASH_DATA - 2},} B | "
f"{mem.data_size:>{self.COL_RAM_DATA - 2},} B | {mem.bss_size:>{self.COL_RAM_BSS - 2},} B | "
f"{mem.flash_total:>{self.COL_TOTAL_FLASH - 2},} B | {mem.ram_total:>{self.COL_TOTAL_RAM - 2},} B"
)
lines.append(self.MAIN_TABLE_SEPARATOR)
lines.append(
f"{'TOTAL':<{self.COL_COMPONENT}} | {' ':>{self.COL_FLASH_TEXT}} | {' ':>{self.COL_FLASH_DATA}} | "
f"{' ':>{self.COL_RAM_DATA}} | {' ':>{self.COL_RAM_BSS}} | "
f"{total_flash:>{self.COL_TOTAL_FLASH - 2},} B | {total_ram:>{self.COL_TOTAL_RAM - 2},} B"
)
# Top consumers
lines.append("")
lines.append("Top Flash Consumers:")
for i, (name, mem) in enumerate(components[:25]):
if mem.flash_total > 0:
percentage = (
(mem.flash_total / total_flash * 100) if total_flash > 0 else 0
)
lines.append(
f"{i + 1}. {name} ({mem.flash_total:,} B) - {percentage:.1f}% of analyzed flash"
)
lines.append("")
lines.append("Top RAM Consumers:")
ram_components = sorted(components, key=lambda x: x[1].ram_total, reverse=True)
for i, (name, mem) in enumerate(ram_components[:25]):
if mem.ram_total > 0:
percentage = (mem.ram_total / total_ram * 100) if total_ram > 0 else 0
lines.append(
f"{i + 1}. {name} ({mem.ram_total:,} B) - {percentage:.1f}% of analyzed RAM"
)
lines.append("")
lines.append(
"Note: This analysis covers symbols in the ELF file. Some runtime allocations may not be included."
)
lines.append("=" * self.TABLE_WIDTH)
# Add ESPHome core detailed analysis if there are core symbols
if self._esphome_core_symbols:
lines.append("")
lines.append("=" * self.TABLE_WIDTH)
lines.append(
f"{_COMPONENT_CORE} Detailed Analysis".center(self.TABLE_WIDTH)
)
lines.append("=" * self.TABLE_WIDTH)
lines.append("")
# Group core symbols by subcategory
core_subcategories: dict[str, list[tuple[str, str, int]]] = defaultdict(
list
)
for symbol, demangled, size in self._esphome_core_symbols:
# Categorize based on demangled name patterns
subcategory = self._categorize_esphome_core_symbol(demangled)
core_subcategories[subcategory].append((symbol, demangled, size))
# Sort subcategories by total size
sorted_subcategories = sorted(
[
(name, symbols, sum(s[2] for s in symbols))
for name, symbols in core_subcategories.items()
],
key=lambda x: x[2],
reverse=True,
)
lines.append(
f"{'Subcategory':<{self.COL_CORE_SUBCATEGORY}} | {'Size':>{self.COL_CORE_SIZE}} | "
f"{'Count':>{self.COL_CORE_COUNT}} | {'% of Core':>{self.COL_CORE_PERCENT}}"
)
lines.append(self.CORE_TABLE_SEPARATOR)
core_total = sum(size for _, _, size in self._esphome_core_symbols)
for subcategory, symbols, total_size in sorted_subcategories:
percentage = (total_size / core_total * 100) if core_total > 0 else 0
lines.append(
f"{subcategory:<{self.COL_CORE_SUBCATEGORY}} | {total_size:>{self.COL_CORE_SIZE - 2},} B | "
f"{len(symbols):>{self.COL_CORE_COUNT}} | {percentage:>{self.COL_CORE_PERCENT - 1}.1f}%"
)
# Top 15 largest core symbols
lines.append("")
lines.append(f"Top 15 Largest {_COMPONENT_CORE} Symbols:")
sorted_core_symbols = sorted(
self._esphome_core_symbols, key=lambda x: x[2], reverse=True
)
for i, (symbol, demangled, size) in enumerate(sorted_core_symbols[:15]):
lines.append(f"{i + 1}. {demangled} ({size:,} B)")
lines.append("=" * self.TABLE_WIDTH)
# Add detailed analysis for top ESPHome and external components
esphome_components = [
(name, mem)
for name, mem in components
if name.startswith(_COMPONENT_PREFIX_ESPHOME) and name != _COMPONENT_CORE
]
external_components = [
(name, mem)
for name, mem in components
if name.startswith(_COMPONENT_PREFIX_EXTERNAL)
]
top_esphome_components = sorted(
esphome_components, key=lambda x: x[1].flash_total, reverse=True
)[:30]
# Include all external components (they're usually important)
top_external_components = sorted(
external_components, key=lambda x: x[1].flash_total, reverse=True
)
# Check if API component exists and ensure it's included
api_component = None
for name, mem in components:
if name == _COMPONENT_API:
api_component = (name, mem)
break
# Combine all components to analyze: top ESPHome + all external + API if not already included
components_to_analyze = list(top_esphome_components) + list(
top_external_components
)
if api_component and api_component not in components_to_analyze:
components_to_analyze.append(api_component)
if components_to_analyze:
for comp_name, comp_mem in components_to_analyze:
if not (comp_symbols := self._component_symbols.get(comp_name, [])):
continue
lines.append("")
lines.append("=" * self.TABLE_WIDTH)
lines.append(f"{comp_name} Detailed Analysis".center(self.TABLE_WIDTH))
lines.append("=" * self.TABLE_WIDTH)
lines.append("")
# Sort symbols by size
sorted_symbols = sorted(comp_symbols, key=lambda x: x[2], reverse=True)
lines.append(f"Total symbols: {len(sorted_symbols)}")
lines.append(f"Total size: {comp_mem.flash_total:,} B")
lines.append("")
# Show all symbols > 100 bytes for better visibility
large_symbols = [
(sym, dem, size) for sym, dem, size in sorted_symbols if size > 100
]
lines.append(
f"{comp_name} Symbols > 100 B ({len(large_symbols)} symbols):"
)
for i, (symbol, demangled, size) in enumerate(large_symbols):
lines.append(f"{i + 1}. {demangled} ({size:,} B)")
lines.append("=" * self.TABLE_WIDTH)
return "\n".join(lines)
def dump_uncategorized_symbols(self, output_file: str | None = None) -> None:
"""Dump uncategorized symbols for analysis."""
# Sort by size descending
sorted_symbols = sorted(
self._uncategorized_symbols, key=lambda x: x[2], reverse=True
)
lines = ["Uncategorized Symbols Analysis", "=" * 80]
lines.append(f"Total uncategorized symbols: {len(sorted_symbols)}")
lines.append(
f"Total uncategorized size: {sum(s[2] for s in sorted_symbols):,} bytes"
)
lines.append("")
lines.append(f"{'Size':>10} | {'Symbol':<60} | Demangled")
lines.append("-" * 10 + "-+-" + "-" * 60 + "-+-" + "-" * 40)
for symbol, demangled, size in sorted_symbols[:100]: # Top 100
demangled_display = (
demangled[:100] if symbol != demangled else "[not demangled]"
)
lines.append(f"{size:>10,} | {symbol[:60]:<60} | {demangled_display}")
if len(sorted_symbols) > 100:
lines.append(f"\n... and {len(sorted_symbols) - 100} more symbols")
content = "\n".join(lines)
if output_file:
with open(output_file, "w", encoding="utf-8") as f:
f.write(content)
else:
print(content)
def analyze_elf(
elf_path: str,
objdump_path: str | None = None,
readelf_path: str | None = None,
detailed: bool = False,
external_components: set[str] | None = None,
) -> str:
"""Analyze an ELF file and return a memory report."""
analyzer = MemoryAnalyzerCLI(
elf_path, objdump_path, readelf_path, external_components
)
analyzer.analyze()
return analyzer.generate_report(detailed)
def main():
"""CLI entrypoint for memory analysis."""
if len(sys.argv) < 2:
print("Usage: python -m esphome.analyze_memory <build_directory>")
print("\nAnalyze memory usage from an ESPHome build directory.")
print("The build directory should contain firmware.elf and idedata will be")
print("loaded from ~/.esphome/.internal/idedata/<device>.json")
print("\nExamples:")
print(" python -m esphome.analyze_memory ~/.esphome/build/my-device")
print(" python -m esphome.analyze_memory .esphome/build/my-device")
print(" python -m esphome.analyze_memory my-device # Short form")
sys.exit(1)
build_dir = sys.argv[1]
# Load build directory
import json
from pathlib import Path
from esphome.platformio_api import IDEData
build_path = Path(build_dir)
# If no path separator in name, assume it's a device name
if "/" not in build_dir and not build_path.is_dir():
# Try current directory first
cwd_path = Path.cwd() / ".esphome" / "build" / build_dir
if cwd_path.is_dir():
build_path = cwd_path
print(f"Using build directory: {build_path}", file=sys.stderr)
else:
# Fall back to home directory
build_path = Path.home() / ".esphome" / "build" / build_dir
print(f"Using build directory: {build_path}", file=sys.stderr)
if not build_path.is_dir():
print(f"Error: {build_path} is not a directory", file=sys.stderr)
sys.exit(1)
# Find firmware.elf
elf_file = None
for elf_candidate in [
build_path / "firmware.elf",
build_path / ".pioenvs" / build_path.name / "firmware.elf",
]:
if elf_candidate.exists():
elf_file = str(elf_candidate)
break
if not elf_file:
print(f"Error: firmware.elf not found in {build_dir}", file=sys.stderr)
sys.exit(1)
# Find idedata.json - check current directory first, then home
device_name = build_path.name
idedata_candidates = [
Path.cwd() / ".esphome" / "idedata" / f"{device_name}.json",
Path.home() / ".esphome" / "idedata" / f"{device_name}.json",
]
idedata = None
for idedata_path in idedata_candidates:
if not idedata_path.exists():
continue
try:
with open(idedata_path, encoding="utf-8") as f:
raw_data = json.load(f)
idedata = IDEData(raw_data)
print(f"Loaded idedata from: {idedata_path}", file=sys.stderr)
break
except (json.JSONDecodeError, OSError) as e:
print(f"Warning: Failed to load idedata: {e}", file=sys.stderr)
if not idedata:
print(
f"Warning: idedata not found (searched {idedata_candidates[0]} and {idedata_candidates[1]})",
file=sys.stderr,
)
analyzer = MemoryAnalyzerCLI(elf_file, idedata=idedata)
analyzer.analyze()
report = analyzer.generate_report()
print(report)
if __name__ == "__main__":
main()

View File

@@ -1,903 +0,0 @@
"""Constants for memory analysis symbol pattern matching."""
import re
# Pattern to extract ESPHome component namespaces dynamically
ESPHOME_COMPONENT_PATTERN = re.compile(r"esphome::([a-zA-Z0-9_]+)::")
# Section mapping for ELF file sections
# Maps standard section names to their various platform-specific variants
SECTION_MAPPING = {
".text": frozenset([".text", ".iram"]),
".rodata": frozenset([".rodata"]),
".data": frozenset([".data", ".dram"]),
".bss": frozenset([".bss"]),
}
# Section to ComponentMemory attribute mapping
# Maps section names to the attribute name in ComponentMemory dataclass
SECTION_TO_ATTR = {
".text": "text_size",
".rodata": "rodata_size",
".data": "data_size",
".bss": "bss_size",
}
# Component identification rules
# Symbol patterns: patterns found in raw symbol names
SYMBOL_PATTERNS = {
"freertos": [
"vTask",
"xTask",
"xQueue",
"pvPort",
"vPort",
"uxTask",
"pcTask",
"prvTimerTask",
"prvAddNewTaskToReadyList",
"pxReadyTasksLists",
"prvAddCurrentTaskToDelayedList",
"xEventGroupWaitBits",
"xRingbufferSendFromISR",
"prvSendItemDoneNoSplit",
"prvReceiveGeneric",
"prvSendAcquireGeneric",
"prvCopyItemAllowSplit",
"xEventGroup",
"xRingbuffer",
"prvSend",
"prvReceive",
"prvCopy",
"xPort",
"ulTaskGenericNotifyTake",
"prvIdleTask",
"prvInitialiseNewTask",
"prvIsYieldRequiredSMP",
"prvGetItemByteBuf",
"prvInitializeNewRingbuffer",
"prvAcquireItemNoSplit",
"prvNotifyQueueSetContainer",
"ucStaticTimerQueueStorage",
"eTaskGetState",
"main_task",
"do_system_init_fn",
"xSemaphoreCreateGenericWithCaps",
"vListInsert",
"uxListRemove",
"vRingbufferReturnItem",
"vRingbufferReturnItemFromISR",
"prvCheckItemFitsByteBuffer",
"prvGetCurMaxSizeAllowSplit",
"tick_hook",
"sys_sem_new",
"sys_arch_mbox_fetch",
"sys_arch_sem_wait",
"prvDeleteTCB",
"vQueueDeleteWithCaps",
"vRingbufferDeleteWithCaps",
"vSemaphoreDeleteWithCaps",
"prvCheckItemAvail",
"prvCheckTaskCanBeScheduledSMP",
"prvGetCurMaxSizeNoSplit",
"prvResetNextTaskUnblockTime",
"prvReturnItemByteBuf",
"vApplicationStackOverflowHook",
"vApplicationGetIdleTaskMemory",
"sys_init",
"sys_mbox_new",
"sys_arch_mbox_tryfetch",
],
"xtensa": ["xt_", "_xt_", "xPortEnterCriticalTimeout"],
"heap": ["heap_", "multi_heap"],
"spi_flash": ["spi_flash"],
"rtc": ["rtc_", "rtcio_ll_"],
"gpio_driver": ["gpio_", "pins"],
"uart_driver": ["uart", "_uart", "UART"],
"timer": ["timer_", "esp_timer"],
"peripherals": ["periph_", "periman"],
"network_stack": [
"vj_compress",
"raw_sendto",
"raw_input",
"etharp_",
"icmp_input",
"socket_ipv6",
"ip_napt",
"socket_ipv4_multicast",
"socket_ipv6_multicast",
"netconn_",
"recv_raw",
"accept_function",
"netconn_recv_data",
"netconn_accept",
"netconn_write_vectors_partly",
"netconn_drain",
"raw_connect",
"raw_bind",
"icmp_send_response",
"sockets",
"icmp_dest_unreach",
"inet_chksum_pseudo",
"alloc_socket",
"done_socket",
"set_global_fd_sets",
"inet_chksum_pbuf",
"tryget_socket_unconn_locked",
"tryget_socket_unconn",
"cs_create_ctrl_sock",
"netbuf_alloc",
],
"ipv6_stack": ["nd6_", "ip6_", "mld6_", "icmp6_", "icmp6_input"],
"wifi_stack": [
"ieee80211",
"hostap",
"sta_",
"ap_",
"scan_",
"wifi_",
"wpa_",
"wps_",
"esp_wifi",
"cnx_",
"wpa3_",
"sae_",
"wDev_",
"ic_",
"mac_",
"esf_buf",
"gWpaSm",
"sm_WPA",
"eapol_",
"owe_",
"wifiLowLevelInit",
"s_do_mapping",
"gScanStruct",
"ppSearchTxframe",
"ppMapWaitTxq",
"ppFillAMPDUBar",
"ppCheckTxConnTrafficIdle",
"ppCalTkipMic",
],
"bluetooth": ["bt_", "ble_", "l2c_", "gatt_", "gap_", "hci_", "BT_init"],
"wifi_bt_coex": ["coex"],
"bluetooth_rom": ["r_ble", "r_lld", "r_llc", "r_llm"],
"bluedroid_bt": [
"bluedroid",
"btc_",
"bta_",
"btm_",
"btu_",
"BTM_",
"GATT",
"L2CA_",
"smp_",
"gatts_",
"attp_",
"l2cu_",
"l2cb",
"smp_cb",
"BTA_GATTC_",
"SMP_",
"BTU_",
"BTA_Dm",
"GAP_Ble",
"BT_tx_if",
"host_recv_pkt_cb",
"saved_local_oob_data",
"string_to_bdaddr",
"string_is_bdaddr",
"CalConnectParamTimeout",
"transmit_fragment",
"transmit_data",
"event_command_ready",
"read_command_complete_header",
"parse_read_local_extended_features_response",
"parse_read_local_version_info_response",
"should_request_high",
"btdm_wakeup_request",
"BTA_SetAttributeValue",
"BTA_EnableBluetooth",
"transmit_command_futured",
"transmit_command",
"get_waiting_command",
"make_command",
"transmit_downward",
"host_recv_adv_packet",
"copy_extra_byte_in_db",
"parse_read_local_supported_commands_response",
],
"crypto_math": [
"ecp_",
"bignum_",
"mpi_",
"sswu",
"modp",
"dragonfly_",
"gcm_mult",
"__multiply",
"quorem",
"__mdiff",
"__lshift",
"__mprec_tens",
"ECC_",
"multiprecision_",
"mix_sub_columns",
"sbox",
"gfm2_sbox",
"gfm3_sbox",
"curve_p256",
"curve",
"p_256_init_curve",
"shift_sub_rows",
"rshift",
],
"hw_crypto": ["esp_aes", "esp_sha", "esp_rsa", "esp_bignum", "esp_mpi"],
"libc": [
"printf",
"scanf",
"malloc",
"free",
"memcpy",
"memset",
"strcpy",
"strlen",
"_dtoa",
"_fopen",
"__sfvwrite_r",
"qsort",
"__sf",
"__sflush_r",
"__srefill_r",
"_impure_data",
"_reclaim_reent",
"_open_r",
"strncpy",
"_strtod_l",
"__gethex",
"__hexnan",
"_setenv_r",
"_tzset_unlocked_r",
"__tzcalc_limits",
"select",
"scalbnf",
"strtof",
"strtof_l",
"__d2b",
"__b2d",
"__s2b",
"_Balloc",
"__multadd",
"__lo0bits",
"__atexit0",
"__smakebuf_r",
"__swhatbuf_r",
"_sungetc_r",
"_close_r",
"_link_r",
"_unsetenv_r",
"_rename_r",
"__month_lengths",
"tzinfo",
"__ratio",
"__hi0bits",
"__ulp",
"__any_on",
"__copybits",
"L_shift",
"_fcntl_r",
"_lseek_r",
"_read_r",
"_write_r",
"_unlink_r",
"_fstat_r",
"access",
"fsync",
"tcsetattr",
"tcgetattr",
"tcflush",
"tcdrain",
"__ssrefill_r",
"_stat_r",
"__hexdig_fun",
"__mcmp",
"_fwalk_sglue",
"__fpclassifyf",
"_setlocale_r",
"_mbrtowc_r",
"fcntl",
"__match",
"_lock_close",
"__c$",
"__func__$",
"__FUNCTION__$",
"DAYS_IN_MONTH",
"_DAYS_BEFORE_MONTH",
"CSWTCH$",
"dst$",
"sulp",
],
"string_ops": ["strcmp", "strncmp", "strchr", "strstr", "strtok", "strdup"],
"memory_alloc": ["malloc", "calloc", "realloc", "free", "_sbrk"],
"file_io": [
"fread",
"fwrite",
"fopen",
"fclose",
"fseek",
"ftell",
"fflush",
"s_fd_table",
],
"string_formatting": [
"snprintf",
"vsnprintf",
"sprintf",
"vsprintf",
"sscanf",
"vsscanf",
],
"cpp_anonymous": ["_GLOBAL__N_", "n$"],
"cpp_runtime": ["__cxx", "_ZN", "_ZL", "_ZSt", "__gxx_personality", "_Z16"],
"exception_handling": ["__cxa_", "_Unwind_", "__gcc_personality", "uw_frame_state"],
"static_init": ["_GLOBAL__sub_I_"],
"mdns_lib": ["mdns"],
"phy_radio": [
"phy_",
"rf_",
"chip_",
"register_chipv7",
"pbus_",
"bb_",
"fe_",
"rfcal_",
"ram_rfcal",
"tx_pwctrl",
"rx_chan",
"set_rx_gain",
"set_chan",
"agc_reg",
"ram_txiq",
"ram_txdc",
"ram_gen_rx_gain",
"rx_11b_opt",
"set_rx_sense",
"set_rx_gain_cal",
"set_chan_dig_gain",
"tx_pwctrl_init_cal",
"rfcal_txiq",
"set_tx_gain_table",
"correct_rfpll_offset",
"pll_correct_dcap",
"txiq_cal_init",
"pwdet_sar",
"pwdet_sar2_init",
"ram_iq_est_enable",
"ram_rfpll_set_freq",
"ant_wifirx_cfg",
"ant_btrx_cfg",
"force_txrxoff",
"force_txrx_off",
"tx_paon_set",
"opt_11b_resart",
"rfpll_1p2_opt",
"ram_dc_iq_est",
"ram_start_tx_tone",
"ram_en_pwdet",
"ram_cbw2040_cfg",
"rxdc_est_min",
"i2cmst_reg_init",
"temprature_sens_read",
"ram_restart_cal",
"ram_write_gain_mem",
"ram_wait_rfpll_cal_end",
"txcal_debuge_mode",
"ant_wifitx_cfg",
"reg_init_begin",
],
"wifi_phy_pp": ["pp_", "ppT", "ppR", "ppP", "ppInstall", "ppCalTxAMPDULength"],
"wifi_lmac": ["lmac"],
"wifi_device": ["wdev", "wDev_"],
"power_mgmt": [
"pm_",
"sleep",
"rtc_sleep",
"light_sleep",
"deep_sleep",
"power_down",
"g_pm",
],
"memory_mgmt": [
"mem_",
"memory_",
"tlsf_",
"memp_",
"pbuf_",
"pbuf_alloc",
"pbuf_copy_partial_pbuf",
],
"hal_layer": ["hal_"],
"clock_mgmt": [
"clk_",
"clock_",
"rtc_clk",
"apb_",
"cpu_freq",
"setCpuFrequencyMhz",
],
"cache_mgmt": ["cache"],
"flash_ops": ["flash", "image_load"],
"interrupt_handlers": [
"isr",
"interrupt",
"intr_",
"exc_",
"exception",
"port_IntStack",
],
"wrapper_functions": ["_wrapper"],
"error_handling": ["panic", "abort", "assert", "error_", "fault"],
"authentication": ["auth"],
"ppp_protocol": ["ppp", "ipcp_", "lcp_", "chap_", "LcpEchoCheck"],
"dhcp": ["dhcp", "handle_dhcp"],
"ethernet_phy": [
"emac_",
"eth_phy_",
"phy_tlk110",
"phy_lan87",
"phy_ip101",
"phy_rtl",
"phy_dp83",
"phy_ksz",
"lan87xx_",
"rtl8201_",
"ip101_",
"ksz80xx_",
"jl1101_",
"dp83848_",
"eth_on_state_changed",
],
"threading": ["pthread_", "thread_", "_task_"],
"pthread": ["pthread"],
"synchronization": ["mutex", "semaphore", "spinlock", "portMUX"],
"math_lib": [
"sin",
"cos",
"tan",
"sqrt",
"pow",
"exp",
"log",
"atan",
"asin",
"acos",
"floor",
"ceil",
"fabs",
"round",
],
"random": ["rand", "random", "rng_", "prng"],
"time_lib": [
"time",
"clock",
"gettimeofday",
"settimeofday",
"localtime",
"gmtime",
"mktime",
"strftime",
],
"console_io": ["console_", "uart_tx", "uart_rx", "puts", "putchar", "getchar"],
"rom_functions": ["r_", "rom_"],
"compiler_runtime": [
"__divdi3",
"__udivdi3",
"__moddi3",
"__muldi3",
"__ashldi3",
"__ashrdi3",
"__lshrdi3",
"__cmpdi2",
"__fixdfdi",
"__floatdidf",
],
"libgcc": ["libgcc", "_divdi3", "_udivdi3"],
"boot_startup": ["boot", "start_cpu", "call_start", "startup", "bootloader"],
"bootloader": ["bootloader_", "esp_bootloader"],
"app_framework": ["app_", "initArduino", "setup", "loop", "Update"],
"weak_symbols": ["__weak_"],
"compiler_builtins": ["__builtin_"],
"vfs": ["vfs_", "VFS"],
"esp32_sdk": ["esp32_", "esp32c", "esp32s"],
"usb": ["usb_", "USB", "cdc_", "CDC"],
"i2c_driver": ["i2c_", "I2C"],
"i2s_driver": ["i2s_", "I2S"],
"spi_driver": ["spi_", "SPI"],
"adc_driver": ["adc_", "ADC"],
"dac_driver": ["dac_", "DAC"],
"touch_driver": ["touch_", "TOUCH"],
"pwm_driver": ["pwm_", "PWM", "ledc_", "LEDC"],
"rmt_driver": ["rmt_", "RMT"],
"pcnt_driver": ["pcnt_", "PCNT"],
"can_driver": ["can_", "CAN", "twai_", "TWAI"],
"sdmmc_driver": ["sdmmc_", "SDMMC", "sdcard", "sd_card"],
"temp_sensor": ["temp_sensor", "tsens_"],
"watchdog": ["wdt_", "WDT", "watchdog"],
"brownout": ["brownout", "bod_"],
"ulp": ["ulp_", "ULP"],
"psram": ["psram", "PSRAM", "spiram", "SPIRAM"],
"efuse": ["efuse", "EFUSE"],
"partition": ["partition", "esp_partition"],
"esp_event": ["esp_event", "event_loop", "event_callback"],
"esp_console": ["esp_console", "console_"],
"chip_specific": ["chip_", "esp_chip"],
"esp_system_utils": ["esp_system", "esp_hw", "esp_clk", "esp_sleep"],
"ipc": ["esp_ipc", "ipc_"],
"wifi_config": [
"g_cnxMgr",
"gChmCxt",
"g_ic",
"TxRxCxt",
"s_dp",
"s_ni",
"s_reg_dump",
"packet$",
"d_mult_table",
"K",
"fcstab",
],
"smartconfig": ["sc_ack_send"],
"rc_calibration": ["rc_cal", "rcUpdate"],
"noise_floor": ["noise_check"],
"rf_calibration": [
"set_rx_sense",
"set_rx_gain_cal",
"set_chan_dig_gain",
"tx_pwctrl_init_cal",
"rfcal_txiq",
"set_tx_gain_table",
"correct_rfpll_offset",
"pll_correct_dcap",
"txiq_cal_init",
"pwdet_sar",
"rx_11b_opt",
],
"wifi_crypto": [
"pk_use_ecparams",
"process_segments",
"ccmp_",
"rc4_",
"aria_",
"mgf_mask",
"dh_group",
"ccmp_aad_nonce",
"ccmp_encrypt",
"rc4_skip",
"aria_sb1",
"aria_sb2",
"aria_is1",
"aria_is2",
"aria_sl",
"aria_a",
],
"radio_control": ["fsm_input", "fsm_sconfreq"],
"pbuf": [
"pbuf_",
],
"event_group": ["xEventGroup"],
"ringbuffer": ["xRingbuffer", "prvSend", "prvReceive", "prvCopy"],
"provisioning": ["prov_", "prov_stop_and_notify"],
"scan": ["gScanStruct"],
"port": ["xPort"],
"elf_loader": [
"elf_add",
"elf_add_note",
"elf_add_segment",
"process_image",
"read_encoded",
"read_encoded_value",
"read_encoded_value_with_base",
"process_image_header",
],
"socket_api": [
"sockets",
"netconn_",
"accept_function",
"recv_raw",
"socket_ipv4_multicast",
"socket_ipv6_multicast",
],
"igmp": ["igmp_", "igmp_send", "igmp_input"],
"icmp6": ["icmp6_"],
"arp": ["arp_table"],
"ampdu": [
"ampdu_",
"rcAmpdu",
"trc_onAmpduOp",
"rcAmpduLowerRate",
"ampdu_dispatch_upto",
],
"ieee802_11": ["ieee802_11_", "ieee802_11_parse_elems"],
"rate_control": ["rssi_margin", "rcGetSched", "get_rate_fcc_index"],
"nan": ["nan_dp_", "nan_dp_post_tx", "nan_dp_delete_peer"],
"channel_mgmt": ["chm_init", "chm_set_current_channel"],
"trace": ["trc_init", "trc_onAmpduOp"],
"country_code": ["country_info", "country_info_24ghz"],
"multicore": ["do_multicore_settings"],
"Update_lib": ["Update"],
"stdio": [
"__sf",
"__sflush_r",
"__srefill_r",
"_impure_data",
"_reclaim_reent",
"_open_r",
],
"strncpy_ops": ["strncpy"],
"math_internal": ["__mdiff", "__lshift", "__mprec_tens", "quorem"],
"character_class": ["__chclass"],
"camellia": ["camellia_", "camellia_feistel"],
"crypto_tables": ["FSb", "FSb2", "FSb3", "FSb4"],
"event_buffer": ["g_eb_list_desc", "eb_space"],
"base_node": ["base_node_", "base_node_add_handler"],
"file_descriptor": ["s_fd_table"],
"tx_delay": ["tx_delay_cfg"],
"deinit": ["deinit_functions"],
"lcp_echo": ["LcpEchoCheck"],
"raw_api": ["raw_bind", "raw_connect"],
"checksum": ["process_checksum"],
"entry_management": ["add_entry"],
"esp_ota": ["esp_ota", "ota_", "read_otadata"],
"http_server": [
"httpd_",
"parse_url_char",
"cb_headers_complete",
"delete_entry",
"validate_structure",
"config_save",
"config_new",
"verify_url",
"cb_url",
],
"misc_system": [
"alarm_cbs",
"start_up",
"tokens",
"unhex",
"osi_funcs_ro",
"enum_function",
"fragment_and_dispatch",
"alarm_set",
"osi_alarm_new",
"config_set_string",
"config_update_newest_section",
"config_remove_key",
"method_strings",
"interop_match",
"interop_database",
"__state_table",
"__action_table",
"s_stub_table",
"s_context",
"s_mmu_ctx",
"s_get_bus_mask",
"hli_queue_put",
"list_remove",
"list_delete",
"lock_acquire_generic",
"is_vect_desc_usable",
"io_mode_str",
"__c$20233",
"interface",
"read_id_core",
"subscribe_idle",
"unsubscribe_idle",
"s_clkout_handle",
"lock_release_generic",
"config_set_int",
"config_get_int",
"config_get_string",
"config_has_key",
"config_remove_section",
"osi_alarm_init",
"osi_alarm_deinit",
"fixed_queue_enqueue",
"fixed_queue_dequeue",
"fixed_queue_new",
"fixed_pkt_queue_enqueue",
"fixed_pkt_queue_new",
"list_append",
"list_prepend",
"list_insert_after",
"list_contains",
"list_get_node",
"hash_function_blob",
"cb_no_body",
"cb_on_body",
"profile_tab",
"get_arg",
"trim",
"buf$",
"process_appended_hash_and_sig$constprop$0",
"uuidType",
"allocate_svc_db_buf",
"_hostname_is_ours",
"s_hli_handlers",
"tick_cb",
"idle_cb",
"input",
"entry_find",
"section_find",
"find_bucket_entry_",
"config_has_section",
"hli_queue_create",
"hli_queue_get",
"hli_c_handler",
"future_ready",
"future_await",
"future_new",
"pkt_queue_enqueue",
"pkt_queue_dequeue",
"pkt_queue_cleanup",
"pkt_queue_create",
"pkt_queue_destroy",
"fixed_pkt_queue_dequeue",
"osi_alarm_cancel",
"osi_alarm_is_active",
"osi_sem_take",
"osi_event_create",
"osi_event_bind",
"alarm_cb_handler",
"list_foreach",
"list_back",
"list_front",
"list_clear",
"fixed_queue_try_peek_first",
"translate_path",
"get_idx",
"find_key",
"init",
"end",
"start",
"set_read_value",
"copy_address_list",
"copy_and_key",
"sdk_cfg_opts",
"leftshift_onebit",
"config_section_end",
"config_section_begin",
"find_entry_and_check_all_reset",
"image_validate",
"xPendingReadyList",
"vListInitialise",
"lock_init_generic",
"ant_bttx_cfg",
"ant_dft_cfg",
"cs_send_to_ctrl_sock",
"config_llc_util_funcs_reset",
"make_set_adv_report_flow_control",
"make_set_event_mask",
"raw_new",
"raw_remove",
"BTE_InitStack",
"parse_read_local_supported_features_response",
"__math_invalidf",
"tinytens",
"__mprec_tinytens",
"__mprec_bigtens",
"vRingbufferDelete",
"vRingbufferDeleteWithCaps",
"vRingbufferReturnItem",
"vRingbufferReturnItemFromISR",
"get_acl_data_size_ble",
"get_features_ble",
"get_features_classic",
"get_acl_packet_size_ble",
"get_acl_packet_size_classic",
"supports_extended_inquiry_response",
"supports_rssi_with_inquiry_results",
"supports_interlaced_inquiry_scan",
"supports_reading_remote_extended_features",
],
"bluetooth_ll": [
"lld_pdu_",
"ld_acl_",
"lld_stop_ind_handler",
"lld_evt_winsize_change",
"config_lld_evt_funcs_reset",
"config_lld_funcs_reset",
"config_llm_funcs_reset",
"llm_set_long_adv_data",
"lld_retry_tx_prog",
"llc_link_sup_to_ind_handler",
"config_llc_funcs_reset",
"lld_evt_rxwin_compute",
"config_btdm_funcs_reset",
"config_ea_funcs_reset",
"llc_defalut_state_tab_reset",
"config_rwip_funcs_reset",
"ke_lmp_rx_flooding_detect",
],
}
# Demangled patterns: patterns found in demangled C++ names
DEMANGLED_PATTERNS = {
"gpio_driver": ["GPIO"],
"uart_driver": ["UART"],
"network_stack": [
"lwip",
"tcp",
"udp",
"ip4",
"ip6",
"dhcp",
"dns",
"netif",
"ethernet",
"ppp",
"slip",
],
"wifi_stack": ["NetworkInterface"],
"nimble_bt": [
"nimble",
"NimBLE",
"ble_hs",
"ble_gap",
"ble_gatt",
"ble_att",
"ble_l2cap",
"ble_sm",
],
"crypto": ["mbedtls", "crypto", "sha", "aes", "rsa", "ecc", "tls", "ssl"],
"cpp_stdlib": ["std::", "__gnu_cxx::", "__cxxabiv"],
"static_init": ["__static_initialization"],
"rtti": ["__type_info", "__class_type_info"],
"web_server_lib": ["AsyncWebServer", "AsyncWebHandler", "WebServer"],
"async_tcp": ["AsyncClient", "AsyncServer"],
"mdns_lib": ["mdns"],
"json_lib": [
"ArduinoJson",
"JsonDocument",
"JsonArray",
"JsonObject",
"deserialize",
"serialize",
],
"http_lib": ["HTTP", "http_", "Request", "Response", "Uri", "WebSocket"],
"logging": ["log", "Log", "print", "Print", "diag_"],
"authentication": ["checkDigestAuthentication"],
"libgcc": ["libgcc"],
"esp_system": ["esp_", "ESP"],
"arduino": ["arduino"],
"nvs": ["nvs_", "_ZTVN3nvs", "nvs::"],
"filesystem": ["spiffs", "vfs"],
"libc": ["newlib"],
}
# Patterns for categorizing ESPHome core symbols into subcategories
CORE_SUBCATEGORY_PATTERNS = {
"Component Framework": ["Component"],
"Application Core": ["Application"],
"Scheduler": ["Scheduler"],
"Component Iterator": ["ComponentIterator"],
"Helper Functions": ["Helpers", "helpers"],
"Preferences/Storage": ["Preferences", "ESPPreferences"],
"I/O Utilities": ["HighFrequencyLoopRequester"],
"String Utilities": ["str_"],
"Bit Utilities": ["reverse_bits"],
"Data Conversion": ["convert_"],
"Network Utilities": ["network", "IPAddress"],
"API Protocol": ["api::"],
"WiFi Manager": ["wifi::"],
"MQTT Client": ["mqtt::"],
"Logger": ["logger::"],
"OTA Updates": ["ota::"],
"Web Server": ["web_server::"],
"Time Management": ["time::"],
"Sensor Framework": ["sensor::"],
"Binary Sensor": ["binary_sensor::"],
"Switch Framework": ["switch_::"],
"Light Framework": ["light::"],
"Climate Framework": ["climate::"],
"Cover Framework": ["cover::"],
}

View File

@@ -1,121 +0,0 @@
"""Helper functions for memory analysis."""
from functools import cache
from pathlib import Path
from .const import SECTION_MAPPING
# Import namespace constant from parent module
# Note: This would create a circular import if done at module level,
# so we'll define it locally here as well
_NAMESPACE_ESPHOME = "esphome::"
# Get the list of actual ESPHome components by scanning the components directory
@cache
def get_esphome_components():
"""Get set of actual ESPHome components from the components directory."""
# Find the components directory relative to this file
# Go up two levels from analyze_memory/helpers.py to esphome/
current_dir = Path(__file__).parent.parent
components_dir = current_dir / "components"
if not components_dir.exists() or not components_dir.is_dir():
return frozenset()
return frozenset(
item.name
for item in components_dir.iterdir()
if item.is_dir()
and not item.name.startswith(".")
and not item.name.startswith("__")
)
@cache
def get_component_class_patterns(component_name: str) -> list[str]:
"""Generate component class name patterns for symbol matching.
Args:
component_name: The component name (e.g., "ota", "wifi", "api")
Returns:
List of pattern strings to match against demangled symbols
"""
component_upper = component_name.upper()
component_camel = component_name.replace("_", "").title()
return [
f"{_NAMESPACE_ESPHOME}{component_upper}Component", # e.g., esphome::OTAComponent
f"{_NAMESPACE_ESPHOME}ESPHome{component_upper}Component", # e.g., esphome::ESPHomeOTAComponent
f"{_NAMESPACE_ESPHOME}{component_camel}Component", # e.g., esphome::OtaComponent
f"{_NAMESPACE_ESPHOME}ESPHome{component_camel}Component", # e.g., esphome::ESPHomeOtaComponent
]
def map_section_name(raw_section: str) -> str | None:
"""Map raw section name to standard section.
Args:
raw_section: Raw section name from ELF file (e.g., ".iram0.text", ".rodata.str1.1")
Returns:
Standard section name (".text", ".rodata", ".data", ".bss") or None
"""
for standard_section, patterns in SECTION_MAPPING.items():
if any(pattern in raw_section for pattern in patterns):
return standard_section
return None
def parse_symbol_line(line: str) -> tuple[str, str, int, str] | None:
"""Parse a single symbol line from objdump output.
Args:
line: Line from objdump -t output
Returns:
Tuple of (section, name, size, address) or None if not a valid symbol.
Format: address l/g w/d F/O section size name
Example: 40084870 l F .iram0.text 00000000 _xt_user_exc
"""
parts = line.split()
if len(parts) < 5:
return None
try:
# Validate and extract address
address = parts[0]
int(address, 16)
except ValueError:
return None
# Look for F (function) or O (object) flag
if "F" not in parts and "O" not in parts:
return None
# Find section, size, and name
for i, part in enumerate(parts):
if not part.startswith("."):
continue
section = map_section_name(part)
if not section:
break
# Need at least size field after section
if i + 1 >= len(parts):
break
try:
size = int(parts[i + 1], 16)
except ValueError:
break
# Need symbol name and non-zero size
if i + 2 >= len(parts) or size == 0:
break
name = " ".join(parts[i + 2 :])
return (section, name, size, address)
return None

View File

@@ -506,7 +506,7 @@ message ListEntitiesLightResponse {
string name = 3;
reserved 4; // Deprecated: was string unique_id
repeated ColorMode supported_color_modes = 12 [(container_pointer_no_template) = "light::ColorModeMask"];
repeated ColorMode supported_color_modes = 12 [(container_pointer) = "std::set<light::ColorMode>"];
// next four supports_* are for legacy clients, newer clients should use color modes
// Deprecated in API version 1.6
bool legacy_supports_brightness = 5 [deprecated=true];

View File

@@ -453,6 +453,7 @@ uint16_t APIConnection::try_send_light_state(EntityBase *entity, APIConnection *
bool is_single) {
auto *light = static_cast<light::LightState *>(entity);
LightStateResponse resp;
auto traits = light->get_traits();
auto values = light->remote_values;
auto color_mode = values.get_color_mode();
resp.state = values.is_on();
@@ -476,8 +477,7 @@ uint16_t APIConnection::try_send_light_info(EntityBase *entity, APIConnection *c
auto *light = static_cast<light::LightState *>(entity);
ListEntitiesLightResponse msg;
auto traits = light->get_traits();
// Pass pointer to ColorModeMask so the iterator can encode actual ColorMode enum values
msg.supported_color_modes = &traits.get_supported_color_modes();
msg.supported_color_modes = &traits.get_supported_color_modes_for_api_();
if (traits.supports_color_capability(light::ColorCapability::COLOR_TEMPERATURE) ||
traits.supports_color_capability(light::ColorCapability::COLD_WARM_WHITE)) {
msg.min_mireds = traits.get_min_mireds();
@@ -1082,8 +1082,13 @@ void APIConnection::on_get_time_response(const GetTimeResponse &value) {
homeassistant::global_homeassistant_time->set_epoch_time(value.epoch_seconds);
#ifdef USE_TIME_TIMEZONE
if (value.timezone_len > 0) {
homeassistant::global_homeassistant_time->set_timezone(reinterpret_cast<const char *>(value.timezone),
value.timezone_len);
const std::string &current_tz = homeassistant::global_homeassistant_time->get_timezone();
// Compare without allocating a string
if (current_tz.length() != value.timezone_len ||
memcmp(current_tz.c_str(), value.timezone, value.timezone_len) != 0) {
homeassistant::global_homeassistant_time->set_timezone(
std::string(reinterpret_cast<const char *>(value.timezone), value.timezone_len));
}
}
#endif
}

View File

@@ -70,14 +70,4 @@ extend google.protobuf.FieldOptions {
// init(size) before adding elements. This eliminates std::vector template overhead
// and is ideal when the exact size is known before populating the array.
optional bool fixed_vector = 50013 [default=false];
// container_pointer_no_template: Use a non-template container type for repeated fields
// Similar to container_pointer, but for containers that don't take template parameters.
// The container type is used as-is without appending element type.
// The container must have:
// - begin() and end() methods returning iterators
// - empty() method
// Example: [(container_pointer_no_template) = "light::ColorModeMask"]
// generates: const light::ColorModeMask *supported_color_modes{};
optional string container_pointer_no_template = 50014;
}

View File

@@ -790,7 +790,7 @@ class ListEntitiesLightResponse final : public InfoResponseProtoMessage {
#ifdef HAS_PROTO_MESSAGE_DUMP
const char *message_name() const override { return "list_entities_light_response"; }
#endif
const light::ColorModeMask *supported_color_modes{};
const std::set<light::ColorMode> *supported_color_modes{};
float min_mireds{0.0f};
float max_mireds{0.0f};
std::vector<std::string> effects{};

View File

@@ -155,12 +155,16 @@ esp32_ble_tracker::AdvertisementParserType BluetoothProxy::get_advertisement_par
BluetoothConnection *BluetoothProxy::get_connection_(uint64_t address, bool reserve) {
for (uint8_t i = 0; i < this->connection_count_; i++) {
auto *connection = this->connections_[i];
uint64_t conn_addr = connection->get_address();
if (conn_addr == address)
if (connection->get_address() == address)
return connection;
}
if (reserve && conn_addr == 0) {
if (!reserve)
return nullptr;
for (uint8_t i = 0; i < this->connection_count_; i++) {
auto *connection = this->connections_[i];
if (connection->get_address() == 0) {
connection->send_service_ = INIT_SENDING_SERVICES;
connection->set_address(address);
// All connections must start at INIT
@@ -171,6 +175,7 @@ BluetoothConnection *BluetoothProxy::get_connection_(uint64_t address, bool rese
return connection;
}
}
return nullptr;
}

View File

@@ -1,8 +1,8 @@
#pragma once
#include <set>
#include "climate_mode.h"
#include "esphome/core/helpers.h"
#include "climate_mode.h"
#include <set>
namespace esphome {
@@ -109,12 +109,44 @@ class ClimateTraits {
void set_supported_modes(std::set<ClimateMode> modes) { this->supported_modes_ = std::move(modes); }
void add_supported_mode(ClimateMode mode) { this->supported_modes_.insert(mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_auto_mode(bool supports_auto_mode) { set_mode_support_(CLIMATE_MODE_AUTO, supports_auto_mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_cool_mode(bool supports_cool_mode) { set_mode_support_(CLIMATE_MODE_COOL, supports_cool_mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_heat_mode(bool supports_heat_mode) { set_mode_support_(CLIMATE_MODE_HEAT, supports_heat_mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_heat_cool_mode(bool supported) { set_mode_support_(CLIMATE_MODE_HEAT_COOL, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_fan_only_mode(bool supports_fan_only_mode) {
set_mode_support_(CLIMATE_MODE_FAN_ONLY, supports_fan_only_mode);
}
ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20")
void set_supports_dry_mode(bool supports_dry_mode) { set_mode_support_(CLIMATE_MODE_DRY, supports_dry_mode); }
bool supports_mode(ClimateMode mode) const { return this->supported_modes_.count(mode); }
const std::set<ClimateMode> &get_supported_modes() const { return this->supported_modes_; }
void set_supported_fan_modes(std::set<ClimateFanMode> modes) { this->supported_fan_modes_ = std::move(modes); }
void add_supported_fan_mode(ClimateFanMode mode) { this->supported_fan_modes_.insert(mode); }
void add_supported_custom_fan_mode(const std::string &mode) { this->supported_custom_fan_modes_.insert(mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_on(bool supported) { set_fan_mode_support_(CLIMATE_FAN_ON, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_off(bool supported) { set_fan_mode_support_(CLIMATE_FAN_OFF, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_auto(bool supported) { set_fan_mode_support_(CLIMATE_FAN_AUTO, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_low(bool supported) { set_fan_mode_support_(CLIMATE_FAN_LOW, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_medium(bool supported) { set_fan_mode_support_(CLIMATE_FAN_MEDIUM, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_high(bool supported) { set_fan_mode_support_(CLIMATE_FAN_HIGH, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_middle(bool supported) { set_fan_mode_support_(CLIMATE_FAN_MIDDLE, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_focus(bool supported) { set_fan_mode_support_(CLIMATE_FAN_FOCUS, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20")
void set_supports_fan_mode_diffuse(bool supported) { set_fan_mode_support_(CLIMATE_FAN_DIFFUSE, supported); }
bool supports_fan_mode(ClimateFanMode fan_mode) const { return this->supported_fan_modes_.count(fan_mode); }
bool get_supports_fan_modes() const {
return !this->supported_fan_modes_.empty() || !this->supported_custom_fan_modes_.empty();
@@ -146,6 +178,16 @@ class ClimateTraits {
void set_supported_swing_modes(std::set<ClimateSwingMode> modes) { this->supported_swing_modes_ = std::move(modes); }
void add_supported_swing_mode(ClimateSwingMode mode) { this->supported_swing_modes_.insert(mode); }
ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20")
void set_supports_swing_mode_off(bool supported) { set_swing_mode_support_(CLIMATE_SWING_OFF, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20")
void set_supports_swing_mode_both(bool supported) { set_swing_mode_support_(CLIMATE_SWING_BOTH, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20")
void set_supports_swing_mode_vertical(bool supported) { set_swing_mode_support_(CLIMATE_SWING_VERTICAL, supported); }
ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20")
void set_supports_swing_mode_horizontal(bool supported) {
set_swing_mode_support_(CLIMATE_SWING_HORIZONTAL, supported);
}
bool supports_swing_mode(ClimateSwingMode swing_mode) const { return this->supported_swing_modes_.count(swing_mode); }
bool get_supports_swing_modes() const { return !this->supported_swing_modes_.empty(); }
const std::set<ClimateSwingMode> &get_supported_swing_modes() const { return this->supported_swing_modes_; }

View File

@@ -6,7 +6,6 @@
#include <freertos/FreeRTOS.h>
#include <freertos/task.h>
#include <esp_idf_version.h>
#include <esp_ota_ops.h>
#include <esp_task_wdt.h>
#include <esp_timer.h>
#include <soc/rtc.h>
@@ -53,16 +52,6 @@ void arch_init() {
disableCore1WDT();
#endif
#endif
// If the bootloader was compiled with CONFIG_BOOTLOADER_APP_ROLLBACK_ENABLE the current
// partition will get rolled back unless it is marked as valid.
esp_ota_img_states_t state;
const esp_partition_t *running = esp_ota_get_running_partition();
if (esp_ota_get_state_partition(running, &state) == ESP_OK) {
if (state == ESP_OTA_IMG_PENDING_VERIFY) {
esp_ota_mark_app_valid_cancel_rollback();
}
}
}
void IRAM_ATTR HOT arch_feed_wdt() { esp_task_wdt_reset(); }

View File

@@ -61,7 +61,12 @@ class BLEClientBase : public espbt::ESPBTClient, public Component {
this->address_str_ = "";
} else {
char buf[18];
format_mac_addr_upper(this->remote_bda_, buf);
uint8_t mac[6] = {
(uint8_t) ((this->address_ >> 40) & 0xff), (uint8_t) ((this->address_ >> 32) & 0xff),
(uint8_t) ((this->address_ >> 24) & 0xff), (uint8_t) ((this->address_ >> 16) & 0xff),
(uint8_t) ((this->address_ >> 8) & 0xff), (uint8_t) ((this->address_ >> 0) & 0xff),
};
format_mac_addr_upper(mac, buf);
this->address_str_ = buf;
}
}

View File

@@ -14,7 +14,7 @@ void Kuntze::on_modbus_data(const std::vector<uint8_t> &data) {
auto get_16bit = [&](int i) -> uint16_t { return (uint16_t(data[i * 2]) << 8) | uint16_t(data[i * 2 + 1]); };
this->waiting_ = false;
ESP_LOGV(TAG, "Data: %s", format_hex_pretty(data).c_str());
ESP_LOGV(TAG, "Data: %s", hexencode(data).c_str());
float value = (float) get_16bit(0);
for (int i = 0; i < data[3]; i++)

View File

@@ -1,11 +1,11 @@
#pragma once
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/core/color.h"
#include "esp_color_correction.h"
#include "esp_color_view.h"
#include "esp_range_view.h"
#include "esphome/core/color.h"
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "light_output.h"
#include "light_state.h"
#include "transformers.h"
@@ -17,6 +17,8 @@
namespace esphome {
namespace light {
using ESPColor ESPDEPRECATED("esphome::light::ESPColor is deprecated, use esphome::Color instead.", "v1.21") = Color;
/// Convert the color information from a `LightColorValues` object to a `Color` object (does not apply brightness).
Color color_from_light_color_values(LightColorValues val);

View File

@@ -104,200 +104,5 @@ constexpr ColorModeHelper operator|(ColorModeHelper lhs, ColorMode rhs) {
return static_cast<ColorMode>(static_cast<uint8_t>(lhs) | static_cast<uint8_t>(rhs));
}
// Type alias for raw color mode bitmask values
using color_mode_bitmask_t = uint16_t;
// Constants for ColorMode count and bit range
static constexpr int COLOR_MODE_COUNT = 10; // UNKNOWN through RGB_COLD_WARM_WHITE
static constexpr int MAX_BIT_INDEX = sizeof(color_mode_bitmask_t) * 8; // Number of bits in bitmask type
// Compile-time array of all ColorMode values in declaration order
// Bit positions (0-9) map directly to enum declaration order
static constexpr ColorMode COLOR_MODES[COLOR_MODE_COUNT] = {
ColorMode::UNKNOWN, // bit 0
ColorMode::ON_OFF, // bit 1
ColorMode::BRIGHTNESS, // bit 2
ColorMode::WHITE, // bit 3
ColorMode::COLOR_TEMPERATURE, // bit 4
ColorMode::COLD_WARM_WHITE, // bit 5
ColorMode::RGB, // bit 6
ColorMode::RGB_WHITE, // bit 7
ColorMode::RGB_COLOR_TEMPERATURE, // bit 8
ColorMode::RGB_COLD_WARM_WHITE, // bit 9
};
/// Map ColorMode enum values to bit positions (0-9)
/// Bit positions follow the enum declaration order
static constexpr int mode_to_bit(ColorMode mode) {
// Linear search through COLOR_MODES array
// Compiler optimizes this to efficient code since array is constexpr
for (int i = 0; i < COLOR_MODE_COUNT; ++i) {
if (COLOR_MODES[i] == mode)
return i;
}
return 0;
}
/// Map bit positions (0-9) to ColorMode enum values
/// Bit positions follow the enum declaration order
static constexpr ColorMode bit_to_mode(int bit) {
// Direct lookup in COLOR_MODES array
return (bit >= 0 && bit < COLOR_MODE_COUNT) ? COLOR_MODES[bit] : ColorMode::UNKNOWN;
}
/// Helper to compute capability bitmask at compile time
static constexpr color_mode_bitmask_t compute_capability_bitmask(ColorCapability capability) {
color_mode_bitmask_t mask = 0;
uint8_t cap_bit = static_cast<uint8_t>(capability);
// Check each ColorMode to see if it has this capability
for (int bit = 0; bit < COLOR_MODE_COUNT; ++bit) {
uint8_t mode_val = static_cast<uint8_t>(bit_to_mode(bit));
if ((mode_val & cap_bit) != 0) {
mask |= (1 << bit);
}
}
return mask;
}
// Number of ColorCapability enum values
static constexpr int COLOR_CAPABILITY_COUNT = 6;
/// Compile-time lookup table mapping ColorCapability to bitmask
/// This array is computed at compile time using constexpr
static constexpr color_mode_bitmask_t CAPABILITY_BITMASKS[] = {
compute_capability_bitmask(ColorCapability::ON_OFF), // 1 << 0
compute_capability_bitmask(ColorCapability::BRIGHTNESS), // 1 << 1
compute_capability_bitmask(ColorCapability::WHITE), // 1 << 2
compute_capability_bitmask(ColorCapability::COLOR_TEMPERATURE), // 1 << 3
compute_capability_bitmask(ColorCapability::COLD_WARM_WHITE), // 1 << 4
compute_capability_bitmask(ColorCapability::RGB), // 1 << 5
};
/// Bitmask for storing a set of ColorMode values efficiently.
/// Replaces std::set<ColorMode> to eliminate red-black tree overhead (~586 bytes).
class ColorModeMask {
public:
constexpr ColorModeMask() = default;
/// Support initializer list syntax: {ColorMode::RGB, ColorMode::WHITE}
constexpr ColorModeMask(std::initializer_list<ColorMode> modes) {
for (auto mode : modes) {
this->add(mode);
}
}
constexpr void add(ColorMode mode) { this->mask_ |= (1 << mode_to_bit(mode)); }
/// Add multiple modes at once using initializer list
constexpr void add(std::initializer_list<ColorMode> modes) {
for (auto mode : modes) {
this->add(mode);
}
}
constexpr bool contains(ColorMode mode) const { return (this->mask_ & (1 << mode_to_bit(mode))) != 0; }
constexpr size_t size() const {
// Count set bits using Brian Kernighan's algorithm
// More efficient for sparse bitmasks (typical case: 2-4 modes out of 10)
uint16_t n = this->mask_;
size_t count = 0;
while (n) {
n &= n - 1; // Clear the least significant set bit
count++;
}
return count;
}
constexpr bool empty() const { return this->mask_ == 0; }
/// Iterator support for API encoding
class Iterator {
public:
using iterator_category = std::forward_iterator_tag;
using value_type = ColorMode;
using difference_type = std::ptrdiff_t;
using pointer = const ColorMode *;
using reference = ColorMode;
constexpr Iterator(color_mode_bitmask_t mask, int bit) : mask_(mask), bit_(bit) { advance_to_next_set_bit_(); }
constexpr ColorMode operator*() const { return bit_to_mode(bit_); }
constexpr Iterator &operator++() {
++bit_;
advance_to_next_set_bit_();
return *this;
}
constexpr bool operator==(const Iterator &other) const { return bit_ == other.bit_; }
constexpr bool operator!=(const Iterator &other) const { return !(*this == other); }
private:
constexpr void advance_to_next_set_bit_() { bit_ = ColorModeMask::find_next_set_bit(mask_, bit_); }
color_mode_bitmask_t mask_;
int bit_;
};
constexpr Iterator begin() const { return Iterator(mask_, 0); }
constexpr Iterator end() const { return Iterator(mask_, MAX_BIT_INDEX); }
/// Get the raw bitmask value for API encoding
constexpr color_mode_bitmask_t get_mask() const { return this->mask_; }
/// Find the next set bit in a bitmask starting from a given position
/// Returns the bit position, or MAX_BIT_INDEX if no more bits are set
static constexpr int find_next_set_bit(color_mode_bitmask_t mask, int start_bit) {
int bit = start_bit;
while (bit < MAX_BIT_INDEX && !(mask & (1 << bit))) {
++bit;
}
return bit;
}
/// Find the first set bit in a bitmask and return the corresponding ColorMode
/// Used for optimizing compute_color_mode_() intersection logic
static constexpr ColorMode first_mode_from_mask(color_mode_bitmask_t mask) {
return bit_to_mode(find_next_set_bit(mask, 0));
}
/// Check if a ColorMode is present in a raw bitmask value
/// Useful for checking intersection results without creating a temporary ColorModeMask
static constexpr bool mask_contains(color_mode_bitmask_t mask, ColorMode mode) {
return (mask & (1 << mode_to_bit(mode))) != 0;
}
/// Check if any mode in the bitmask has a specific capability
/// Used for checking if a light supports a capability (e.g., BRIGHTNESS, RGB)
bool has_capability(ColorCapability capability) const {
// Lookup the pre-computed bitmask for this capability and check intersection with our mask
// ColorCapability values: 1, 2, 4, 8, 16, 32 -> array indices: 0, 1, 2, 3, 4, 5
// We need to convert the power-of-2 value to an index
uint8_t cap_val = static_cast<uint8_t>(capability);
#if defined(__GNUC__) || defined(__clang__)
// Use compiler intrinsic for efficient bit position lookup (O(1) vs O(log n))
int index = __builtin_ctz(cap_val);
#else
// Fallback for compilers without __builtin_ctz
int index = 0;
while (cap_val > 1) {
cap_val >>= 1;
++index;
}
#endif
return (this->mask_ & CAPABILITY_BITMASKS[index]) != 0;
}
private:
// Using uint16_t instead of uint32_t for more efficient iteration (fewer bits to scan).
// Currently only 10 ColorMode values exist, so 16 bits is sufficient.
// Can be changed to uint32_t if more than 16 color modes are needed in the future.
// Note: Due to struct padding, uint16_t and uint32_t result in same LightTraits size (12 bytes).
color_mode_bitmask_t mask_{0};
};
} // namespace light
} // namespace esphome

View File

@@ -406,7 +406,7 @@ void LightCall::transform_parameters_() {
}
}
ColorMode LightCall::compute_color_mode_() {
const auto &supported_modes = this->parent_->get_traits().get_supported_color_modes();
auto supported_modes = this->parent_->get_traits().get_supported_color_modes();
int supported_count = supported_modes.size();
// Some lights don't support any color modes (e.g. monochromatic light), leave it at unknown.
@@ -425,19 +425,20 @@ ColorMode LightCall::compute_color_mode_() {
// If no color mode is specified, we try to guess the color mode. This is needed for backward compatibility to
// pre-colormode clients and automations, but also for the MQTT API, where HA doesn't let us know which color mode
// was used for some reason.
// Compute intersection of suitable and supported modes using bitwise AND
color_mode_bitmask_t intersection = this->get_suitable_color_modes_mask_() & supported_modes.get_mask();
std::set<ColorMode> suitable_modes = this->get_suitable_color_modes_();
// Don't change if the current mode is in the intersection (suitable AND supported)
if (ColorModeMask::mask_contains(intersection, current_mode)) {
// Don't change if the current mode is suitable.
if (suitable_modes.count(current_mode) > 0) {
ESP_LOGI(TAG, "'%s': color mode not specified; retaining %s", this->parent_->get_name().c_str(),
LOG_STR_ARG(color_mode_to_human(current_mode)));
return current_mode;
}
// Use the preferred suitable mode.
if (intersection != 0) {
ColorMode mode = ColorModeMask::first_mode_from_mask(intersection);
for (auto mode : suitable_modes) {
if (supported_modes.count(mode) == 0)
continue;
ESP_LOGI(TAG, "'%s': color mode not specified; using %s", this->parent_->get_name().c_str(),
LOG_STR_ARG(color_mode_to_human(mode)));
return mode;
@@ -450,7 +451,7 @@ ColorMode LightCall::compute_color_mode_() {
LOG_STR_ARG(color_mode_to_human(color_mode)));
return color_mode;
}
color_mode_bitmask_t LightCall::get_suitable_color_modes_mask_() {
std::set<ColorMode> LightCall::get_suitable_color_modes_() {
bool has_white = this->has_white() && this->white_ > 0.0f;
bool has_ct = this->has_color_temperature();
bool has_cwww =
@@ -458,44 +459,36 @@ color_mode_bitmask_t LightCall::get_suitable_color_modes_mask_() {
bool has_rgb = (this->has_color_brightness() && this->color_brightness_ > 0.0f) ||
(this->has_red() || this->has_green() || this->has_blue());
// Build key from flags: [rgb][cwww][ct][white]
// Build key from flags: [rgb][cwww][ct][white]
#define KEY(white, ct, cwww, rgb) ((white) << 0 | (ct) << 1 | (cwww) << 2 | (rgb) << 3)
uint8_t key = KEY(has_white, has_ct, has_cwww, has_rgb);
switch (key) {
case KEY(true, false, false, false): // white only
return ColorModeMask({ColorMode::WHITE, ColorMode::RGB_WHITE, ColorMode::RGB_COLOR_TEMPERATURE,
ColorMode::COLD_WARM_WHITE, ColorMode::RGB_COLD_WARM_WHITE})
.get_mask();
return {ColorMode::WHITE, ColorMode::RGB_WHITE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::COLD_WARM_WHITE,
ColorMode::RGB_COLD_WARM_WHITE};
case KEY(false, true, false, false): // ct only
return ColorModeMask({ColorMode::COLOR_TEMPERATURE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::COLD_WARM_WHITE,
ColorMode::RGB_COLD_WARM_WHITE})
.get_mask();
return {ColorMode::COLOR_TEMPERATURE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::COLD_WARM_WHITE,
ColorMode::RGB_COLD_WARM_WHITE};
case KEY(true, true, false, false): // white + ct
return ColorModeMask(
{ColorMode::COLD_WARM_WHITE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::RGB_COLD_WARM_WHITE})
.get_mask();
return {ColorMode::COLD_WARM_WHITE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::RGB_COLD_WARM_WHITE};
case KEY(false, false, true, false): // cwww only
return ColorModeMask({ColorMode::COLD_WARM_WHITE, ColorMode::RGB_COLD_WARM_WHITE}).get_mask();
return {ColorMode::COLD_WARM_WHITE, ColorMode::RGB_COLD_WARM_WHITE};
case KEY(false, false, false, false): // none
return ColorModeMask({ColorMode::RGB_WHITE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::RGB_COLD_WARM_WHITE,
ColorMode::RGB, ColorMode::WHITE, ColorMode::COLOR_TEMPERATURE, ColorMode::COLD_WARM_WHITE})
.get_mask();
return {ColorMode::RGB_WHITE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::RGB_COLD_WARM_WHITE, ColorMode::RGB,
ColorMode::WHITE, ColorMode::COLOR_TEMPERATURE, ColorMode::COLD_WARM_WHITE};
case KEY(true, false, false, true): // rgb + white
return ColorModeMask({ColorMode::RGB_WHITE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::RGB_COLD_WARM_WHITE})
.get_mask();
return {ColorMode::RGB_WHITE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::RGB_COLD_WARM_WHITE};
case KEY(false, true, false, true): // rgb + ct
case KEY(true, true, false, true): // rgb + white + ct
return ColorModeMask({ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::RGB_COLD_WARM_WHITE}).get_mask();
return {ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::RGB_COLD_WARM_WHITE};
case KEY(false, false, true, true): // rgb + cwww
return ColorModeMask({ColorMode::RGB_COLD_WARM_WHITE}).get_mask();
return {ColorMode::RGB_COLD_WARM_WHITE};
case KEY(false, false, false, true): // rgb only
return ColorModeMask({ColorMode::RGB, ColorMode::RGB_WHITE, ColorMode::RGB_COLOR_TEMPERATURE,
ColorMode::RGB_COLD_WARM_WHITE})
.get_mask();
return {ColorMode::RGB, ColorMode::RGB_WHITE, ColorMode::RGB_COLOR_TEMPERATURE, ColorMode::RGB_COLD_WARM_WHITE};
default:
return 0; // conflicting flags
return {}; // conflicting flags
}
#undef KEY

View File

@@ -1,6 +1,7 @@
#pragma once
#include "light_color_values.h"
#include <set>
namespace esphome {
@@ -185,8 +186,8 @@ class LightCall {
//// Compute the color mode that should be used for this call.
ColorMode compute_color_mode_();
/// Get potential color modes bitmask for this light call.
color_mode_bitmask_t get_suitable_color_modes_mask_();
/// Get potential color modes for this light call.
std::set<ColorMode> get_suitable_color_modes_();
/// Some color modes also can be set using non-native parameters, transform those calls.
void transform_parameters_();

View File

@@ -43,6 +43,7 @@ void LightJSONSchema::dump_json(LightState &state, JsonObject root) {
}
auto values = state.remote_values;
auto traits = state.get_output()->get_traits();
const auto color_mode = values.get_color_mode();
const char *mode_str = get_color_mode_json_str(color_mode);

View File

@@ -191,9 +191,11 @@ void LightState::current_values_as_brightness(float *brightness) {
this->current_values.as_brightness(brightness, this->gamma_correct_);
}
void LightState::current_values_as_rgb(float *red, float *green, float *blue, bool color_interlock) {
auto traits = this->get_traits();
this->current_values.as_rgb(red, green, blue, this->gamma_correct_, false);
}
void LightState::current_values_as_rgbw(float *red, float *green, float *blue, float *white, bool color_interlock) {
auto traits = this->get_traits();
this->current_values.as_rgbw(red, green, blue, white, this->gamma_correct_, false);
}
void LightState::current_values_as_rgbww(float *red, float *green, float *blue, float *cold_white, float *warm_white,
@@ -207,6 +209,7 @@ void LightState::current_values_as_rgbct(float *red, float *green, float *blue,
white_brightness, this->gamma_correct_);
}
void LightState::current_values_as_cwww(float *cold_white, float *warm_white, bool constant_brightness) {
auto traits = this->get_traits();
this->current_values.as_cwww(cold_white, warm_white, this->gamma_correct_, constant_brightness);
}
void LightState::current_values_as_ct(float *color_temperature, float *white_brightness) {

View File

@@ -1,7 +1,8 @@
#pragma once
#include "color_mode.h"
#include "esphome/core/helpers.h"
#include "color_mode.h"
#include <set>
namespace esphome {
@@ -18,17 +19,38 @@ class LightTraits {
public:
LightTraits() = default;
const ColorModeMask &get_supported_color_modes() const { return this->supported_color_modes_; }
void set_supported_color_modes(ColorModeMask supported_color_modes) {
this->supported_color_modes_ = supported_color_modes;
}
void set_supported_color_modes(std::initializer_list<ColorMode> modes) {
this->supported_color_modes_ = ColorModeMask(modes);
const std::set<ColorMode> &get_supported_color_modes() const { return this->supported_color_modes_; }
void set_supported_color_modes(std::set<ColorMode> supported_color_modes) {
this->supported_color_modes_ = std::move(supported_color_modes);
}
bool supports_color_mode(ColorMode color_mode) const { return this->supported_color_modes_.contains(color_mode); }
bool supports_color_mode(ColorMode color_mode) const { return this->supported_color_modes_.count(color_mode); }
bool supports_color_capability(ColorCapability color_capability) const {
return this->supported_color_modes_.has_capability(color_capability);
for (auto mode : this->supported_color_modes_) {
if (mode & color_capability)
return true;
}
return false;
}
ESPDEPRECATED("get_supports_brightness() is deprecated, use color modes instead.", "v1.21")
bool get_supports_brightness() const { return this->supports_color_capability(ColorCapability::BRIGHTNESS); }
ESPDEPRECATED("get_supports_rgb() is deprecated, use color modes instead.", "v1.21")
bool get_supports_rgb() const { return this->supports_color_capability(ColorCapability::RGB); }
ESPDEPRECATED("get_supports_rgb_white_value() is deprecated, use color modes instead.", "v1.21")
bool get_supports_rgb_white_value() const {
return this->supports_color_mode(ColorMode::RGB_WHITE) ||
this->supports_color_mode(ColorMode::RGB_COLOR_TEMPERATURE);
}
ESPDEPRECATED("get_supports_color_temperature() is deprecated, use color modes instead.", "v1.21")
bool get_supports_color_temperature() const {
return this->supports_color_capability(ColorCapability::COLOR_TEMPERATURE);
}
ESPDEPRECATED("get_supports_color_interlock() is deprecated, use color modes instead.", "v1.21")
bool get_supports_color_interlock() const {
return this->supports_color_mode(ColorMode::RGB) &&
(this->supports_color_mode(ColorMode::WHITE) || this->supports_color_mode(ColorMode::COLD_WARM_WHITE) ||
this->supports_color_mode(ColorMode::COLOR_TEMPERATURE));
}
float get_min_mireds() const { return this->min_mireds_; }
@@ -37,9 +59,19 @@ class LightTraits {
void set_max_mireds(float max_mireds) { this->max_mireds_ = max_mireds; }
protected:
#ifdef USE_API
// The API connection is a friend class to access internal methods
friend class api::APIConnection;
// This method returns a reference to the internal color modes set.
// It is used by the API to avoid copying data when encoding messages.
// Warning: Do not use this method outside of the API connection code.
// It returns a reference to internal data that can be invalidated.
const std::set<ColorMode> &get_supported_color_modes_for_api_() const { return this->supported_color_modes_; }
#endif
std::set<ColorMode> supported_color_modes_{};
float min_mireds_{0};
float max_mireds_{0};
ColorModeMask supported_color_modes_{};
};
} // namespace light

View File

@@ -31,17 +31,18 @@ void MDNSComponent::setup() {
mdns_instance_name_set(this->hostname_.c_str());
for (const auto &service : services) {
auto txt_records = std::make_unique<mdns_txt_item_t[]>(service.txt_records.size());
for (size_t i = 0; i < service.txt_records.size(); i++) {
const auto &record = service.txt_records[i];
std::vector<mdns_txt_item_t> txt_records;
for (const auto &record : service.txt_records) {
mdns_txt_item_t it{};
// key and value are either compile-time string literals in flash or pointers to dynamic_txt_values_
// Both remain valid for the lifetime of this function, and ESP-IDF makes internal copies
txt_records[i].key = MDNS_STR_ARG(record.key);
txt_records[i].value = MDNS_STR_ARG(record.value);
it.key = MDNS_STR_ARG(record.key);
it.value = MDNS_STR_ARG(record.value);
txt_records.push_back(it);
}
uint16_t port = const_cast<TemplatableValue<uint16_t> &>(service.port).value();
err = mdns_service_add(nullptr, MDNS_STR_ARG(service.service_type), MDNS_STR_ARG(service.proto), port,
txt_records.get(), service.txt_records.size());
txt_records.data(), txt_records.size());
if (err != ESP_OK) {
ESP_LOGW(TAG, "Failed to register service %s: %s", MDNS_STR_ARG(service.service_type), esp_err_to_name(err));

View File

@@ -140,8 +140,11 @@ void MQTTClientComponent::send_device_info_() {
#endif
#ifdef USE_API_NOISE
root[api::global_api_server->get_noise_ctx()->has_psk() ? "api_encryption" : "api_encryption_supported"] =
"Noise_NNpsk0_25519_ChaChaPoly_SHA256";
if (api::global_api_server->get_noise_ctx()->has_psk()) {
root["api_encryption"] = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
} else {
root["api_encryption_supported"] = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
}
#endif
},
2, this->discovery_info_.retain);

View File

@@ -85,20 +85,24 @@ bool MQTTComponent::send_discovery_() {
}
// Fields from EntityBase
root[MQTT_NAME] = this->get_entity()->has_own_name() ? this->friendly_name() : "";
if (this->get_entity()->has_own_name()) {
root[MQTT_NAME] = this->friendly_name();
} else {
root[MQTT_NAME] = "";
}
if (this->is_disabled_by_default())
root[MQTT_ENABLED_BY_DEFAULT] = false;
if (!this->get_icon().empty())
root[MQTT_ICON] = this->get_icon();
const auto entity_category = this->get_entity()->get_entity_category();
switch (entity_category) {
switch (this->get_entity()->get_entity_category()) {
case ENTITY_CATEGORY_NONE:
break;
case ENTITY_CATEGORY_CONFIG:
root[MQTT_ENTITY_CATEGORY] = "config";
break;
case ENTITY_CATEGORY_DIAGNOSTIC:
root[MQTT_ENTITY_CATEGORY] = entity_category == ENTITY_CATEGORY_CONFIG ? "config" : "diagnostic";
root[MQTT_ENTITY_CATEGORY] = "diagnostic";
break;
}
@@ -109,14 +113,20 @@ bool MQTTComponent::send_discovery_() {
if (this->command_retain_)
root[MQTT_COMMAND_RETAIN] = true;
const Availability &avail =
this->availability_ == nullptr ? global_mqtt_client->get_availability() : *this->availability_;
if (!avail.topic.empty()) {
root[MQTT_AVAILABILITY_TOPIC] = avail.topic;
if (avail.payload_available != "online")
root[MQTT_PAYLOAD_AVAILABLE] = avail.payload_available;
if (avail.payload_not_available != "offline")
root[MQTT_PAYLOAD_NOT_AVAILABLE] = avail.payload_not_available;
if (this->availability_ == nullptr) {
if (!global_mqtt_client->get_availability().topic.empty()) {
root[MQTT_AVAILABILITY_TOPIC] = global_mqtt_client->get_availability().topic;
if (global_mqtt_client->get_availability().payload_available != "online")
root[MQTT_PAYLOAD_AVAILABLE] = global_mqtt_client->get_availability().payload_available;
if (global_mqtt_client->get_availability().payload_not_available != "offline")
root[MQTT_PAYLOAD_NOT_AVAILABLE] = global_mqtt_client->get_availability().payload_not_available;
}
} else if (!this->availability_->topic.empty()) {
root[MQTT_AVAILABILITY_TOPIC] = this->availability_->topic;
if (this->availability_->payload_available != "online")
root[MQTT_PAYLOAD_AVAILABLE] = this->availability_->payload_available;
if (this->availability_->payload_not_available != "offline")
root[MQTT_PAYLOAD_NOT_AVAILABLE] = this->availability_->payload_not_available;
}
const MQTTDiscoveryInfo &discovery_info = global_mqtt_client->get_discovery_info();
@@ -135,8 +145,10 @@ bool MQTTComponent::send_discovery_() {
if (discovery_info.object_id_generator == MQTT_DEVICE_NAME_OBJECT_ID_GENERATOR)
root[MQTT_OBJECT_ID] = node_name + "_" + this->get_default_object_id_();
const std::string &friendly_name_ref = App.get_friendly_name();
const std::string &node_friendly_name = friendly_name_ref.empty() ? node_name : friendly_name_ref;
std::string node_friendly_name = App.get_friendly_name();
if (node_friendly_name.empty()) {
node_friendly_name = node_name;
}
std::string node_area = App.get_area();
JsonObject device_info = root[MQTT_DEVICE].to<JsonObject>();
@@ -146,9 +158,13 @@ bool MQTTComponent::send_discovery_() {
#ifdef ESPHOME_PROJECT_NAME
device_info[MQTT_DEVICE_SW_VERSION] = ESPHOME_PROJECT_VERSION " (ESPHome " ESPHOME_VERSION ")";
const char *model = std::strchr(ESPHOME_PROJECT_NAME, '.');
device_info[MQTT_DEVICE_MODEL] = model == nullptr ? ESPHOME_BOARD : model + 1;
device_info[MQTT_DEVICE_MANUFACTURER] =
model == nullptr ? ESPHOME_PROJECT_NAME : std::string(ESPHOME_PROJECT_NAME, model - ESPHOME_PROJECT_NAME);
if (model == nullptr) { // must never happen but check anyway
device_info[MQTT_DEVICE_MODEL] = ESPHOME_BOARD;
device_info[MQTT_DEVICE_MANUFACTURER] = ESPHOME_PROJECT_NAME;
} else {
device_info[MQTT_DEVICE_MODEL] = model + 1;
device_info[MQTT_DEVICE_MANUFACTURER] = std::string(ESPHOME_PROJECT_NAME, model - ESPHOME_PROJECT_NAME);
}
#else
device_info[MQTT_DEVICE_SW_VERSION] = ESPHOME_VERSION " (" + App.get_compilation_time() + ")";
device_info[MQTT_DEVICE_MODEL] = ESPHOME_BOARD;

View File

@@ -1291,6 +1291,9 @@ void Nextion::check_pending_waveform_() {
void Nextion::set_writer(const nextion_writer_t &writer) { this->writer_ = writer; }
ESPDEPRECATED("set_wait_for_ack(bool) deprecated, no effect", "v1.20")
void Nextion::set_wait_for_ack(bool wait_for_ack) { ESP_LOGE(TAG, "Deprecated"); }
bool Nextion::is_updating() { return this->connection_state_.is_updating_; }
} // namespace nextion

View File

@@ -45,26 +45,13 @@ def get_script(script_id):
def check_max_runs(value):
# Set default for queued mode to prevent unbounded queue growth
if CONF_MAX_RUNS not in value and value[CONF_MODE] == CONF_QUEUED:
value[CONF_MAX_RUNS] = 5
if CONF_MAX_RUNS not in value:
return value
if value[CONF_MODE] not in [CONF_QUEUED, CONF_PARALLEL]:
raise cv.Invalid(
"The option 'max_runs' is only valid in 'queued' and 'parallel' mode.",
"The option 'max_runs' is only valid in 'queue' and 'parallel' mode.",
path=[CONF_MAX_RUNS],
)
# Queued mode must have bounded queue (min 1), parallel mode can be unlimited (0)
if value[CONF_MODE] == CONF_QUEUED and value[CONF_MAX_RUNS] < 1:
raise cv.Invalid(
"The option 'max_runs' must be at least 1 for queued mode.",
path=[CONF_MAX_RUNS],
)
return value
@@ -119,7 +106,7 @@ CONFIG_SCHEMA = automation.validate_automation(
cv.Optional(CONF_MODE, default=CONF_SINGLE): cv.one_of(
*SCRIPT_MODES, lower=True
),
cv.Optional(CONF_MAX_RUNS): cv.int_range(min=0, max=100),
cv.Optional(CONF_MAX_RUNS): cv.positive_int,
cv.Optional(CONF_PARAMETERS, default={}): cv.Schema(
{
validate_parameter_name: validate_parameter_type,

View File

@@ -1,11 +1,10 @@
#pragma once
#include <memory>
#include <tuple>
#include "esphome/core/automation.h"
#include "esphome/core/component.h"
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
#include <queue>
namespace esphome {
namespace script {
@@ -97,41 +96,23 @@ template<typename... Ts> class RestartScript : public Script<Ts...> {
/** A script type that queues new instances that are created.
*
* Only one instance of the script can be active at a time.
*
* Ring buffer implementation:
* - num_queued_ tracks the number of queued (waiting) instances, NOT including the currently running one
* - queue_front_ points to the next item to execute (read position)
* - Buffer size is max_runs_ - 1 (max total instances minus the running one)
* - Write position is calculated as: (queue_front_ + num_queued_) % (max_runs_ - 1)
* - When an item finishes, queue_front_ advances: (queue_front_ + 1) % (max_runs_ - 1)
* - First execute() runs immediately without queuing (num_queued_ stays 0)
* - Subsequent executes while running are queued starting at position 0
* - Maximum total instances = max_runs_ (includes 1 running + (max_runs_ - 1) queued)
*/
template<typename... Ts> class QueueingScript : public Script<Ts...>, public Component {
public:
void execute(Ts... x) override {
if (this->is_action_running() || this->num_queued_ > 0) {
// num_queued_ is the number of *queued* instances (waiting, not including currently running)
// max_runs_ is the maximum *total* instances (running + queued)
// So we reject when num_queued_ + 1 >= max_runs_ (queued + running >= max)
if (this->num_queued_ + 1 >= this->max_runs_) {
this->esp_logw_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' max instances (running + queued) reached!"),
if (this->is_action_running() || this->num_runs_ > 0) {
// num_runs_ is the number of *queued* instances, so total number of instances is
// num_runs_ + 1
if (this->max_runs_ != 0 && this->num_runs_ + 1 >= this->max_runs_) {
this->esp_logw_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' maximum number of queued runs exceeded!"),
LOG_STR_ARG(this->name_));
return;
}
// Initialize queue on first queued item (after capacity check)
this->lazy_init_queue_();
this->esp_logd_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' queueing new instance (mode: queued)"),
LOG_STR_ARG(this->name_));
// Ring buffer: write to (queue_front_ + num_queued_) % queue_capacity
const size_t queue_capacity = static_cast<size_t>(this->max_runs_ - 1);
size_t write_pos = (this->queue_front_ + this->num_queued_) % queue_capacity;
// Use std::make_unique to replace the unique_ptr
this->var_queue_[write_pos] = std::make_unique<std::tuple<Ts...>>(x...);
this->num_queued_++;
this->num_runs_++;
this->var_queue_.push(std::make_tuple(x...));
return;
}
@@ -141,46 +122,29 @@ template<typename... Ts> class QueueingScript : public Script<Ts...>, public Com
}
void stop() override {
// Clear all queued items to free memory immediately
// Resetting the array automatically destroys all unique_ptrs and their contents
this->var_queue_.reset();
this->num_queued_ = 0;
this->queue_front_ = 0;
this->num_runs_ = 0;
Script<Ts...>::stop();
}
void loop() override {
if (this->num_queued_ != 0 && !this->is_action_running()) {
// Dequeue: decrement count, move tuple out (frees slot), advance read position
this->num_queued_--;
const size_t queue_capacity = static_cast<size_t>(this->max_runs_ - 1);
auto tuple_ptr = std::move(this->var_queue_[this->queue_front_]);
this->queue_front_ = (this->queue_front_ + 1) % queue_capacity;
this->trigger_tuple_(*tuple_ptr, typename gens<sizeof...(Ts)>::type());
if (this->num_runs_ != 0 && !this->is_action_running()) {
this->num_runs_--;
auto &vars = this->var_queue_.front();
this->var_queue_.pop();
this->trigger_tuple_(vars, typename gens<sizeof...(Ts)>::type());
}
}
void set_max_runs(int max_runs) { max_runs_ = max_runs; }
protected:
// Lazy init queue on first use - avoids setup() ordering issues and saves memory
// if script is never executed during this boot cycle
inline void lazy_init_queue_() {
if (!this->var_queue_) {
// Allocate array of max_runs_ - 1 slots for queued items (running item is separate)
// unique_ptr array is zero-initialized, so all slots start as nullptr
this->var_queue_ = std::make_unique<std::unique_ptr<std::tuple<Ts...>>[]>(this->max_runs_ - 1);
}
}
template<int... S> void trigger_tuple_(const std::tuple<Ts...> &tuple, seq<S...> /*unused*/) {
this->trigger(std::get<S>(tuple)...);
}
int num_queued_ = 0; // Number of queued instances (not including currently running)
int max_runs_ = 0; // Maximum total instances (running + queued)
size_t queue_front_ = 0; // Ring buffer read position (next item to execute)
std::unique_ptr<std::unique_ptr<std::tuple<Ts...>>[]> var_queue_; // Ring buffer of queued parameters
int num_runs_ = 0;
int max_runs_ = 0;
std::queue<std::tuple<Ts...>> var_queue_;
};
/** A script type that executes new instances in parallel.

View File

@@ -251,9 +251,6 @@ MaxFilter = sensor_ns.class_("MaxFilter", Filter)
SlidingWindowMovingAverageFilter = sensor_ns.class_(
"SlidingWindowMovingAverageFilter", Filter
)
StreamingMinFilter = sensor_ns.class_("StreamingMinFilter", Filter)
StreamingMaxFilter = sensor_ns.class_("StreamingMaxFilter", Filter)
StreamingMovingAverageFilter = sensor_ns.class_("StreamingMovingAverageFilter", Filter)
ExponentialMovingAverageFilter = sensor_ns.class_(
"ExponentialMovingAverageFilter", Filter
)
@@ -455,21 +452,14 @@ async def skip_initial_filter_to_code(config, filter_id):
return cg.new_Pvariable(filter_id, config)
@FILTER_REGISTRY.register("min", Filter, MIN_SCHEMA)
@FILTER_REGISTRY.register("min", MinFilter, MIN_SCHEMA)
async def min_filter_to_code(config, filter_id):
window_size: int = config[CONF_WINDOW_SIZE]
send_every: int = config[CONF_SEND_EVERY]
send_first_at: int = config[CONF_SEND_FIRST_AT]
# Optimization: Use streaming filter for batch windows (window_size == send_every)
# Saves 99.98% memory for large windows (e.g., 20KB → 4 bytes for window_size=5000)
if window_size == send_every:
# Use streaming filter - O(1) memory instead of O(n)
rhs = StreamingMinFilter.new(window_size, send_first_at)
return cg.Pvariable(filter_id, rhs, StreamingMinFilter)
# Use sliding window filter - maintains ring buffer
rhs = MinFilter.new(window_size, send_every, send_first_at)
return cg.Pvariable(filter_id, rhs, MinFilter)
return cg.new_Pvariable(
filter_id,
config[CONF_WINDOW_SIZE],
config[CONF_SEND_EVERY],
config[CONF_SEND_FIRST_AT],
)
MAX_SCHEMA = cv.All(
@@ -484,18 +474,14 @@ MAX_SCHEMA = cv.All(
)
@FILTER_REGISTRY.register("max", Filter, MAX_SCHEMA)
@FILTER_REGISTRY.register("max", MaxFilter, MAX_SCHEMA)
async def max_filter_to_code(config, filter_id):
window_size: int = config[CONF_WINDOW_SIZE]
send_every: int = config[CONF_SEND_EVERY]
send_first_at: int = config[CONF_SEND_FIRST_AT]
# Optimization: Use streaming filter for batch windows (window_size == send_every)
if window_size == send_every:
rhs = StreamingMaxFilter.new(window_size, send_first_at)
return cg.Pvariable(filter_id, rhs, StreamingMaxFilter)
rhs = MaxFilter.new(window_size, send_every, send_first_at)
return cg.Pvariable(filter_id, rhs, MaxFilter)
return cg.new_Pvariable(
filter_id,
config[CONF_WINDOW_SIZE],
config[CONF_SEND_EVERY],
config[CONF_SEND_FIRST_AT],
)
SLIDING_AVERAGE_SCHEMA = cv.All(
@@ -512,20 +498,16 @@ SLIDING_AVERAGE_SCHEMA = cv.All(
@FILTER_REGISTRY.register(
"sliding_window_moving_average",
Filter,
SlidingWindowMovingAverageFilter,
SLIDING_AVERAGE_SCHEMA,
)
async def sliding_window_moving_average_filter_to_code(config, filter_id):
window_size: int = config[CONF_WINDOW_SIZE]
send_every: int = config[CONF_SEND_EVERY]
send_first_at: int = config[CONF_SEND_FIRST_AT]
# Optimization: Use streaming filter for batch windows (window_size == send_every)
if window_size == send_every:
rhs = StreamingMovingAverageFilter.new(window_size, send_first_at)
return cg.Pvariable(filter_id, rhs, StreamingMovingAverageFilter)
rhs = SlidingWindowMovingAverageFilter.new(window_size, send_every, send_first_at)
return cg.Pvariable(filter_id, rhs, SlidingWindowMovingAverageFilter)
return cg.new_Pvariable(
filter_id,
config[CONF_WINDOW_SIZE],
config[CONF_SEND_EVERY],
config[CONF_SEND_FIRST_AT],
)
EXPONENTIAL_AVERAGE_SCHEMA = cv.All(

View File

@@ -32,76 +32,50 @@ void Filter::initialize(Sensor *parent, Filter *next) {
this->next_ = next;
}
// SlidingWindowFilter
SlidingWindowFilter::SlidingWindowFilter(size_t window_size, size_t send_every, size_t send_first_at)
: window_size_(window_size), send_every_(send_every), send_at_(send_every - send_first_at) {
// Allocate ring buffer once at initialization
this->window_.init(window_size);
}
optional<float> SlidingWindowFilter::new_value(float value) {
// Add value to ring buffer
if (this->window_count_ < this->window_size_) {
// Buffer not yet full - just append
this->window_.push_back(value);
this->window_count_++;
} else {
// Buffer full - overwrite oldest value (ring buffer)
this->window_[this->window_head_] = value;
this->window_head_++;
if (this->window_head_ >= this->window_size_) {
this->window_head_ = 0;
}
// MedianFilter
MedianFilter::MedianFilter(size_t window_size, size_t send_every, size_t send_first_at)
: send_every_(send_every), send_at_(send_every - send_first_at), window_size_(window_size) {}
void MedianFilter::set_send_every(size_t send_every) { this->send_every_ = send_every; }
void MedianFilter::set_window_size(size_t window_size) { this->window_size_ = window_size; }
optional<float> MedianFilter::new_value(float value) {
while (this->queue_.size() >= this->window_size_) {
this->queue_.pop_front();
}
this->queue_.push_back(value);
ESP_LOGVV(TAG, "MedianFilter(%p)::new_value(%f)", this, value);
// Check if we should send a result
if (++this->send_at_ >= this->send_every_) {
this->send_at_ = 0;
float result = this->compute_result();
ESP_LOGVV(TAG, "SlidingWindowFilter(%p)::new_value(%f) SENDING %f", this, value, result);
return result;
float median = NAN;
if (!this->queue_.empty()) {
// Copy queue without NaN values
std::vector<float> median_queue;
median_queue.reserve(this->queue_.size());
for (auto v : this->queue_) {
if (!std::isnan(v)) {
median_queue.push_back(v);
}
}
sort(median_queue.begin(), median_queue.end());
size_t queue_size = median_queue.size();
if (queue_size) {
if (queue_size % 2) {
median = median_queue[queue_size / 2];
} else {
median = (median_queue[queue_size / 2] + median_queue[(queue_size / 2) - 1]) / 2.0f;
}
}
}
ESP_LOGVV(TAG, "MedianFilter(%p)::new_value(%f) SENDING %f", this, value, median);
return median;
}
return {};
}
// SortedWindowFilter
FixedVector<float> SortedWindowFilter::get_window_values_() {
// Copy window without NaN values using FixedVector (no heap allocation)
// Returns unsorted values - caller will use std::nth_element for partial sorting as needed
FixedVector<float> values;
values.init(this->window_count_);
for (size_t i = 0; i < this->window_count_; i++) {
float v = this->window_[i];
if (!std::isnan(v)) {
values.push_back(v);
}
}
return values;
}
// MedianFilter
float MedianFilter::compute_result() {
FixedVector<float> values = this->get_window_values_();
if (values.empty())
return NAN;
size_t size = values.size();
size_t mid = size / 2;
if (size % 2) {
// Odd number of elements - use nth_element to find middle element
std::nth_element(values.begin(), values.begin() + mid, values.end());
return values[mid];
}
// Even number of elements - need both middle elements
// Use nth_element to find upper middle element
std::nth_element(values.begin(), values.begin() + mid, values.end());
float upper = values[mid];
// Find the maximum of the lower half (which is now everything before mid)
float lower = *std::max_element(values.begin(), values.begin() + mid);
return (lower + upper) / 2.0f;
}
// SkipInitialFilter
SkipInitialFilter::SkipInitialFilter(size_t num_to_ignore) : num_to_ignore_(num_to_ignore) {}
optional<float> SkipInitialFilter::new_value(float value) {
@@ -117,39 +91,136 @@ optional<float> SkipInitialFilter::new_value(float value) {
// QuantileFilter
QuantileFilter::QuantileFilter(size_t window_size, size_t send_every, size_t send_first_at, float quantile)
: SortedWindowFilter(window_size, send_every, send_first_at), quantile_(quantile) {}
: send_every_(send_every), send_at_(send_every - send_first_at), window_size_(window_size), quantile_(quantile) {}
void QuantileFilter::set_send_every(size_t send_every) { this->send_every_ = send_every; }
void QuantileFilter::set_window_size(size_t window_size) { this->window_size_ = window_size; }
void QuantileFilter::set_quantile(float quantile) { this->quantile_ = quantile; }
optional<float> QuantileFilter::new_value(float value) {
while (this->queue_.size() >= this->window_size_) {
this->queue_.pop_front();
}
this->queue_.push_back(value);
ESP_LOGVV(TAG, "QuantileFilter(%p)::new_value(%f), quantile:%f", this, value, this->quantile_);
float QuantileFilter::compute_result() {
FixedVector<float> values = this->get_window_values_();
if (values.empty())
return NAN;
if (++this->send_at_ >= this->send_every_) {
this->send_at_ = 0;
size_t position = ceilf(values.size() * this->quantile_) - 1;
ESP_LOGVV(TAG, "QuantileFilter(%p)::position: %zu/%zu", this, position + 1, values.size());
float result = NAN;
if (!this->queue_.empty()) {
// Copy queue without NaN values
std::vector<float> quantile_queue;
for (auto v : this->queue_) {
if (!std::isnan(v)) {
quantile_queue.push_back(v);
}
}
// Use nth_element to find the quantile element (O(n) instead of O(n log n))
std::nth_element(values.begin(), values.begin() + position, values.end());
return values[position];
sort(quantile_queue.begin(), quantile_queue.end());
size_t queue_size = quantile_queue.size();
if (queue_size) {
size_t position = ceilf(queue_size * this->quantile_) - 1;
ESP_LOGVV(TAG, "QuantileFilter(%p)::position: %zu/%zu", this, position + 1, queue_size);
result = quantile_queue[position];
}
}
ESP_LOGVV(TAG, "QuantileFilter(%p)::new_value(%f) SENDING %f", this, value, result);
return result;
}
return {};
}
// MinFilter
float MinFilter::compute_result() { return this->find_extremum_<std::less<float>>(); }
MinFilter::MinFilter(size_t window_size, size_t send_every, size_t send_first_at)
: send_every_(send_every), send_at_(send_every - send_first_at), window_size_(window_size) {}
void MinFilter::set_send_every(size_t send_every) { this->send_every_ = send_every; }
void MinFilter::set_window_size(size_t window_size) { this->window_size_ = window_size; }
optional<float> MinFilter::new_value(float value) {
while (this->queue_.size() >= this->window_size_) {
this->queue_.pop_front();
}
this->queue_.push_back(value);
ESP_LOGVV(TAG, "MinFilter(%p)::new_value(%f)", this, value);
if (++this->send_at_ >= this->send_every_) {
this->send_at_ = 0;
float min = NAN;
for (auto v : this->queue_) {
if (!std::isnan(v)) {
min = std::isnan(min) ? v : std::min(min, v);
}
}
ESP_LOGVV(TAG, "MinFilter(%p)::new_value(%f) SENDING %f", this, value, min);
return min;
}
return {};
}
// MaxFilter
float MaxFilter::compute_result() { return this->find_extremum_<std::greater<float>>(); }
MaxFilter::MaxFilter(size_t window_size, size_t send_every, size_t send_first_at)
: send_every_(send_every), send_at_(send_every - send_first_at), window_size_(window_size) {}
void MaxFilter::set_send_every(size_t send_every) { this->send_every_ = send_every; }
void MaxFilter::set_window_size(size_t window_size) { this->window_size_ = window_size; }
optional<float> MaxFilter::new_value(float value) {
while (this->queue_.size() >= this->window_size_) {
this->queue_.pop_front();
}
this->queue_.push_back(value);
ESP_LOGVV(TAG, "MaxFilter(%p)::new_value(%f)", this, value);
if (++this->send_at_ >= this->send_every_) {
this->send_at_ = 0;
float max = NAN;
for (auto v : this->queue_) {
if (!std::isnan(v)) {
max = std::isnan(max) ? v : std::max(max, v);
}
}
ESP_LOGVV(TAG, "MaxFilter(%p)::new_value(%f) SENDING %f", this, value, max);
return max;
}
return {};
}
// SlidingWindowMovingAverageFilter
float SlidingWindowMovingAverageFilter::compute_result() {
float sum = 0;
size_t valid_count = 0;
for (size_t i = 0; i < this->window_count_; i++) {
float v = this->window_[i];
if (!std::isnan(v)) {
sum += v;
valid_count++;
}
SlidingWindowMovingAverageFilter::SlidingWindowMovingAverageFilter(size_t window_size, size_t send_every,
size_t send_first_at)
: send_every_(send_every), send_at_(send_every - send_first_at), window_size_(window_size) {}
void SlidingWindowMovingAverageFilter::set_send_every(size_t send_every) { this->send_every_ = send_every; }
void SlidingWindowMovingAverageFilter::set_window_size(size_t window_size) { this->window_size_ = window_size; }
optional<float> SlidingWindowMovingAverageFilter::new_value(float value) {
while (this->queue_.size() >= this->window_size_) {
this->queue_.pop_front();
}
return valid_count ? sum / valid_count : NAN;
this->queue_.push_back(value);
ESP_LOGVV(TAG, "SlidingWindowMovingAverageFilter(%p)::new_value(%f)", this, value);
if (++this->send_at_ >= this->send_every_) {
this->send_at_ = 0;
float sum = 0;
size_t valid_count = 0;
for (auto v : this->queue_) {
if (!std::isnan(v)) {
sum += v;
valid_count++;
}
}
float average = NAN;
if (valid_count) {
average = sum / valid_count;
}
ESP_LOGVV(TAG, "SlidingWindowMovingAverageFilter(%p)::new_value(%f) SENDING %f", this, value, average);
return average;
}
return {};
}
// ExponentialMovingAverageFilter
@@ -472,78 +543,5 @@ optional<float> ToNTCTemperatureFilter::new_value(float value) {
return temp;
}
// StreamingFilter (base class)
StreamingFilter::StreamingFilter(size_t window_size, size_t send_first_at)
: window_size_(window_size), send_first_at_(send_first_at) {}
optional<float> StreamingFilter::new_value(float value) {
// Process the value (child class tracks min/max/sum/etc)
this->process_value(value);
this->count_++;
// Check if we should send (handle send_first_at for first value)
bool should_send = false;
if (this->first_send_ && this->count_ >= this->send_first_at_) {
should_send = true;
this->first_send_ = false;
} else if (!this->first_send_ && this->count_ >= this->window_size_) {
should_send = true;
}
if (should_send) {
float result = this->compute_batch_result();
// Reset for next batch
this->count_ = 0;
this->reset_batch();
ESP_LOGVV(TAG, "StreamingFilter(%p)::new_value(%f) SENDING %f", this, value, result);
return result;
}
return {};
}
// StreamingMinFilter
void StreamingMinFilter::process_value(float value) {
// Update running minimum (ignore NaN values)
if (!std::isnan(value)) {
this->current_min_ = std::isnan(this->current_min_) ? value : std::min(this->current_min_, value);
}
}
float StreamingMinFilter::compute_batch_result() { return this->current_min_; }
void StreamingMinFilter::reset_batch() { this->current_min_ = NAN; }
// StreamingMaxFilter
void StreamingMaxFilter::process_value(float value) {
// Update running maximum (ignore NaN values)
if (!std::isnan(value)) {
this->current_max_ = std::isnan(this->current_max_) ? value : std::max(this->current_max_, value);
}
}
float StreamingMaxFilter::compute_batch_result() { return this->current_max_; }
void StreamingMaxFilter::reset_batch() { this->current_max_ = NAN; }
// StreamingMovingAverageFilter
void StreamingMovingAverageFilter::process_value(float value) {
// Accumulate sum (ignore NaN values)
if (!std::isnan(value)) {
this->sum_ += value;
this->valid_count_++;
}
}
float StreamingMovingAverageFilter::compute_batch_result() {
return this->valid_count_ > 0 ? this->sum_ / this->valid_count_ : NAN;
}
void StreamingMovingAverageFilter::reset_batch() {
this->sum_ = 0.0f;
this->valid_count_ = 0;
}
} // namespace sensor
} // namespace esphome

View File

@@ -44,75 +44,11 @@ class Filter {
Sensor *parent_{nullptr};
};
/** Base class for filters that use a sliding window of values.
*
* Uses a ring buffer to efficiently maintain a fixed-size sliding window without
* reallocations or pop_front() overhead. Eliminates deque fragmentation issues.
*/
class SlidingWindowFilter : public Filter {
public:
SlidingWindowFilter(size_t window_size, size_t send_every, size_t send_first_at);
optional<float> new_value(float value) final;
protected:
/// Called by new_value() to compute the filtered result from the current window
virtual float compute_result() = 0;
/// Access the sliding window values (ring buffer implementation)
/// Use: for (size_t i = 0; i < window_count_; i++) { float val = window_[i]; }
FixedVector<float> window_;
size_t window_head_{0}; ///< Index where next value will be written
size_t window_count_{0}; ///< Number of valid values in window (0 to window_size_)
size_t window_size_; ///< Maximum window size
size_t send_every_; ///< Send result every N values
size_t send_at_; ///< Counter for send_every
};
/** Base class for Min/Max filters.
*
* Provides a templated helper to find extremum values efficiently.
*/
class MinMaxFilter : public SlidingWindowFilter {
public:
using SlidingWindowFilter::SlidingWindowFilter;
protected:
/// Helper to find min or max value in window, skipping NaN values
/// Usage: find_extremum_<std::less<float>>() for min, find_extremum_<std::greater<float>>() for max
template<typename Compare> float find_extremum_() {
float result = NAN;
Compare comp;
for (size_t i = 0; i < this->window_count_; i++) {
float v = this->window_[i];
if (!std::isnan(v)) {
result = std::isnan(result) ? v : (comp(v, result) ? v : result);
}
}
return result;
}
};
/** Base class for filters that need a sorted window (Median, Quantile).
*
* Extends SlidingWindowFilter to provide a helper that filters out NaN values.
* Derived classes use std::nth_element for efficient partial sorting.
*/
class SortedWindowFilter : public SlidingWindowFilter {
public:
using SlidingWindowFilter::SlidingWindowFilter;
protected:
/// Helper to get non-NaN values from the window (not sorted - caller will use nth_element)
/// Returns empty FixedVector if all values are NaN
FixedVector<float> get_window_values_();
};
/** Simple quantile filter.
*
* Takes the quantile of the last <window_size> values and pushes it out every <send_every>.
* Takes the quantile of the last <send_every> values and pushes it out every <send_every>.
*/
class QuantileFilter : public SortedWindowFilter {
class QuantileFilter : public Filter {
public:
/** Construct a QuantileFilter.
*
@@ -125,18 +61,25 @@ class QuantileFilter : public SortedWindowFilter {
*/
explicit QuantileFilter(size_t window_size, size_t send_every, size_t send_first_at, float quantile);
void set_quantile(float quantile) { this->quantile_ = quantile; }
optional<float> new_value(float value) override;
void set_send_every(size_t send_every);
void set_window_size(size_t window_size);
void set_quantile(float quantile);
protected:
float compute_result() override;
std::deque<float> queue_;
size_t send_every_;
size_t send_at_;
size_t window_size_;
float quantile_;
};
/** Simple median filter.
*
* Takes the median of the last <window_size> values and pushes it out every <send_every>.
* Takes the median of the last <send_every> values and pushes it out every <send_every>.
*/
class MedianFilter : public SortedWindowFilter {
class MedianFilter : public Filter {
public:
/** Construct a MedianFilter.
*
@@ -146,10 +89,18 @@ class MedianFilter : public SortedWindowFilter {
* on startup being published on the first *raw* value, so with no filter applied. Must be less than or equal to
* send_every.
*/
using SortedWindowFilter::SortedWindowFilter;
explicit MedianFilter(size_t window_size, size_t send_every, size_t send_first_at);
optional<float> new_value(float value) override;
void set_send_every(size_t send_every);
void set_window_size(size_t window_size);
protected:
float compute_result() override;
std::deque<float> queue_;
size_t send_every_;
size_t send_at_;
size_t window_size_;
};
/** Simple skip filter.
@@ -172,9 +123,9 @@ class SkipInitialFilter : public Filter {
/** Simple min filter.
*
* Takes the min of the last <window_size> values and pushes it out every <send_every>.
* Takes the min of the last <send_every> values and pushes it out every <send_every>.
*/
class MinFilter : public MinMaxFilter {
class MinFilter : public Filter {
public:
/** Construct a MinFilter.
*
@@ -184,17 +135,25 @@ class MinFilter : public MinMaxFilter {
* on startup being published on the first *raw* value, so with no filter applied. Must be less than or equal to
* send_every.
*/
using MinMaxFilter::MinMaxFilter;
explicit MinFilter(size_t window_size, size_t send_every, size_t send_first_at);
optional<float> new_value(float value) override;
void set_send_every(size_t send_every);
void set_window_size(size_t window_size);
protected:
float compute_result() override;
std::deque<float> queue_;
size_t send_every_;
size_t send_at_;
size_t window_size_;
};
/** Simple max filter.
*
* Takes the max of the last <window_size> values and pushes it out every <send_every>.
* Takes the max of the last <send_every> values and pushes it out every <send_every>.
*/
class MaxFilter : public MinMaxFilter {
class MaxFilter : public Filter {
public:
/** Construct a MaxFilter.
*
@@ -204,10 +163,18 @@ class MaxFilter : public MinMaxFilter {
* on startup being published on the first *raw* value, so with no filter applied. Must be less than or equal to
* send_every.
*/
using MinMaxFilter::MinMaxFilter;
explicit MaxFilter(size_t window_size, size_t send_every, size_t send_first_at);
optional<float> new_value(float value) override;
void set_send_every(size_t send_every);
void set_window_size(size_t window_size);
protected:
float compute_result() override;
std::deque<float> queue_;
size_t send_every_;
size_t send_at_;
size_t window_size_;
};
/** Simple sliding window moving average filter.
@@ -215,7 +182,7 @@ class MaxFilter : public MinMaxFilter {
* Essentially just takes takes the average of the last window_size values and pushes them out
* every send_every.
*/
class SlidingWindowMovingAverageFilter : public SlidingWindowFilter {
class SlidingWindowMovingAverageFilter : public Filter {
public:
/** Construct a SlidingWindowMovingAverageFilter.
*
@@ -225,10 +192,18 @@ class SlidingWindowMovingAverageFilter : public SlidingWindowFilter {
* on startup being published on the first *raw* value, so with no filter applied. Must be less than or equal to
* send_every.
*/
using SlidingWindowFilter::SlidingWindowFilter;
explicit SlidingWindowMovingAverageFilter(size_t window_size, size_t send_every, size_t send_first_at);
optional<float> new_value(float value) override;
void set_send_every(size_t send_every);
void set_window_size(size_t window_size);
protected:
float compute_result() override;
std::deque<float> queue_;
size_t send_every_;
size_t send_at_;
size_t window_size_;
};
/** Simple exponential moving average filter.
@@ -501,81 +476,5 @@ class ToNTCTemperatureFilter : public Filter {
double c_;
};
/** Base class for streaming filters (batch windows where window_size == send_every).
*
* When window_size equals send_every, we don't need a sliding window.
* This base class handles the common batching logic.
*/
class StreamingFilter : public Filter {
public:
StreamingFilter(size_t window_size, size_t send_first_at);
optional<float> new_value(float value) final;
protected:
/// Called by new_value() to process each value in the batch
virtual void process_value(float value) = 0;
/// Called by new_value() to compute the result after collecting window_size values
virtual float compute_batch_result() = 0;
/// Called by new_value() to reset internal state after sending a result
virtual void reset_batch() = 0;
size_t window_size_;
size_t count_{0};
size_t send_first_at_;
bool first_send_{true};
};
/** Streaming min filter for batch windows (window_size == send_every).
*
* Uses O(1) memory instead of O(n) by tracking only the minimum value.
*/
class StreamingMinFilter : public StreamingFilter {
public:
using StreamingFilter::StreamingFilter;
protected:
void process_value(float value) override;
float compute_batch_result() override;
void reset_batch() override;
float current_min_{NAN};
};
/** Streaming max filter for batch windows (window_size == send_every).
*
* Uses O(1) memory instead of O(n) by tracking only the maximum value.
*/
class StreamingMaxFilter : public StreamingFilter {
public:
using StreamingFilter::StreamingFilter;
protected:
void process_value(float value) override;
float compute_batch_result() override;
void reset_batch() override;
float current_max_{NAN};
};
/** Streaming moving average filter for batch windows (window_size == send_every).
*
* Uses O(1) memory instead of O(n) by tracking only sum and count.
*/
class StreamingMovingAverageFilter : public StreamingFilter {
public:
using StreamingFilter::StreamingFilter;
protected:
void process_value(float value) override;
float compute_batch_result() override;
void reset_batch() override;
float sum_{0.0f};
size_t valid_count_{0};
};
} // namespace sensor
} // namespace esphome

View File

@@ -27,14 +27,6 @@ class RealTimeClock : public PollingComponent {
this->apply_timezone_();
}
/// Set the time zone from raw buffer, only if it differs from the current one.
void set_timezone(const char *tz, size_t len) {
if (this->timezone_.length() != len || memcmp(this->timezone_.c_str(), tz, len) != 0) {
this->timezone_.assign(tz, len);
this->apply_timezone_();
}
}
/// Get the time zone currently in use.
std::string get_timezone() { return this->timezone_; }
#endif

View File

@@ -7,13 +7,24 @@ namespace touchscreen {
static const char *const TAG = "touchscreen";
void TouchscreenInterrupt::gpio_intr(TouchscreenInterrupt *store) { store->touched = true; }
void TouchscreenInterrupt::gpio_intr(TouchscreenInterrupt *store) {
bool new_state = store->isr_pin_.digital_read();
if (new_state != store->inverted) {
store->touched = true;
if (store->component_ != nullptr) {
store->component_->enable_loop_soon_any_context();
}
}
}
void Touchscreen::attach_interrupt_(InternalGPIOPin *irq_pin, esphome::gpio::InterruptType type) {
this->store_.isr_pin_ = irq_pin->to_isr();
this->store_.component_ = this;
this->store_.inverted = irq_pin->is_inverted();
irq_pin->attach_interrupt(TouchscreenInterrupt::gpio_intr, &this->store_, type);
this->store_.init = true;
this->store_.touched = false;
ESP_LOGD(TAG, "Attach Touch Interupt");
ESP_LOGD(TAG, "Attach Touch Interrupt");
}
void Touchscreen::call_setup() {
@@ -71,6 +82,8 @@ void Touchscreen::loop() {
}
}
}
if (this->store_.init)
this->disable_loop();
}
void Touchscreen::add_raw_touch_position_(uint8_t id, int16_t x_raw, int16_t y_raw, int16_t z_raw) {

View File

@@ -1,13 +1,13 @@
#pragma once
#include "esphome/core/defines.h"
#include "esphome/components/display/display.h"
#include "esphome/core/defines.h"
#include "esphome/core/automation.h"
#include "esphome/core/hal.h"
#include <vector>
#include <map>
#include <vector>
namespace esphome {
namespace touchscreen {
@@ -30,9 +30,12 @@ struct TouchPoint {
using TouchPoints_t = std::vector<TouchPoint>;
struct TouchscreenInterrupt {
ISRInternalGPIOPin isr_pin_;
volatile bool touched{true};
bool init{false};
bool inverted{false};
static void gpio_intr(TouchscreenInterrupt *store);
Component *component_{nullptr};
};
class TouchListener {

View File

@@ -407,8 +407,7 @@ async def to_code(config):
cg.add(var.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT]))
cg.add(var.set_power_save_mode(config[CONF_POWER_SAVE_MODE]))
if config[CONF_FAST_CONNECT]:
cg.add_define("USE_WIFI_FAST_CONNECT")
cg.add(var.set_fast_connect(config[CONF_FAST_CONNECT]))
cg.add(var.set_passive_scan(config[CONF_PASSIVE_SCAN]))
if CONF_OUTPUT_POWER in config:
cg.add(var.set_output_power(config[CONF_OUTPUT_POWER]))

View File

@@ -84,9 +84,9 @@ void WiFiComponent::start() {
uint32_t hash = this->has_sta() ? fnv1_hash(App.get_compilation_time()) : 88491487UL;
this->pref_ = global_preferences->make_preference<wifi::SavedWifiSettings>(hash, true);
#ifdef USE_WIFI_FAST_CONNECT
this->fast_connect_pref_ = global_preferences->make_preference<wifi::SavedWifiFastConnectSettings>(hash + 1, false);
#endif
if (this->fast_connect_) {
this->fast_connect_pref_ = global_preferences->make_preference<wifi::SavedWifiFastConnectSettings>(hash + 1, false);
}
SavedWifiSettings save{};
if (this->pref_.load(&save)) {
@@ -108,16 +108,16 @@ void WiFiComponent::start() {
ESP_LOGV(TAG, "Setting Power Save Option failed");
}
#ifdef USE_WIFI_FAST_CONNECT
this->trying_loaded_ap_ = this->load_fast_connect_settings_();
if (!this->trying_loaded_ap_) {
this->ap_index_ = 0;
this->selected_ap_ = this->sta_[this->ap_index_];
if (this->fast_connect_) {
this->trying_loaded_ap_ = this->load_fast_connect_settings_();
if (!this->trying_loaded_ap_) {
this->ap_index_ = 0;
this->selected_ap_ = this->sta_[this->ap_index_];
}
this->start_connecting(this->selected_ap_, false);
} else {
this->start_scanning();
}
this->start_connecting(this->selected_ap_, false);
#else
this->start_scanning();
#endif
#ifdef USE_WIFI_AP
} else if (this->has_ap()) {
this->setup_ap_config_();
@@ -168,20 +168,13 @@ void WiFiComponent::loop() {
case WIFI_COMPONENT_STATE_COOLDOWN: {
this->status_set_warning(LOG_STR("waiting to reconnect"));
if (millis() - this->action_started_ > 5000) {
#ifdef USE_WIFI_FAST_CONNECT
// NOTE: This check may not make sense here as it could interfere with AP cycling
if (!this->selected_ap_.get_bssid().has_value())
this->selected_ap_ = this->sta_[0];
this->start_connecting(this->selected_ap_, false);
#else
if (this->retry_hidden_) {
if (this->fast_connect_ || this->retry_hidden_) {
if (!this->selected_ap_.get_bssid().has_value())
this->selected_ap_ = this->sta_[0];
this->start_connecting(this->selected_ap_, false);
} else {
this->start_scanning();
}
#endif
}
break;
}
@@ -251,6 +244,7 @@ WiFiComponent::WiFiComponent() { global_wifi_component = this; }
bool WiFiComponent::has_ap() const { return this->has_ap_; }
bool WiFiComponent::has_sta() const { return !this->sta_.empty(); }
void WiFiComponent::set_fast_connect(bool fast_connect) { this->fast_connect_ = fast_connect; }
#ifdef USE_WIFI_11KV_SUPPORT
void WiFiComponent::set_btm(bool btm) { this->btm_ = btm; }
void WiFiComponent::set_rrm(bool rrm) { this->rrm_ = rrm; }
@@ -613,12 +607,10 @@ void WiFiComponent::check_scanning_finished() {
for (auto &ap : this->sta_) {
if (res.matches(ap)) {
res.set_matches(true);
// Cache priority lookup - do single search instead of 2 separate searches
const bssid_t &bssid = res.get_bssid();
if (!this->has_sta_priority(bssid)) {
this->set_sta_priority(bssid, ap.get_priority());
if (!this->has_sta_priority(res.get_bssid())) {
this->set_sta_priority(res.get_bssid(), ap.get_priority());
}
res.set_priority(this->get_sta_priority(bssid));
res.set_priority(this->get_sta_priority(res.get_bssid()));
break;
}
}
@@ -637,9 +629,8 @@ void WiFiComponent::check_scanning_finished() {
return;
}
// Build connection params directly into selected_ap_ to avoid extra copy
const WiFiScanResult &scan_res = this->scan_result_[0];
WiFiAP &selected = this->selected_ap_;
WiFiAP connect_params;
WiFiScanResult scan_res = this->scan_result_[0];
for (auto &config : this->sta_) {
// search for matching STA config, at least one will match (from checks before)
if (!scan_res.matches(config)) {
@@ -648,38 +639,37 @@ void WiFiComponent::check_scanning_finished() {
if (config.get_hidden()) {
// selected network is hidden, we use the data from the config
selected.set_hidden(true);
selected.set_ssid(config.get_ssid());
// Clear channel and BSSID for hidden networks - there might be multiple hidden networks
connect_params.set_hidden(true);
connect_params.set_ssid(config.get_ssid());
// don't set BSSID and channel, there might be multiple hidden networks
// but we can't know which one is the correct one. Rely on probe-req with just SSID.
selected.set_channel(0);
selected.set_bssid(optional<bssid_t>{});
} else {
// selected network is visible, we use the data from the scan
// limit the connect params to only connect to exactly this network
// (network selection is done during scan phase).
selected.set_hidden(false);
selected.set_ssid(scan_res.get_ssid());
selected.set_channel(scan_res.get_channel());
selected.set_bssid(scan_res.get_bssid());
connect_params.set_hidden(false);
connect_params.set_ssid(scan_res.get_ssid());
connect_params.set_channel(scan_res.get_channel());
connect_params.set_bssid(scan_res.get_bssid());
}
// copy manual IP (if set)
selected.set_manual_ip(config.get_manual_ip());
connect_params.set_manual_ip(config.get_manual_ip());
#ifdef USE_WIFI_WPA2_EAP
// copy EAP parameters (if set)
selected.set_eap(config.get_eap());
connect_params.set_eap(config.get_eap());
#endif
// copy password (if set)
selected.set_password(config.get_password());
connect_params.set_password(config.get_password());
break;
}
yield();
this->start_connecting(this->selected_ap_, false);
this->selected_ap_ = connect_params;
this->start_connecting(connect_params, false);
}
void WiFiComponent::dump_config() {
@@ -729,9 +719,9 @@ void WiFiComponent::check_connecting_finished() {
this->scan_result_.shrink_to_fit();
}
#ifdef USE_WIFI_FAST_CONNECT
this->save_fast_connect_settings_();
#endif
if (this->fast_connect_) {
this->save_fast_connect_settings_();
}
return;
}
@@ -779,31 +769,31 @@ void WiFiComponent::retry_connect() {
delay(10);
if (!this->is_captive_portal_active_() && !this->is_esp32_improv_active_() &&
(this->num_retried_ > 3 || this->error_from_callback_)) {
#ifdef USE_WIFI_FAST_CONNECT
if (this->trying_loaded_ap_) {
this->trying_loaded_ap_ = false;
this->ap_index_ = 0; // Retry from the first configured AP
} else if (this->ap_index_ >= this->sta_.size() - 1) {
ESP_LOGW(TAG, "No more APs to try");
this->ap_index_ = 0;
this->restart_adapter();
if (this->fast_connect_) {
if (this->trying_loaded_ap_) {
this->trying_loaded_ap_ = false;
this->ap_index_ = 0; // Retry from the first configured AP
} else if (this->ap_index_ >= this->sta_.size() - 1) {
ESP_LOGW(TAG, "No more APs to try");
this->ap_index_ = 0;
this->restart_adapter();
} else {
// Try next AP
this->ap_index_++;
}
this->num_retried_ = 0;
this->selected_ap_ = this->sta_[this->ap_index_];
} else {
// Try next AP
this->ap_index_++;
if (this->num_retried_ > 5) {
// If retry failed for more than 5 times, let's restart STA
this->restart_adapter();
} else {
// Try hidden networks after 3 failed retries
ESP_LOGD(TAG, "Retrying with hidden networks");
this->retry_hidden_ = true;
this->num_retried_++;
}
}
this->num_retried_ = 0;
this->selected_ap_ = this->sta_[this->ap_index_];
#else
if (this->num_retried_ > 5) {
// If retry failed for more than 5 times, let's restart STA
this->restart_adapter();
} else {
// Try hidden networks after 3 failed retries
ESP_LOGD(TAG, "Retrying with hidden networks");
this->retry_hidden_ = true;
this->num_retried_++;
}
#endif
} else {
this->num_retried_++;
}
@@ -849,7 +839,6 @@ bool WiFiComponent::is_esp32_improv_active_() {
#endif
}
#ifdef USE_WIFI_FAST_CONNECT
bool WiFiComponent::load_fast_connect_settings_() {
SavedWifiFastConnectSettings fast_connect_save{};
@@ -884,7 +873,6 @@ void WiFiComponent::save_fast_connect_settings_() {
ESP_LOGD(TAG, "Saved fast_connect settings");
}
}
#endif
void WiFiAP::set_ssid(const std::string &ssid) { this->ssid_ = ssid; }
void WiFiAP::set_bssid(bssid_t bssid) { this->bssid_ = bssid; }
@@ -914,7 +902,7 @@ WiFiScanResult::WiFiScanResult(const bssid_t &bssid, std::string ssid, uint8_t c
rssi_(rssi),
with_auth_(with_auth),
is_hidden_(is_hidden) {}
bool WiFiScanResult::matches(const WiFiAP &config) const {
bool WiFiScanResult::matches(const WiFiAP &config) {
if (config.get_hidden()) {
// User configured a hidden network, only match actually hidden networks
// don't match SSID

View File

@@ -170,7 +170,7 @@ class WiFiScanResult {
public:
WiFiScanResult(const bssid_t &bssid, std::string ssid, uint8_t channel, int8_t rssi, bool with_auth, bool is_hidden);
bool matches(const WiFiAP &config) const;
bool matches(const WiFiAP &config);
bool get_matches() const;
void set_matches(bool matches);
@@ -240,6 +240,7 @@ class WiFiComponent : public Component {
void start_scanning();
void check_scanning_finished();
void start_connecting(const WiFiAP &ap, bool two);
void set_fast_connect(bool fast_connect);
void set_ap_timeout(uint32_t ap_timeout) { ap_timeout_ = ap_timeout; }
void check_connecting_finished();
@@ -363,10 +364,8 @@ class WiFiComponent : public Component {
bool is_captive_portal_active_();
bool is_esp32_improv_active_();
#ifdef USE_WIFI_FAST_CONNECT
bool load_fast_connect_settings_();
void save_fast_connect_settings_();
#endif
#ifdef USE_ESP8266
static void wifi_event_callback(System_Event_t *event);
@@ -400,9 +399,7 @@ class WiFiComponent : public Component {
WiFiAP ap_;
optional<float> output_power_;
ESPPreferenceObject pref_;
#ifdef USE_WIFI_FAST_CONNECT
ESPPreferenceObject fast_connect_pref_;
#endif
// Group all 32-bit integers together
uint32_t action_started_;
@@ -414,17 +411,14 @@ class WiFiComponent : public Component {
WiFiComponentState state_{WIFI_COMPONENT_STATE_OFF};
WiFiPowerSaveMode power_save_{WIFI_POWER_SAVE_NONE};
uint8_t num_retried_{0};
#ifdef USE_WIFI_FAST_CONNECT
uint8_t ap_index_{0};
#endif
#if USE_NETWORK_IPV6
uint8_t num_ipv6_addresses_{0};
#endif /* USE_NETWORK_IPV6 */
// Group all boolean values together
#ifdef USE_WIFI_FAST_CONNECT
bool fast_connect_{false};
bool trying_loaded_ap_{false};
#endif
bool retry_hidden_{false};
bool has_ap_{false};
bool handled_connected_state_{false};

View File

@@ -706,10 +706,10 @@ void WiFiComponent::wifi_scan_done_callback_(void *arg, STATUS status) {
this->scan_result_.init(count);
for (bss_info *it = head; it != nullptr; it = STAILQ_NEXT(it, next)) {
this->scan_result_.emplace_back(
bssid_t{it->bssid[0], it->bssid[1], it->bssid[2], it->bssid[3], it->bssid[4], it->bssid[5]},
std::string(reinterpret_cast<char *>(it->ssid), it->ssid_len), it->channel, it->rssi, it->authmode != AUTH_OPEN,
it->is_hidden != 0);
WiFiScanResult res({it->bssid[0], it->bssid[1], it->bssid[2], it->bssid[3], it->bssid[4], it->bssid[5]},
std::string(reinterpret_cast<char *>(it->ssid), it->ssid_len), it->channel, it->rssi,
it->authmode != AUTH_OPEN, it->is_hidden != 0);
this->scan_result_.push_back(res);
}
this->scan_done_ = true;
}

View File

@@ -776,12 +776,13 @@ void WiFiComponent::wifi_process_event_(IDFWiFiEvent *data) {
}
uint16_t number = it.number;
auto records = std::make_unique<wifi_ap_record_t[]>(number);
err = esp_wifi_scan_get_ap_records(&number, records.get());
std::vector<wifi_ap_record_t> records(number);
err = esp_wifi_scan_get_ap_records(&number, records.data());
if (err != ESP_OK) {
ESP_LOGW(TAG, "esp_wifi_scan_get_ap_records failed: %s", esp_err_to_name(err));
return;
}
records.resize(number);
scan_result_.init(number);
for (int i = 0; i < number; i++) {
@@ -789,8 +790,8 @@ void WiFiComponent::wifi_process_event_(IDFWiFiEvent *data) {
bssid_t bssid;
std::copy(record.bssid, record.bssid + 6, bssid.begin());
std::string ssid(reinterpret_cast<const char *>(record.ssid));
scan_result_.emplace_back(bssid, ssid, record.primary, record.rssi, record.authmode != WIFI_AUTH_OPEN,
ssid.empty());
WiFiScanResult result(bssid, ssid, record.primary, record.rssi, record.authmode != WIFI_AUTH_OPEN, ssid.empty());
scan_result_.push_back(result);
}
} else if (data->event_base == WIFI_EVENT && data->event_id == WIFI_EVENT_AP_START) {

View File

@@ -419,9 +419,9 @@ void WiFiComponent::wifi_scan_done_callback_() {
uint8_t *bssid = WiFi.BSSID(i);
int32_t channel = WiFi.channel(i);
this->scan_result_.emplace_back(bssid_t{bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5]},
std::string(ssid.c_str()), channel, rssi, authmode != WIFI_AUTH_OPEN,
ssid.length() == 0);
WiFiScanResult scan({bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5]}, std::string(ssid.c_str()),
channel, rssi, authmode != WIFI_AUTH_OPEN, ssid.length() == 0);
this->scan_result_.push_back(scan);
}
WiFi.scanDelete();
this->scan_done_ = true;

View File

@@ -12,7 +12,7 @@ from typing import Any
import voluptuous as vol
from esphome import core, loader, pins, yaml_util
from esphome.config_helpers import Extend, Remove, merge_config, merge_dicts_ordered
from esphome.config_helpers import Extend, Remove, merge_dicts_ordered
import esphome.config_validation as cv
from esphome.const import (
CONF_ESPHOME,
@@ -324,7 +324,13 @@ def iter_ids(config, path=None):
yield from iter_ids(value, path + [key])
def check_replaceme(value):
def recursive_check_replaceme(value):
if isinstance(value, list):
return cv.Schema([recursive_check_replaceme])(value)
if isinstance(value, dict):
return cv.Schema({cv.valid: recursive_check_replaceme})(value)
if isinstance(value, ESPLiteralValue):
pass
if isinstance(value, str) and value == "REPLACEME":
raise cv.Invalid(
"Found 'REPLACEME' in configuration, this is most likely an error. "
@@ -333,86 +339,7 @@ def check_replaceme(value):
"If you want to use the literal REPLACEME string, "
'please use "!literal REPLACEME"'
)
def _build_list_index(lst):
index = OrderedDict()
extensions, removals = [], set()
for item in lst:
if item is None:
removals.add(None)
continue
item_id = None
if isinstance(item, dict) and (item_id := item.get(CONF_ID)):
if isinstance(item_id, Extend):
extensions.append(item)
continue
if isinstance(item_id, Remove):
removals.add(item_id.value)
continue
if not item_id or item_id in index:
# no id or duplicate -> pass through with identity-based key
item_id = id(item)
index[item_id] = item
return index, extensions, removals
def resolve_extend_remove(value, is_key=None):
if isinstance(value, ESPLiteralValue):
return # do not check inside literal blocks
if isinstance(value, list):
index, extensions, removals = _build_list_index(value)
if extensions or removals:
# Rebuild the original list after
# processing all extensions and removals
for item in extensions:
item_id = item[CONF_ID].value
if item_id in removals:
continue
old = index.get(item_id)
if old is None:
# Failed to find source for extension
# Find index of item to show error at correct position
i = next(
(
i
for i, d in enumerate(value)
if d.get(CONF_ID) == item[CONF_ID]
)
)
with cv.prepend_path(i):
raise cv.Invalid(
f"Source for extension of ID '{item_id}' was not found."
)
item[CONF_ID] = item_id
index[item_id] = merge_config(old, item)
for item_id in removals:
index.pop(item_id, None)
value[:] = index.values()
for i, item in enumerate(value):
with cv.prepend_path(i):
resolve_extend_remove(item, False)
return
if isinstance(value, dict):
removals = []
for k, v in value.items():
with cv.prepend_path(k):
if isinstance(v, Remove):
removals.append(k)
continue
resolve_extend_remove(k, True)
resolve_extend_remove(v, False)
for k in removals:
value.pop(k, None)
return
if is_key:
return # do not check keys (yet)
check_replaceme(value)
return
return value
class ConfigValidationStep(abc.ABC):
@@ -510,6 +437,19 @@ class LoadValidationStep(ConfigValidationStep):
continue
p_name = p_config.get("platform")
if p_name is None:
p_id = p_config.get(CONF_ID)
if isinstance(p_id, Extend):
result.add_str_error(
f"Source for extension of ID '{p_id.value}' was not found.",
path + [CONF_ID],
)
continue
if isinstance(p_id, Remove):
result.add_str_error(
f"Source for removal of ID '{p_id.value}' was not found.",
path + [CONF_ID],
)
continue
result.add_str_error(
f"'{self.domain}' requires a 'platform' key but it was not specified.",
path,
@@ -994,10 +934,9 @@ def validate_config(
CORE.raw_config = config
# 1.1. Resolve !extend and !remove and check for REPLACEME
# After this step, there will not be any Extend or Remove values in the config anymore
# 1.1. Check for REPLACEME special value
try:
resolve_extend_remove(config)
recursive_check_replaceme(config)
except vol.Invalid as err:
result.add_error(err)

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from esphome.const import (
CONF_ID,
CONF_LEVEL,
CONF_LOGGER,
KEY_CORE,
@@ -74,28 +75,73 @@ class Remove:
return isinstance(b, Remove) and self.value == b.value
def merge_config(old, new):
if isinstance(new, Remove):
return new
if isinstance(new, dict):
if not isinstance(old, dict):
return new
# Preserve OrderedDict type by copying to OrderedDict if either input is OrderedDict
if isinstance(old, OrderedDict) or isinstance(new, OrderedDict):
res = OrderedDict(old)
else:
def merge_config(full_old, full_new):
def merge(old, new):
if isinstance(new, dict):
if not isinstance(old, dict):
return new
# Preserve OrderedDict type by copying to OrderedDict if either input is OrderedDict
if isinstance(old, OrderedDict) or isinstance(new, OrderedDict):
res = OrderedDict(old)
else:
res = old.copy()
for k, v in new.items():
if isinstance(v, Remove) and k in old:
del res[k]
else:
res[k] = merge(old[k], v) if k in old else v
return res
if isinstance(new, list):
if not isinstance(old, list):
return new
res = old.copy()
for k, v in new.items():
res[k] = merge_config(old.get(k), v)
return res
if isinstance(new, list):
if not isinstance(old, list):
return new
return old + new
if new is None:
return old
ids = {
v_id: i
for i, v in enumerate(res)
if isinstance(v, dict)
and (v_id := v.get(CONF_ID))
and isinstance(v_id, str)
}
extend_ids = {
v_id.value: i
for i, v in enumerate(res)
if isinstance(v, dict)
and (v_id := v.get(CONF_ID))
and isinstance(v_id, Extend)
}
return new
ids_to_delete = []
for v in new:
if isinstance(v, dict) and (new_id := v.get(CONF_ID)):
if isinstance(new_id, Extend):
new_id = new_id.value
if new_id in ids:
v[CONF_ID] = new_id
res[ids[new_id]] = merge(res[ids[new_id]], v)
continue
elif isinstance(new_id, Remove):
new_id = new_id.value
if new_id in ids:
ids_to_delete.append(ids[new_id])
continue
elif (
new_id in extend_ids
): # When a package is extending a non-packaged item
extend_res = res[extend_ids[new_id]]
extend_res[CONF_ID] = new_id
new_v = merge(v, extend_res)
res[extend_ids[new_id]] = new_v
continue
else:
ids[new_id] = len(res)
res.append(v)
return [v for i, v in enumerate(res) if i not in ids_to_delete]
if new is None:
return old
return new
return merge(full_old, full_new)
def filter_source_files_from_platform(

View File

@@ -24,6 +24,7 @@ import voluptuous as vol
from esphome import core
import esphome.codegen as cg
from esphome.config_helpers import Extend, Remove
from esphome.const import (
ALLOWED_NAME_CHARS,
CONF_AVAILABILITY,
@@ -623,6 +624,12 @@ def declare_id(type):
if value is None:
return core.ID(None, is_declaration=True, type=type)
if isinstance(value, Extend):
raise Invalid(f"Source for extension of ID '{value.value}' was not found.")
if isinstance(value, Remove):
raise Invalid(f"Source for Removal of ID '{value.value}' was not found.")
return core.ID(validate_id_name(value), is_declaration=True, type=type)
return validator

View File

@@ -199,7 +199,6 @@
#define USE_WEBSERVER_PORT 80 // NOLINT
#define USE_WEBSERVER_SORTING
#define USE_WIFI_11KV_SUPPORT
#define USE_WIFI_FAST_CONNECT
#define USB_HOST_MAX_REQUESTS 16
#ifdef USE_ARDUINO

View File

@@ -281,13 +281,13 @@ template<typename T> class FixedVector {
}
}
/// Emplace element without bounds checking - constructs in-place with arguments
/// Emplace element without bounds checking - constructs in-place
/// Caller must ensure sufficient capacity was allocated via init()
/// Returns reference to the newly constructed element
/// NOTE: Caller MUST ensure size_ < capacity_ before calling
template<typename... Args> T &emplace_back(Args &&...args) {
// Use placement new to construct the object in pre-allocated memory
new (&data_[size_]) T(std::forward<Args>(args)...);
T &emplace_back() {
// Use placement new to default-construct the object in pre-allocated memory
new (&data_[size_]) T();
size_++;
return data_[size_ - 1];
}
@@ -1158,4 +1158,18 @@ template<typename T, enable_if_t<std::is_pointer<T *>::value, int> = 0> T &id(T
///@}
/// @name Deprecated functions
///@{
ESPDEPRECATED("hexencode() is deprecated, use format_hex_pretty() instead.", "2022.1")
inline std::string hexencode(const uint8_t *data, uint32_t len) { return format_hex_pretty(data, len); }
template<typename T>
ESPDEPRECATED("hexencode() is deprecated, use format_hex_pretty() instead.", "2022.1")
std::string hexencode(const T &data) {
return hexencode(data.data(), data.size());
}
///@}
} // namespace esphome

View File

@@ -328,30 +328,17 @@ void HOT Scheduler::call(uint32_t now) {
// Single-core platforms don't use this queue and fall back to the heap-based approach.
//
// Note: Items cancelled via cancel_item_locked_() are marked with remove=true but still
// processed here. They are skipped during execution by should_skip_item_().
// This is intentional - no memory leak occurs.
//
// We use an index (defer_queue_front_) to track the read position instead of calling
// erase() on every pop, which would be O(n). The queue is processed once per loop -
// any items added during processing are left for the next loop iteration.
// Snapshot the queue end point - only process items that existed at loop start
// Items added during processing (by callbacks or other threads) run next loop
// No lock needed: single consumer (main loop), stale read just means we process less this iteration
size_t defer_queue_end = this->defer_queue_.size();
while (this->defer_queue_front_ < defer_queue_end) {
// processed here. They are removed from the queue normally via pop_front() but skipped
// during execution by should_skip_item_(). This is intentional - no memory leak occurs.
while (!this->defer_queue_.empty()) {
// The outer check is done without a lock for performance. If the queue
// appears non-empty, we lock and process an item. We don't need to check
// empty() again inside the lock because only this thread can remove items.
std::unique_ptr<SchedulerItem> item;
{
LockGuard lock(this->lock_);
// SAFETY: Moving out the unique_ptr leaves a nullptr in the vector at defer_queue_front_.
// This is intentional and safe because:
// 1. The vector is only cleaned up by cleanup_defer_queue_locked_() at the end of this function
// 2. Any code iterating defer_queue_ MUST check for nullptr items (see mark_matching_items_removed_
// and has_cancelled_timeout_in_container_ in scheduler.h)
// 3. The lock protects concurrent access, but the nullptr remains until cleanup
item = std::move(this->defer_queue_[this->defer_queue_front_]);
this->defer_queue_front_++;
item = std::move(this->defer_queue_.front());
this->defer_queue_.pop_front();
}
// Execute callback without holding lock to prevent deadlocks
@@ -362,13 +349,6 @@ void HOT Scheduler::call(uint32_t now) {
// Recycle the defer item after execution
this->recycle_item_(std::move(item));
}
// If we've consumed all items up to the snapshot point, clean up the dead space
// Single consumer (main loop), so no lock needed for this check
if (this->defer_queue_front_ >= defer_queue_end) {
LockGuard lock(this->lock_);
this->cleanup_defer_queue_locked_();
}
#endif /* not ESPHOME_THREAD_SINGLE */
// Convert the fresh timestamp from main loop to 64-bit for scheduler operations

View File

@@ -264,36 +264,6 @@ class Scheduler {
// Helper to recycle a SchedulerItem
void recycle_item_(std::unique_ptr<SchedulerItem> item);
#ifndef ESPHOME_THREAD_SINGLE
// Helper to cleanup defer_queue_ after processing
// IMPORTANT: Caller must hold the scheduler lock before calling this function.
inline void cleanup_defer_queue_locked_() {
// Check if new items were added by producers during processing
if (this->defer_queue_front_ >= this->defer_queue_.size()) {
// Common case: no new items - clear everything
this->defer_queue_.clear();
} else {
// Rare case: new items were added during processing - compact the vector
// This only happens when:
// 1. A deferred callback calls defer() again, or
// 2. Another thread calls defer() while we're processing
//
// Move unprocessed items (added during this loop) to the front for next iteration
//
// SAFETY: Compacted items may include cancelled items (marked for removal via
// cancel_item_locked_() during execution). This is safe because should_skip_item_()
// checks is_item_removed_() before executing, so cancelled items will be skipped
// and recycled on the next loop iteration.
size_t remaining = this->defer_queue_.size() - this->defer_queue_front_;
for (size_t i = 0; i < remaining; i++) {
this->defer_queue_[i] = std::move(this->defer_queue_[this->defer_queue_front_ + i]);
}
this->defer_queue_.resize(remaining);
}
this->defer_queue_front_ = 0;
}
#endif /* not ESPHOME_THREAD_SINGLE */
// Helper to check if item is marked for removal (platform-specific)
// Returns true if item should be skipped, handles platform-specific synchronization
// For ESPHOME_THREAD_MULTI_NO_ATOMICS platforms, the caller must hold the scheduler lock before calling this
@@ -312,18 +282,13 @@ class Scheduler {
// Helper to mark matching items in a container as removed
// Returns the number of items marked for removal
// IMPORTANT: Caller must hold the scheduler lock before calling this function.
// For ESPHOME_THREAD_MULTI_NO_ATOMICS platforms, the caller must hold the scheduler lock before calling this
// function.
template<typename Container>
size_t mark_matching_items_removed_(Container &container, Component *component, const char *name_cstr,
SchedulerItem::Type type, bool match_retry) {
size_t count = 0;
for (auto &item : container) {
// Skip nullptr items (can happen in defer_queue_ when items are being processed)
// The defer_queue_ uses index-based processing: items are std::moved out but left in the
// vector as nullptr until cleanup. Even though this function is called with lock held,
// the vector can still contain nullptr items from the processing loop. This check prevents crashes.
if (!item)
continue;
if (this->matches_item_(item, component, name_cstr, type, match_retry)) {
// Mark item for removal (platform-specific)
#ifdef ESPHOME_THREAD_MULTI_ATOMICS
@@ -346,12 +311,6 @@ class Scheduler {
bool has_cancelled_timeout_in_container_(const Container &container, Component *component, const char *name_cstr,
bool match_retry) const {
for (const auto &item : container) {
// Skip nullptr items (can happen in defer_queue_ when items are being processed)
// The defer_queue_ uses index-based processing: items are std::moved out but left in the
// vector as nullptr until cleanup. If this function is called during defer queue processing,
// it will iterate over these nullptr items. This check prevents crashes.
if (!item)
continue;
if (is_item_removed_(item.get()) &&
this->matches_item_(item, component, name_cstr, SchedulerItem::TIMEOUT, match_retry,
/* skip_removed= */ false)) {
@@ -365,12 +324,9 @@ class Scheduler {
std::vector<std::unique_ptr<SchedulerItem>> items_;
std::vector<std::unique_ptr<SchedulerItem>> to_add_;
#ifndef ESPHOME_THREAD_SINGLE
// Single-core platforms don't need the defer queue and save ~32 bytes of RAM
// Using std::vector instead of std::deque avoids 512-byte chunked allocations
// Index tracking avoids O(n) erase() calls when draining the queue each loop
std::vector<std::unique_ptr<SchedulerItem>> defer_queue_; // FIFO queue for defer() calls
size_t defer_queue_front_{0}; // Index of first valid item in defer_queue_ (tracks consumed items)
#endif /* ESPHOME_THREAD_SINGLE */
// Single-core platforms don't need the defer queue and save 40 bytes of RAM
std::deque<std::unique_ptr<SchedulerItem>> defer_queue_; // FIFO queue for defer() calls
#endif /* ESPHOME_THREAD_SINGLE */
uint32_t to_remove_{0};
// Memory pool for recycling SchedulerItem objects to reduce heap churn.

View File

@@ -1,362 +0,0 @@
"""GitHub download cache for ESPHome.
This module provides caching functionality for GitHub release downloads
to avoid redundant network I/O when switching between platforms.
"""
from __future__ import annotations
import hashlib
import json
import logging
from pathlib import Path
import shutil
import time
import urllib.error
import urllib.request
_LOGGER = logging.getLogger(__name__)
class GitHubCache:
"""Manages caching of GitHub release downloads."""
# Cache expiration time in seconds (30 days)
CACHE_EXPIRATION_SECONDS = 30 * 24 * 60 * 60
def __init__(self, cache_dir: Path | None = None):
"""Initialize the cache manager.
Args:
cache_dir: Directory to store cached files.
Defaults to ~/.esphome_cache/github
"""
if cache_dir is None:
cache_dir = Path.home() / ".esphome_cache" / "github"
self.cache_dir = cache_dir
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.metadata_file = self.cache_dir / "cache_metadata.json"
# Prune old files on initialization
try:
self._prune_old_files()
except Exception as e:
_LOGGER.debug("Failed to prune old cache files: %s", e)
def _load_metadata(self) -> dict:
"""Load cache metadata from disk."""
if self.metadata_file.exists():
try:
with open(self.metadata_file) as f:
return json.load(f)
except (OSError, ValueError, json.JSONDecodeError):
return {}
return {}
def _save_metadata(self, metadata: dict) -> None:
"""Save cache metadata to disk."""
try:
with open(self.metadata_file, "w") as f:
json.dump(metadata, f, indent=2)
except OSError as e:
_LOGGER.debug("Failed to save cache metadata: %s", e)
@staticmethod
def is_github_url(url: str) -> bool:
"""Check if URL is a GitHub release download."""
return "github.com" in url.lower() and url.endswith(".zip")
def _get_cache_key(self, url: str) -> str:
"""Get cache key (hash) for a URL."""
return hashlib.sha256(url.encode()).hexdigest()
def _get_cache_path(self, url: str) -> Path:
"""Get cache file path for a URL."""
cache_key = self._get_cache_key(url)
ext = Path(url.split("?")[0]).suffix
return self.cache_dir / f"{cache_key}{ext}"
def _check_if_modified(
self,
url: str,
last_modified: str | None = None,
etag: str | None = None,
) -> bool:
"""Check if a URL has been modified using HTTP 304.
Args:
url: URL to check
last_modified: Last-Modified header from previous response
etag: ETag header from previous response
Returns:
True if modified, False if not modified (or offline/unreachable)
"""
if not last_modified and not etag:
# No cache headers available, assume modified
return True
try:
request = urllib.request.Request(url)
request.get_method = lambda: "HEAD"
if last_modified:
request.add_header("If-Modified-Since", last_modified)
if etag:
request.add_header("If-None-Match", etag)
try:
urllib.request.urlopen(request, timeout=10)
# 200 OK = file was modified
return True
except urllib.error.HTTPError as e:
if e.code == 304:
# Not modified
_LOGGER.debug("File not modified (HTTP 304): %s", url)
return False
# Other errors, assume modified to be safe
return True
except (OSError, urllib.error.URLError):
# If check fails (offline/network error), assume not modified (use cache)
_LOGGER.info("Cannot reach server (offline?), using cached file: %s", url)
return False
def get_cached_path(self, url: str, check_updates: bool = True) -> Path | None:
"""Get path to cached file if available and valid.
Args:
url: URL to check
check_updates: Whether to check for updates using HTTP 304
Returns:
Path to cached file if valid, None if needs download
"""
if not self.is_github_url(url):
return None
cache_path = self._get_cache_path(url)
if not cache_path.exists():
return None
# Load metadata
metadata = self._load_metadata()
cache_key = self._get_cache_key(url)
# Check if file should be re-downloaded
should_redownload = False
if check_updates and cache_key in metadata:
last_modified = metadata[cache_key].get("last_modified")
etag = metadata[cache_key].get("etag")
if self._check_if_modified(url, last_modified, etag):
# File was modified, need to re-download
_LOGGER.debug("Cached file is outdated: %s", url)
should_redownload = True
if should_redownload:
return None
# File is valid, update cached_at timestamp to keep it fresh
if cache_key in metadata:
metadata[cache_key]["cached_at"] = time.time()
self._save_metadata(metadata)
# Log appropriate message
if not check_updates:
_LOGGER.debug("Using cached file (no update check): %s", url)
elif cache_key not in metadata:
_LOGGER.debug("Using cached file (no metadata): %s", url)
else:
_LOGGER.debug("Using cached file: %s", url)
return cache_path
def save_to_cache(self, url: str, source_path: Path) -> None:
"""Save a downloaded file to cache.
Args:
url: URL the file was downloaded from
source_path: Path to the downloaded file
"""
if not self.is_github_url(url):
return
try:
cache_path = self._get_cache_path(url)
# Only copy if source and destination are different
if source_path.resolve() != cache_path.resolve():
shutil.copy2(source_path, cache_path)
# Try to get HTTP headers for caching
last_modified = None
etag = None
try:
request = urllib.request.Request(url)
request.get_method = lambda: "HEAD"
response = urllib.request.urlopen(request, timeout=10)
last_modified = response.headers.get("Last-Modified")
etag = response.headers.get("ETag")
except (OSError, urllib.error.URLError):
pass
# Update metadata
metadata = self._load_metadata()
cache_key = self._get_cache_key(url)
metadata[cache_key] = {
"url": url,
"size": cache_path.stat().st_size,
"cached_at": time.time(),
"last_modified": last_modified,
"etag": etag,
}
self._save_metadata(metadata)
_LOGGER.debug("Saved to cache: %s", url)
except OSError as e:
_LOGGER.debug("Failed to save to cache: %s", e)
def copy_from_cache(self, url: str, destination: Path) -> bool:
"""Copy a cached file to destination.
Args:
url: URL of the cached file
destination: Where to copy the file
Returns:
True if successful, False otherwise
"""
cached_path = self.get_cached_path(url, check_updates=True)
if not cached_path:
return False
try:
shutil.copy2(cached_path, destination)
_LOGGER.info("Using cached download for %s", url)
return True
except OSError as e:
_LOGGER.warning("Failed to use cache: %s", e)
return False
def cache_size(self) -> int:
"""Get total size of cached files in bytes."""
total = 0
try:
for file_path in self.cache_dir.glob("*"):
if file_path.is_file() and file_path != self.metadata_file:
total += file_path.stat().st_size
except OSError:
pass
return total
def list_cached(self) -> list[dict]:
"""List all cached files with metadata."""
cached_files = []
metadata = self._load_metadata()
for cache_key, meta in metadata.items():
cache_path = (
self.cache_dir / f"{cache_key}{Path(meta['url'].split('?')[0]).suffix}"
)
if cache_path.exists():
cached_files.append(
{
"url": meta["url"],
"path": cache_path,
"size": meta["size"],
"cached_at": meta.get("cached_at"),
"last_modified": meta.get("last_modified"),
"etag": meta.get("etag"),
}
)
return cached_files
def clear_cache(self) -> None:
"""Clear all cached files."""
try:
for file_path in self.cache_dir.glob("*"):
if file_path.is_file():
file_path.unlink()
_LOGGER.info("Cache cleared: %s", self.cache_dir)
except OSError as e:
_LOGGER.warning("Failed to clear cache: %s", e)
def _prune_old_files(self) -> None:
"""Remove cache files older than CACHE_EXPIRATION_SECONDS."""
current_time = time.time()
metadata = self._load_metadata()
removed_count = 0
removed_size = 0
# Check each file in metadata
for cache_key, meta in list(metadata.items()):
cached_at = meta.get("cached_at", 0)
age_seconds = current_time - cached_at
if age_seconds > self.CACHE_EXPIRATION_SECONDS:
# File is too old, remove it
cache_path = (
self.cache_dir
/ f"{cache_key}{Path(meta['url'].split('?')[0]).suffix}"
)
if cache_path.exists():
file_size = cache_path.stat().st_size
cache_path.unlink()
removed_size += file_size
removed_count += 1
_LOGGER.debug(
"Pruned old cache file (age: %.1f days): %s",
age_seconds / (24 * 60 * 60),
meta["url"],
)
# Remove from metadata
del metadata[cache_key]
# Also check for orphaned files (files without metadata)
for file_path in self.cache_dir.glob("*.zip"):
if file_path == self.metadata_file:
continue
# Check if file is in metadata
found_in_metadata = False
for cache_key in metadata:
if file_path.name.startswith(cache_key):
found_in_metadata = True
break
if not found_in_metadata:
# Orphaned file - check age by modification time
file_age = current_time - file_path.stat().st_mtime
if file_age > self.CACHE_EXPIRATION_SECONDS:
file_size = file_path.stat().st_size
file_path.unlink()
removed_size += file_size
removed_count += 1
_LOGGER.debug(
"Pruned orphaned cache file (age: %.1f days): %s",
file_age / (24 * 60 * 60),
file_path.name,
)
# Save updated metadata if anything was removed
if removed_count > 0:
self._save_metadata(metadata)
removed_mb = removed_size / (1024 * 1024)
_LOGGER.info(
"Pruned %d old cache file(s), freed %.2f MB",
removed_count,
removed_mb,
)
# Global cache instance
_cache: GitHubCache | None = None
def get_cache() -> GitHubCache:
"""Get the global GitHub cache instance."""
global _cache # noqa: PLW0603
if _cache is None:
_cache = GitHubCache()
return _cache

View File

@@ -5,6 +5,7 @@ import os
from pathlib import Path
import re
import subprocess
from typing import Any
from esphome.const import CONF_COMPILE_PROCESS_LIMIT, CONF_ESPHOME, KEY_CORE
from esphome.core import CORE, EsphomeError
@@ -43,168 +44,32 @@ def patch_structhash():
def patch_file_downloader():
"""Patch PlatformIO's FileDownloader to add caching and retry on PackageException errors.
"""Patch PlatformIO's FileDownloader to retry on PackageException errors."""
from platformio.package.download import FileDownloader
from platformio.package.exception import PackageException
This function attempts to patch PlatformIO's internal download mechanism.
If patching fails (due to API changes), it gracefully falls back to no caching.
"""
try:
from platformio.package.download import FileDownloader
from platformio.package.exception import PackageException
except ImportError as e:
_LOGGER.debug("Could not import PlatformIO modules for patching: %s", e)
return
original_init = FileDownloader.__init__
# Import our cache module
from esphome.github_cache import GitHubCache
def patched_init(self, *args: Any, **kwargs: Any) -> None:
max_retries = 3
_LOGGER.debug("Applying GitHub download cache patch...")
# Verify the classes have the expected methods before patching
if not hasattr(FileDownloader, "__init__") or not hasattr(FileDownloader, "start"):
_LOGGER.warning(
"PlatformIO FileDownloader API has changed, skipping cache patch"
)
return
try:
original_init = FileDownloader.__init__
original_start = FileDownloader.start
# Initialize cache in .platformio directory so it benefits from GitHub Actions cache
platformio_dir = Path.home() / ".platformio"
cache_dir = platformio_dir / "esphome_download_cache"
cache_dir_existed = cache_dir.exists()
cache = GitHubCache(cache_dir=cache_dir)
if not cache_dir_existed:
_LOGGER.info("Created GitHub download cache at: %s", cache.cache_dir)
except Exception as e:
_LOGGER.warning("Failed to initialize GitHub download cache: %s", e)
return
def patched_init(self, *args, **kwargs):
"""Patched init that checks cache before making HTTP connection."""
try:
# Extract URL from args (first positional argument)
url = args[0] if args else kwargs.get("url")
dest_dir = args[1] if len(args) > 1 else kwargs.get("dest_dir")
# Debug: Log all downloads
_LOGGER.debug("[GitHub Cache] Download request for: %s", url)
# Store URL for later use (original FileDownloader doesn't store it)
self._esphome_cache_url = url if cache.is_github_url(url) else None
# Check cache for GitHub URLs BEFORE making HTTP request
if self._esphome_cache_url:
_LOGGER.debug("[GitHub Cache] This is a GitHub URL, checking cache...")
self._esphome_use_cache = cache.get_cached_path(url, check_updates=True)
if self._esphome_use_cache:
_LOGGER.info(
"Found %s in cache, will restore instead of downloading",
Path(url.split("?")[0]).name,
)
_LOGGER.debug(
"[GitHub Cache] Found in cache: %s", self._esphome_use_cache
for attempt in range(max_retries):
try:
return original_init(self, *args, **kwargs)
except PackageException as e:
if attempt < max_retries - 1:
_LOGGER.warning(
"Package download failed: %s. Retrying... (attempt %d/%d)",
str(e),
attempt + 1,
max_retries,
)
else:
_LOGGER.debug(
"[GitHub Cache] Not in cache, will download and cache"
)
else:
self._esphome_use_cache = None
if url and str(url).startswith("http"):
_LOGGER.debug("[GitHub Cache] Not a GitHub URL, skipping cache")
# Final attempt - re-raise
raise
return None
# Only make HTTP connection if we don't have cached file
if self._esphome_use_cache:
# Skip HTTP connection, we'll handle this in start()
# Set minimal attributes to satisfy FileDownloader
# Create a mock session that can be safely closed in __del__
class MockSession:
def close(self):
pass
self._http_session = MockSession()
self._http_response = None
self._fname = Path(url.split("?")[0]).name
self._destination = self._fname
if dest_dir:
from os.path import join
self._destination = join(dest_dir, self._fname)
# Note: Actual restoration logged in patched_start
return None # Don't call original_init
# Normal initialization with retry logic
max_retries = 3
for attempt in range(max_retries):
try:
return original_init(self, *args, **kwargs)
except PackageException as e:
if attempt < max_retries - 1:
_LOGGER.warning(
"Package download failed: %s. Retrying... (attempt %d/%d)",
str(e),
attempt + 1,
max_retries,
)
else:
# Final attempt - re-raise
raise
return None
except Exception as e:
# If anything goes wrong in our cache logic, fall back to normal download
_LOGGER.debug("Cache check failed, falling back to normal download: %s", e)
self._esphome_cache_url = None
self._esphome_use_cache = None
return original_init(self, *args, **kwargs)
def patched_start(self, *args, **kwargs):
"""Patched start that uses cache when available."""
try:
import shutil
# Get the cache URL and path that were set in __init__
cache_url = getattr(self, "_esphome_cache_url", None)
cached_file = getattr(self, "_esphome_use_cache", None)
# If we're using cache, copy file instead of downloading
if cached_file:
try:
shutil.copy2(cached_file, self._destination)
_LOGGER.info(
"Restored %s from cache (avoided download)",
Path(cached_file).name,
)
return True
except OSError as e:
_LOGGER.warning("Failed to copy from cache: %s", e)
# Fall through to re-download
# Perform normal download
result = original_start(self, *args, **kwargs)
# Save to cache if it was a GitHub URL
if cache_url:
try:
cache.save_to_cache(cache_url, Path(self._destination))
except OSError as e:
_LOGGER.debug("Failed to save to cache: %s", e)
return result
except Exception as e:
# If anything goes wrong, fall back to normal download
_LOGGER.debug("Cache restoration failed, using normal download: %s", e)
return original_start(self, *args, **kwargs)
# Apply the patches
try:
FileDownloader.__init__ = patched_init
FileDownloader.start = patched_start
_LOGGER.debug("GitHub download cache patch applied successfully")
except Exception as e:
_LOGGER.warning("Failed to apply GitHub download cache patch: %s", e)
FileDownloader.__init__ = patched_init
IGNORE_LIB_WARNINGS = f"(?:{'|'.join(['Hash', 'Update'])})"
@@ -222,8 +87,6 @@ FILTER_PLATFORMIO_LINES = [
r"Memory Usage -> https://bit.ly/pio-memory-usage",
r"Found: https://platformio.org/lib/show/.*",
r"Using cache: .*",
# Don't filter our cache messages - let users see when cache is being used
# r"Using cached download for .*",
r"Installing dependencies",
r"Library Manager: Already installed, built-in library",
r"Building in .* mode",
@@ -511,23 +374,3 @@ class IDEData:
return f"{self.cc_path[:-7]}addr2line.exe"
return f"{self.cc_path[:-3]}addr2line"
@property
def objdump_path(self) -> str:
# replace gcc at end with objdump
path = self.cc_path
return (
f"{path[:-7]}objdump.exe"
if path.endswith(".exe")
else f"{path[:-3]}objdump"
)
@property
def readelf_path(self) -> str:
# replace gcc at end with readelf
path = self.cc_path
return (
f"{path[:-7]}readelf.exe"
if path.endswith(".exe")
else f"{path[:-3]}readelf"
)

View File

@@ -34,8 +34,6 @@ from typing import Any
# Add esphome to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from helpers import BASE_BUS_COMPONENTS
from esphome import yaml_util
from esphome.config_helpers import Extend, Remove
@@ -69,6 +67,18 @@ NO_BUSES_SIGNATURE = "no_buses"
# Isolated components have unique signatures and cannot be merged with others
ISOLATED_SIGNATURE_PREFIX = "isolated_"
# Base bus components - these ARE the bus implementations and should not
# be flagged as needing migration since they are the platform/base components
BASE_BUS_COMPONENTS = {
"i2c",
"spi",
"uart",
"modbus",
"canbus",
"remote_transmitter",
"remote_receiver",
}
# Components that must be tested in isolation (not grouped or batched with others)
# These have known build issues that prevent grouping
# NOTE: This should be kept in sync with both test_build_components and split_components_for_ci.py

View File

@@ -1415,13 +1415,7 @@ class RepeatedTypeInfo(TypeInfo):
super().__init__(field)
# Check if this is a pointer field by looking for container_pointer option
self._container_type = get_field_opt(field, pb.container_pointer, "")
# Check for non-template container pointer
self._container_no_template = get_field_opt(
field, pb.container_pointer_no_template, ""
)
self._use_pointer = bool(self._container_type) or bool(
self._container_no_template
)
self._use_pointer = bool(self._container_type)
# Check if this should use FixedVector instead of std::vector
self._use_fixed_vector = get_field_opt(field, pb.fixed_vector, False)
@@ -1440,18 +1434,12 @@ class RepeatedTypeInfo(TypeInfo):
@property
def cpp_type(self) -> str:
if self._container_no_template:
# Non-template container: use type as-is without appending template parameters
return f"const {self._container_no_template}*"
if self._use_pointer and self._container_type:
# For pointer fields, use the specified container type
# Two cases:
# 1. "std::set<climate::ClimateMode>" - Full type with template params, use as-is
# 2. "std::set" - No <>, append the element type
# If the container type already includes the element type (e.g., std::set<climate::ClimateMode>)
# use it as-is, otherwise append the element type
if "<" in self._container_type and ">" in self._container_type:
# Has template parameters specified, use as-is
return f"const {self._container_type}*"
# No <> at all, append element type
return f"const {self._container_type}<{self._ti.cpp_type}>*"
if self._use_fixed_vector:
return f"FixedVector<{self._ti.cpp_type}>"

View File

@@ -1,164 +0,0 @@
#!/usr/bin/env python3
"""
Pre-cache PlatformIO GitHub Downloads
This script extracts GitHub URLs from platformio.ini and pre-caches them
to avoid redundant downloads when switching between ESP8266 and ESP32 builds.
Usage:
python3 script/cache_platformio_downloads.py [platformio.ini]
"""
import argparse
import configparser
from pathlib import Path
import re
import sys
# Import the cache manager
sys.path.insert(0, str(Path(__file__).parent.parent))
from esphome.github_cache import GitHubCache
def extract_github_urls(platformio_ini: Path) -> list[str]:
"""Extract all GitHub URLs from platformio.ini.
Args:
platformio_ini: Path to platformio.ini file
Returns:
List of GitHub URLs found
"""
config = configparser.ConfigParser(inline_comment_prefixes=(";",))
config.read(platformio_ini)
urls = []
github_pattern = re.compile(r"https://github\.com/[^\s;]+\.zip")
for section in config.sections():
conf = config[section]
# Check platform
if "platform" in conf:
platform_value = conf["platform"]
matches = github_pattern.findall(platform_value)
urls.extend(matches)
# Check platform_packages
if "platform_packages" in conf:
for line in conf["platform_packages"].splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
matches = github_pattern.findall(line)
urls.extend(matches)
# Remove duplicates while preserving order using dict
return list(dict.fromkeys(urls))
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Pre-cache PlatformIO GitHub downloads",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
This script scans platformio.ini for GitHub URLs and pre-caches them.
This avoids redundant downloads when switching between platforms (e.g., ESP8266 and ESP32).
Examples:
# Cache downloads from default platformio.ini
%(prog)s
# Cache downloads from specific file
%(prog)s custom_platformio.ini
# Show what would be cached without downloading
%(prog)s --dry-run
""",
)
parser.add_argument(
"platformio_ini",
nargs="?",
default="platformio.ini",
help="Path to platformio.ini (default: platformio.ini)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be cached without downloading",
)
parser.add_argument(
"--cache-dir",
type=Path,
help="Cache directory (default: ~/.platformio/esphome_download_cache)",
)
parser.add_argument(
"--force",
action="store_true",
help="Force re-download even if cached",
)
args = parser.parse_args()
platformio_ini = Path(args.platformio_ini)
if not platformio_ini.exists():
print(f"Error: {platformio_ini} not found", file=sys.stderr)
return 1
# Extract URLs
print(f"Scanning {platformio_ini} for GitHub URLs...")
urls = extract_github_urls(platformio_ini)
if not urls:
print("No GitHub URLs found in platformio.ini")
return 0
print(f"Found {len(urls)} unique GitHub URL(s):")
for url in urls:
print(f" - {url}")
print()
if args.dry_run:
print("Dry run - not downloading")
return 0
# Initialize cache (use PlatformIO directory by default)
cache_dir = args.cache_dir
if cache_dir is None:
cache_dir = Path.home() / ".platformio" / "esphome_download_cache"
cache = GitHubCache(cache_dir)
# Cache each URL
success_count = 0
for i, url in enumerate(urls, 1):
print(f"[{i}/{len(urls)}] Checking {url}")
try:
# Use the download_with_progress from github_download_cache CLI
from script.github_download_cache import download_with_progress
download_with_progress(cache, url, force=args.force, check_updates=True)
success_count += 1
print()
except Exception as e:
print(f"Error caching {url}: {e}", file=sys.stderr)
print()
# Show cache stats
total_size = cache.cache_size()
size_mb = total_size / (1024 * 1024)
print("\nCache summary:")
print(f" Successfully cached: {success_count}/{len(urls)}")
print(f" Total cache size: {size_mb:.2f} MB")
print(f" Cache location: {cache.cache_dir}")
return 0 if success_count == len(urls) else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -1,88 +0,0 @@
#!/usr/bin/env python3
"""Add metadata to memory analysis JSON file.
This script adds components and platform metadata to an existing
memory analysis JSON file. Used by CI to ensure all required fields are present
for the comment script.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import sys
def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Add metadata to memory analysis JSON file"
)
parser.add_argument(
"--json-file",
required=True,
help="Path to JSON file to update",
)
parser.add_argument(
"--components",
required=True,
help='JSON array of component names (e.g., \'["api", "wifi"]\')',
)
parser.add_argument(
"--platform",
required=True,
help="Platform name",
)
args = parser.parse_args()
# Load existing JSON
json_path = Path(args.json_file)
if not json_path.exists():
print(f"Error: JSON file not found: {args.json_file}", file=sys.stderr)
return 1
try:
with open(json_path, encoding="utf-8") as f:
data = json.load(f)
except (json.JSONDecodeError, OSError) as e:
print(f"Error loading JSON: {e}", file=sys.stderr)
return 1
# Parse components
try:
components = json.loads(args.components)
if not isinstance(components, list):
print("Error: --components must be a JSON array", file=sys.stderr)
return 1
# Element-level validation: ensure each component is a non-empty string
for idx, comp in enumerate(components):
if not isinstance(comp, str) or not comp.strip():
print(
f"Error: component at index {idx} is not a non-empty string: {comp!r}",
file=sys.stderr,
)
return 1
except json.JSONDecodeError as e:
print(f"Error parsing components: {e}", file=sys.stderr)
return 1
# Add metadata
data["components"] = components
data["platform"] = args.platform
# Write back
try:
with open(json_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
print(f"Added metadata to {args.json_file}", file=sys.stderr)
except OSError as e:
print(f"Error writing JSON: {e}", file=sys.stderr)
return 1
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -1,23 +0,0 @@
"""Common helper functions for CI scripts."""
from __future__ import annotations
import os
def write_github_output(outputs: dict[str, str | int]) -> None:
"""Write multiple outputs to GITHUB_OUTPUT or stdout.
When running in GitHub Actions, writes to the GITHUB_OUTPUT file.
When running locally, writes to stdout for debugging.
Args:
outputs: Dictionary of key-value pairs to write
"""
github_output = os.environ.get("GITHUB_OUTPUT")
if github_output:
with open(github_output, "a", encoding="utf-8") as f:
f.writelines(f"{key}={value}\n" for key, value in outputs.items())
else:
for key, value in outputs.items():
print(f"{key}={value}")

View File

@@ -1,643 +0,0 @@
#!/usr/bin/env python3
"""Post or update a PR comment with memory impact analysis results.
This script creates or updates a GitHub PR comment with memory usage changes.
It uses the GitHub CLI (gh) to manage comments and maintains a single comment
that gets updated on subsequent runs.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import subprocess
import sys
from jinja2 import Environment, FileSystemLoader
# Add esphome to path for analyze_memory import
sys.path.insert(0, str(Path(__file__).parent.parent))
# pylint: disable=wrong-import-position
# Comment marker to identify our memory impact comments
COMMENT_MARKER = "<!-- esphome-memory-impact-analysis -->"
def run_gh_command(args: list[str], operation: str) -> subprocess.CompletedProcess:
"""Run a gh CLI command with error handling.
Args:
args: Command arguments (including 'gh')
operation: Description of the operation for error messages
Returns:
CompletedProcess result
Raises:
subprocess.CalledProcessError: If command fails (with detailed error output)
"""
try:
return subprocess.run(
args,
check=True,
capture_output=True,
text=True,
)
except subprocess.CalledProcessError as e:
print(
f"ERROR: {operation} failed with exit code {e.returncode}", file=sys.stderr
)
print(f"ERROR: Command: {' '.join(args)}", file=sys.stderr)
print(f"ERROR: stdout: {e.stdout}", file=sys.stderr)
print(f"ERROR: stderr: {e.stderr}", file=sys.stderr)
raise
# Thresholds for emoji significance indicators (percentage)
OVERALL_CHANGE_THRESHOLD = 1.0 # Overall RAM/Flash changes
COMPONENT_CHANGE_THRESHOLD = 3.0 # Component breakdown changes
# Display limits for tables
MAX_COMPONENT_BREAKDOWN_ROWS = 20 # Maximum components to show in breakdown table
MAX_CHANGED_SYMBOLS_ROWS = 30 # Maximum changed symbols to show
MAX_NEW_SYMBOLS_ROWS = 15 # Maximum new symbols to show
MAX_REMOVED_SYMBOLS_ROWS = 15 # Maximum removed symbols to show
# Symbol display formatting
SYMBOL_DISPLAY_MAX_LENGTH = 100 # Max length before using <details> tag
SYMBOL_DISPLAY_TRUNCATE_LENGTH = 97 # Length to truncate in summary
# Component change noise threshold
COMPONENT_CHANGE_NOISE_THRESHOLD = 2 # Ignore component changes ≤ this many bytes
# Template directory
TEMPLATE_DIR = Path(__file__).parent / "templates"
def load_analysis_json(json_path: str) -> dict | None:
"""Load memory analysis results from JSON file.
Args:
json_path: Path to analysis JSON file
Returns:
Dictionary with analysis results or None if file doesn't exist/can't be loaded
"""
json_file = Path(json_path)
if not json_file.exists():
print(f"Analysis JSON not found: {json_path}", file=sys.stderr)
return None
try:
with open(json_file, encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, OSError) as e:
print(f"Failed to load analysis JSON: {e}", file=sys.stderr)
return None
def format_bytes(bytes_value: int) -> str:
"""Format bytes value with comma separators.
Args:
bytes_value: Number of bytes
Returns:
Formatted string with comma separators (e.g., "1,234 bytes")
"""
return f"{bytes_value:,} bytes"
def format_change(before: int, after: int, threshold: float | None = None) -> str:
"""Format memory change with delta and percentage.
Args:
before: Memory usage before change (in bytes)
after: Memory usage after change (in bytes)
threshold: Optional percentage threshold for "significant" change.
If provided, adds supplemental emoji (🎉/🚨/🔸/✅) to chart icons.
If None, only shows chart icons (📈/📉/➡️).
Returns:
Formatted string with delta and percentage
"""
delta = after - before
percentage = 0.0 if before == 0 else (delta / before) * 100
# Always use chart icons to show direction
if delta > 0:
delta_str = f"+{delta:,} bytes"
trend_icon = "📈"
# Add supplemental emoji based on threshold if provided
if threshold is not None:
significance = "🚨" if abs(percentage) > threshold else "🔸"
emoji = f"{trend_icon} {significance}"
else:
emoji = trend_icon
elif delta < 0:
delta_str = f"{delta:,} bytes"
trend_icon = "📉"
# Add supplemental emoji based on threshold if provided
if threshold is not None:
significance = "🎉" if abs(percentage) > threshold else ""
emoji = f"{trend_icon} {significance}"
else:
emoji = trend_icon
else:
delta_str = "+0 bytes"
emoji = "➡️"
# Format percentage with sign
if percentage > 0:
pct_str = f"+{percentage:.2f}%"
elif percentage < 0:
pct_str = f"{percentage:.2f}%"
else:
pct_str = "0.00%"
return f"{emoji} {delta_str} ({pct_str})"
def prepare_symbol_changes_data(
target_symbols: dict | None, pr_symbols: dict | None
) -> dict | None:
"""Prepare symbol changes data for template rendering.
Args:
target_symbols: Symbol name to size mapping for target branch
pr_symbols: Symbol name to size mapping for PR branch
Returns:
Dictionary with changed, new, and removed symbols, or None if no changes
"""
if not target_symbols or not pr_symbols:
return None
# Find all symbols that exist in both branches or only in one
all_symbols = set(target_symbols.keys()) | set(pr_symbols.keys())
# Track changes
changed_symbols: list[
tuple[str, int, int, int]
] = [] # (symbol, target_size, pr_size, delta)
new_symbols: list[tuple[str, int]] = [] # (symbol, size)
removed_symbols: list[tuple[str, int]] = [] # (symbol, size)
for symbol in all_symbols:
target_size = target_symbols.get(symbol, 0)
pr_size = pr_symbols.get(symbol, 0)
if target_size == 0 and pr_size > 0:
# New symbol
new_symbols.append((symbol, pr_size))
elif target_size > 0 and pr_size == 0:
# Removed symbol
removed_symbols.append((symbol, target_size))
elif target_size != pr_size:
# Changed symbol
delta = pr_size - target_size
changed_symbols.append((symbol, target_size, pr_size, delta))
if not changed_symbols and not new_symbols and not removed_symbols:
return None
# Sort by size/delta
changed_symbols.sort(key=lambda x: abs(x[3]), reverse=True)
new_symbols.sort(key=lambda x: x[1], reverse=True)
removed_symbols.sort(key=lambda x: x[1], reverse=True)
return {
"changed_symbols": changed_symbols,
"new_symbols": new_symbols,
"removed_symbols": removed_symbols,
}
def prepare_component_breakdown_data(
target_analysis: dict | None, pr_analysis: dict | None
) -> list[tuple[str, int, int, int]] | None:
"""Prepare component breakdown data for template rendering.
Args:
target_analysis: Component memory breakdown for target branch
pr_analysis: Component memory breakdown for PR branch
Returns:
List of tuples (component, target_flash, pr_flash, delta), or None if no changes
"""
if not target_analysis or not pr_analysis:
return None
# Combine all components from both analyses
all_components = set(target_analysis.keys()) | set(pr_analysis.keys())
# Filter to components that have changed (ignoring noise)
changed_components: list[
tuple[str, int, int, int]
] = [] # (comp, target_flash, pr_flash, delta)
for comp in all_components:
target_mem = target_analysis.get(comp, {})
pr_mem = pr_analysis.get(comp, {})
target_flash = target_mem.get("flash_total", 0)
pr_flash = pr_mem.get("flash_total", 0)
# Only include if component has meaningful change (above noise threshold)
delta = pr_flash - target_flash
if abs(delta) > COMPONENT_CHANGE_NOISE_THRESHOLD:
changed_components.append((comp, target_flash, pr_flash, delta))
if not changed_components:
return None
# Sort by absolute delta (largest changes first)
changed_components.sort(key=lambda x: abs(x[3]), reverse=True)
return changed_components
def create_comment_body(
components: list[str],
platform: str,
target_ram: int,
target_flash: int,
pr_ram: int,
pr_flash: int,
target_analysis: dict | None = None,
pr_analysis: dict | None = None,
target_symbols: dict | None = None,
pr_symbols: dict | None = None,
) -> str:
"""Create the comment body with memory impact analysis using Jinja2 templates.
Args:
components: List of component names (merged config)
platform: Platform name
target_ram: RAM usage in target branch
target_flash: Flash usage in target branch
pr_ram: RAM usage in PR branch
pr_flash: Flash usage in PR branch
target_analysis: Optional component breakdown for target branch
pr_analysis: Optional component breakdown for PR branch
target_symbols: Optional symbol map for target branch
pr_symbols: Optional symbol map for PR branch
Returns:
Formatted comment body
"""
# Set up Jinja2 environment
env = Environment(
loader=FileSystemLoader(TEMPLATE_DIR),
trim_blocks=True,
lstrip_blocks=True,
)
# Register custom filters
env.filters["format_bytes"] = format_bytes
env.filters["format_change"] = format_change
# Prepare template context
context = {
"comment_marker": COMMENT_MARKER,
"platform": platform,
"target_ram": format_bytes(target_ram),
"pr_ram": format_bytes(pr_ram),
"target_flash": format_bytes(target_flash),
"pr_flash": format_bytes(pr_flash),
"ram_change": format_change(
target_ram, pr_ram, threshold=OVERALL_CHANGE_THRESHOLD
),
"flash_change": format_change(
target_flash, pr_flash, threshold=OVERALL_CHANGE_THRESHOLD
),
"component_change_threshold": COMPONENT_CHANGE_THRESHOLD,
}
# Format components list
if len(components) == 1:
context["components_str"] = f"`{components[0]}`"
context["config_note"] = "a representative test configuration"
else:
context["components_str"] = ", ".join(f"`{c}`" for c in sorted(components))
context["config_note"] = (
f"a merged configuration with {len(components)} components"
)
# Prepare component breakdown if available
component_breakdown = ""
if target_analysis and pr_analysis:
changed_components = prepare_component_breakdown_data(
target_analysis, pr_analysis
)
if changed_components:
template = env.get_template("ci_memory_impact_component_breakdown.j2")
component_breakdown = template.render(
changed_components=changed_components,
format_bytes=format_bytes,
format_change=format_change,
component_change_threshold=COMPONENT_CHANGE_THRESHOLD,
max_rows=MAX_COMPONENT_BREAKDOWN_ROWS,
)
# Prepare symbol changes if available
symbol_changes = ""
if target_symbols and pr_symbols:
symbol_data = prepare_symbol_changes_data(target_symbols, pr_symbols)
if symbol_data:
template = env.get_template("ci_memory_impact_symbol_changes.j2")
symbol_changes = template.render(
**symbol_data,
format_bytes=format_bytes,
format_change=format_change,
max_changed_rows=MAX_CHANGED_SYMBOLS_ROWS,
max_new_rows=MAX_NEW_SYMBOLS_ROWS,
max_removed_rows=MAX_REMOVED_SYMBOLS_ROWS,
symbol_max_length=SYMBOL_DISPLAY_MAX_LENGTH,
symbol_truncate_length=SYMBOL_DISPLAY_TRUNCATE_LENGTH,
)
if not target_analysis or not pr_analysis:
print("No ELF files provided, skipping detailed analysis", file=sys.stderr)
context["component_breakdown"] = component_breakdown
context["symbol_changes"] = symbol_changes
# Render main template
template = env.get_template("ci_memory_impact_comment_template.j2")
return template.render(**context)
def find_existing_comment(pr_number: str) -> str | None:
"""Find existing memory impact comment on the PR.
Args:
pr_number: PR number
Returns:
Comment numeric ID if found, None otherwise
Raises:
subprocess.CalledProcessError: If gh command fails
"""
print(f"DEBUG: Looking for existing comment on PR #{pr_number}", file=sys.stderr)
# Use gh api to get comments directly - this returns the numeric id field
result = run_gh_command(
[
"gh",
"api",
f"/repos/{{owner}}/{{repo}}/issues/{pr_number}/comments",
"--jq",
".[] | {id, body}",
],
operation="Get PR comments",
)
print(
f"DEBUG: gh api comments output (first 500 chars):\n{result.stdout[:500]}",
file=sys.stderr,
)
# Parse comments and look for our marker
comment_count = 0
for line in result.stdout.strip().split("\n"):
if not line:
continue
try:
comment = json.loads(line)
comment_count += 1
comment_id = comment.get("id")
print(
f"DEBUG: Checking comment {comment_count}: id={comment_id}",
file=sys.stderr,
)
body = comment.get("body", "")
if COMMENT_MARKER in body:
print(
f"DEBUG: Found existing comment with id={comment_id}",
file=sys.stderr,
)
# Return the numeric id
return str(comment_id)
print("DEBUG: Comment does not contain marker", file=sys.stderr)
except json.JSONDecodeError as e:
print(f"DEBUG: JSON decode error: {e}", file=sys.stderr)
continue
print(
f"DEBUG: No existing comment found (checked {comment_count} comments)",
file=sys.stderr,
)
return None
def update_existing_comment(comment_id: str, comment_body: str) -> None:
"""Update an existing comment.
Args:
comment_id: Comment ID to update
comment_body: New comment body text
Raises:
subprocess.CalledProcessError: If gh command fails
"""
print(f"DEBUG: Updating existing comment {comment_id}", file=sys.stderr)
print(f"DEBUG: Comment body length: {len(comment_body)} bytes", file=sys.stderr)
result = run_gh_command(
[
"gh",
"api",
f"/repos/{{owner}}/{{repo}}/issues/comments/{comment_id}",
"-X",
"PATCH",
"-f",
f"body={comment_body}",
],
operation="Update PR comment",
)
print(f"DEBUG: Update response: {result.stdout}", file=sys.stderr)
def create_new_comment(pr_number: str, comment_body: str) -> None:
"""Create a new PR comment.
Args:
pr_number: PR number
comment_body: Comment body text
Raises:
subprocess.CalledProcessError: If gh command fails
"""
print(f"DEBUG: Posting new comment on PR #{pr_number}", file=sys.stderr)
print(f"DEBUG: Comment body length: {len(comment_body)} bytes", file=sys.stderr)
result = run_gh_command(
["gh", "pr", "comment", pr_number, "--body", comment_body],
operation="Create PR comment",
)
print(f"DEBUG: Post response: {result.stdout}", file=sys.stderr)
def post_or_update_comment(pr_number: str, comment_body: str) -> None:
"""Post a new comment or update existing one.
Args:
pr_number: PR number
comment_body: Comment body text
Raises:
subprocess.CalledProcessError: If gh command fails
"""
# Look for existing comment
existing_comment_id = find_existing_comment(pr_number)
if existing_comment_id and existing_comment_id != "None":
update_existing_comment(existing_comment_id, comment_body)
else:
create_new_comment(pr_number, comment_body)
print("Comment posted/updated successfully", file=sys.stderr)
def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Post or update PR comment with memory impact analysis"
)
parser.add_argument("--pr-number", required=True, help="PR number")
parser.add_argument(
"--target-json",
required=True,
help="Path to target branch analysis JSON file",
)
parser.add_argument(
"--pr-json",
required=True,
help="Path to PR branch analysis JSON file",
)
args = parser.parse_args()
# Load analysis JSON files (all data comes from JSON for security)
target_data: dict | None = load_analysis_json(args.target_json)
if not target_data:
print("Error: Failed to load target analysis JSON", file=sys.stderr)
sys.exit(1)
pr_data: dict | None = load_analysis_json(args.pr_json)
if not pr_data:
print("Error: Failed to load PR analysis JSON", file=sys.stderr)
sys.exit(1)
# Extract detailed analysis if available
target_analysis: dict | None = None
pr_analysis: dict | None = None
target_symbols: dict | None = None
pr_symbols: dict | None = None
if target_data.get("detailed_analysis"):
target_analysis = target_data["detailed_analysis"].get("components")
target_symbols = target_data["detailed_analysis"].get("symbols")
if pr_data.get("detailed_analysis"):
pr_analysis = pr_data["detailed_analysis"].get("components")
pr_symbols = pr_data["detailed_analysis"].get("symbols")
# Extract all values from JSON files (prevents shell injection from PR code)
components = target_data.get("components")
platform = target_data.get("platform")
target_ram = target_data.get("ram_bytes")
target_flash = target_data.get("flash_bytes")
pr_ram = pr_data.get("ram_bytes")
pr_flash = pr_data.get("flash_bytes")
# Validate required fields and types
missing_fields: list[str] = []
type_errors: list[str] = []
if components is None:
missing_fields.append("components")
elif not isinstance(components, list):
type_errors.append(
f"components must be a list, got {type(components).__name__}"
)
else:
for idx, comp in enumerate(components):
if not isinstance(comp, str):
type_errors.append(
f"components[{idx}] must be a string, got {type(comp).__name__}"
)
if platform is None:
missing_fields.append("platform")
elif not isinstance(platform, str):
type_errors.append(f"platform must be a string, got {type(platform).__name__}")
if target_ram is None:
missing_fields.append("target.ram_bytes")
elif not isinstance(target_ram, int):
type_errors.append(
f"target.ram_bytes must be an integer, got {type(target_ram).__name__}"
)
if target_flash is None:
missing_fields.append("target.flash_bytes")
elif not isinstance(target_flash, int):
type_errors.append(
f"target.flash_bytes must be an integer, got {type(target_flash).__name__}"
)
if pr_ram is None:
missing_fields.append("pr.ram_bytes")
elif not isinstance(pr_ram, int):
type_errors.append(
f"pr.ram_bytes must be an integer, got {type(pr_ram).__name__}"
)
if pr_flash is None:
missing_fields.append("pr.flash_bytes")
elif not isinstance(pr_flash, int):
type_errors.append(
f"pr.flash_bytes must be an integer, got {type(pr_flash).__name__}"
)
if missing_fields or type_errors:
if missing_fields:
print(
f"Error: JSON files missing required fields: {', '.join(missing_fields)}",
file=sys.stderr,
)
if type_errors:
print(
f"Error: Type validation failed: {'; '.join(type_errors)}",
file=sys.stderr,
)
print(f"Target JSON keys: {list(target_data.keys())}", file=sys.stderr)
print(f"PR JSON keys: {list(pr_data.keys())}", file=sys.stderr)
sys.exit(1)
# Create comment body
# Note: Memory totals (RAM/Flash) are summed across all builds if multiple were run.
comment_body = create_comment_body(
components=components,
platform=platform,
target_ram=target_ram,
target_flash=target_flash,
pr_ram=pr_ram,
pr_flash=pr_flash,
target_analysis=target_analysis,
pr_analysis=pr_analysis,
target_symbols=target_symbols,
pr_symbols=pr_symbols,
)
# Post or update comment
post_or_update_comment(args.pr_number, comment_body)
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -1,281 +0,0 @@
#!/usr/bin/env python3
"""Extract memory usage statistics from ESPHome build output.
This script parses the PlatformIO build output to extract RAM and flash
usage statistics for a compiled component. It's used by the CI workflow to
compare memory usage between branches.
The script reads compile output from stdin and looks for the standard
PlatformIO output format:
RAM: [==== ] 36.1% (used 29548 bytes from 81920 bytes)
Flash: [=== ] 34.0% (used 348511 bytes from 1023984 bytes)
Optionally performs detailed memory analysis if a build directory is provided.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import re
import sys
# Add esphome to path
sys.path.insert(0, str(Path(__file__).parent.parent))
# pylint: disable=wrong-import-position
from esphome.analyze_memory import MemoryAnalyzer
from esphome.platformio_api import IDEData
from script.ci_helpers import write_github_output
# Regex patterns for extracting memory usage from PlatformIO output
_RAM_PATTERN = re.compile(r"RAM:\s+\[.*?\]\s+\d+\.\d+%\s+\(used\s+(\d+)\s+bytes")
_FLASH_PATTERN = re.compile(r"Flash:\s+\[.*?\]\s+\d+\.\d+%\s+\(used\s+(\d+)\s+bytes")
_BUILD_PATH_PATTERN = re.compile(r"Build path: (.+)")
def extract_from_compile_output(
output_text: str,
) -> tuple[int | None, int | None, str | None]:
"""Extract memory usage and build directory from PlatformIO compile output.
Supports multiple builds (for component groups or isolated components).
When test_build_components.py creates multiple builds, this sums the
memory usage across all builds.
Looks for lines like:
RAM: [==== ] 36.1% (used 29548 bytes from 81920 bytes)
Flash: [=== ] 34.0% (used 348511 bytes from 1023984 bytes)
Also extracts build directory from lines like:
INFO Compiling app... Build path: /path/to/build
Args:
output_text: Compile output text (may contain multiple builds)
Returns:
Tuple of (total_ram_bytes, total_flash_bytes, build_dir) or (None, None, None) if not found
"""
# Find all RAM and Flash matches (may be multiple builds)
ram_matches = _RAM_PATTERN.findall(output_text)
flash_matches = _FLASH_PATTERN.findall(output_text)
if not ram_matches or not flash_matches:
return None, None, None
# Sum all builds (handles multiple component groups)
total_ram = sum(int(match) for match in ram_matches)
total_flash = sum(int(match) for match in flash_matches)
# Extract build directory from ESPHome's explicit build path output
# Look for: INFO Compiling app... Build path: /path/to/build
# Note: Multiple builds reuse the same build path (each overwrites the previous)
build_dir = None
if match := _BUILD_PATH_PATTERN.search(output_text):
build_dir = match.group(1).strip()
return total_ram, total_flash, build_dir
def run_detailed_analysis(build_dir: str) -> dict | None:
"""Run detailed memory analysis on build directory.
Args:
build_dir: Path to ESPHome build directory
Returns:
Dictionary with analysis results or None if analysis fails
"""
build_path = Path(build_dir)
if not build_path.exists():
print(f"Build directory not found: {build_dir}", file=sys.stderr)
return None
# Find firmware.elf
elf_path = None
for elf_candidate in [
build_path / "firmware.elf",
build_path / ".pioenvs" / build_path.name / "firmware.elf",
]:
if elf_candidate.exists():
elf_path = str(elf_candidate)
break
if not elf_path:
print(f"firmware.elf not found in {build_dir}", file=sys.stderr)
return None
# Find idedata.json - check multiple locations
device_name = build_path.name
idedata_candidates = [
# In .pioenvs for test builds
build_path / ".pioenvs" / device_name / "idedata.json",
# In .esphome/idedata for regular builds
Path.home() / ".esphome" / "idedata" / f"{device_name}.json",
# Check parent directories for .esphome/idedata (for test_build_components)
build_path.parent.parent.parent / "idedata" / f"{device_name}.json",
]
idedata = None
for idedata_path in idedata_candidates:
if not idedata_path.exists():
continue
try:
with open(idedata_path, encoding="utf-8") as f:
raw_data = json.load(f)
idedata = IDEData(raw_data)
print(f"Loaded idedata from: {idedata_path}", file=sys.stderr)
break
except (json.JSONDecodeError, OSError) as e:
print(
f"Warning: Failed to load idedata from {idedata_path}: {e}",
file=sys.stderr,
)
analyzer = MemoryAnalyzer(elf_path, idedata=idedata)
components = analyzer.analyze()
# Convert to JSON-serializable format
result = {
"components": {
name: {
"text": mem.text_size,
"rodata": mem.rodata_size,
"data": mem.data_size,
"bss": mem.bss_size,
"flash_total": mem.flash_total,
"ram_total": mem.ram_total,
"symbol_count": mem.symbol_count,
}
for name, mem in components.items()
},
"symbols": {},
}
# Build symbol map
for section in analyzer.sections.values():
for symbol_name, size, _ in section.symbols:
if size > 0:
demangled = analyzer._demangle_symbol(symbol_name)
result["symbols"][demangled] = size
return result
def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Extract memory usage from ESPHome build output"
)
parser.add_argument(
"--output-env",
action="store_true",
help="Output to GITHUB_OUTPUT environment file",
)
parser.add_argument(
"--build-dir",
help="Optional build directory for detailed memory analysis (overrides auto-detection)",
)
parser.add_argument(
"--output-json",
help="Optional path to save detailed analysis JSON",
)
parser.add_argument(
"--output-build-dir",
help="Optional path to write the detected build directory",
)
args = parser.parse_args()
# Read compile output from stdin
compile_output = sys.stdin.read()
# Extract memory usage and build directory
ram_bytes, flash_bytes, detected_build_dir = extract_from_compile_output(
compile_output
)
if ram_bytes is None or flash_bytes is None:
print("Failed to extract memory usage from compile output", file=sys.stderr)
print("Expected lines like:", file=sys.stderr)
print(
" RAM: [==== ] 36.1% (used 29548 bytes from 81920 bytes)",
file=sys.stderr,
)
print(
" Flash: [=== ] 34.0% (used 348511 bytes from 1023984 bytes)",
file=sys.stderr,
)
return 1
# Count how many builds were found
num_builds = len(_RAM_PATTERN.findall(compile_output))
if num_builds > 1:
print(
f"Found {num_builds} builds - summing memory usage across all builds",
file=sys.stderr,
)
print(
"WARNING: Detailed analysis will only cover the last build",
file=sys.stderr,
)
print(f"Total RAM: {ram_bytes} bytes", file=sys.stderr)
print(f"Total Flash: {flash_bytes} bytes", file=sys.stderr)
# Determine which build directory to use (explicit arg overrides auto-detection)
build_dir = args.build_dir or detected_build_dir
if detected_build_dir:
print(f"Detected build directory: {detected_build_dir}", file=sys.stderr)
if num_builds > 1:
print(
f" (using last of {num_builds} builds for detailed analysis)",
file=sys.stderr,
)
# Write build directory to file if requested
if args.output_build_dir and build_dir:
build_dir_path = Path(args.output_build_dir)
build_dir_path.parent.mkdir(parents=True, exist_ok=True)
build_dir_path.write_text(build_dir)
print(f"Wrote build directory to {args.output_build_dir}", file=sys.stderr)
# Run detailed analysis if build directory available
detailed_analysis = None
if build_dir:
print(f"Running detailed analysis on {build_dir}", file=sys.stderr)
detailed_analysis = run_detailed_analysis(build_dir)
# Save JSON output if requested
if args.output_json:
output_data = {
"ram_bytes": ram_bytes,
"flash_bytes": flash_bytes,
"detailed_analysis": detailed_analysis,
}
output_path = Path(args.output_json)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(output_data, f, indent=2)
print(f"Saved analysis to {args.output_json}", file=sys.stderr)
if args.output_env:
# Output to GitHub Actions
write_github_output(
{
"ram_usage": ram_bytes,
"flash_usage": flash_bytes,
}
)
else:
print(f"{ram_bytes},{flash_bytes}")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -10,13 +10,7 @@ what files have changed. It outputs JSON with the following structure:
"clang_format": true/false,
"python_linters": true/false,
"changed_components": ["component1", "component2", ...],
"component_test_count": 5,
"memory_impact": {
"should_run": "true/false",
"components": ["component1", "component2", ...],
"platform": "esp32-idf",
"use_merged_config": "true"
}
"component_test_count": 5
}
The CI workflow uses this information to:
@@ -26,7 +20,6 @@ The CI workflow uses this information to:
- Skip or run Python linters (ruff, flake8, pylint, pyupgrade)
- Determine which components to test individually
- Decide how to split component tests (if there are many)
- Run memory impact analysis whenever there are changed components (merged config), and also for core-only changes
Usage:
python script/determine-jobs.py [-b BRANCH]
@@ -38,8 +31,6 @@ Options:
from __future__ import annotations
import argparse
from collections import Counter
from enum import StrEnum
from functools import cache
import json
import os
@@ -49,47 +40,16 @@ import sys
from typing import Any
from helpers import (
BASE_BUS_COMPONENTS,
CPP_FILE_EXTENSIONS,
ESPHOME_COMPONENTS_PATH,
PYTHON_FILE_EXTENSIONS,
changed_files,
get_all_dependencies,
get_component_from_path,
get_component_test_files,
get_components_from_integration_fixtures,
parse_test_filename,
root_path,
)
class Platform(StrEnum):
"""Platform identifiers for memory impact analysis."""
ESP8266_ARD = "esp8266-ard"
ESP32_IDF = "esp32-idf"
ESP32_C3_IDF = "esp32-c3-idf"
ESP32_C6_IDF = "esp32-c6-idf"
ESP32_S2_IDF = "esp32-s2-idf"
ESP32_S3_IDF = "esp32-s3-idf"
# Memory impact analysis constants
MEMORY_IMPACT_FALLBACK_COMPONENT = "api" # Representative component for core changes
MEMORY_IMPACT_FALLBACK_PLATFORM = Platform.ESP32_IDF # Most representative platform
# Platform preference order for memory impact analysis
# Prefer newer platforms first as they represent the future of ESPHome
# ESP8266 is most constrained but many new features don't support it
MEMORY_IMPACT_PLATFORM_PREFERENCE = [
Platform.ESP32_C6_IDF, # ESP32-C6 IDF (newest, supports Thread/Zigbee)
Platform.ESP8266_ARD, # ESP8266 Arduino (most memory constrained - best for impact analysis)
Platform.ESP32_IDF, # ESP32 IDF platform (primary ESP32 platform, most representative)
Platform.ESP32_C3_IDF, # ESP32-C3 IDF
Platform.ESP32_S2_IDF, # ESP32-S2 IDF
Platform.ESP32_S3_IDF, # ESP32-S3 IDF
]
def should_run_integration_tests(branch: str | None = None) -> bool:
"""Determine if integration tests should run based on changed files.
@@ -145,9 +105,12 @@ def should_run_integration_tests(branch: str | None = None) -> bool:
# Check if any required components changed
for file in files:
component = get_component_from_path(file)
if component and component in all_required_components:
return True
if file.startswith(ESPHOME_COMPONENTS_PATH):
parts = file.split("/")
if len(parts) >= 3:
component = parts[2]
if component in all_required_components:
return True
return False
@@ -261,136 +224,10 @@ def _component_has_tests(component: str) -> bool:
Returns:
True if the component has test YAML files
"""
return bool(get_component_test_files(component))
def detect_memory_impact_config(
branch: str | None = None,
) -> dict[str, Any]:
"""Determine memory impact analysis configuration.
Always runs memory impact analysis when there are changed components,
building a merged configuration with all changed components (like
test_build_components.py does) to get comprehensive memory analysis.
Args:
branch: Branch to compare against
Returns:
Dictionary with memory impact analysis parameters:
- should_run: "true" or "false"
- components: list of component names to analyze
- platform: platform name for the merged build
- use_merged_config: "true" (always use merged config)
"""
# Get actually changed files (not dependencies)
files = changed_files(branch)
# Find all changed components (excluding core and base bus components)
changed_component_set: set[str] = set()
has_core_changes = False
for file in files:
component = get_component_from_path(file)
if component:
# Skip base bus components as they're used across many builds
if component not in BASE_BUS_COMPONENTS:
changed_component_set.add(component)
elif file.startswith("esphome/"):
# Core ESPHome files changed (not component-specific)
has_core_changes = True
# If no components changed but core changed, test representative component
force_fallback_platform = False
if not changed_component_set and has_core_changes:
print(
f"Memory impact: No components changed, but core files changed. "
f"Testing {MEMORY_IMPACT_FALLBACK_COMPONENT} component on {MEMORY_IMPACT_FALLBACK_PLATFORM}.",
file=sys.stderr,
)
changed_component_set.add(MEMORY_IMPACT_FALLBACK_COMPONENT)
force_fallback_platform = True # Use fallback platform (most representative)
elif not changed_component_set:
# No components and no core changes
return {"should_run": "false"}
# Find components that have tests and collect their supported platforms
components_with_tests: list[str] = []
component_platforms_map: dict[
str, set[Platform]
] = {} # Track which platforms each component supports
for component in sorted(changed_component_set):
# Look for test files on preferred platforms
test_files = get_component_test_files(component)
if not test_files:
continue
# Check if component has tests for any preferred platform
available_platforms = [
platform
for test_file in test_files
if (platform := parse_test_filename(test_file)[1]) != "all"
and platform in MEMORY_IMPACT_PLATFORM_PREFERENCE
]
if not available_platforms:
continue
component_platforms_map[component] = set(available_platforms)
components_with_tests.append(component)
# If no components have tests, don't run memory impact
if not components_with_tests:
return {"should_run": "false"}
# Find common platforms supported by ALL components
# This ensures we can build all components together in a merged config
common_platforms = set(MEMORY_IMPACT_PLATFORM_PREFERENCE)
for component, platforms in component_platforms_map.items():
common_platforms &= platforms
# Select the most preferred platform from the common set
# Exception: for core changes, use fallback platform (most representative of codebase)
if force_fallback_platform:
platform = MEMORY_IMPACT_FALLBACK_PLATFORM
elif common_platforms:
# Pick the most preferred platform that all components support
platform = min(common_platforms, key=MEMORY_IMPACT_PLATFORM_PREFERENCE.index)
else:
# No common platform - pick the most commonly supported platform
# This allows testing components individually even if they can't be merged
# Count how many components support each platform
platform_counts = Counter(
p for platforms in component_platforms_map.values() for p in platforms
)
# Pick the platform supported by most components, preferring earlier in MEMORY_IMPACT_PLATFORM_PREFERENCE
platform = max(
platform_counts.keys(),
key=lambda p: (
platform_counts[p],
-MEMORY_IMPACT_PLATFORM_PREFERENCE.index(p),
),
)
# Debug output
print("Memory impact analysis:", file=sys.stderr)
print(f" Changed components: {sorted(changed_component_set)}", file=sys.stderr)
print(f" Components with tests: {components_with_tests}", file=sys.stderr)
print(
f" Component platforms: {dict(sorted(component_platforms_map.items()))}",
file=sys.stderr,
)
print(f" Common platforms: {sorted(common_platforms)}", file=sys.stderr)
print(f" Selected platform: {platform}", file=sys.stderr)
return {
"should_run": "true",
"components": components_with_tests,
"platform": platform,
"use_merged_config": "true",
}
tests_dir = Path(root_path) / "tests" / "components" / component
if not tests_dir.exists():
return False
return any(tests_dir.glob("test.*.yaml"))
def main() -> None:
@@ -442,9 +279,6 @@ def main() -> None:
if component not in directly_changed_components
]
# Detect components for memory impact analysis (merged config)
memory_impact = detect_memory_impact_config(args.branch)
# Build output
output: dict[str, Any] = {
"integration_tests": run_integration,
@@ -458,7 +292,6 @@ def main() -> None:
"component_test_count": len(changed_components_with_tests),
"directly_changed_count": len(directly_changed_with_tests),
"dependency_only_count": len(dependency_only_components),
"memory_impact": memory_impact,
}
# Output as JSON

View File

@@ -1,195 +0,0 @@
#!/usr/bin/env python3
"""
GitHub Download Cache CLI
This script provides a command-line interface to the GitHub download cache.
The actual caching logic is in esphome/github_cache.py.
Usage:
python3 script/github_download_cache.py download URL
python3 script/github_download_cache.py list
python3 script/github_download_cache.py stats
python3 script/github_download_cache.py clear
"""
import argparse
from pathlib import Path
import sys
import urllib.request
# Add parent directory to path to import esphome modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from esphome.github_cache import GitHubCache
def download_with_progress(
cache: GitHubCache, url: str, force: bool = False, check_updates: bool = True
) -> Path:
"""Download a URL with progress indicator and caching.
Args:
cache: GitHubCache instance
url: URL to download
force: Force re-download even if cached
check_updates: Check for updates using HTTP 304
Returns:
Path to cached file
"""
# If force, skip cache check
if not force:
cached_path = cache.get_cached_path(url, check_updates=check_updates)
if cached_path:
print(f"Using cached file for {url}")
print(f" Cache: {cached_path}")
return cached_path
# Need to download
print(f"Downloading {url}")
cache_path = cache._get_cache_path(url)
print(f" Cache: {cache_path}")
# Download with progress
temp_path = cache_path.with_suffix(cache_path.suffix + ".tmp")
try:
with urllib.request.urlopen(url) as response:
total_size = int(response.headers.get("Content-Length", 0))
downloaded = 0
with open(temp_path, "wb") as f:
while True:
chunk = response.read(8192)
if not chunk:
break
f.write(chunk)
downloaded += len(chunk)
if total_size > 0:
percent = (downloaded / total_size) * 100
print(f"\r Progress: {percent:.1f}%", end="", flush=True)
print() # New line after progress
# Move to final location
temp_path.replace(cache_path)
# Let cache handle metadata
cache.save_to_cache(url, cache_path)
return cache_path
except (OSError, urllib.error.URLError) as e:
if temp_path.exists():
temp_path.unlink()
raise RuntimeError(f"Failed to download {url}: {e}") from e
def main():
"""CLI entry point."""
parser = argparse.ArgumentParser(
description="GitHub Download Cache Manager",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Download and cache a URL
%(prog)s download https://github.com/pioarduino/registry/releases/download/0.0.1/esptoolpy-v5.1.0.zip
# List cached files
%(prog)s list
# Show cache statistics
%(prog)s stats
# Clear cache
%(prog)s clear
""",
)
parser.add_argument(
"--cache-dir",
type=Path,
help="Cache directory (default: ~/.platformio/esphome_download_cache)",
)
subparsers = parser.add_subparsers(dest="command", help="Command to execute")
# Download command
download_parser = subparsers.add_parser("download", help="Download and cache a URL")
download_parser.add_argument("url", help="URL to download")
download_parser.add_argument(
"--force", action="store_true", help="Force re-download even if cached"
)
download_parser.add_argument(
"--no-check-updates",
action="store_true",
help="Skip checking for updates (don't use HTTP 304)",
)
# List command
subparsers.add_parser("list", help="List cached files")
# Stats command
subparsers.add_parser("stats", help="Show cache statistics")
# Clear command
subparsers.add_parser("clear", help="Clear all cached files")
args = parser.parse_args()
if not args.command:
parser.print_help()
return 1
# Use PlatformIO cache directory by default
if args.cache_dir is None:
args.cache_dir = Path.home() / ".platformio" / "esphome_download_cache"
cache = GitHubCache(args.cache_dir)
if args.command == "download":
try:
check_updates = not args.no_check_updates
cache_path = download_with_progress(
cache, args.url, force=args.force, check_updates=check_updates
)
print(f"\nCached at: {cache_path}")
return 0
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return 1
elif args.command == "list":
cached = cache.list_cached()
if not cached:
print("No cached files")
return 0
print(f"Cached files ({len(cached)}):")
for item in cached:
size_mb = item["size"] / (1024 * 1024)
print(f" {item['url']}")
print(f" Size: {size_mb:.2f} MB")
print(f" Path: {item['path']}")
return 0
elif args.command == "stats":
total_size = cache.cache_size()
cached_count = len(cache.list_cached())
size_mb = total_size / (1024 * 1024)
print(f"Cache directory: {cache.cache_dir}")
print(f"Cached files: {cached_count}")
print(f"Total size: {size_mb:.2f} MB")
return 0
elif args.command == "clear":
cache.clear_cache()
return 0
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -29,18 +29,6 @@ YAML_FILE_EXTENSIONS = (".yaml", ".yml")
# Component path prefix
ESPHOME_COMPONENTS_PATH = "esphome/components/"
# Base bus components - these ARE the bus implementations and should not
# be flagged as needing migration since they are the platform/base components
BASE_BUS_COMPONENTS = {
"i2c",
"spi",
"uart",
"modbus",
"canbus",
"remote_transmitter",
"remote_receiver",
}
def parse_list_components_output(output: str) -> list[str]:
"""Parse the output from list-components.py script.
@@ -58,65 +46,6 @@ def parse_list_components_output(output: str) -> list[str]:
return [c.strip() for c in output.strip().split("\n") if c.strip()]
def parse_test_filename(test_file: Path) -> tuple[str, str]:
"""Parse test filename to extract test name and platform.
Test files follow the naming pattern: test.<platform>.yaml or test-<variant>.<platform>.yaml
Args:
test_file: Path to test file
Returns:
Tuple of (test_name, platform)
"""
parts = test_file.stem.split(".")
if len(parts) == 2:
return parts[0], parts[1] # test, platform
return parts[0], "all"
def get_component_from_path(file_path: str) -> str | None:
"""Extract component name from a file path.
Args:
file_path: Path to a file (e.g., "esphome/components/wifi/wifi.cpp")
Returns:
Component name if path is in components directory, None otherwise
"""
if not file_path.startswith(ESPHOME_COMPONENTS_PATH):
return None
parts = file_path.split("/")
if len(parts) >= 3:
return parts[2]
return None
def get_component_test_files(
component: str, *, all_variants: bool = False
) -> list[Path]:
"""Get test files for a component.
Args:
component: Component name (e.g., "wifi")
all_variants: If True, returns all test files including variants (test-*.yaml).
If False, returns only base test files (test.*.yaml).
Default is False.
Returns:
List of test file paths for the component, or empty list if none exist
"""
tests_dir = Path(root_path) / "tests" / "components" / component
if not tests_dir.exists():
return []
if all_variants:
# Match both test.*.yaml and test-*.yaml patterns
return list(tests_dir.glob("test[.-]*.yaml"))
# Match only test.*.yaml (base tests)
return list(tests_dir.glob("test.*.yaml"))
def styled(color: str | tuple[str, ...], msg: str, reset: bool = True) -> str:
prefix = "".join(color) if isinstance(color, tuple) else color
suffix = colorama.Style.RESET_ALL if reset else ""
@@ -385,9 +314,11 @@ def _filter_changed_ci(files: list[str]) -> list[str]:
# because changes in one file can affect other files in the same component.
filtered_files = []
for f in files:
component = get_component_from_path(f)
if component and component in component_set:
filtered_files.append(f)
if f.startswith(ESPHOME_COMPONENTS_PATH):
# Check if file belongs to any of the changed components
parts = f.split("/")
if len(parts) >= 3 and parts[2] in component_set:
filtered_files.append(f)
return filtered_files

View File

@@ -4,7 +4,7 @@ from collections.abc import Callable
from pathlib import Path
import sys
from helpers import changed_files, get_component_from_path, git_ls_files
from helpers import changed_files, git_ls_files
from esphome.const import (
KEY_CORE,
@@ -30,9 +30,11 @@ def get_all_component_files() -> list[str]:
def extract_component_names_array_from_files_array(files):
components = []
for file in files:
component_name = get_component_from_path(file)
if component_name and component_name not in components:
components.append(component_name)
file_parts = file.split("/")
if len(file_parts) >= 4:
component_name = file_parts[2]
if component_name not in components:
components.append(component_name)
return components

View File

@@ -1,138 +0,0 @@
#!/usr/bin/env python3
"""
PlatformIO Download Wrapper with Caching
This script can be used as a wrapper around PlatformIO downloads to add caching.
It intercepts download operations and uses the GitHub download cache.
This is designed to be called from PlatformIO's extra_scripts if needed.
"""
from pathlib import Path
import sys
# Import the cache manager
sys.path.insert(0, str(Path(__file__).parent))
from github_download_cache import GitHubDownloadCache
def is_github_url(url: str) -> bool:
"""Check if a URL is a GitHub URL."""
return "github.com" in url.lower()
def cached_download_handler(source, target, env):
"""PlatformIO download handler that uses caching for GitHub URLs.
This function can be registered as a custom download handler in PlatformIO.
Args:
source: Source URL
target: Target file path
env: SCons environment
"""
import shutil
import urllib.request
url = str(source[0])
target_path = Path(str(target[0]))
# Only cache GitHub URLs
if not is_github_url(url):
# Fall back to default download
print(f"Downloading (no cache): {url}")
with (
urllib.request.urlopen(url) as response,
open(target_path, "wb") as out_file,
):
shutil.copyfileobj(response, out_file)
return
# Use cache for GitHub URLs
cache = GitHubDownloadCache()
print(f"Downloading with cache: {url}")
try:
cached_path = cache.download_with_cache(url, check_updates=True)
# Copy from cache to target
shutil.copy2(cached_path, target_path)
print(f" Copied to: {target_path}")
except Exception as e:
print(f"Cache download failed, using direct download: {e}")
# Fall back to direct download
with (
urllib.request.urlopen(url) as response,
open(target_path, "wb") as out_file,
):
shutil.copyfileobj(response, out_file)
def setup_platformio_caching():
"""Setup PlatformIO to use cached downloads.
This should be called from an extra_scripts file in platformio.ini.
Example extra_scripts file (e.g., platformio_hooks.py):
Import("env")
from script.platformio_download_wrapper import setup_platformio_caching
setup_platformio_caching()
"""
try:
from SCons.Script import DefaultEnvironment
DefaultEnvironment()
# Register custom download handler
# Note: This may not work with all PlatformIO versions
# as the download mechanism is internal
print("Note: Direct download interception is not fully supported.")
print("Please use the cache_platformio_downloads.py script instead.")
except ImportError:
print("Warning: SCons not available, cannot setup download caching")
if __name__ == "__main__":
# CLI mode - can be used to manually download a URL with caching
import argparse
parser = argparse.ArgumentParser(description="Download a URL with caching")
parser.add_argument("url", help="URL to download")
parser.add_argument("target", help="Target file path")
parser.add_argument("--cache-dir", type=Path, help="Cache directory")
args = parser.parse_args()
cache = GitHubDownloadCache(args.cache_dir)
target_path = Path(args.target)
try:
if is_github_url(args.url):
print(f"Downloading with cache: {args.url}")
cached_path = cache.download_with_cache(args.url)
# Copy to target
import shutil
target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(cached_path, target_path)
print(f"Copied to: {target_path}")
else:
print(f"Downloading directly (not a GitHub URL): {args.url}")
import shutil
import urllib.request
target_path.parent.mkdir(parents=True, exist_ok=True)
with (
urllib.request.urlopen(args.url) as response,
open(target_path, "wb") as out_file,
):
shutil.copyfileobj(response, out_file)
sys.exit(0)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)

View File

@@ -28,7 +28,6 @@ from script.analyze_component_buses import (
create_grouping_signature,
merge_compatible_bus_groups,
)
from script.helpers import get_component_test_files
# Weighting for batch creation
# Isolated components can't be grouped/merged, so they count as 10x
@@ -46,12 +45,17 @@ def has_test_files(component_name: str, tests_dir: Path) -> bool:
Args:
component_name: Name of the component
tests_dir: Path to tests/components directory (unused, kept for compatibility)
tests_dir: Path to tests/components directory
Returns:
True if the component has test.*.yaml files
"""
return bool(get_component_test_files(component_name))
component_dir = tests_dir / component_name
if not component_dir.exists() or not component_dir.is_dir():
return False
# Check for test.*.yaml files
return any(component_dir.glob("test.*.yaml"))
def create_intelligent_batches(

View File

@@ -1,27 +0,0 @@
{{ comment_marker }}
## Memory Impact Analysis
**Components:** {{ components_str }}
**Platform:** `{{ platform }}`
| Metric | Target Branch | This PR | Change |
|--------|--------------|---------|--------|
| **RAM** | {{ target_ram }} | {{ pr_ram }} | {{ ram_change }} |
| **Flash** | {{ target_flash }} | {{ pr_flash }} | {{ flash_change }} |
{% if component_breakdown %}
{{ component_breakdown }}
{% endif %}
{% if symbol_changes %}
{{ symbol_changes }}
{% endif %}
{%- if target_cache_hit %}
> ⚡ Target branch analysis was loaded from cache (build skipped for faster CI).
{%- endif %}
---
> **Note:** This analysis measures **static RAM and Flash usage** only (compile-time allocation).
> **Dynamic memory (heap)** cannot be measured automatically.
> **⚠️ You must test this PR on a real device** to measure free heap and ensure no runtime memory issues.
*This analysis runs automatically when components change. Memory usage is measured from {{ config_note }}.*

View File

@@ -1,15 +0,0 @@
<details open>
<summary>📊 Component Memory Breakdown</summary>
| Component | Target Flash | PR Flash | Change |
|-----------|--------------|----------|--------|
{% for comp, target_flash, pr_flash, delta in changed_components[:max_rows] -%}
{% set threshold = component_change_threshold if comp.startswith("[esphome]") else none -%}
| `{{ comp }}` | {{ target_flash|format_bytes }} | {{ pr_flash|format_bytes }} | {{ format_change(target_flash, pr_flash, threshold=threshold) }} |
{% endfor -%}
{% if changed_components|length > max_rows -%}
| ... | ... | ... | *({{ changed_components|length - max_rows }} more components not shown)* |
{% endif -%}
</details>

View File

@@ -1,8 +0,0 @@
{#- Macro for formatting symbol names in tables -#}
{%- macro format_symbol(symbol, max_length, truncate_length) -%}
{%- if symbol|length <= max_length -%}
`{{ symbol }}`
{%- else -%}
<details><summary><code>{{ symbol[:truncate_length] }}...</code></summary><code>{{ symbol }}</code></details>
{%- endif -%}
{%- endmacro -%}

View File

@@ -1,51 +0,0 @@
{%- from 'ci_memory_impact_macros.j2' import format_symbol -%}
<details>
<summary>🔍 Symbol-Level Changes (click to expand)</summary>
{% if changed_symbols %}
### Changed Symbols
| Symbol | Target Size | PR Size | Change |
|--------|-------------|---------|--------|
{% for symbol, target_size, pr_size, delta in changed_symbols[:max_changed_rows] -%}
| {{ format_symbol(symbol, symbol_max_length, symbol_truncate_length) }} | {{ target_size|format_bytes }} | {{ pr_size|format_bytes }} | {{ format_change(target_size, pr_size) }} |
{% endfor -%}
{% if changed_symbols|length > max_changed_rows -%}
| ... | ... | ... | *({{ changed_symbols|length - max_changed_rows }} more changed symbols not shown)* |
{% endif -%}
{% endif %}
{% if new_symbols %}
### New Symbols (top {{ max_new_rows }})
| Symbol | Size |
|--------|------|
{% for symbol, size in new_symbols[:max_new_rows] -%}
| {{ format_symbol(symbol, symbol_max_length, symbol_truncate_length) }} | {{ size|format_bytes }} |
{% endfor -%}
{% if new_symbols|length > max_new_rows -%}
{% set total_new_size = new_symbols|sum(attribute=1) -%}
| *{{ new_symbols|length - max_new_rows }} more new symbols...* | *Total: {{ total_new_size|format_bytes }}* |
{% endif -%}
{% endif %}
{% if removed_symbols %}
### Removed Symbols (top {{ max_removed_rows }})
| Symbol | Size |
|--------|------|
{% for symbol, size in removed_symbols[:max_removed_rows] -%}
| {{ format_symbol(symbol, symbol_max_length, symbol_truncate_length) }} | {{ size|format_bytes }} |
{% endfor -%}
{% if removed_symbols|length > max_removed_rows -%}
{% set total_removed_size = removed_symbols|sum(attribute=1) -%}
| *{{ removed_symbols|length - max_removed_rows }} more removed symbols...* | *Total: {{ total_removed_size|format_bytes }}* |
{% endif -%}
{% endif %}
</details>

View File

@@ -39,7 +39,6 @@ from script.analyze_component_buses import (
merge_compatible_bus_groups,
uses_local_file_references,
)
from script.helpers import get_component_test_files
from script.merge_component_configs import merge_component_configs
@@ -83,14 +82,13 @@ def show_disk_space_if_ci(esphome_command: str) -> None:
def find_component_tests(
components_dir: Path, component_pattern: str = "*", base_only: bool = False
components_dir: Path, component_pattern: str = "*"
) -> dict[str, list[Path]]:
"""Find all component test files.
Args:
components_dir: Path to tests/components directory
component_pattern: Glob pattern for component names
base_only: If True, only find base test files (test.*.yaml), not variant files (test-*.yaml)
Returns:
Dictionary mapping component name to list of test files
@@ -101,10 +99,9 @@ def find_component_tests(
if not comp_dir.is_dir():
continue
# Get test files using helper function
test_files = get_component_test_files(comp_dir.name, all_variants=not base_only)
if test_files:
component_tests[comp_dir.name] = test_files
# Find test files matching test.*.yaml or test-*.yaml patterns
for test_file in comp_dir.glob("test[.-]*.yaml"):
component_tests[comp_dir.name].append(test_file)
return dict(component_tests)
@@ -934,7 +931,6 @@ def test_components(
continue_on_fail: bool,
enable_grouping: bool = True,
isolated_components: set[str] | None = None,
base_only: bool = False,
) -> int:
"""Test components with optional intelligent grouping.
@@ -948,7 +944,6 @@ def test_components(
These are tested WITHOUT --testing-mode to enable full validation
(pin conflicts, etc). This is used in CI for directly changed components
to catch issues that would be missed with --testing-mode.
base_only: If True, only test base test files (test.*.yaml), not variant files (test-*.yaml)
Returns:
Exit code (0 for success, 1 for failure)
@@ -966,7 +961,7 @@ def test_components(
# Find all component tests
all_tests = {}
for pattern in component_patterns:
all_tests.update(find_component_tests(tests_dir, pattern, base_only))
all_tests.update(find_component_tests(tests_dir, pattern))
if not all_tests:
print(f"No components found matching: {component_patterns}")
@@ -1127,11 +1122,6 @@ def main() -> int:
"These are tested WITHOUT --testing-mode to enable full validation. "
"Used in CI for directly changed components to catch pin conflicts and other issues.",
)
parser.add_argument(
"--base-only",
action="store_true",
help="Only test base test files (test.*.yaml), not variant files (test-*.yaml)",
)
args = parser.parse_args()
@@ -1150,7 +1140,6 @@ def main() -> int:
continue_on_fail=args.continue_on_fail,
enable_grouping=not args.no_grouping,
isolated_components=isolated_components,
base_only=args.base_only,
)

View File

@@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch
import pytest
from esphome.components.packages import do_packages_pass
from esphome.config import resolve_extend_remove
from esphome.config_helpers import Extend, Remove
import esphome.config_validation as cv
from esphome.const import (
@@ -65,20 +64,13 @@ def fixture_basic_esphome():
return {CONF_NAME: TEST_DEVICE_NAME, CONF_PLATFORM: TEST_PLATFORM}
def packages_pass(config):
"""Wrapper around packages_pass that also resolves Extend and Remove."""
config = do_packages_pass(config)
resolve_extend_remove(config)
return config
def test_package_unused(basic_esphome, basic_wifi):
"""
Ensures do_package_pass does not change a config if packages aren't used.
"""
config = {CONF_ESPHOME: basic_esphome, CONF_WIFI: basic_wifi}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == config
@@ -91,7 +83,7 @@ def test_package_invalid_dict(basic_esphome, basic_wifi):
config = {CONF_ESPHOME: basic_esphome, CONF_PACKAGES: basic_wifi | {CONF_URL: ""}}
with pytest.raises(cv.Invalid):
packages_pass(config)
do_packages_pass(config)
def test_package_include(basic_wifi, basic_esphome):
@@ -107,7 +99,7 @@ def test_package_include(basic_wifi, basic_esphome):
expected = {CONF_ESPHOME: basic_esphome, CONF_WIFI: basic_wifi}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -132,7 +124,7 @@ def test_package_append(basic_wifi, basic_esphome):
},
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -156,7 +148,7 @@ def test_package_override(basic_wifi, basic_esphome):
},
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -185,7 +177,7 @@ def test_multiple_package_order():
},
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -241,7 +233,7 @@ def test_package_list_merge():
]
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -319,7 +311,7 @@ def test_package_list_merge_by_id():
]
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -358,13 +350,13 @@ def test_package_merge_by_id_with_list():
]
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
def test_package_merge_by_missing_id():
"""
Ensures that a validation error is thrown when trying to extend a missing ID.
Ensures that components with missing IDs are not merged.
"""
config = {
@@ -387,15 +379,25 @@ def test_package_merge_by_missing_id():
],
}
error_raised = False
try:
packages_pass(config)
assert False, "Expected validation error for missing ID"
except cv.Invalid as err:
error_raised = True
assert err.path == [CONF_SENSOR, 2]
expected = {
CONF_SENSOR: [
{
CONF_ID: TEST_SENSOR_ID_1,
CONF_FILTERS: [{CONF_MULTIPLY: 42.0}],
},
{
CONF_ID: TEST_SENSOR_ID_1,
CONF_FILTERS: [{CONF_MULTIPLY: 10.0}],
},
{
CONF_ID: Extend(TEST_SENSOR_ID_2),
CONF_FILTERS: [{CONF_OFFSET: 146.0}],
},
]
}
assert error_raised
actual = do_packages_pass(config)
assert actual == expected
def test_package_list_remove_by_id():
@@ -445,7 +447,7 @@ def test_package_list_remove_by_id():
]
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -491,7 +493,7 @@ def test_multiple_package_list_remove_by_id():
]
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -512,7 +514,7 @@ def test_package_dict_remove_by_id(basic_wifi, basic_esphome):
CONF_ESPHOME: basic_esphome,
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -543,6 +545,7 @@ def test_package_remove_by_missing_id():
}
expected = {
"missing_key": Remove(),
CONF_SENSOR: [
{
CONF_ID: TEST_SENSOR_ID_1,
@@ -552,10 +555,14 @@ def test_package_remove_by_missing_id():
CONF_ID: TEST_SENSOR_ID_1,
CONF_FILTERS: [{CONF_MULTIPLY: 10.0}],
},
{
CONF_ID: Remove(TEST_SENSOR_ID_2),
CONF_FILTERS: [{CONF_OFFSET: 146.0}],
},
],
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -627,7 +634,7 @@ def test_remote_packages_with_files_list(
]
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected
@@ -723,5 +730,5 @@ def test_remote_packages_with_files_and_vars(
]
}
actual = packages_pass(config)
actual = do_packages_pass(config)
assert actual == expected

View File

@@ -1,101 +0,0 @@
sensor:
# Source sensor for testing filters
- platform: template
name: "Source Sensor"
id: source_sensor
lambda: return 42.0;
update_interval: 1s
# Streaming filters (window_size == send_every) - uses StreamingFilter base class
- platform: copy
source_id: source_sensor
name: "Streaming Min Filter"
filters:
- min:
window_size: 10
send_every: 10 # Batch window → StreamingMinFilter
- platform: copy
source_id: source_sensor
name: "Streaming Max Filter"
filters:
- max:
window_size: 10
send_every: 10 # Batch window → StreamingMaxFilter
- platform: copy
source_id: source_sensor
name: "Streaming Moving Average Filter"
filters:
- sliding_window_moving_average:
window_size: 10
send_every: 10 # Batch window → StreamingMovingAverageFilter
# Sliding window filters (window_size != send_every) - uses SlidingWindowFilter base class with ring buffer
- platform: copy
source_id: source_sensor
name: "Sliding Min Filter"
filters:
- min:
window_size: 10
send_every: 5 # Sliding window → MinFilter with ring buffer
- platform: copy
source_id: source_sensor
name: "Sliding Max Filter"
filters:
- max:
window_size: 10
send_every: 5 # Sliding window → MaxFilter with ring buffer
- platform: copy
source_id: source_sensor
name: "Sliding Median Filter"
filters:
- median:
window_size: 10
send_every: 5 # Sliding window → MedianFilter with ring buffer
- platform: copy
source_id: source_sensor
name: "Sliding Quantile Filter"
filters:
- quantile:
window_size: 10
send_every: 5
quantile: 0.9 # Sliding window → QuantileFilter with ring buffer
- platform: copy
source_id: source_sensor
name: "Sliding Moving Average Filter"
filters:
- sliding_window_moving_average:
window_size: 10
send_every: 5 # Sliding window → SlidingWindowMovingAverageFilter with ring buffer
# Edge cases
- platform: copy
source_id: source_sensor
name: "Large Batch Window Min"
filters:
- min:
window_size: 1000
send_every: 1000 # Large batch → StreamingMinFilter (4 bytes, not 4KB)
- platform: copy
source_id: source_sensor
name: "Small Sliding Window"
filters:
- median:
window_size: 3
send_every: 1 # Frequent output → MedianFilter with 3-element ring buffer
# send_first_at parameter test
- platform: copy
source_id: source_sensor
name: "Early Send Filter"
filters:
- max:
window_size: 10
send_every: 10
send_first_at: 1 # Send after first value

View File

@@ -1 +0,0 @@
<<: !include common.yaml

View File

@@ -1,5 +1,4 @@
wifi:
fast_connect: true
networks:
- ssid: MySSID
eap:

View File

@@ -7,7 +7,6 @@ This directory contains end-to-end integration tests for ESPHome, focusing on te
- `conftest.py` - Common fixtures and utilities
- `const.py` - Constants used throughout the integration tests
- `types.py` - Type definitions for fixtures and functions
- `state_utils.py` - State handling utilities (e.g., `InitialStateHelper`, `build_key_to_entity_mapping`)
- `fixtures/` - YAML configuration files for tests
- `test_*.py` - Individual test files
@@ -27,32 +26,6 @@ The `yaml_config` fixture automatically loads YAML configurations based on the t
- `reserved_tcp_port` - Reserves a TCP port by holding the socket open until ESPHome needs it
- `unused_tcp_port` - Provides the reserved port number for each test
### Helper Utilities
#### InitialStateHelper (`state_utils.py`)
The `InitialStateHelper` class solves a common problem in integration tests: when an API client connects, ESPHome automatically broadcasts the current state of all entities. This can interfere with tests that want to track only new state changes triggered by test actions.
**What it does:**
- Tracks all entities (except stateless ones like buttons)
- Swallows the first state broadcast for each entity
- Forwards all subsequent state changes to your test callback
- Provides `wait_for_initial_states()` to synchronize before test actions
**When to use it:**
- Any test that triggers entity state changes and needs to verify them
- Tests that would otherwise see duplicate or unexpected states
- Tests that need clean separation between initial state and test-triggered changes
**Implementation details:**
- Uses `(device_id, key)` tuples to uniquely identify entities across devices
- Automatically excludes `ButtonInfo` entities (stateless)
- Provides debug logging to track state reception (use `--log-cli-level=DEBUG`)
- Safe for concurrent use with multiple entity types
**Future work:**
Consider converting existing integration tests to use `InitialStateHelper` for more reliable state tracking and to eliminate race conditions related to initial state broadcasts.
### Writing Tests
The simplest way to write a test is to use the `run_compiled` and `api_client_connected` fixtures:
@@ -152,54 +125,6 @@ async def test_my_sensor(
```
##### State Subscription Pattern
**Recommended: Using InitialStateHelper**
When an API client connects, ESPHome automatically sends the current state of all entities. The `InitialStateHelper` (from `state_utils.py`) handles this by swallowing these initial states and only forwarding subsequent state changes to your test callback:
```python
from .state_utils import InitialStateHelper
# Track state changes with futures
loop = asyncio.get_running_loop()
states: dict[int, EntityState] = {}
state_future: asyncio.Future[EntityState] = loop.create_future()
def on_state(state: EntityState) -> None:
"""This callback only receives NEW state changes, not initial states."""
states[state.key] = state
# Check for specific condition using isinstance
if isinstance(state, SensorState) and state.state == expected_value:
if not state_future.done():
state_future.set_result(state)
# Get entities and set up state synchronization
entities, services = await client.list_entities_services()
initial_state_helper = InitialStateHelper(entities)
# Subscribe with the wrapper that filters initial states
client.subscribe_states(initial_state_helper.on_state_wrapper(on_state))
# Wait for all initial states to be broadcast
try:
await initial_state_helper.wait_for_initial_states()
except TimeoutError:
pytest.fail("Timeout waiting for initial states")
# Now perform your test actions - on_state will only receive new changes
# ... trigger state changes ...
# Wait for expected state
try:
result = await asyncio.wait_for(state_future, timeout=5.0)
except asyncio.TimeoutError:
pytest.fail(f"Expected state not received. Got: {list(states.values())}")
```
**Legacy: Manual State Tracking**
If you need to handle initial states manually (not recommended for new tests):
```python
# Track state changes with futures
loop = asyncio.get_running_loop()

View File

@@ -1,170 +0,0 @@
esphome:
name: test-script-queued
host:
api:
actions:
# Test 1: Queue depth with default max_runs=5
- action: test_queue_depth
then:
- logger.log: "=== TEST 1: Queue depth (max_runs=5 means 5 total, reject 6-7) ==="
- script.execute:
id: queue_depth_script
value: 1
- script.execute:
id: queue_depth_script
value: 2
- script.execute:
id: queue_depth_script
value: 3
- script.execute:
id: queue_depth_script
value: 4
- script.execute:
id: queue_depth_script
value: 5
- script.execute:
id: queue_depth_script
value: 6
- script.execute:
id: queue_depth_script
value: 7
# Test 2: Ring buffer wrap test
- action: test_ring_buffer
then:
- logger.log: "=== TEST 2: Ring buffer wrap (should process A, B, C in order) ==="
- script.execute:
id: wrap_script
msg: "A"
- script.execute:
id: wrap_script
msg: "B"
- script.execute:
id: wrap_script
msg: "C"
# Test 3: Stop clears queue
- action: test_stop_clears
then:
- logger.log: "=== TEST 3: Stop clears queue (should only see 1, then 'STOPPED') ==="
- script.execute:
id: stop_script
num: 1
- script.execute:
id: stop_script
num: 2
- script.execute:
id: stop_script
num: 3
- delay: 50ms
- logger.log: "STOPPING script now"
- script.stop: stop_script
# Test 4: Verify rejection (max_runs=3)
- action: test_rejection
then:
- logger.log: "=== TEST 4: Verify rejection (max_runs=3 means 3 total, reject 4-8) ==="
- script.execute:
id: rejection_script
val: 1
- script.execute:
id: rejection_script
val: 2
- script.execute:
id: rejection_script
val: 3
- script.execute:
id: rejection_script
val: 4
- script.execute:
id: rejection_script
val: 5
- script.execute:
id: rejection_script
val: 6
- script.execute:
id: rejection_script
val: 7
- script.execute:
id: rejection_script
val: 8
# Test 5: No parameters test
- action: test_no_params
then:
- logger.log: "=== TEST 5: No params (should process 3 times) ==="
- script.execute: no_params_script
- script.execute: no_params_script
- script.execute: no_params_script
logger:
level: DEBUG
script:
# Test script 1: Queue depth test (default max_runs=5)
- id: queue_depth_script
mode: queued
parameters:
value: int
then:
- logger.log:
format: "Queue test: START item %d"
args: ['value']
- delay: 100ms
- logger.log:
format: "Queue test: END item %d"
args: ['value']
# Test script 2: Ring buffer wrap test (max_runs=3)
- id: wrap_script
mode: queued
max_runs: 3
parameters:
msg: string
then:
- logger.log:
format: "Ring buffer: START '%s'"
args: ['msg.c_str()']
- delay: 50ms
- logger.log:
format: "Ring buffer: END '%s'"
args: ['msg.c_str()']
# Test script 3: Stop test
- id: stop_script
mode: queued
max_runs: 5
parameters:
num: int
then:
- logger.log:
format: "Stop test: START %d"
args: ['num']
- delay: 100ms
- logger.log:
format: "Stop test: END %d"
args: ['num']
# Test script 4: Rejection test (max_runs=3)
- id: rejection_script
mode: queued
max_runs: 3
parameters:
val: int
then:
- logger.log:
format: "Rejection test: START %d"
args: ['val']
- delay: 200ms
- logger.log:
format: "Rejection test: END %d"
args: ['val']
# Test script 5: No parameters
- id: no_params_script
mode: queued
then:
- logger.log: "No params: START"
- delay: 50ms
- logger.log: "No params: END"

View File

@@ -1,58 +0,0 @@
esphome:
name: test-batch-window-filters
host:
api:
batch_delay: 0ms # Disable batching to receive all state updates
logger:
level: DEBUG
# Template sensor that we'll use to publish values
sensor:
- platform: template
name: "Source Sensor"
id: source_sensor
accuracy_decimals: 2
# Batch window filters (window_size == send_every) - use streaming filters
- platform: copy
source_id: source_sensor
name: "Min Sensor"
id: min_sensor
filters:
- min:
window_size: 5
send_every: 5
send_first_at: 1
- platform: copy
source_id: source_sensor
name: "Max Sensor"
id: max_sensor
filters:
- max:
window_size: 5
send_every: 5
send_first_at: 1
- platform: copy
source_id: source_sensor
name: "Moving Avg Sensor"
id: moving_avg_sensor
filters:
- sliding_window_moving_average:
window_size: 5
send_every: 5
send_first_at: 1
# Button to trigger publishing test values
button:
- platform: template
name: "Publish Values Button"
id: publish_button
on_press:
- lambda: |-
// Publish 10 values: 1.0, 2.0, ..., 10.0
for (int i = 1; i <= 10; i++) {
id(source_sensor).publish_state(float(i));
}

View File

@@ -1,84 +0,0 @@
esphome:
name: test-nan-handling
host:
api:
batch_delay: 0ms # Disable batching to receive all state updates
logger:
level: DEBUG
sensor:
- platform: template
name: "Source NaN Sensor"
id: source_nan_sensor
accuracy_decimals: 2
- platform: copy
source_id: source_nan_sensor
name: "Min NaN Sensor"
id: min_nan_sensor
filters:
- min:
window_size: 5
send_every: 5
send_first_at: 1
- platform: copy
source_id: source_nan_sensor
name: "Max NaN Sensor"
id: max_nan_sensor
filters:
- max:
window_size: 5
send_every: 5
send_first_at: 1
script:
- id: publish_nan_values_script
then:
- sensor.template.publish:
id: source_nan_sensor
state: 10.0
- delay: 20ms
- sensor.template.publish:
id: source_nan_sensor
state: !lambda 'return NAN;'
- delay: 20ms
- sensor.template.publish:
id: source_nan_sensor
state: 5.0
- delay: 20ms
- sensor.template.publish:
id: source_nan_sensor
state: !lambda 'return NAN;'
- delay: 20ms
- sensor.template.publish:
id: source_nan_sensor
state: 15.0
- delay: 20ms
- sensor.template.publish:
id: source_nan_sensor
state: 8.0
- delay: 20ms
- sensor.template.publish:
id: source_nan_sensor
state: !lambda 'return NAN;'
- delay: 20ms
- sensor.template.publish:
id: source_nan_sensor
state: 12.0
- delay: 20ms
- sensor.template.publish:
id: source_nan_sensor
state: 3.0
- delay: 20ms
- sensor.template.publish:
id: source_nan_sensor
state: !lambda 'return NAN;'
button:
- platform: template
name: "Publish NaN Values Button"
id: publish_nan_button
on_press:
- script.execute: publish_nan_values_script

View File

@@ -1,115 +0,0 @@
esphome:
name: test-sliding-window-filters
host:
api:
batch_delay: 0ms # Disable batching to receive all state updates
logger:
level: DEBUG
# Template sensor that we'll use to publish values
sensor:
- platform: template
name: "Source Sensor"
id: source_sensor
accuracy_decimals: 2
# ACTUAL sliding window filters (window_size != send_every) - use ring buffers
# Window of 5, send every 2 values
- platform: copy
source_id: source_sensor
name: "Sliding Min Sensor"
id: sliding_min_sensor
filters:
- min:
window_size: 5
send_every: 2
send_first_at: 1
- platform: copy
source_id: source_sensor
name: "Sliding Max Sensor"
id: sliding_max_sensor
filters:
- max:
window_size: 5
send_every: 2
send_first_at: 1
- platform: copy
source_id: source_sensor
name: "Sliding Median Sensor"
id: sliding_median_sensor
filters:
- median:
window_size: 5
send_every: 2
send_first_at: 1
- platform: copy
source_id: source_sensor
name: "Sliding Moving Avg Sensor"
id: sliding_moving_avg_sensor
filters:
- sliding_window_moving_average:
window_size: 5
send_every: 2
send_first_at: 1
# Button to trigger publishing test values
script:
- id: publish_values_script
then:
# Publish 10 values: 1.0, 2.0, ..., 10.0
# With window_size=5, send_every=2, send_first_at=1:
# - Output at position 1: window=[1], min=1, max=1, median=1, avg=1
# - Output at position 3: window=[1,2,3], min=1, max=3, median=2, avg=2
# - Output at position 5: window=[1,2,3,4,5], min=1, max=5, median=3, avg=3
# - Output at position 7: window=[3,4,5,6,7], min=3, max=7, median=5, avg=5
# - Output at position 9: window=[5,6,7,8,9], min=5, max=9, median=7, avg=7
- sensor.template.publish:
id: source_sensor
state: 1.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 2.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 3.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 4.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 5.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 6.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 7.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 8.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 9.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 10.0
button:
- platform: template
name: "Publish Values Button"
id: publish_button
on_press:
- script.execute: publish_values_script

View File

@@ -1,72 +0,0 @@
esphome:
name: test-ring-buffer-wraparound
host:
api:
batch_delay: 0ms # Disable batching to receive all state updates
logger:
level: DEBUG
sensor:
- platform: template
name: "Source Wraparound Sensor"
id: source_wraparound
accuracy_decimals: 2
- platform: copy
source_id: source_wraparound
name: "Wraparound Min Sensor"
id: wraparound_min_sensor
filters:
- min:
window_size: 3
send_every: 3
send_first_at: 1
script:
- id: publish_wraparound_script
then:
# Publish 9 values to test ring buffer wraparound
# Values: 10, 20, 30, 5, 25, 15, 40, 35, 20
- sensor.template.publish:
id: source_wraparound
state: 10.0
- delay: 20ms
- sensor.template.publish:
id: source_wraparound
state: 20.0
- delay: 20ms
- sensor.template.publish:
id: source_wraparound
state: 30.0
- delay: 20ms
- sensor.template.publish:
id: source_wraparound
state: 5.0
- delay: 20ms
- sensor.template.publish:
id: source_wraparound
state: 25.0
- delay: 20ms
- sensor.template.publish:
id: source_wraparound
state: 15.0
- delay: 20ms
- sensor.template.publish:
id: source_wraparound
state: 40.0
- delay: 20ms
- sensor.template.publish:
id: source_wraparound
state: 35.0
- delay: 20ms
- sensor.template.publish:
id: source_wraparound
state: 20.0
button:
- platform: template
name: "Publish Wraparound Button"
id: publish_wraparound_button
on_press:
- script.execute: publish_wraparound_script

View File

@@ -1,123 +0,0 @@
esphome:
name: test-sliding-window-filters
host:
api:
batch_delay: 0ms # Disable batching to receive all state updates
logger:
level: DEBUG
# Template sensor that we'll use to publish values
sensor:
- platform: template
name: "Source Sensor"
id: source_sensor
accuracy_decimals: 2
# Min filter sensor
- platform: copy
source_id: source_sensor
name: "Min Sensor"
id: min_sensor
filters:
- min:
window_size: 5
send_every: 5
send_first_at: 1
# Max filter sensor
- platform: copy
source_id: source_sensor
name: "Max Sensor"
id: max_sensor
filters:
- max:
window_size: 5
send_every: 5
send_first_at: 1
# Median filter sensor
- platform: copy
source_id: source_sensor
name: "Median Sensor"
id: median_sensor
filters:
- median:
window_size: 5
send_every: 5
send_first_at: 1
# Quantile filter sensor (90th percentile)
- platform: copy
source_id: source_sensor
name: "Quantile Sensor"
id: quantile_sensor
filters:
- quantile:
window_size: 5
send_every: 5
send_first_at: 1
quantile: 0.9
# Moving average filter sensor
- platform: copy
source_id: source_sensor
name: "Moving Avg Sensor"
id: moving_avg_sensor
filters:
- sliding_window_moving_average:
window_size: 5
send_every: 5
send_first_at: 1
# Script to publish values with delays
script:
- id: publish_values_script
then:
- sensor.template.publish:
id: source_sensor
state: 1.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 2.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 3.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 4.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 5.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 6.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 7.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 8.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 9.0
- delay: 20ms
- sensor.template.publish:
id: source_sensor
state: 10.0
# Button to trigger publishing test values
button:
- platform: template
name: "Publish Values Button"
id: publish_button
on_press:
- script.execute: publish_values_script

View File

@@ -1,167 +0,0 @@
"""Shared utilities for ESPHome integration tests - state handling."""
from __future__ import annotations
import asyncio
import logging
from aioesphomeapi import ButtonInfo, EntityInfo, EntityState
_LOGGER = logging.getLogger(__name__)
def build_key_to_entity_mapping(
entities: list[EntityInfo], entity_names: list[str]
) -> dict[int, str]:
"""Build a mapping from entity keys to entity names.
Args:
entities: List of entity info objects from the API
entity_names: List of entity names to search for in object_ids
Returns:
Dictionary mapping entity keys to entity names
"""
key_to_entity: dict[int, str] = {}
for entity in entities:
obj_id = entity.object_id.lower()
for entity_name in entity_names:
if entity_name in obj_id:
key_to_entity[entity.key] = entity_name
break
return key_to_entity
class InitialStateHelper:
"""Helper to wait for initial states before processing test states.
When an API client connects, ESPHome sends the current state of all entities.
This helper wraps the user's state callback and swallows the first state for
each entity, then forwards all subsequent states to the user callback.
Usage:
entities, services = await client.list_entities_services()
helper = InitialStateHelper(entities)
client.subscribe_states(helper.on_state_wrapper(user_callback))
await helper.wait_for_initial_states()
"""
def __init__(self, entities: list[EntityInfo]) -> None:
"""Initialize the helper.
Args:
entities: All entities from list_entities_services()
"""
# Set of (device_id, key) tuples waiting for initial state
# Buttons are stateless, so exclude them
self._wait_initial_states = {
(entity.device_id, entity.key)
for entity in entities
if not isinstance(entity, ButtonInfo)
}
# Keep entity info for debugging - use (device_id, key) tuple
self._entities_by_id = {
(entity.device_id, entity.key): entity for entity in entities
}
# Log all entities
_LOGGER.debug(
"InitialStateHelper: Found %d total entities: %s",
len(entities),
[(type(e).__name__, e.object_id) for e in entities],
)
# Log which ones we're waiting for
_LOGGER.debug(
"InitialStateHelper: Waiting for %d entities (excluding ButtonInfo): %s",
len(self._wait_initial_states),
[self._entities_by_id[k].object_id for k in self._wait_initial_states],
)
# Log which ones we're NOT waiting for
not_waiting = {
(e.device_id, e.key) for e in entities
} - self._wait_initial_states
if not_waiting:
not_waiting_info = [
f"{type(self._entities_by_id[k]).__name__}:{self._entities_by_id[k].object_id}"
for k in not_waiting
]
_LOGGER.debug(
"InitialStateHelper: NOT waiting for %d entities: %s",
len(not_waiting),
not_waiting_info,
)
# Create future in the running event loop
self._initial_states_received = asyncio.get_running_loop().create_future()
# If no entities to wait for, mark complete immediately
if not self._wait_initial_states:
self._initial_states_received.set_result(True)
def on_state_wrapper(self, user_callback):
"""Wrap a user callback to track initial states.
Args:
user_callback: The user's state callback function
Returns:
Wrapped callback that swallows first state per entity, forwards rest
"""
def wrapper(state: EntityState) -> None:
"""Swallow initial state per entity, forward subsequent states."""
# Create entity identifier tuple
entity_id = (state.device_id, state.key)
# Log which entity is sending state
if entity_id in self._entities_by_id:
entity = self._entities_by_id[entity_id]
_LOGGER.debug(
"Received state for %s (type: %s, device_id: %s, key: %d)",
entity.object_id,
type(entity).__name__,
state.device_id,
state.key,
)
# If this entity is waiting for initial state
if entity_id in self._wait_initial_states:
# Remove from waiting set
self._wait_initial_states.discard(entity_id)
_LOGGER.debug(
"Swallowed initial state for %s, %d entities remaining",
self._entities_by_id[entity_id].object_id
if entity_id in self._entities_by_id
else entity_id,
len(self._wait_initial_states),
)
# Check if we've now seen all entities
if (
not self._wait_initial_states
and not self._initial_states_received.done()
):
_LOGGER.debug("All initial states received")
self._initial_states_received.set_result(True)
# Don't forward initial state to user
return
# Forward subsequent states to user callback
_LOGGER.debug("Forwarding state to user callback")
user_callback(state)
return wrapper
async def wait_for_initial_states(self, timeout: float = 5.0) -> None:
"""Wait for all initial states to be received.
Args:
timeout: Maximum time to wait in seconds
Raises:
asyncio.TimeoutError: If initial states aren't received within timeout
"""
await asyncio.wait_for(self._initial_states_received, timeout=timeout)

View File

@@ -8,7 +8,6 @@ import asyncio
from typing import Any
from aioesphomeapi import LightState
from aioesphomeapi.model import ColorMode
import pytest
from .types import APIClientConnectedFactory, RunCompiledFunction
@@ -36,51 +35,10 @@ async def test_light_calls(
# Get the light entities
entities = await client.list_entities_services()
lights = [e for e in entities[0] if e.object_id.startswith("test_")]
assert len(lights) >= 3 # Should have RGBCW, RGB, and Binary lights
assert len(lights) >= 2 # Should have RGBCW and RGB lights
rgbcw_light = next(light for light in lights if "RGBCW" in light.name)
rgb_light = next(light for light in lights if "RGB Light" in light.name)
binary_light = next(light for light in lights if "Binary" in light.name)
# Test color mode encoding: Verify supported_color_modes contains actual ColorMode enum values
# not bit positions. This is critical - the iterator must convert bit positions to actual
# ColorMode enum values for API encoding.
# RGBCW light (rgbww platform) should support RGB_COLD_WARM_WHITE mode
assert ColorMode.RGB_COLD_WARM_WHITE in rgbcw_light.supported_color_modes, (
f"RGBCW light missing RGB_COLD_WARM_WHITE mode. Got: {rgbcw_light.supported_color_modes}"
)
# Verify it's the actual enum value, not bit position
assert ColorMode.RGB_COLD_WARM_WHITE.value in [
mode.value for mode in rgbcw_light.supported_color_modes
], (
f"RGBCW light has wrong color mode values. Expected {ColorMode.RGB_COLD_WARM_WHITE.value} "
f"(RGB_COLD_WARM_WHITE), got: {[mode.value for mode in rgbcw_light.supported_color_modes]}"
)
# RGB light should support RGB mode
assert ColorMode.RGB in rgb_light.supported_color_modes, (
f"RGB light missing RGB color mode. Got: {rgb_light.supported_color_modes}"
)
# Verify it's the actual enum value, not bit position
assert ColorMode.RGB.value in [
mode.value for mode in rgb_light.supported_color_modes
], (
f"RGB light has wrong color mode values. Expected {ColorMode.RGB.value} (RGB), got: "
f"{[mode.value for mode in rgb_light.supported_color_modes]}"
)
# Binary light (on/off only) should support ON_OFF mode
assert ColorMode.ON_OFF in binary_light.supported_color_modes, (
f"Binary light missing ON_OFF color mode. Got: {binary_light.supported_color_modes}"
)
# Verify it's the actual enum value, not bit position
assert ColorMode.ON_OFF.value in [
mode.value for mode in binary_light.supported_color_modes
], (
f"Binary light has wrong color mode values. Expected {ColorMode.ON_OFF.value} (ON_OFF), got: "
f"{[mode.value for mode in binary_light.supported_color_modes]}"
)
async def wait_for_state_change(key: int, timeout: float = 1.0) -> Any:
"""Wait for a state change for the given entity key."""

View File

@@ -1,203 +0,0 @@
"""Test ESPHome queued script functionality."""
from __future__ import annotations
import asyncio
import re
import pytest
from .types import APIClientConnectedFactory, RunCompiledFunction
@pytest.mark.asyncio
async def test_script_queued(
yaml_config: str,
run_compiled: RunCompiledFunction,
api_client_connected: APIClientConnectedFactory,
) -> None:
"""Test comprehensive queued script functionality."""
loop = asyncio.get_running_loop()
# Track all test results
test_results = {
"queue_depth": {"processed": [], "rejections": 0},
"ring_buffer": {"start_order": [], "end_order": []},
"stop": {"processed": [], "stop_logged": False},
"rejection": {"processed": [], "rejections": 0},
"no_params": {"executions": 0},
}
# Patterns for Test 1: Queue depth
queue_start = re.compile(r"Queue test: START item (\d+)")
queue_end = re.compile(r"Queue test: END item (\d+)")
queue_reject = re.compile(r"Script 'queue_depth_script' max instances")
# Patterns for Test 2: Ring buffer
ring_start = re.compile(r"Ring buffer: START '([A-Z])'")
ring_end = re.compile(r"Ring buffer: END '([A-Z])'")
# Patterns for Test 3: Stop
stop_start = re.compile(r"Stop test: START (\d+)")
stop_log = re.compile(r"STOPPING script now")
# Patterns for Test 4: Rejection
reject_start = re.compile(r"Rejection test: START (\d+)")
reject_end = re.compile(r"Rejection test: END (\d+)")
reject_reject = re.compile(r"Script 'rejection_script' max instances")
# Patterns for Test 5: No params
no_params_end = re.compile(r"No params: END")
# Test completion futures
test1_complete = loop.create_future()
test2_complete = loop.create_future()
test3_complete = loop.create_future()
test4_complete = loop.create_future()
test5_complete = loop.create_future()
def check_output(line: str) -> None:
"""Check log output for all test messages."""
# Test 1: Queue depth
if match := queue_start.search(line):
item = int(match.group(1))
if item not in test_results["queue_depth"]["processed"]:
test_results["queue_depth"]["processed"].append(item)
if match := queue_end.search(line):
item = int(match.group(1))
if item == 5 and not test1_complete.done():
test1_complete.set_result(True)
if queue_reject.search(line):
test_results["queue_depth"]["rejections"] += 1
# Test 2: Ring buffer
if match := ring_start.search(line):
msg = match.group(1)
test_results["ring_buffer"]["start_order"].append(msg)
if match := ring_end.search(line):
msg = match.group(1)
test_results["ring_buffer"]["end_order"].append(msg)
if (
len(test_results["ring_buffer"]["end_order"]) == 3
and not test2_complete.done()
):
test2_complete.set_result(True)
# Test 3: Stop
if match := stop_start.search(line):
item = int(match.group(1))
if item not in test_results["stop"]["processed"]:
test_results["stop"]["processed"].append(item)
if stop_log.search(line):
test_results["stop"]["stop_logged"] = True
# Give time for any queued items to be cleared
if not test3_complete.done():
loop.call_later(
0.3,
lambda: test3_complete.set_result(True)
if not test3_complete.done()
else None,
)
# Test 4: Rejection
if match := reject_start.search(line):
item = int(match.group(1))
if item not in test_results["rejection"]["processed"]:
test_results["rejection"]["processed"].append(item)
if match := reject_end.search(line):
item = int(match.group(1))
if item == 3 and not test4_complete.done():
test4_complete.set_result(True)
if reject_reject.search(line):
test_results["rejection"]["rejections"] += 1
# Test 5: No params
if no_params_end.search(line):
test_results["no_params"]["executions"] += 1
if (
test_results["no_params"]["executions"] == 3
and not test5_complete.done()
):
test5_complete.set_result(True)
async with (
run_compiled(yaml_config, line_callback=check_output),
api_client_connected() as client,
):
# Get services
_, services = await client.list_entities_services()
# Test 1: Queue depth limit
test_service = next((s for s in services if s.name == "test_queue_depth"), None)
assert test_service is not None, "test_queue_depth service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test1_complete, timeout=2.0)
await asyncio.sleep(0.1) # Give time for rejections
# Verify Test 1
assert sorted(test_results["queue_depth"]["processed"]) == [1, 2, 3, 4, 5], (
f"Test 1: Expected to process items 1-5 (max_runs=5 means 5 total), got {sorted(test_results['queue_depth']['processed'])}"
)
assert test_results["queue_depth"]["rejections"] >= 2, (
"Test 1: Expected at least 2 rejection warnings (items 6-7 should be rejected)"
)
# Test 2: Ring buffer order
test_service = next((s for s in services if s.name == "test_ring_buffer"), None)
assert test_service is not None, "test_ring_buffer service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test2_complete, timeout=2.0)
# Verify Test 2
assert test_results["ring_buffer"]["start_order"] == ["A", "B", "C"], (
f"Test 2: Expected start order [A, B, C], got {test_results['ring_buffer']['start_order']}"
)
assert test_results["ring_buffer"]["end_order"] == ["A", "B", "C"], (
f"Test 2: Expected end order [A, B, C], got {test_results['ring_buffer']['end_order']}"
)
# Test 3: Stop clears queue
test_service = next((s for s in services if s.name == "test_stop_clears"), None)
assert test_service is not None, "test_stop_clears service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test3_complete, timeout=2.0)
# Verify Test 3
assert test_results["stop"]["stop_logged"], (
"Test 3: Stop command was not logged"
)
assert test_results["stop"]["processed"] == [1], (
f"Test 3: Expected only item 1 to process, got {test_results['stop']['processed']}"
)
# Test 4: Rejection enforcement (max_runs=3)
test_service = next((s for s in services if s.name == "test_rejection"), None)
assert test_service is not None, "test_rejection service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test4_complete, timeout=2.0)
await asyncio.sleep(0.1) # Give time for rejections
# Verify Test 4
assert sorted(test_results["rejection"]["processed"]) == [1, 2, 3], (
f"Test 4: Expected to process items 1-3 (max_runs=3 means 3 total), got {sorted(test_results['rejection']['processed'])}"
)
assert test_results["rejection"]["rejections"] == 5, (
f"Test 4: Expected 5 rejections (items 4-8), got {test_results['rejection']['rejections']}"
)
# Test 5: No parameters
test_service = next((s for s in services if s.name == "test_no_params"), None)
assert test_service is not None, "test_no_params service not found"
client.execute_service(test_service, {})
await asyncio.wait_for(test5_complete, timeout=2.0)
# Verify Test 5
assert test_results["no_params"]["executions"] == 3, (
f"Test 5: Expected 3 executions, got {test_results['no_params']['executions']}"
)

View File

@@ -1,151 +0,0 @@
"""Test sensor ring buffer filter functionality (window_size != send_every)."""
from __future__ import annotations
import asyncio
from aioesphomeapi import EntityState, SensorState
import pytest
from .state_utils import InitialStateHelper, build_key_to_entity_mapping
from .types import APIClientConnectedFactory, RunCompiledFunction
@pytest.mark.asyncio
async def test_sensor_filters_ring_buffer(
yaml_config: str,
run_compiled: RunCompiledFunction,
api_client_connected: APIClientConnectedFactory,
) -> None:
"""Test that ring buffer filters (window_size != send_every) work correctly."""
loop = asyncio.get_running_loop()
# Track state changes for each sensor
sensor_states: dict[str, list[float]] = {
"sliding_min": [],
"sliding_max": [],
"sliding_median": [],
"sliding_moving_avg": [],
}
# Futures to track when we receive expected values
all_updates_received = loop.create_future()
def on_state(state: EntityState) -> None:
"""Track sensor state updates."""
if not isinstance(state, SensorState):
return
# Skip NaN values
if state.missing_state:
return
# Get the sensor name from the key mapping
sensor_name = key_to_sensor.get(state.key)
if not sensor_name or sensor_name not in sensor_states:
return
sensor_states[sensor_name].append(state.state)
# Check if we've received enough updates from all sensors
# With send_every=2, send_first_at=1, we expect 5 outputs per sensor
if (
len(sensor_states["sliding_min"]) >= 5
and len(sensor_states["sliding_max"]) >= 5
and len(sensor_states["sliding_median"]) >= 5
and len(sensor_states["sliding_moving_avg"]) >= 5
and not all_updates_received.done()
):
all_updates_received.set_result(True)
async with (
run_compiled(yaml_config),
api_client_connected() as client,
):
# Get entities first to build key mapping
entities, services = await client.list_entities_services()
# Build key-to-sensor mapping
key_to_sensor = build_key_to_entity_mapping(
entities,
[
"sliding_min",
"sliding_max",
"sliding_median",
"sliding_moving_avg",
],
)
# Set up initial state helper with all entities
initial_state_helper = InitialStateHelper(entities)
# Subscribe to state changes with wrapper
client.subscribe_states(initial_state_helper.on_state_wrapper(on_state))
# Wait for initial states to be sent before pressing button
try:
await initial_state_helper.wait_for_initial_states()
except TimeoutError:
pytest.fail("Timeout waiting for initial states")
# Find the publish button
publish_button = next(
(e for e in entities if "publish_values_button" in e.object_id.lower()),
None,
)
assert publish_button is not None, "Publish Values Button not found"
# Press the button to publish test values
client.button_command(publish_button.key)
# Wait for all sensors to receive their values
try:
await asyncio.wait_for(all_updates_received, timeout=10.0)
except TimeoutError:
# Provide detailed failure info
pytest.fail(
f"Timeout waiting for updates. Received states:\n"
f" min: {sensor_states['sliding_min']}\n"
f" max: {sensor_states['sliding_max']}\n"
f" median: {sensor_states['sliding_median']}\n"
f" moving_avg: {sensor_states['sliding_moving_avg']}"
)
# Verify we got 5 outputs per sensor (positions 1, 3, 5, 7, 9)
assert len(sensor_states["sliding_min"]) == 5, (
f"Min sensor should have 5 values, got {len(sensor_states['sliding_min'])}: {sensor_states['sliding_min']}"
)
assert len(sensor_states["sliding_max"]) == 5
assert len(sensor_states["sliding_median"]) == 5
assert len(sensor_states["sliding_moving_avg"]) == 5
# Verify the values at each output position
# Position 1: window=[1]
assert sensor_states["sliding_min"][0] == pytest.approx(1.0)
assert sensor_states["sliding_max"][0] == pytest.approx(1.0)
assert sensor_states["sliding_median"][0] == pytest.approx(1.0)
assert sensor_states["sliding_moving_avg"][0] == pytest.approx(1.0)
# Position 3: window=[1,2,3]
assert sensor_states["sliding_min"][1] == pytest.approx(1.0)
assert sensor_states["sliding_max"][1] == pytest.approx(3.0)
assert sensor_states["sliding_median"][1] == pytest.approx(2.0)
assert sensor_states["sliding_moving_avg"][1] == pytest.approx(2.0)
# Position 5: window=[1,2,3,4,5]
assert sensor_states["sliding_min"][2] == pytest.approx(1.0)
assert sensor_states["sliding_max"][2] == pytest.approx(5.0)
assert sensor_states["sliding_median"][2] == pytest.approx(3.0)
assert sensor_states["sliding_moving_avg"][2] == pytest.approx(3.0)
# Position 7: window=[3,4,5,6,7] (ring buffer wrapped)
assert sensor_states["sliding_min"][3] == pytest.approx(3.0)
assert sensor_states["sliding_max"][3] == pytest.approx(7.0)
assert sensor_states["sliding_median"][3] == pytest.approx(5.0)
assert sensor_states["sliding_moving_avg"][3] == pytest.approx(5.0)
# Position 9: window=[5,6,7,8,9] (ring buffer wrapped)
assert sensor_states["sliding_min"][4] == pytest.approx(5.0)
assert sensor_states["sliding_max"][4] == pytest.approx(9.0)
assert sensor_states["sliding_median"][4] == pytest.approx(7.0)
assert sensor_states["sliding_moving_avg"][4] == pytest.approx(7.0)

View File

@@ -1,395 +0,0 @@
"""Test sensor sliding window filter functionality."""
from __future__ import annotations
import asyncio
from aioesphomeapi import EntityState, SensorState
import pytest
from .state_utils import InitialStateHelper, build_key_to_entity_mapping
from .types import APIClientConnectedFactory, RunCompiledFunction
@pytest.mark.asyncio
async def test_sensor_filters_sliding_window(
yaml_config: str,
run_compiled: RunCompiledFunction,
api_client_connected: APIClientConnectedFactory,
) -> None:
"""Test that sliding window filters (min, max, median, quantile, moving_average) work correctly."""
loop = asyncio.get_running_loop()
# Track state changes for each sensor
sensor_states: dict[str, list[float]] = {
"min_sensor": [],
"max_sensor": [],
"median_sensor": [],
"quantile_sensor": [],
"moving_avg_sensor": [],
}
# Futures to track when we receive expected values
min_received = loop.create_future()
max_received = loop.create_future()
median_received = loop.create_future()
quantile_received = loop.create_future()
moving_avg_received = loop.create_future()
def on_state(state: EntityState) -> None:
"""Track sensor state updates."""
if not isinstance(state, SensorState):
return
# Skip NaN values
if state.missing_state:
return
# Get the sensor name from the key mapping
sensor_name = key_to_sensor.get(state.key)
if not sensor_name or sensor_name not in sensor_states:
return
sensor_states[sensor_name].append(state.state)
# Check if we received the expected final value
# After publishing 10 values [1.0, 2.0, ..., 10.0], the window has the last 5: [2, 3, 4, 5, 6]
# Filters send at position 1 and position 6 (send_every=5 means every 5th value after first)
if (
sensor_name == "min_sensor"
and state.state == pytest.approx(2.0)
and not min_received.done()
):
min_received.set_result(True)
elif (
sensor_name == "max_sensor"
and state.state == pytest.approx(6.0)
and not max_received.done()
):
max_received.set_result(True)
elif (
sensor_name == "median_sensor"
and state.state == pytest.approx(4.0)
and not median_received.done()
):
# Median of [2, 3, 4, 5, 6] = 4
median_received.set_result(True)
elif (
sensor_name == "quantile_sensor"
and state.state == pytest.approx(6.0)
and not quantile_received.done()
):
# 90th percentile of [2, 3, 4, 5, 6] = 6
quantile_received.set_result(True)
elif (
sensor_name == "moving_avg_sensor"
and state.state == pytest.approx(4.0)
and not moving_avg_received.done()
):
# Average of [2, 3, 4, 5, 6] = 4
moving_avg_received.set_result(True)
async with (
run_compiled(yaml_config),
api_client_connected() as client,
):
# Get entities first to build key mapping
entities, services = await client.list_entities_services()
# Build key-to-sensor mapping
key_to_sensor = build_key_to_entity_mapping(
entities,
[
"min_sensor",
"max_sensor",
"median_sensor",
"quantile_sensor",
"moving_avg_sensor",
],
)
# Set up initial state helper with all entities
initial_state_helper = InitialStateHelper(entities)
# Subscribe to state changes with wrapper
client.subscribe_states(initial_state_helper.on_state_wrapper(on_state))
# Wait for initial states to be sent before pressing button
try:
await initial_state_helper.wait_for_initial_states()
except TimeoutError:
pytest.fail("Timeout waiting for initial states")
# Find the publish button
publish_button = next(
(e for e in entities if "publish_values_button" in e.object_id.lower()),
None,
)
assert publish_button is not None, "Publish Values Button not found"
# Press the button to publish test values
client.button_command(publish_button.key)
# Wait for all sensors to receive their final values
try:
await asyncio.wait_for(
asyncio.gather(
min_received,
max_received,
median_received,
quantile_received,
moving_avg_received,
),
timeout=10.0,
)
except TimeoutError:
# Provide detailed failure info
pytest.fail(
f"Timeout waiting for expected values. Received states:\n"
f" min: {sensor_states['min_sensor']}\n"
f" max: {sensor_states['max_sensor']}\n"
f" median: {sensor_states['median_sensor']}\n"
f" quantile: {sensor_states['quantile_sensor']}\n"
f" moving_avg: {sensor_states['moving_avg_sensor']}"
)
# Verify we got the expected values
# With batch_delay: 0ms, we should receive all outputs
# Filters output at positions 1 and 6 (send_every: 5)
assert len(sensor_states["min_sensor"]) == 2, (
f"Min sensor should have 2 values, got {len(sensor_states['min_sensor'])}: {sensor_states['min_sensor']}"
)
assert len(sensor_states["max_sensor"]) == 2, (
f"Max sensor should have 2 values, got {len(sensor_states['max_sensor'])}: {sensor_states['max_sensor']}"
)
assert len(sensor_states["median_sensor"]) == 2
assert len(sensor_states["quantile_sensor"]) == 2
assert len(sensor_states["moving_avg_sensor"]) == 2
# Verify the first output (after 1 value: [1])
assert sensor_states["min_sensor"][0] == pytest.approx(1.0), (
f"First min should be 1.0, got {sensor_states['min_sensor'][0]}"
)
assert sensor_states["max_sensor"][0] == pytest.approx(1.0), (
f"First max should be 1.0, got {sensor_states['max_sensor'][0]}"
)
assert sensor_states["median_sensor"][0] == pytest.approx(1.0), (
f"First median should be 1.0, got {sensor_states['median_sensor'][0]}"
)
assert sensor_states["moving_avg_sensor"][0] == pytest.approx(1.0), (
f"First moving avg should be 1.0, got {sensor_states['moving_avg_sensor'][0]}"
)
# Verify the second output (after 6 values, window has [2, 3, 4, 5, 6])
assert sensor_states["min_sensor"][1] == pytest.approx(2.0), (
f"Second min should be 2.0, got {sensor_states['min_sensor'][1]}"
)
assert sensor_states["max_sensor"][1] == pytest.approx(6.0), (
f"Second max should be 6.0, got {sensor_states['max_sensor'][1]}"
)
assert sensor_states["median_sensor"][1] == pytest.approx(4.0), (
f"Second median should be 4.0, got {sensor_states['median_sensor'][1]}"
)
assert sensor_states["moving_avg_sensor"][1] == pytest.approx(4.0), (
f"Second moving avg should be 4.0, got {sensor_states['moving_avg_sensor'][1]}"
)
@pytest.mark.asyncio
async def test_sensor_filters_nan_handling(
yaml_config: str,
run_compiled: RunCompiledFunction,
api_client_connected: APIClientConnectedFactory,
) -> None:
"""Test that sliding window filters handle NaN values correctly."""
loop = asyncio.get_running_loop()
# Track states
min_states: list[float] = []
max_states: list[float] = []
# Future to track completion
filters_completed = loop.create_future()
def on_state(state: EntityState) -> None:
"""Track sensor state updates."""
if not isinstance(state, SensorState):
return
# Skip NaN values
if state.missing_state:
return
sensor_name = key_to_sensor.get(state.key)
if sensor_name == "min_nan":
min_states.append(state.state)
elif sensor_name == "max_nan":
max_states.append(state.state)
# Check if both have received their final values
# With batch_delay: 0ms, we should receive 2 outputs each
if (
len(min_states) >= 2
and len(max_states) >= 2
and not filters_completed.done()
):
filters_completed.set_result(True)
async with (
run_compiled(yaml_config),
api_client_connected() as client,
):
# Get entities first to build key mapping
entities, services = await client.list_entities_services()
# Build key-to-sensor mapping
key_to_sensor = build_key_to_entity_mapping(entities, ["min_nan", "max_nan"])
# Set up initial state helper with all entities
initial_state_helper = InitialStateHelper(entities)
# Subscribe to state changes with wrapper
client.subscribe_states(initial_state_helper.on_state_wrapper(on_state))
# Wait for initial states
try:
await initial_state_helper.wait_for_initial_states()
except TimeoutError:
pytest.fail("Timeout waiting for initial states")
# Find the publish button
publish_button = next(
(e for e in entities if "publish_nan_values_button" in e.object_id.lower()),
None,
)
assert publish_button is not None, "Publish NaN Values Button not found"
# Press the button
client.button_command(publish_button.key)
# Wait for filters to process
try:
await asyncio.wait_for(filters_completed, timeout=10.0)
except TimeoutError:
pytest.fail(
f"Timeout waiting for NaN handling. Received:\n"
f" min_states: {min_states}\n"
f" max_states: {max_states}"
)
# Verify NaN values were ignored
# With batch_delay: 0ms, we should receive both outputs (at positions 1 and 6)
# Position 1: window=[10], min=10, max=10
# Position 6: window=[NaN, 5, NaN, 15, 8], ignoring NaN -> [5, 15, 8], min=5, max=15
assert len(min_states) == 2, (
f"Should have 2 min states, got {len(min_states)}: {min_states}"
)
assert len(max_states) == 2, (
f"Should have 2 max states, got {len(max_states)}: {max_states}"
)
# First output
assert min_states[0] == pytest.approx(10.0), (
f"First min should be 10.0, got {min_states[0]}"
)
assert max_states[0] == pytest.approx(10.0), (
f"First max should be 10.0, got {max_states[0]}"
)
# Second output - verify NaN values were ignored
assert min_states[1] == pytest.approx(5.0), (
f"Second min should ignore NaN and return 5.0, got {min_states[1]}"
)
assert max_states[1] == pytest.approx(15.0), (
f"Second max should ignore NaN and return 15.0, got {max_states[1]}"
)
@pytest.mark.asyncio
async def test_sensor_filters_ring_buffer_wraparound(
yaml_config: str,
run_compiled: RunCompiledFunction,
api_client_connected: APIClientConnectedFactory,
) -> None:
"""Test that ring buffer correctly wraps around when window fills up."""
loop = asyncio.get_running_loop()
min_states: list[float] = []
test_completed = loop.create_future()
def on_state(state: EntityState) -> None:
"""Track min sensor states."""
if not isinstance(state, SensorState):
return
# Skip NaN values
if state.missing_state:
return
sensor_name = key_to_sensor.get(state.key)
if sensor_name == "wraparound_min":
min_states.append(state.state)
# With batch_delay: 0ms, we should receive all 3 outputs
if len(min_states) >= 3 and not test_completed.done():
test_completed.set_result(True)
async with (
run_compiled(yaml_config),
api_client_connected() as client,
):
# Get entities first to build key mapping
entities, services = await client.list_entities_services()
# Build key-to-sensor mapping
key_to_sensor = build_key_to_entity_mapping(entities, ["wraparound_min"])
# Set up initial state helper with all entities
initial_state_helper = InitialStateHelper(entities)
# Subscribe to state changes with wrapper
client.subscribe_states(initial_state_helper.on_state_wrapper(on_state))
# Wait for initial state
try:
await initial_state_helper.wait_for_initial_states()
except TimeoutError:
pytest.fail("Timeout waiting for initial state")
# Find the publish button
publish_button = next(
(e for e in entities if "publish_wraparound_button" in e.object_id.lower()),
None,
)
assert publish_button is not None, "Publish Wraparound Button not found"
# Press the button
# Will publish: 10, 20, 30, 5, 25, 15, 40, 35, 20
client.button_command(publish_button.key)
# Wait for completion
try:
await asyncio.wait_for(test_completed, timeout=10.0)
except TimeoutError:
pytest.fail(f"Timeout waiting for wraparound test. Received: {min_states}")
# Verify outputs
# With window_size=3, send_every=3, we get outputs at positions 1, 4, 7
# Position 1: window=[10], min=10
# Position 4: window=[20, 30, 5], min=5
# Position 7: window=[15, 40, 35], min=15
# With batch_delay: 0ms, we should receive all 3 outputs
assert len(min_states) == 3, (
f"Should have 3 states, got {len(min_states)}: {min_states}"
)
assert min_states[0] == pytest.approx(10.0), (
f"First min should be 10.0, got {min_states[0]}"
)
assert min_states[1] == pytest.approx(5.0), (
f"Second min should be 5.0, got {min_states[1]}"
)
assert min_states[2] == pytest.approx(15.0), (
f"Third min should be 15.0, got {min_states[2]}"
)

View File

@@ -17,9 +17,6 @@ script_dir = os.path.abspath(
)
sys.path.insert(0, script_dir)
# Import helpers module for patching
import helpers # noqa: E402
spec = importlib.util.spec_from_file_location(
"determine_jobs", os.path.join(script_dir, "determine-jobs.py")
)
@@ -62,29 +59,15 @@ def mock_subprocess_run() -> Generator[Mock, None, None]:
yield mock
@pytest.fixture
def mock_changed_files() -> Generator[Mock, None, None]:
"""Mock changed_files for memory impact detection."""
with patch.object(determine_jobs, "changed_files") as mock:
# Default to empty list
mock.return_value = []
yield mock
def test_main_all_tests_should_run(
mock_should_run_integration_tests: Mock,
mock_should_run_clang_tidy: Mock,
mock_should_run_clang_format: Mock,
mock_should_run_python_linters: Mock,
mock_subprocess_run: Mock,
mock_changed_files: Mock,
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test when all tests should run."""
# Ensure we're not in GITHUB_ACTIONS mode for this test
monkeypatch.delenv("GITHUB_ACTIONS", raising=False)
mock_should_run_integration_tests.return_value = True
mock_should_run_clang_tidy.return_value = True
mock_should_run_clang_format.return_value = True
@@ -117,9 +100,6 @@ def test_main_all_tests_should_run(
assert output["component_test_count"] == len(
output["changed_components_with_tests"]
)
# memory_impact should be present
assert "memory_impact" in output
assert output["memory_impact"]["should_run"] == "false" # No files changed
def test_main_no_tests_should_run(
@@ -128,14 +108,9 @@ def test_main_no_tests_should_run(
mock_should_run_clang_format: Mock,
mock_should_run_python_linters: Mock,
mock_subprocess_run: Mock,
mock_changed_files: Mock,
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test when no tests should run."""
# Ensure we're not in GITHUB_ACTIONS mode for this test
monkeypatch.delenv("GITHUB_ACTIONS", raising=False)
mock_should_run_integration_tests.return_value = False
mock_should_run_clang_tidy.return_value = False
mock_should_run_clang_format.return_value = False
@@ -161,9 +136,6 @@ def test_main_no_tests_should_run(
assert output["changed_components"] == []
assert output["changed_components_with_tests"] == []
assert output["component_test_count"] == 0
# memory_impact should be present
assert "memory_impact" in output
assert output["memory_impact"]["should_run"] == "false"
def test_main_list_components_fails(
@@ -197,14 +169,9 @@ def test_main_with_branch_argument(
mock_should_run_clang_format: Mock,
mock_should_run_python_linters: Mock,
mock_subprocess_run: Mock,
mock_changed_files: Mock,
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test with branch argument."""
# Ensure we're not in GITHUB_ACTIONS mode for this test
monkeypatch.delenv("GITHUB_ACTIONS", raising=False)
mock_should_run_integration_tests.return_value = False
mock_should_run_clang_tidy.return_value = True
mock_should_run_clang_format.return_value = False
@@ -249,9 +216,6 @@ def test_main_with_branch_argument(
assert output["component_test_count"] == len(
output["changed_components_with_tests"]
)
# memory_impact should be present
assert "memory_impact" in output
assert output["memory_impact"]["should_run"] == "false"
def test_should_run_integration_tests(
@@ -439,15 +403,10 @@ def test_main_filters_components_without_tests(
mock_should_run_clang_format: Mock,
mock_should_run_python_linters: Mock,
mock_subprocess_run: Mock,
mock_changed_files: Mock,
capsys: pytest.CaptureFixture[str],
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test that components without test files are filtered out."""
# Ensure we're not in GITHUB_ACTIONS mode for this test
monkeypatch.delenv("GITHUB_ACTIONS", raising=False)
mock_should_run_integration_tests.return_value = False
mock_should_run_clang_tidy.return_value = False
mock_should_run_clang_format.return_value = False
@@ -481,10 +440,9 @@ def test_main_filters_components_without_tests(
airthings_dir = tests_dir / "airthings_ble"
airthings_dir.mkdir(parents=True)
# Mock root_path to use tmp_path (need to patch both determine_jobs and helpers)
# Mock root_path to use tmp_path
with (
patch.object(determine_jobs, "root_path", str(tmp_path)),
patch.object(helpers, "root_path", str(tmp_path)),
patch("sys.argv", ["determine-jobs.py"]),
):
# Clear the cache since we're mocking root_path
@@ -501,188 +459,3 @@ def test_main_filters_components_without_tests(
assert set(output["changed_components_with_tests"]) == {"wifi", "sensor"}
# component_test_count should be based on components with tests
assert output["component_test_count"] == 2
# memory_impact should be present
assert "memory_impact" in output
assert output["memory_impact"]["should_run"] == "false"
# Tests for detect_memory_impact_config function
def test_detect_memory_impact_config_with_common_platform(tmp_path: Path) -> None:
"""Test memory impact detection when components share a common platform."""
# Create test directory structure
tests_dir = tmp_path / "tests" / "components"
# wifi component with esp32-idf test
wifi_dir = tests_dir / "wifi"
wifi_dir.mkdir(parents=True)
(wifi_dir / "test.esp32-idf.yaml").write_text("test: wifi")
# api component with esp32-idf test
api_dir = tests_dir / "api"
api_dir.mkdir(parents=True)
(api_dir / "test.esp32-idf.yaml").write_text("test: api")
# Mock changed_files to return wifi and api component changes
with (
patch.object(determine_jobs, "root_path", str(tmp_path)),
patch.object(helpers, "root_path", str(tmp_path)),
patch.object(determine_jobs, "changed_files") as mock_changed_files,
):
mock_changed_files.return_value = [
"esphome/components/wifi/wifi.cpp",
"esphome/components/api/api.cpp",
]
determine_jobs._component_has_tests.cache_clear()
result = determine_jobs.detect_memory_impact_config()
assert result["should_run"] == "true"
assert set(result["components"]) == {"wifi", "api"}
assert result["platform"] == "esp32-idf" # Common platform
assert result["use_merged_config"] == "true"
def test_detect_memory_impact_config_core_only_changes(tmp_path: Path) -> None:
"""Test memory impact detection with core-only changes (no component changes)."""
# Create test directory structure with fallback component
tests_dir = tmp_path / "tests" / "components"
# api component (fallback component) with esp32-idf test
api_dir = tests_dir / "api"
api_dir.mkdir(parents=True)
(api_dir / "test.esp32-idf.yaml").write_text("test: api")
# Mock changed_files to return only core files (no component files)
with (
patch.object(determine_jobs, "root_path", str(tmp_path)),
patch.object(helpers, "root_path", str(tmp_path)),
patch.object(determine_jobs, "changed_files") as mock_changed_files,
):
mock_changed_files.return_value = [
"esphome/core/application.cpp",
"esphome/core/component.h",
]
determine_jobs._component_has_tests.cache_clear()
result = determine_jobs.detect_memory_impact_config()
assert result["should_run"] == "true"
assert result["components"] == ["api"] # Fallback component
assert result["platform"] == "esp32-idf" # Fallback platform
assert result["use_merged_config"] == "true"
def test_detect_memory_impact_config_no_common_platform(tmp_path: Path) -> None:
"""Test memory impact detection when components have no common platform."""
# Create test directory structure
tests_dir = tmp_path / "tests" / "components"
# wifi component only has esp32-idf test
wifi_dir = tests_dir / "wifi"
wifi_dir.mkdir(parents=True)
(wifi_dir / "test.esp32-idf.yaml").write_text("test: wifi")
# logger component only has esp8266-ard test
logger_dir = tests_dir / "logger"
logger_dir.mkdir(parents=True)
(logger_dir / "test.esp8266-ard.yaml").write_text("test: logger")
# Mock changed_files to return both components
with (
patch.object(determine_jobs, "root_path", str(tmp_path)),
patch.object(helpers, "root_path", str(tmp_path)),
patch.object(determine_jobs, "changed_files") as mock_changed_files,
):
mock_changed_files.return_value = [
"esphome/components/wifi/wifi.cpp",
"esphome/components/logger/logger.cpp",
]
determine_jobs._component_has_tests.cache_clear()
result = determine_jobs.detect_memory_impact_config()
# Should pick the most frequently supported platform
assert result["should_run"] == "true"
assert set(result["components"]) == {"wifi", "logger"}
# When no common platform, picks most commonly supported
# esp8266-ard is preferred over esp32-idf in the preference list
assert result["platform"] in ["esp32-idf", "esp8266-ard"]
assert result["use_merged_config"] == "true"
def test_detect_memory_impact_config_no_changes(tmp_path: Path) -> None:
"""Test memory impact detection when no files changed."""
# Mock changed_files to return empty list
with (
patch.object(determine_jobs, "root_path", str(tmp_path)),
patch.object(helpers, "root_path", str(tmp_path)),
patch.object(determine_jobs, "changed_files") as mock_changed_files,
):
mock_changed_files.return_value = []
determine_jobs._component_has_tests.cache_clear()
result = determine_jobs.detect_memory_impact_config()
assert result["should_run"] == "false"
def test_detect_memory_impact_config_no_components_with_tests(tmp_path: Path) -> None:
"""Test memory impact detection when changed components have no tests."""
# Create test directory structure
tests_dir = tmp_path / "tests" / "components"
# Create component directory but no test files
custom_component_dir = tests_dir / "my_custom_component"
custom_component_dir.mkdir(parents=True)
# Mock changed_files to return component without tests
with (
patch.object(determine_jobs, "root_path", str(tmp_path)),
patch.object(helpers, "root_path", str(tmp_path)),
patch.object(determine_jobs, "changed_files") as mock_changed_files,
):
mock_changed_files.return_value = [
"esphome/components/my_custom_component/component.cpp",
]
determine_jobs._component_has_tests.cache_clear()
result = determine_jobs.detect_memory_impact_config()
assert result["should_run"] == "false"
def test_detect_memory_impact_config_skips_base_bus_components(tmp_path: Path) -> None:
"""Test that base bus components (i2c, spi, uart) are skipped."""
# Create test directory structure
tests_dir = tmp_path / "tests" / "components"
# i2c component (should be skipped as it's a base bus component)
i2c_dir = tests_dir / "i2c"
i2c_dir.mkdir(parents=True)
(i2c_dir / "test.esp32-idf.yaml").write_text("test: i2c")
# wifi component (should not be skipped)
wifi_dir = tests_dir / "wifi"
wifi_dir.mkdir(parents=True)
(wifi_dir / "test.esp32-idf.yaml").write_text("test: wifi")
# Mock changed_files to return both i2c and wifi
with (
patch.object(determine_jobs, "root_path", str(tmp_path)),
patch.object(helpers, "root_path", str(tmp_path)),
patch.object(determine_jobs, "changed_files") as mock_changed_files,
):
mock_changed_files.return_value = [
"esphome/components/i2c/i2c.cpp",
"esphome/components/wifi/wifi.cpp",
]
determine_jobs._component_has_tests.cache_clear()
result = determine_jobs.detect_memory_impact_config()
# Should only include wifi, not i2c
assert result["should_run"] == "true"
assert result["components"] == ["wifi"]
assert "i2c" not in result["components"]

View File

@@ -1,9 +0,0 @@
substitutions:
A: component1
B: component2
C: component3
some_component:
- id: component1
value: 2
- id: component2
value: 5

View File

@@ -1,22 +0,0 @@
substitutions:
A: component1
B: component2
C: component3
packages:
- some_component:
- id: component1
value: 1
- id: !extend ${B}
value: 4
- id: !extend ${B}
value: 5
- id: component3
value: 6
some_component:
- id: !extend ${A}
value: 2
- id: component2
value: 3
- id: !remove ${C}

View File

@@ -321,14 +321,12 @@ def test_choose_upload_log_host_with_serial_device_no_ports(
) -> None:
"""Test SERIAL device when no serial ports are found."""
setup_core()
with pytest.raises(
EsphomeError, match="All specified devices .* could not be resolved"
):
choose_upload_log_host(
default="SERIAL",
check_default=None,
purpose=Purpose.UPLOADING,
)
result = choose_upload_log_host(
default="SERIAL",
check_default=None,
purpose=Purpose.UPLOADING,
)
assert result == []
assert "No serial ports found, skipping SERIAL device" in caplog.text
@@ -369,14 +367,12 @@ def test_choose_upload_log_host_with_ota_device_with_api_config() -> None:
"""Test OTA device when API is configured (no upload without OTA in config)."""
setup_core(config={CONF_API: {}}, address="192.168.1.100")
with pytest.raises(
EsphomeError, match="All specified devices .* could not be resolved"
):
choose_upload_log_host(
default="OTA",
check_default=None,
purpose=Purpose.UPLOADING,
)
result = choose_upload_log_host(
default="OTA",
check_default=None,
purpose=Purpose.UPLOADING,
)
assert result == []
def test_choose_upload_log_host_with_ota_device_with_api_config_logging() -> None:
@@ -409,14 +405,12 @@ def test_choose_upload_log_host_with_ota_device_no_fallback() -> None:
"""Test OTA device with no valid fallback options."""
setup_core()
with pytest.raises(
EsphomeError, match="All specified devices .* could not be resolved"
):
choose_upload_log_host(
default="OTA",
check_default=None,
purpose=Purpose.UPLOADING,
)
result = choose_upload_log_host(
default="OTA",
check_default=None,
purpose=Purpose.UPLOADING,
)
assert result == []
@pytest.mark.usefixtures("mock_choose_prompt")
@@ -621,19 +615,21 @@ def test_choose_upload_log_host_empty_defaults_list() -> None:
@pytest.mark.usefixtures("mock_no_serial_ports", "mock_no_mqtt_logging")
def test_choose_upload_log_host_all_devices_unresolved() -> None:
def test_choose_upload_log_host_all_devices_unresolved(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test when all specified devices cannot be resolved."""
setup_core()
with pytest.raises(
EsphomeError,
match=r"All specified devices \['SERIAL', 'OTA'\] could not be resolved",
):
choose_upload_log_host(
default=["SERIAL", "OTA"],
check_default=None,
purpose=Purpose.UPLOADING,
)
result = choose_upload_log_host(
default=["SERIAL", "OTA"],
check_default=None,
purpose=Purpose.UPLOADING,
)
assert result == []
assert (
"All specified devices: ['SERIAL', 'OTA'] could not be resolved." in caplog.text
)
@pytest.mark.usefixtures("mock_no_serial_ports", "mock_no_mqtt_logging")
@@ -766,14 +762,12 @@ def test_choose_upload_log_host_no_address_with_ota_config() -> None:
"""Test OTA device when OTA is configured but no address is set."""
setup_core(config={CONF_OTA: {}})
with pytest.raises(
EsphomeError, match="All specified devices .* could not be resolved"
):
choose_upload_log_host(
default="OTA",
check_default=None,
purpose=Purpose.UPLOADING,
)
result = choose_upload_log_host(
default="OTA",
check_default=None,
purpose=Purpose.UPLOADING,
)
assert result == []
@dataclass

View File

@@ -387,42 +387,6 @@ def test_idedata_addr2line_path_unix(setup_core: Path) -> None:
assert result == "/usr/bin/addr2line"
def test_idedata_objdump_path_windows(setup_core: Path) -> None:
"""Test IDEData.objdump_path on Windows."""
raw_data = {"prog_path": "/path/to/firmware.elf", "cc_path": "C:\\tools\\gcc.exe"}
idedata = platformio_api.IDEData(raw_data)
result = idedata.objdump_path
assert result == "C:\\tools\\objdump.exe"
def test_idedata_objdump_path_unix(setup_core: Path) -> None:
"""Test IDEData.objdump_path on Unix."""
raw_data = {"prog_path": "/path/to/firmware.elf", "cc_path": "/usr/bin/gcc"}
idedata = platformio_api.IDEData(raw_data)
result = idedata.objdump_path
assert result == "/usr/bin/objdump"
def test_idedata_readelf_path_windows(setup_core: Path) -> None:
"""Test IDEData.readelf_path on Windows."""
raw_data = {"prog_path": "/path/to/firmware.elf", "cc_path": "C:\\tools\\gcc.exe"}
idedata = platformio_api.IDEData(raw_data)
result = idedata.readelf_path
assert result == "C:\\tools\\readelf.exe"
def test_idedata_readelf_path_unix(setup_core: Path) -> None:
"""Test IDEData.readelf_path on Unix."""
raw_data = {"prog_path": "/path/to/firmware.elf", "cc_path": "/usr/bin/gcc"}
idedata = platformio_api.IDEData(raw_data)
result = idedata.readelf_path
assert result == "/usr/bin/readelf"
def test_patch_structhash(setup_core: Path) -> None:
"""Test patch_structhash monkey patches platformio functions."""
# Create simple namespace objects to act as modules

View File

@@ -4,7 +4,6 @@ from pathlib import Path
from esphome import config as config_module, yaml_util
from esphome.components import substitutions
from esphome.config import resolve_extend_remove
from esphome.config_helpers import merge_config
from esphome.const import CONF_PACKAGES, CONF_SUBSTITUTIONS
from esphome.core import CORE
@@ -82,8 +81,6 @@ def test_substitutions_fixtures(fixture_path):
substitutions.do_substitution_pass(config, None)
resolve_extend_remove(config)
# Also load expected using ESPHome's loader, or use {} if missing and DEV_MODE
if expected_path.is_file():
expected = yaml_util.load_yaml(expected_path)