diff --git a/.clang-tidy.hash b/.clang-tidy.hash index f61b79de4d..4901c0ccac 100644 --- a/.clang-tidy.hash +++ b/.clang-tidy.hash @@ -1 +1 @@ -4368db58e8f884aff245996b1e8b644cc0796c0bb2fa706d5740d40b823d3ac9 +049d60eed541730efaa4c0dc5d337b4287bf29b6daa350b5dfc1f23915f1c52f diff --git a/.github/actions/build-image/action.yaml b/.github/actions/build-image/action.yaml index 403b9d8c2a..9c7f051e05 100644 --- a/.github/actions/build-image/action.yaml +++ b/.github/actions/build-image/action.yaml @@ -47,7 +47,7 @@ runs: - name: Build and push to ghcr by digest id: build-ghcr - uses: docker/build-push-action@v6.18.0 + uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0 env: DOCKER_BUILD_SUMMARY: false DOCKER_BUILD_RECORD_UPLOAD: false @@ -73,7 +73,7 @@ runs: - name: Build and push to dockerhub by digest id: build-dockerhub - uses: docker/build-push-action@v6.18.0 + uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6.18.0 env: DOCKER_BUILD_SUMMARY: false DOCKER_BUILD_RECORD_UPLOAD: false diff --git a/.github/actions/restore-python/action.yml b/.github/actions/restore-python/action.yml index 5d290894a7..f314e79ad9 100644 --- a/.github/actions/restore-python/action.yml +++ b/.github/actions/restore-python/action.yml @@ -17,12 +17,12 @@ runs: steps: - name: Set up Python ${{ inputs.python-version }} id: python - uses: actions/setup-python@v6.0.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: ${{ inputs.python-version }} - name: Restore Python virtual environment id: cache-venv - uses: actions/cache/restore@v4.2.4 + uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: venv # yamllint disable-line rule:line-length diff --git a/.github/workflows/auto-label-pr.yml b/.github/workflows/auto-label-pr.yml index 66369c706f..1670bd1821 100644 --- a/.github/workflows/auto-label-pr.yml +++ b/.github/workflows/auto-label-pr.yml @@ -22,17 +22,17 @@ jobs: if: github.event.action != 'labeled' || github.event.sender.type != 'Bot' steps: - name: Checkout - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Generate a token id: generate-token - uses: actions/create-github-app-token@v2 + uses: actions/create-github-app-token@67018539274d69449ef7c02e8e71183d1719ab42 # v2 with: app-id: ${{ secrets.ESPHOME_GITHUB_APP_ID }} private-key: ${{ secrets.ESPHOME_GITHUB_APP_PRIVATE_KEY }} - name: Auto Label PR - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ steps.generate-token.outputs.token }} script: | diff --git a/.github/workflows/ci-api-proto.yml b/.github/workflows/ci-api-proto.yml index ec214d1a77..c122859442 100644 --- a/.github/workflows/ci-api-proto.yml +++ b/.github/workflows/ci-api-proto.yml @@ -21,9 +21,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Set up Python - uses: actions/setup-python@v6.0.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.11" @@ -47,7 +47,7 @@ jobs: fi - if: failure() name: Review PR - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | await github.rest.pulls.createReview({ @@ -62,7 +62,7 @@ jobs: run: git diff - if: failure() name: Archive artifacts - uses: actions/upload-artifact@v4.6.2 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: generated-proto-files path: | @@ -70,7 +70,7 @@ jobs: esphome/components/api/api_pb2_service.* - if: success() name: Dismiss review - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | let reviews = await github.rest.pulls.listReviews({ diff --git a/.github/workflows/ci-clang-tidy-hash.yml b/.github/workflows/ci-clang-tidy-hash.yml index 2f47386abf..78d1c2b87f 100644 --- a/.github/workflows/ci-clang-tidy-hash.yml +++ b/.github/workflows/ci-clang-tidy-hash.yml @@ -6,6 +6,7 @@ on: - ".clang-tidy" - "platformio.ini" - "requirements_dev.txt" + - "sdkconfig.defaults" - ".clang-tidy.hash" - "script/clang_tidy_hash.py" - ".github/workflows/ci-clang-tidy-hash.yml" @@ -20,10 +21,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Set up Python - uses: actions/setup-python@v6.0.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.11" @@ -41,7 +42,7 @@ jobs: - if: failure() name: Request changes - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | await github.rest.pulls.createReview({ @@ -54,7 +55,7 @@ jobs: - if: success() name: Dismiss review - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | let reviews = await github.rest.pulls.listReviews({ diff --git a/.github/workflows/ci-docker.yml b/.github/workflows/ci-docker.yml index 915a4dfb7e..7111c61dda 100644 --- a/.github/workflows/ci-docker.yml +++ b/.github/workflows/ci-docker.yml @@ -43,13 +43,13 @@ jobs: - "docker" # - "lint" steps: - - uses: actions/checkout@v5.0.0 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Set up Python - uses: actions/setup-python@v6.0.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.11" - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3.11.1 + uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1 - name: Set TAG run: | diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 07fd91b1c8..1d7043c888 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,18 +36,18 @@ jobs: cache-key: ${{ steps.cache-key.outputs.key }} steps: - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Generate cache-key id: cache-key run: echo key="${{ hashFiles('requirements.txt', 'requirements_test.txt', '.pre-commit-config.yaml') }}" >> $GITHUB_OUTPUT - name: Set up Python ${{ env.DEFAULT_PYTHON }} id: python - uses: actions/setup-python@v6.0.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: ${{ env.DEFAULT_PYTHON }} - name: Restore Python virtual environment id: cache-venv - uses: actions/cache@v4.2.4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: venv # yamllint disable-line rule:line-length @@ -70,7 +70,7 @@ jobs: if: needs.determine-jobs.outputs.python-linters == 'true' steps: - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Restore Python uses: ./.github/actions/restore-python with: @@ -91,7 +91,7 @@ jobs: - common steps: - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Restore Python uses: ./.github/actions/restore-python with: @@ -105,6 +105,7 @@ jobs: script/ci-custom.py script/build_codeowners.py --check script/build_language_schema.py --check + script/generate-esp32-boards.py --check pytest: name: Run pytest @@ -136,7 +137,7 @@ jobs: - common steps: - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Restore Python id: restore-python uses: ./.github/actions/restore-python @@ -156,12 +157,12 @@ jobs: . venv/bin/activate pytest -vv --cov-report=xml --tb=native -n auto tests --ignore=tests/integration/ - name: Upload coverage to Codecov - uses: codecov/codecov-action@v5.5.1 + uses: codecov/codecov-action@5a1091511ad55cbe89839c7260b706298ca349f7 # v5.5.1 with: token: ${{ secrets.CODECOV_TOKEN }} - name: Save Python virtual environment cache if: github.ref == 'refs/heads/dev' - uses: actions/cache/save@v4.2.4 + uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: venv key: ${{ runner.os }}-${{ steps.restore-python.outputs.python-version }}-venv-${{ needs.common.outputs.cache-key }} @@ -179,7 +180,7 @@ jobs: component-test-count: ${{ steps.determine.outputs.component-test-count }} steps: - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: # Fetch enough history to find the merge base fetch-depth: 2 @@ -214,15 +215,15 @@ jobs: if: needs.determine-jobs.outputs.integration-tests == 'true' steps: - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Set up Python 3.13 id: python - uses: actions/setup-python@v6.0.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.13" - name: Restore Python virtual environment id: cache-venv - uses: actions/cache@v4.2.4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: venv key: ${{ runner.os }}-${{ steps.python.outputs.python-version }}-venv-${{ needs.common.outputs.cache-key }} @@ -287,7 +288,7 @@ jobs: steps: - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: # Need history for HEAD~1 to work for checking changed files fetch-depth: 2 @@ -300,14 +301,14 @@ jobs: - name: Cache platformio if: github.ref == 'refs/heads/dev' - uses: actions/cache@v4.2.4 + uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: ~/.platformio key: platformio-${{ matrix.pio_cache_key }}-${{ hashFiles('platformio.ini') }} - name: Cache platformio if: github.ref != 'refs/heads/dev' - uses: actions/cache/restore@v4.2.4 + uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0 with: path: ~/.platformio key: platformio-${{ matrix.pio_cache_key }}-${{ hashFiles('platformio.ini') }} @@ -374,7 +375,7 @@ jobs: sudo apt-get install libsdl2-dev - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Restore Python uses: ./.github/actions/restore-python with: @@ -390,7 +391,7 @@ jobs: ./script/test_build_components -e compile -c ${{ matrix.file }} test-build-components-splitter: - name: Split components for testing into 20 groups maximum + name: Split components for testing into 10 components per group runs-on: ubuntu-24.04 needs: - common @@ -400,11 +401,11 @@ jobs: matrix: ${{ steps.split.outputs.components }} steps: - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 - - name: Split components into 20 groups + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - name: Split components into groups of 10 id: split run: | - components=$(echo '${{ needs.determine-jobs.outputs.changed-components }}' | jq -c '.[]' | shuf | jq -s -c '[_nwise(20) | join(" ")]') + components=$(echo '${{ needs.determine-jobs.outputs.changed-components }}' | jq -c '.[]' | shuf | jq -s -c '[_nwise(10) | join(" ")]') echo "components=$components" >> $GITHUB_OUTPUT test-build-components-split: @@ -430,7 +431,7 @@ jobs: sudo apt-get install libsdl2-dev - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Restore Python uses: ./.github/actions/restore-python with: @@ -459,16 +460,16 @@ jobs: if: github.event_name == 'pull_request' && github.base_ref != 'beta' && github.base_ref != 'release' steps: - name: Check out code from GitHub - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Restore Python uses: ./.github/actions/restore-python with: python-version: ${{ env.DEFAULT_PYTHON }} cache-key: ${{ needs.common.outputs.cache-key }} - - uses: pre-commit/action@v3.0.1 + - uses: esphome/action@43cd1109c09c544d97196f7730ee5b2e0cc6d81e # v3.0.1 fork with pinned actions/cache env: SKIP: pylint,clang-tidy-hash - - uses: pre-commit-ci/lite-action@v1.1.0 + - uses: pre-commit-ci/lite-action@5d6cc0eb514c891a40562a58a8e71576c5c7fb43 # v1.1.0 if: always() ci-status: diff --git a/.github/workflows/codeowner-review-request.yml b/.github/workflows/codeowner-review-request.yml index 475e05b970..563d55f42b 100644 --- a/.github/workflows/codeowner-review-request.yml +++ b/.github/workflows/codeowner-review-request.yml @@ -25,7 +25,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Request reviews from component codeowners - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | const owner = context.repo.owner; diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 7a7c39aeec..59f58b7236 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -54,11 +54,11 @@ jobs: # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages steps: - name: Checkout repository - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@e296a935590eb16afc0c0108289f68c87e2a89a5 # v4.30.7 with: languages: ${{ matrix.language }} build-mode: ${{ matrix.build-mode }} @@ -86,6 +86,6 @@ jobs: exit 1 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@e296a935590eb16afc0c0108289f68c87e2a89a5 # v4.30.7 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/external-component-bot.yml b/.github/workflows/external-component-bot.yml index 736c986f7e..4fa020f63d 100644 --- a/.github/workflows/external-component-bot.yml +++ b/.github/workflows/external-component-bot.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Add external component comment - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/issue-codeowner-notify.yml b/.github/workflows/issue-codeowner-notify.yml index ab9b96b45a..6faf956c87 100644 --- a/.github/workflows/issue-codeowner-notify.yml +++ b/.github/workflows/issue-codeowner-notify.yml @@ -19,7 +19,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Notify codeowners for component issues - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | const owner = context.repo.owner; diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index efc8424cd6..2b3b3bdc1b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ jobs: branch_build: ${{ steps.tag.outputs.branch_build }} deploy_env: ${{ steps.tag.outputs.deploy_env }} steps: - - uses: actions/checkout@v5.0.0 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Get tag id: tag # yamllint disable rule:line-length @@ -60,9 +60,9 @@ jobs: contents: read id-token: write steps: - - uses: actions/checkout@v5.0.0 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Set up Python - uses: actions/setup-python@v6.0.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.x" - name: Build @@ -70,7 +70,7 @@ jobs: pip3 install build python3 -m build - name: Publish - uses: pypa/gh-action-pypi-publish@v1.13.0 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 with: skip-existing: true @@ -92,22 +92,22 @@ jobs: os: "ubuntu-24.04-arm" steps: - - uses: actions/checkout@v5.0.0 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Set up Python - uses: actions/setup-python@v6.0.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.11" - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3.11.1 + uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1 - name: Log in to docker hub - uses: docker/login-action@v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Log in to the GitHub container registry - uses: docker/login-action@v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: registry: ghcr.io username: ${{ github.actor }} @@ -138,7 +138,7 @@ jobs: # version: ${{ needs.init.outputs.tag }} - name: Upload digests - uses: actions/upload-artifact@v4.6.2 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: digests-${{ matrix.platform.arch }} path: /tmp/digests @@ -168,27 +168,27 @@ jobs: - ghcr - dockerhub steps: - - uses: actions/checkout@v5.0.0 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Download digests - uses: actions/download-artifact@v5.0.0 + uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0 with: pattern: digests-* path: /tmp/digests merge-multiple: true - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3.11.1 + uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1 - name: Log in to docker hub if: matrix.registry == 'dockerhub' - uses: docker/login-action@v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Log in to the GitHub container registry if: matrix.registry == 'ghcr' - uses: docker/login-action@v3.5.0 + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 with: registry: ghcr.io username: ${{ github.actor }} @@ -220,7 +220,7 @@ jobs: - deploy-manifest steps: - name: Trigger Workflow - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.DEPLOY_HA_ADDON_REPO_TOKEN }} script: | @@ -246,7 +246,7 @@ jobs: environment: ${{ needs.init.outputs.deploy_env }} steps: - name: Trigger Workflow - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.DEPLOY_ESPHOME_SCHEMA_REPO_TOKEN }} script: | diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 88e07d3f58..63a8ade37f 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -15,36 +15,52 @@ concurrency: jobs: stale: + if: github.repository_owner == 'esphome' runs-on: ubuntu-latest steps: - - uses: actions/stale@v10.0.0 + - name: Stale + uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 with: + debug-only: ${{ github.ref != 'refs/heads/dev' }} # Dry-run when not run on dev branch + remove-stale-when-updated: true + operations-per-run: 150 + + # The 90 day stale policy for PRs + # - PRs + # - No PRs marked as "not-stale" + # - No Issues (see below) days-before-pr-stale: 90 days-before-pr-close: 7 - days-before-issue-stale: -1 - days-before-issue-close: -1 - remove-stale-when-updated: true stale-pr-label: "stale" exempt-pr-labels: "not-stale" stale-pr-message: > There hasn't been any activity on this pull request recently. This pull request has been automatically marked as stale because of that and will be closed if no further activity occurs within 7 days. - Thank you for your contributions. - # Use stale to automatically close issues with a - # reference to the issue tracker - close-issues: - runs-on: ubuntu-latest - steps: - - uses: actions/stale@v10.0.0 - with: - days-before-pr-stale: -1 - days-before-pr-close: -1 - days-before-issue-stale: 1 - days-before-issue-close: 1 - remove-stale-when-updated: true + If you are the author of this PR, please leave a comment if you want + to keep it open. Also, please rebase your PR onto the latest dev + branch to ensure that it's up to date with the latest changes. + + Thank you for your contribution! + + # The 90 day stale policy for Issues + # - Issues + # - No Issues marked as "not-stale" + # - No PRs (see above) + days-before-issue-stale: 90 + days-before-issue-close: 7 stale-issue-label: "stale" exempt-issue-labels: "not-stale" stale-issue-message: > - https://github.com/esphome/esphome/issues/430 + There hasn't been any activity on this issue recently. Due to the + high number of incoming GitHub notifications, we have to clean some + of the old issues, as many of them have already been resolved with + the latest updates. + + Please make sure to update to the latest ESPHome version and + check if that solves the issue. Let us know if that works for you by + adding a comment 👍 + + This issue has now been marked as stale and will be closed if no + further activity occurs. Thank you for your contributions. diff --git a/.github/workflows/status-check-labels.yml b/.github/workflows/status-check-labels.yml index 675be49c27..e44fd18132 100644 --- a/.github/workflows/status-check-labels.yml +++ b/.github/workflows/status-check-labels.yml @@ -16,7 +16,7 @@ jobs: - merge-after-release steps: - name: Check for ${{ matrix.label }} label - uses: actions/github-script@v8.0.0 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | const { data: labels } = await github.rest.issues.listLabelsOnIssue({ diff --git a/.github/workflows/sync-device-classes.yml b/.github/workflows/sync-device-classes.yml index b129e8f4bf..9479645ccc 100644 --- a/.github/workflows/sync-device-classes.yml +++ b/.github/workflows/sync-device-classes.yml @@ -13,16 +13,16 @@ jobs: if: github.repository == 'esphome/esphome' steps: - name: Checkout - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Checkout Home Assistant - uses: actions/checkout@v5.0.0 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: home-assistant/core path: lib/home-assistant - name: Setup Python - uses: actions/setup-python@v6.0.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: 3.13 @@ -30,13 +30,18 @@ jobs: run: | python -m pip install --upgrade pip pip install -e lib/home-assistant + pip install -r requirements_test.txt pre-commit - name: Sync run: | python ./script/sync-device_class.py + - name: Run pre-commit hooks + run: | + python script/run-in-env.py pre-commit run --all-files + - name: Commit changes - uses: peter-evans/create-pull-request@v7.0.8 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 with: commit-message: "Synchronise Device Classes from Home Assistant" committer: esphomebot diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2b161cf05c..521aaf9cc8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.12.12 + rev: v0.14.0 hooks: # Run the linter. - id: ruff diff --git a/CODEOWNERS b/CODEOWNERS index dc567ca5c0..03ea5d0e47 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -139,6 +139,7 @@ esphome/components/ens160_base/* @latonita @vincentscode esphome/components/ens160_i2c/* @latonita esphome/components/ens160_spi/* @latonita esphome/components/ens210/* @itn3rd77 +esphome/components/epaper_spi/* @esphome/core esphome/components/es7210/* @kahrendt esphome/components/es7243e/* @kbx81 esphome/components/es8156/* @kbx81 @@ -160,7 +161,6 @@ esphome/components/esp_ldo/* @clydebarrow esphome/components/espnow/* @jesserockz esphome/components/ethernet_info/* @gtjadsonsantos esphome/components/event/* @nohat -esphome/components/event_emitter/* @Rapsssito esphome/components/exposure_notifications/* @OttoWinter esphome/components/ezo/* @ssieb esphome/components/ezo_pmp/* @carlos-sarmiento @@ -257,6 +257,7 @@ esphome/components/libretiny_pwm/* @kuba2k2 esphome/components/light/* @esphome/core esphome/components/lightwaverf/* @max246 esphome/components/lilygo_t5_47/touchscreen/* @jesserockz +esphome/components/lm75b/* @beormund esphome/components/ln882x/* @lamauny esphome/components/lock/* @esphome/core esphome/components/logger/* @esphome/core @@ -407,6 +408,7 @@ esphome/components/sensor/* @esphome/core esphome/components/sfa30/* @ghsensdev esphome/components/sgp40/* @SenexCrenshaw esphome/components/sgp4x/* @martgras @SenexCrenshaw +esphome/components/sha256/* @esphome/core esphome/components/shelly_dimmer/* @edge90 @rnauber esphome/components/sht3xd/* @mrtoy-me esphome/components/sht4x/* @sjtrny @@ -428,6 +430,7 @@ esphome/components/speaker/media_player/* @kahrendt @synesthesiam esphome/components/spi/* @clydebarrow @esphome/core esphome/components/spi_device/* @clydebarrow esphome/components/spi_led_strip/* @clydebarrow +esphome/components/split_buffer/* @jesserockz esphome/components/sprinkler/* @kbx81 esphome/components/sps30/* @martgras esphome/components/ssd1322_base/* @kbx81 @@ -533,6 +536,7 @@ esphome/components/wk2204_spi/* @DrCoolZic esphome/components/wk2212_i2c/* @DrCoolZic esphome/components/wk2212_spi/* @DrCoolZic esphome/components/wl_134/* @hobbypunk90 +esphome/components/wts01/* @alepee esphome/components/x9c/* @EtienneMD esphome/components/xgzp68xx/* @gcormier esphome/components/xiaomi_hhccjcy10/* @fariouche @@ -548,3 +552,4 @@ esphome/components/xxtea/* @clydebarrow esphome/components/zephyr/* @tomaszduda23 esphome/components/zhlt01/* @cfeenstra1024 esphome/components/zio_ultrasonic/* @kahrendt +esphome/components/zwave_proxy/* @kbx81 diff --git a/Doxyfile b/Doxyfile index d14d1a2adb..8284c564e0 100644 --- a/Doxyfile +++ b/Doxyfile @@ -48,7 +48,7 @@ PROJECT_NAME = ESPHome # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 2025.9.3 +PROJECT_NUMBER = 2025.10.0b1 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/esphome/__main__.py b/esphome/__main__.py index f54fa8e3c6..b0f541f521 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -6,6 +6,7 @@ import getpass import importlib import logging import os +from pathlib import Path import re import sys import time @@ -13,9 +14,11 @@ from typing import Protocol import argcomplete +# Note: Do not import modules from esphome.components here, as this would +# cause them to be loaded before external components are processed, resulting +# in the built-in version being used instead of the external component one. from esphome import const, writer, yaml_util import esphome.codegen as cg -from esphome.components.mqtt import CONF_DISCOVER_IP from esphome.config import iter_component_configs, read_config, strip_default_ids from esphome.const import ( ALLOWED_NAME_CHARS, @@ -114,6 +117,14 @@ class Purpose(StrEnum): LOGGING = "logging" +def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]: + """Resolve an address using cache if available, otherwise return the address itself.""" + if CORE.address_cache and (cached := CORE.address_cache.get_addresses(address)): + _LOGGER.debug("Using cached addresses for %s: %s", purpose.value, cached) + return cached + return [address] + + def choose_upload_log_host( default: list[str] | str | None, check_default: str | None, @@ -142,7 +153,7 @@ def choose_upload_log_host( (purpose == Purpose.LOGGING and has_api()) or (purpose == Purpose.UPLOADING and has_ota()) ): - resolved.append(CORE.address) + resolved.extend(_resolve_with_cache(CORE.address, purpose)) if purpose == Purpose.LOGGING: if has_api() and has_mqtt_ip_lookup(): @@ -152,15 +163,14 @@ def choose_upload_log_host( resolved.append("MQTT") if has_api() and has_non_ip_address(): - resolved.append(CORE.address) + resolved.extend(_resolve_with_cache(CORE.address, purpose)) elif purpose == Purpose.UPLOADING: if has_ota() and has_mqtt_ip_lookup(): resolved.append("MQTTIP") if has_ota() and has_non_ip_address(): - resolved.append(CORE.address) - + resolved.extend(_resolve_with_cache(CORE.address, purpose)) else: resolved.append(device) if not resolved: @@ -232,6 +242,8 @@ def has_ota() -> bool: def has_mqtt_ip_lookup() -> bool: """Check if MQTT is available and IP lookup is supported.""" + from esphome.components.mqtt import CONF_DISCOVER_IP + if CONF_MQTT not in CORE.config: return False # Default Enabled @@ -445,7 +457,7 @@ def upload_using_esptool( "detect", ] for img in flash_images: - cmd += [img.offset, img.path] + cmd += [img.offset, str(img.path)] if os.environ.get("ESPHOME_USE_SUBPROCESS") is None: import esptool @@ -531,7 +543,10 @@ def upload_program( remote_port = int(ota_conf[CONF_PORT]) password = ota_conf.get(CONF_PASSWORD, "") - binary = args.file if getattr(args, "file", None) is not None else CORE.firmware_bin + if getattr(args, "file", None) is not None: + binary = Path(args.file) + else: + binary = CORE.firmware_bin # MQTT address resolution if get_port_type(host) in ("MQTT", "MQTTIP"): @@ -598,7 +613,7 @@ def clean_mqtt(config: ConfigType, args: ArgsProtocol) -> int | None: def command_wizard(args: ArgsProtocol) -> int | None: from esphome import wizard - return wizard.wizard(args.configuration) + return wizard.wizard(Path(args.configuration)) def command_config(args: ArgsProtocol, config: ConfigType) -> int | None: @@ -720,6 +735,16 @@ def command_clean_mqtt(args: ArgsProtocol, config: ConfigType) -> int | None: return clean_mqtt(config, args) +def command_clean_all(args: ArgsProtocol) -> int | None: + try: + writer.clean_all(args.configuration) + except OSError as err: + _LOGGER.error("Error cleaning all files: %s", err) + return 1 + _LOGGER.info("Done!") + return 0 + + def command_mqtt_fingerprint(args: ArgsProtocol, config: ConfigType) -> int | None: from esphome import mqtt @@ -761,7 +786,7 @@ def command_update_all(args: ArgsProtocol) -> int | None: safe_print(f"{half_line}{middle_text}{half_line}") for f in files: - safe_print(f"Updating {color(AnsiFore.CYAN, f)}") + safe_print(f"Updating {color(AnsiFore.CYAN, str(f))}") safe_print("-" * twidth) safe_print() if CORE.dashboard: @@ -773,10 +798,10 @@ def command_update_all(args: ArgsProtocol) -> int | None: "esphome", "run", f, "--no-logs", "--device", "OTA" ) if rc == 0: - print_bar(f"[{color(AnsiFore.BOLD_GREEN, 'SUCCESS')}] {f}") + print_bar(f"[{color(AnsiFore.BOLD_GREEN, 'SUCCESS')}] {str(f)}") success[f] = True else: - print_bar(f"[{color(AnsiFore.BOLD_RED, 'ERROR')}] {f}") + print_bar(f"[{color(AnsiFore.BOLD_RED, 'ERROR')}] {str(f)}") success[f] = False safe_print() @@ -787,9 +812,9 @@ def command_update_all(args: ArgsProtocol) -> int | None: failed = 0 for f in files: if success[f]: - safe_print(f" - {f}: {color(AnsiFore.GREEN, 'SUCCESS')}") + safe_print(f" - {str(f)}: {color(AnsiFore.GREEN, 'SUCCESS')}") else: - safe_print(f" - {f}: {color(AnsiFore.BOLD_RED, 'FAILED')}") + safe_print(f" - {str(f)}: {color(AnsiFore.BOLD_RED, 'FAILED')}") failed += 1 return failed @@ -811,7 +836,8 @@ def command_idedata(args: ArgsProtocol, config: ConfigType) -> int: def command_rename(args: ArgsProtocol, config: ConfigType) -> int | None: - for c in args.name: + new_name = args.name + for c in new_name: if c not in ALLOWED_NAME_CHARS: print( color( @@ -822,8 +848,7 @@ def command_rename(args: ArgsProtocol, config: ConfigType) -> int | None: ) return 1 # Load existing yaml file - with open(CORE.config_path, mode="r+", encoding="utf-8") as raw_file: - raw_contents = raw_file.read() + raw_contents = CORE.config_path.read_text(encoding="utf-8") yaml = yaml_util.load_yaml(CORE.config_path) if CONF_ESPHOME not in yaml or CONF_NAME not in yaml[CONF_ESPHOME]: @@ -838,7 +863,7 @@ def command_rename(args: ArgsProtocol, config: ConfigType) -> int | None: if match is None: new_raw = re.sub( rf"name:\s+[\"']?{old_name}[\"']?", - f'name: "{args.name}"', + f'name: "{new_name}"', raw_contents, ) else: @@ -858,29 +883,28 @@ def command_rename(args: ArgsProtocol, config: ConfigType) -> int | None: new_raw = re.sub( rf"^(\s+{match.group(1)}):\s+[\"']?{old_name}[\"']?", - f'\\1: "{args.name}"', + f'\\1: "{new_name}"', raw_contents, flags=re.MULTILINE, ) - new_path = os.path.join(CORE.config_dir, args.name + ".yaml") + new_path: Path = CORE.config_dir / (new_name + ".yaml") print( - f"Updating {color(AnsiFore.CYAN, CORE.config_path)} to {color(AnsiFore.CYAN, new_path)}" + f"Updating {color(AnsiFore.CYAN, str(CORE.config_path))} to {color(AnsiFore.CYAN, str(new_path))}" ) print() - with open(new_path, mode="w", encoding="utf-8") as new_file: - new_file.write(new_raw) + new_path.write_text(new_raw, encoding="utf-8") - rc = run_external_process("esphome", "config", new_path) + rc = run_external_process("esphome", "config", str(new_path)) if rc != 0: print(color(AnsiFore.BOLD_RED, "Rename failed. Reverting changes.")) - os.remove(new_path) + new_path.unlink() return 1 cli_args = [ "run", - new_path, + str(new_path), "--no-logs", "--device", CORE.address, @@ -894,11 +918,11 @@ def command_rename(args: ArgsProtocol, config: ConfigType) -> int | None: except KeyboardInterrupt: rc = 1 if rc != 0: - os.remove(new_path) + new_path.unlink() return 1 if CORE.config_path != new_path: - os.remove(CORE.config_path) + CORE.config_path.unlink() print(color(AnsiFore.BOLD_GREEN, "SUCCESS")) print() @@ -911,6 +935,7 @@ PRE_CONFIG_ACTIONS = { "dashboard": command_dashboard, "vscode": command_vscode, "update-all": command_update_all, + "clean-all": command_clean_all, } POST_CONFIG_ACTIONS = { @@ -919,9 +944,9 @@ POST_CONFIG_ACTIONS = { "upload": command_upload, "logs": command_logs, "run": command_run, + "clean": command_clean, "clean-mqtt": command_clean_mqtt, "mqtt-fingerprint": command_mqtt_fingerprint, - "clean": command_clean, "idedata": command_idedata, "rename": command_rename, "discover": command_discover, @@ -965,6 +990,18 @@ def parse_args(argv): help="Add a substitution", metavar=("key", "value"), ) + options_parser.add_argument( + "--mdns-address-cache", + help="mDNS address cache mapping in format 'hostname=ip1,ip2'", + action="append", + default=[], + ) + options_parser.add_argument( + "--dns-address-cache", + help="DNS address cache mapping in format 'hostname=ip1,ip2'", + action="append", + default=[], + ) parser = argparse.ArgumentParser( description=f"ESPHome {const.__version__}", parents=[options_parser] @@ -1122,6 +1159,13 @@ def parse_args(argv): "configuration", help="Your YAML configuration file(s).", nargs="+" ) + parser_clean_all = subparsers.add_parser( + "clean-all", help="Clean all build and platform files." + ) + parser_clean_all.add_argument( + "configuration", help="Your YAML configuration directory.", nargs="*" + ) + parser_dashboard = subparsers.add_parser( "dashboard", help="Create a simple web server for a dashboard." ) @@ -1168,7 +1212,7 @@ def parse_args(argv): parser_update = subparsers.add_parser("update-all") parser_update.add_argument( - "configuration", help="Your YAML configuration file directories.", nargs="+" + "configuration", help="Your YAML configuration file or directory.", nargs="+" ) parser_idedata = subparsers.add_parser("idedata") @@ -1212,9 +1256,15 @@ def parse_args(argv): def run_esphome(argv): + from esphome.address_cache import AddressCache + args = parse_args(argv) CORE.dashboard = args.dashboard + # Create address cache from command-line arguments + CORE.address_cache = AddressCache.from_cli_args( + args.mdns_address_cache, args.dns_address_cache + ) # Override log level if verbose is set if args.verbose: args.log_level = "DEBUG" @@ -1237,14 +1287,20 @@ def run_esphome(argv): _LOGGER.info("ESPHome %s", const.__version__) for conf_path in args.configuration: - if any(os.path.basename(conf_path) == x for x in SECRETS_FILES): + conf_path = Path(conf_path) + if any(conf_path.name == x for x in SECRETS_FILES): _LOGGER.warning("Skipping secrets file %s", conf_path) continue CORE.config_path = conf_path CORE.dashboard = args.dashboard - config = read_config(dict(args.substitution) if args.substitution else {}) + # For logs command, skip updating external components + skip_external = args.command == "logs" + config = read_config( + dict(args.substitution) if args.substitution else {}, + skip_external_update=skip_external, + ) if config is None: return 2 CORE.config = config diff --git a/esphome/address_cache.py b/esphome/address_cache.py new file mode 100644 index 0000000000..7c20be90f0 --- /dev/null +++ b/esphome/address_cache.py @@ -0,0 +1,142 @@ +"""Address cache for DNS and mDNS lookups.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable + +_LOGGER = logging.getLogger(__name__) + + +def normalize_hostname(hostname: str) -> str: + """Normalize hostname for cache lookups. + + Removes trailing dots and converts to lowercase. + """ + return hostname.rstrip(".").lower() + + +class AddressCache: + """Cache for DNS and mDNS address lookups. + + This cache stores pre-resolved addresses from command-line arguments + to avoid slow DNS/mDNS lookups during builds. + """ + + def __init__( + self, + mdns_cache: dict[str, list[str]] | None = None, + dns_cache: dict[str, list[str]] | None = None, + ) -> None: + """Initialize the address cache. + + Args: + mdns_cache: Pre-populated mDNS addresses (hostname -> IPs) + dns_cache: Pre-populated DNS addresses (hostname -> IPs) + """ + self.mdns_cache = mdns_cache or {} + self.dns_cache = dns_cache or {} + + def _get_cached_addresses( + self, hostname: str, cache: dict[str, list[str]], cache_type: str + ) -> list[str] | None: + """Get cached addresses from a specific cache. + + Args: + hostname: The hostname to look up + cache: The cache dictionary to check + cache_type: Type of cache for logging ("mDNS" or "DNS") + + Returns: + List of IP addresses if found in cache, None otherwise + """ + normalized = normalize_hostname(hostname) + if addresses := cache.get(normalized): + _LOGGER.debug("Using %s cache for %s: %s", cache_type, hostname, addresses) + return addresses + return None + + def get_mdns_addresses(self, hostname: str) -> list[str] | None: + """Get cached mDNS addresses for a hostname. + + Args: + hostname: The hostname to look up (should end with .local) + + Returns: + List of IP addresses if found in cache, None otherwise + """ + return self._get_cached_addresses(hostname, self.mdns_cache, "mDNS") + + def get_dns_addresses(self, hostname: str) -> list[str] | None: + """Get cached DNS addresses for a hostname. + + Args: + hostname: The hostname to look up + + Returns: + List of IP addresses if found in cache, None otherwise + """ + return self._get_cached_addresses(hostname, self.dns_cache, "DNS") + + def get_addresses(self, hostname: str) -> list[str] | None: + """Get cached addresses for a hostname. + + Checks mDNS cache for .local domains, DNS cache otherwise. + + Args: + hostname: The hostname to look up + + Returns: + List of IP addresses if found in cache, None otherwise + """ + normalized = normalize_hostname(hostname) + if normalized.endswith(".local"): + return self.get_mdns_addresses(hostname) + return self.get_dns_addresses(hostname) + + def has_cache(self) -> bool: + """Check if any cache entries exist.""" + return bool(self.mdns_cache or self.dns_cache) + + @classmethod + def from_cli_args( + cls, mdns_args: Iterable[str], dns_args: Iterable[str] + ) -> AddressCache: + """Create cache from command-line arguments. + + Args: + mdns_args: List of mDNS cache entries like ['host=ip1,ip2'] + dns_args: List of DNS cache entries like ['host=ip1,ip2'] + + Returns: + Configured AddressCache instance + """ + mdns_cache = cls._parse_cache_args(mdns_args) + dns_cache = cls._parse_cache_args(dns_args) + return cls(mdns_cache=mdns_cache, dns_cache=dns_cache) + + @staticmethod + def _parse_cache_args(cache_args: Iterable[str]) -> dict[str, list[str]]: + """Parse cache arguments into a dictionary. + + Args: + cache_args: List of cache mappings like ['host1=ip1,ip2', 'host2=ip3'] + + Returns: + Dictionary mapping normalized hostnames to list of IP addresses + """ + cache: dict[str, list[str]] = {} + for arg in cache_args: + if "=" not in arg: + _LOGGER.warning( + "Invalid cache format: %s (expected 'hostname=ip1,ip2')", arg + ) + continue + hostname, ips = arg.split("=", 1) + # Normalize hostname for consistent lookups + normalized = normalize_hostname(hostname) + cache[normalized] = [ip.strip() for ip in ips.split(",")] + return cache diff --git a/esphome/automation.py b/esphome/automation.py index 99d4362845..99def9f273 100644 --- a/esphome/automation.py +++ b/esphome/automation.py @@ -15,7 +15,10 @@ from esphome.const import ( CONF_TYPE_ID, CONF_UPDATE_INTERVAL, ) +from esphome.core import ID +from esphome.cpp_generator import MockObj, MockObjClass, TemplateArgsType from esphome.schema_extractors import SCHEMA_EXTRACT, schema_extractor +from esphome.types import ConfigType from esphome.util import Registry @@ -49,11 +52,11 @@ def maybe_conf(conf, *validators): return validate -def register_action(name, action_type, schema): +def register_action(name: str, action_type: MockObjClass, schema: cv.Schema): return ACTION_REGISTRY.register(name, action_type, schema) -def register_condition(name, condition_type, schema): +def register_condition(name: str, condition_type: MockObjClass, schema: cv.Schema): return CONDITION_REGISTRY.register(name, condition_type, schema) @@ -164,43 +167,78 @@ XorCondition = cg.esphome_ns.class_("XorCondition", Condition) @register_condition("and", AndCondition, validate_condition_list) -async def and_condition_to_code(config, condition_id, template_arg, args): +async def and_condition_to_code( + config: ConfigType, + condition_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: conditions = await build_condition_list(config, template_arg, args) return cg.new_Pvariable(condition_id, template_arg, conditions) @register_condition("or", OrCondition, validate_condition_list) -async def or_condition_to_code(config, condition_id, template_arg, args): +async def or_condition_to_code( + config: ConfigType, + condition_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: conditions = await build_condition_list(config, template_arg, args) return cg.new_Pvariable(condition_id, template_arg, conditions) @register_condition("all", AndCondition, validate_condition_list) -async def all_condition_to_code(config, condition_id, template_arg, args): +async def all_condition_to_code( + config: ConfigType, + condition_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: conditions = await build_condition_list(config, template_arg, args) return cg.new_Pvariable(condition_id, template_arg, conditions) @register_condition("any", OrCondition, validate_condition_list) -async def any_condition_to_code(config, condition_id, template_arg, args): +async def any_condition_to_code( + config: ConfigType, + condition_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: conditions = await build_condition_list(config, template_arg, args) return cg.new_Pvariable(condition_id, template_arg, conditions) @register_condition("not", NotCondition, validate_potentially_and_condition) -async def not_condition_to_code(config, condition_id, template_arg, args): +async def not_condition_to_code( + config: ConfigType, + condition_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: condition = await build_condition(config, template_arg, args) return cg.new_Pvariable(condition_id, template_arg, condition) @register_condition("xor", XorCondition, validate_condition_list) -async def xor_condition_to_code(config, condition_id, template_arg, args): +async def xor_condition_to_code( + config: ConfigType, + condition_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: conditions = await build_condition_list(config, template_arg, args) return cg.new_Pvariable(condition_id, template_arg, conditions) @register_condition("lambda", LambdaCondition, cv.returning_lambda) -async def lambda_condition_to_code(config, condition_id, template_arg, args): +async def lambda_condition_to_code( + config: ConfigType, + condition_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: lambda_ = await cg.process_lambda(config, args, return_type=bool) return cg.new_Pvariable(condition_id, template_arg, lambda_) @@ -217,7 +255,12 @@ async def lambda_condition_to_code(config, condition_id, template_arg, args): } ).extend(cv.COMPONENT_SCHEMA), ) -async def for_condition_to_code(config, condition_id, template_arg, args): +async def for_condition_to_code( + config: ConfigType, + condition_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: condition = await build_condition( config[CONF_CONDITION], cg.TemplateArguments(), [] ) @@ -231,7 +274,12 @@ async def for_condition_to_code(config, condition_id, template_arg, args): @register_action( "delay", DelayAction, cv.templatable(cv.positive_time_period_milliseconds) ) -async def delay_action_to_code(config, action_id, template_arg, args): +async def delay_action_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: var = cg.new_Pvariable(action_id, template_arg) await cg.register_component(var, {}) template_ = await cg.templatable(config, args, cg.uint32) @@ -256,10 +304,15 @@ async def delay_action_to_code(config, action_id, template_arg, args): cv.has_at_least_one_key(CONF_CONDITION, CONF_ANY, CONF_ALL), ), ) -async def if_action_to_code(config, action_id, template_arg, args): +async def if_action_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: cond_conf = next(el for el in config if el in (CONF_ANY, CONF_ALL, CONF_CONDITION)) - conditions = await build_condition(config[cond_conf], template_arg, args) - var = cg.new_Pvariable(action_id, template_arg, conditions) + condition = await build_condition(config[cond_conf], template_arg, args) + var = cg.new_Pvariable(action_id, template_arg, condition) if CONF_THEN in config: actions = await build_action_list(config[CONF_THEN], template_arg, args) cg.add(var.add_then(actions)) @@ -279,9 +332,14 @@ async def if_action_to_code(config, action_id, template_arg, args): } ), ) -async def while_action_to_code(config, action_id, template_arg, args): - conditions = await build_condition(config[CONF_CONDITION], template_arg, args) - var = cg.new_Pvariable(action_id, template_arg, conditions) +async def while_action_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: + condition = await build_condition(config[CONF_CONDITION], template_arg, args) + var = cg.new_Pvariable(action_id, template_arg, condition) actions = await build_action_list(config[CONF_THEN], template_arg, args) cg.add(var.add_then(actions)) return var @@ -297,7 +355,12 @@ async def while_action_to_code(config, action_id, template_arg, args): } ), ) -async def repeat_action_to_code(config, action_id, template_arg, args): +async def repeat_action_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: var = cg.new_Pvariable(action_id, template_arg) count_template = await cg.templatable(config[CONF_COUNT], args, cg.uint32) cg.add(var.set_count(count_template)) @@ -320,9 +383,14 @@ _validate_wait_until = cv.maybe_simple_value( @register_action("wait_until", WaitUntilAction, _validate_wait_until) -async def wait_until_action_to_code(config, action_id, template_arg, args): - conditions = await build_condition(config[CONF_CONDITION], template_arg, args) - var = cg.new_Pvariable(action_id, template_arg, conditions) +async def wait_until_action_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: + condition = await build_condition(config[CONF_CONDITION], template_arg, args) + var = cg.new_Pvariable(action_id, template_arg, condition) if CONF_TIMEOUT in config: template_ = await cg.templatable(config[CONF_TIMEOUT], args, cg.uint32) cg.add(var.set_timeout_value(template_)) @@ -331,7 +399,12 @@ async def wait_until_action_to_code(config, action_id, template_arg, args): @register_action("lambda", LambdaAction, cv.lambda_) -async def lambda_action_to_code(config, action_id, template_arg, args): +async def lambda_action_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: lambda_ = await cg.process_lambda(config, args, return_type=cg.void) return cg.new_Pvariable(action_id, template_arg, lambda_) @@ -345,7 +418,12 @@ async def lambda_action_to_code(config, action_id, template_arg, args): } ), ) -async def component_update_action_to_code(config, action_id, template_arg, args): +async def component_update_action_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: comp = await cg.get_variable(config[CONF_ID]) return cg.new_Pvariable(action_id, template_arg, comp) @@ -359,7 +437,12 @@ async def component_update_action_to_code(config, action_id, template_arg, args) } ), ) -async def component_suspend_action_to_code(config, action_id, template_arg, args): +async def component_suspend_action_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: comp = await cg.get_variable(config[CONF_ID]) return cg.new_Pvariable(action_id, template_arg, comp) @@ -376,7 +459,12 @@ async def component_suspend_action_to_code(config, action_id, template_arg, args } ), ) -async def component_resume_action_to_code(config, action_id, template_arg, args): +async def component_resume_action_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +) -> MockObj: comp = await cg.get_variable(config[CONF_ID]) var = cg.new_Pvariable(action_id, template_arg, comp) if CONF_UPDATE_INTERVAL in config: @@ -385,7 +473,9 @@ async def component_resume_action_to_code(config, action_id, template_arg, args) return var -async def build_action(full_config, template_arg, args): +async def build_action( + full_config: ConfigType, template_arg: cg.TemplateArguments, args: TemplateArgsType +) -> MockObj: registry_entry, config = cg.extract_registry_entry_config( ACTION_REGISTRY, full_config ) @@ -394,15 +484,19 @@ async def build_action(full_config, template_arg, args): return await builder(config, action_id, template_arg, args) -async def build_action_list(config, templ, arg_type): - actions = [] +async def build_action_list( + config: list[ConfigType], templ: cg.TemplateArguments, arg_type: TemplateArgsType +) -> list[MockObj]: + actions: list[MockObj] = [] for conf in config: action = await build_action(conf, templ, arg_type) actions.append(action) return actions -async def build_condition(full_config, template_arg, args): +async def build_condition( + full_config: ConfigType, template_arg: cg.TemplateArguments, args: TemplateArgsType +) -> MockObj: registry_entry, config = cg.extract_registry_entry_config( CONDITION_REGISTRY, full_config ) @@ -411,15 +505,19 @@ async def build_condition(full_config, template_arg, args): return await builder(config, action_id, template_arg, args) -async def build_condition_list(config, templ, args): - conditions = [] +async def build_condition_list( + config: ConfigType, templ: cg.TemplateArguments, args: TemplateArgsType +) -> list[MockObj]: + conditions: list[MockObj] = [] for conf in config: condition = await build_condition(conf, templ, args) conditions.append(condition) return conditions -async def build_automation(trigger, args, config): +async def build_automation( + trigger: MockObj, args: TemplateArgsType, config: ConfigType +) -> MockObj: arg_types = [arg[0] for arg in args] templ = cg.TemplateArguments(*arg_types) obj = cg.new_Pvariable(config[CONF_AUTOMATION_ID], templ, trigger) diff --git a/esphome/build_gen/platformio.py b/esphome/build_gen/platformio.py index 9bbe86694b..30dbb69d86 100644 --- a/esphome/build_gen/platformio.py +++ b/esphome/build_gen/platformio.py @@ -1,5 +1,3 @@ -import os - from esphome.const import __version__ from esphome.core import CORE from esphome.helpers import mkdir_p, read_file, write_file_if_changed @@ -63,7 +61,7 @@ def write_ini(content): update_storage_json() path = CORE.relative_build_path("platformio.ini") - if os.path.isfile(path): + if path.is_file(): text = read_file(path) content_format = find_begin_end( text, INI_AUTO_GENERATE_BEGIN, INI_AUTO_GENERATE_END diff --git a/esphome/codegen.py b/esphome/codegen.py index 8e02ec1164..6decd77c62 100644 --- a/esphome/codegen.py +++ b/esphome/codegen.py @@ -12,6 +12,7 @@ from esphome.cpp_generator import ( # noqa: F401 ArrayInitializer, Expression, LineComment, + LogStringLiteral, MockObj, MockObjClass, Pvariable, diff --git a/esphome/components/animation/animation.cpp b/esphome/components/animation/animation.cpp index 6db6f1a7bd..c2ae3b2f76 100644 --- a/esphome/components/animation/animation.cpp +++ b/esphome/components/animation/animation.cpp @@ -26,12 +26,12 @@ uint32_t Animation::get_animation_frame_count() const { return this->animation_f int Animation::get_current_frame() const { return this->current_frame_; } void Animation::next_frame() { this->current_frame_++; - if (loop_count_ && this->current_frame_ == loop_end_frame_ && + if (loop_count_ && static_cast(this->current_frame_) == loop_end_frame_ && (this->loop_current_iteration_ < loop_count_ || loop_count_ < 0)) { this->current_frame_ = loop_start_frame_; this->loop_current_iteration_++; } - if (this->current_frame_ >= animation_frame_count_) { + if (static_cast(this->current_frame_) >= animation_frame_count_) { this->loop_current_iteration_ = 1; this->current_frame_ = 0; } diff --git a/esphome/components/api/__init__.py b/esphome/components/api/__init__.py index b120503a2e..58828c131d 100644 --- a/esphome/components/api/__init__.py +++ b/esphome/components/api/__init__.py @@ -1,4 +1,5 @@ import base64 +import logging from esphome import automation from esphome.automation import Condition @@ -8,34 +9,59 @@ import esphome.config_validation as cv from esphome.const import ( CONF_ACTION, CONF_ACTIONS, + CONF_CAPTURE_RESPONSE, CONF_DATA, CONF_DATA_TEMPLATE, CONF_EVENT, CONF_ID, CONF_KEY, + CONF_MAX_CONNECTIONS, CONF_ON_CLIENT_CONNECTED, CONF_ON_CLIENT_DISCONNECTED, + CONF_ON_ERROR, + CONF_ON_SUCCESS, CONF_PASSWORD, CONF_PORT, CONF_REBOOT_TIMEOUT, + CONF_RESPONSE_TEMPLATE, CONF_SERVICE, CONF_SERVICES, CONF_TAG, CONF_TRIGGER_ID, CONF_VARIABLES, ) -from esphome.core import CORE, CoroPriority, coroutine_with_priority +from esphome.core import CORE, ID, CoroPriority, coroutine_with_priority +from esphome.cpp_generator import TemplateArgsType +from esphome.types import ConfigType + +_LOGGER = logging.getLogger(__name__) DOMAIN = "api" DEPENDENCIES = ["network"] -AUTO_LOAD = ["socket"] CODEOWNERS = ["@esphome/core"] + +def AUTO_LOAD(config: ConfigType) -> list[str]: + """Conditionally auto-load json only when capture_response is used.""" + base = ["socket"] + + # Check if any homeassistant.action/homeassistant.service has capture_response: true + # This flag is set during config validation in _validate_response_config + if not config or CORE.data.get(DOMAIN, {}).get(CONF_CAPTURE_RESPONSE, False): + return base + ["json"] + + return base + + api_ns = cg.esphome_ns.namespace("api") APIServer = api_ns.class_("APIServer", cg.Component, cg.Controller) HomeAssistantServiceCallAction = api_ns.class_( "HomeAssistantServiceCallAction", automation.Action ) +ActionResponse = api_ns.class_("ActionResponse") +HomeAssistantActionResponseTrigger = api_ns.class_( + "HomeAssistantActionResponseTrigger", automation.Trigger +) APIConnectedCondition = api_ns.class_("APIConnectedCondition", Condition) UserServiceTrigger = api_ns.class_("UserServiceTrigger", automation.Trigger) @@ -55,6 +81,8 @@ CONF_BATCH_DELAY = "batch_delay" CONF_CUSTOM_SERVICES = "custom_services" CONF_HOMEASSISTANT_SERVICES = "homeassistant_services" CONF_HOMEASSISTANT_STATES = "homeassistant_states" +CONF_LISTEN_BACKLOG = "listen_backlog" +CONF_MAX_SEND_QUEUE = "max_send_queue" def validate_encryption_key(value): @@ -101,6 +129,32 @@ def _encryption_schema(config): return ENCRYPTION_SCHEMA(config) +def _validate_api_config(config: ConfigType) -> ConfigType: + """Validate API configuration with mutual exclusivity check and deprecation warning.""" + # Check if both password and encryption are configured + has_password = CONF_PASSWORD in config and config[CONF_PASSWORD] + has_encryption = CONF_ENCRYPTION in config + + if has_password and has_encryption: + raise cv.Invalid( + "The 'password' and 'encryption' options are mutually exclusive. " + "The API client only supports one authentication method at a time. " + "Please remove one of them. " + "Note: 'password' authentication is deprecated and will be removed in version 2026.1.0. " + "We strongly recommend using 'encryption' instead for better security." + ) + + # Warn about password deprecation + if has_password: + _LOGGER.warning( + "API 'password' authentication has been deprecated since May 2022 and will be removed in version 2026.1.0. " + "Please migrate to the 'encryption' configuration. " + "See https://esphome.io/components/api.html#configuration-variables" + ) + + return config + + CONFIG_SCHEMA = cv.All( cv.Schema( { @@ -128,9 +182,46 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_ON_CLIENT_DISCONNECTED): automation.validate_automation( single=True ), + # Connection limits to prevent memory exhaustion on resource-constrained devices + # Each connection uses ~500-1000 bytes of RAM plus system resources + # Platform defaults based on available RAM and network stack implementation: + cv.SplitDefault( + CONF_LISTEN_BACKLOG, + esp8266=1, # Limited RAM (~40KB free), LWIP raw sockets + esp32=4, # More RAM (520KB), BSD sockets + rp2040=1, # Limited RAM (264KB), LWIP raw sockets like ESP8266 + bk72xx=4, # Moderate RAM, BSD-style sockets + rtl87xx=4, # Moderate RAM, BSD-style sockets + host=4, # Abundant resources + ln882x=4, # Moderate RAM + ): cv.int_range(min=1, max=10), + cv.SplitDefault( + CONF_MAX_CONNECTIONS, + esp8266=4, # ~40KB free RAM, each connection uses ~500-1000 bytes + esp32=8, # 520KB RAM available + rp2040=4, # 264KB RAM but LWIP constraints + bk72xx=8, # Moderate RAM + rtl87xx=8, # Moderate RAM + host=8, # Abundant resources + ln882x=8, # Moderate RAM + ): cv.int_range(min=1, max=20), + # Maximum queued send buffers per connection before dropping connection + # Each buffer uses ~8-12 bytes overhead plus actual message size + # Platform defaults based on available RAM and typical message rates: + cv.SplitDefault( + CONF_MAX_SEND_QUEUE, + esp8266=5, # Limited RAM, need to fail fast + esp32=8, # More RAM, can buffer more + rp2040=5, # Limited RAM + bk72xx=8, # Moderate RAM + rtl87xx=8, # Moderate RAM + host=16, # Abundant resources + ln882x=8, # Moderate RAM + ): cv.int_range(min=1, max=64), } ).extend(cv.COMPONENT_SCHEMA), cv.rename_key(CONF_SERVICES, CONF_ACTIONS), + _validate_api_config, ) @@ -145,6 +236,11 @@ async def to_code(config): cg.add(var.set_password(config[CONF_PASSWORD])) cg.add(var.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT])) cg.add(var.set_batch_delay(config[CONF_BATCH_DELAY])) + if CONF_LISTEN_BACKLOG in config: + cg.add(var.set_listen_backlog(config[CONF_LISTEN_BACKLOG])) + if CONF_MAX_CONNECTIONS in config: + cg.add(var.set_max_connections(config[CONF_MAX_CONNECTIONS])) + cg.add_define("API_MAX_SEND_QUEUE", config[CONF_MAX_SEND_QUEUE]) # Set USE_API_SERVICES if any services are enabled if config.get(CONF_ACTIONS) or config[CONF_CUSTOM_SERVICES]: @@ -213,6 +309,29 @@ async def to_code(config): KEY_VALUE_SCHEMA = cv.Schema({cv.string: cv.templatable(cv.string_strict)}) +def _validate_response_config(config: ConfigType) -> ConfigType: + # Validate dependencies: + # - response_template requires capture_response: true + # - capture_response: true requires on_success + if CONF_RESPONSE_TEMPLATE in config and not config[CONF_CAPTURE_RESPONSE]: + raise cv.Invalid( + f"`{CONF_RESPONSE_TEMPLATE}` requires `{CONF_CAPTURE_RESPONSE}: true` to be set.", + path=[CONF_RESPONSE_TEMPLATE], + ) + + if config[CONF_CAPTURE_RESPONSE] and CONF_ON_SUCCESS not in config: + raise cv.Invalid( + f"`{CONF_CAPTURE_RESPONSE}: true` requires `{CONF_ON_SUCCESS}` to be set.", + path=[CONF_CAPTURE_RESPONSE], + ) + + # Track if any action uses capture_response for AUTO_LOAD + if config[CONF_CAPTURE_RESPONSE]: + CORE.data.setdefault(DOMAIN, {})[CONF_CAPTURE_RESPONSE] = True + + return config + + HOMEASSISTANT_ACTION_ACTION_SCHEMA = cv.All( cv.Schema( { @@ -228,10 +347,15 @@ HOMEASSISTANT_ACTION_ACTION_SCHEMA = cv.All( cv.Optional(CONF_VARIABLES, default={}): cv.Schema( {cv.string: cv.returning_lambda} ), + cv.Optional(CONF_RESPONSE_TEMPLATE): cv.templatable(cv.string), + cv.Optional(CONF_CAPTURE_RESPONSE, default=False): cv.boolean, + cv.Optional(CONF_ON_SUCCESS): automation.validate_automation(single=True), + cv.Optional(CONF_ON_ERROR): automation.validate_automation(single=True), } ), cv.has_exactly_one_key(CONF_SERVICE, CONF_ACTION), cv.rename_key(CONF_SERVICE, CONF_ACTION), + _validate_response_config, ) @@ -245,7 +369,12 @@ HOMEASSISTANT_ACTION_ACTION_SCHEMA = cv.All( HomeAssistantServiceCallAction, HOMEASSISTANT_ACTION_ACTION_SCHEMA, ) -async def homeassistant_service_to_code(config, action_id, template_arg, args): +async def homeassistant_service_to_code( + config: ConfigType, + action_id: ID, + template_arg: cg.TemplateArguments, + args: TemplateArgsType, +): cg.add_define("USE_API_HOMEASSISTANT_SERVICES") serv = await cg.get_variable(config[CONF_ID]) var = cg.new_Pvariable(action_id, template_arg, serv, False) @@ -260,6 +389,40 @@ async def homeassistant_service_to_code(config, action_id, template_arg, args): for key, value in config[CONF_VARIABLES].items(): templ = await cg.templatable(value, args, None) cg.add(var.add_variable(key, templ)) + + if on_error := config.get(CONF_ON_ERROR): + cg.add_define("USE_API_HOMEASSISTANT_ACTION_RESPONSES") + cg.add_define("USE_API_HOMEASSISTANT_ACTION_RESPONSES_ERRORS") + cg.add(var.set_wants_status()) + await automation.build_automation( + var.get_error_trigger(), + [(cg.std_string, "error"), *args], + on_error, + ) + + if on_success := config.get(CONF_ON_SUCCESS): + cg.add_define("USE_API_HOMEASSISTANT_ACTION_RESPONSES") + cg.add(var.set_wants_status()) + if config[CONF_CAPTURE_RESPONSE]: + cg.add(var.set_wants_response()) + cg.add_define("USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON") + await automation.build_automation( + var.get_success_trigger_with_response(), + [(cg.JsonObjectConst, "response"), *args], + on_success, + ) + + if response_template := config.get(CONF_RESPONSE_TEMPLATE): + templ = await cg.templatable(response_template, args, cg.std_string) + cg.add(var.set_response_template(templ)) + + else: + await automation.build_automation( + var.get_success_trigger(), + args, + on_success, + ) + return var diff --git a/esphome/components/api/api.proto b/esphome/components/api/api.proto index 471127e93a..87f477799d 100644 --- a/esphome/components/api/api.proto +++ b/esphome/components/api/api.proto @@ -7,7 +7,7 @@ service APIConnection { option (needs_setup_connection) = false; option (needs_authentication) = false; } - rpc connect (ConnectRequest) returns (ConnectResponse) { + rpc authenticate (AuthenticationRequest) returns (AuthenticationResponse) { option (needs_setup_connection) = false; option (needs_authentication) = false; } @@ -66,6 +66,9 @@ service APIConnection { rpc voice_assistant_set_configuration(VoiceAssistantSetConfiguration) returns (void) {} rpc alarm_control_panel_command (AlarmControlPanelCommandRequest) returns (void) {} + + rpc zwave_proxy_frame(ZWaveProxyFrame) returns (void) {} + rpc zwave_proxy_request(ZWaveProxyRequest) returns (void) {} } @@ -99,7 +102,7 @@ message HelloRequest { // For example "Home Assistant" // Not strictly necessary to send but nice for debugging // purposes. - string client_info = 1; + string client_info = 1 [(pointer_to_buffer) = true]; uint32 api_version_major = 2; uint32 api_version_minor = 3; } @@ -129,21 +132,23 @@ message HelloResponse { // Message sent at the beginning of each connection to authenticate the client // Can only be sent by the client and only at the beginning of the connection -message ConnectRequest { +message AuthenticationRequest { option (id) = 3; option (source) = SOURCE_CLIENT; option (no_delay) = true; + option (ifdef) = "USE_API_PASSWORD"; // The password to log in with - string password = 1; + string password = 1 [(pointer_to_buffer) = true]; } // Confirmation of successful connection. After this the connection is available for all traffic. // Can only be sent by the server and only at the beginning of the connection -message ConnectResponse { +message AuthenticationResponse { option (id) = 4; option (source) = SOURCE_SERVER; option (no_delay) = true; + option (ifdef) = "USE_API_PASSWORD"; bool invalid_password = 1; } @@ -252,6 +257,10 @@ message DeviceInfoResponse { // Top-level area info to phase out suggested_area AreaInfo area = 22 [(field_ifdef) = "USE_AREAS"]; + + // Indicates if Z-Wave proxy support is available and features supported + uint32 zwave_proxy_feature_flags = 23 [(field_ifdef) = "USE_ZWAVE_PROXY"]; + uint32 zwave_home_id = 24 [(field_ifdef) = "USE_ZWAVE_PROXY"]; } message ListEntitiesRequest { @@ -760,7 +769,7 @@ message HomeassistantServiceMap { string value = 2 [(no_zero_copy) = true]; } -message HomeassistantServiceResponse { +message HomeassistantActionRequest { option (id) = 35; option (source) = SOURCE_SERVER; option (no_delay) = true; @@ -771,6 +780,22 @@ message HomeassistantServiceResponse { repeated HomeassistantServiceMap data_template = 3; repeated HomeassistantServiceMap variables = 4; bool is_event = 5; + uint32 call_id = 6 [(field_ifdef) = "USE_API_HOMEASSISTANT_ACTION_RESPONSES"]; + bool wants_response = 7 [(field_ifdef) = "USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON"]; + string response_template = 8 [(no_zero_copy) = true, (field_ifdef) = "USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON"]; +} + +// Message sent by Home Assistant to ESPHome with service call response data +message HomeassistantActionResponse { + option (id) = 130; + option (source) = SOURCE_CLIENT; + option (no_delay) = true; + option (ifdef) = "USE_API_HOMEASSISTANT_ACTION_RESPONSES"; + + uint32 call_id = 1; // Matches the call_id from HomeassistantActionRequest + bool success = 2; // Whether the service call succeeded + string error_message = 3; // Error message if success = false + bytes response_data = 4 [(pointer_to_buffer) = true, (field_ifdef) = "USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON"]; } // ==================== IMPORT HOME ASSISTANT STATES ==================== @@ -815,7 +840,7 @@ message GetTimeResponse { option (no_delay) = true; fixed32 epoch_seconds = 1; - string timezone = 2; + string timezone = 2 [(pointer_to_buffer) = true]; } // ==================== USER-DEFINES SERVICES ==================== @@ -1456,7 +1481,7 @@ message BluetoothDeviceRequest { uint64 address = 1; BluetoothDeviceRequestType request_type = 2; - bool has_address_type = 3; + bool has_address_type = 3; // Deprecated, should be removed in 2027.8 - https://github.com/esphome/esphome/pull/10318 uint32 address_type = 4; } @@ -1562,7 +1587,7 @@ message BluetoothGATTWriteRequest { uint32 handle = 2; bool response = 3; - bytes data = 4; + bytes data = 4 [(pointer_to_buffer) = true]; } message BluetoothGATTReadDescriptorRequest { @@ -1582,7 +1607,7 @@ message BluetoothGATTWriteDescriptorRequest { uint64 address = 1; uint32 handle = 2; - bytes data = 3; + bytes data = 3 [(pointer_to_buffer) = true]; } message BluetoothGATTNotifyRequest { @@ -1856,10 +1881,22 @@ message VoiceAssistantWakeWord { repeated string trained_languages = 3; } +message VoiceAssistantExternalWakeWord { + string id = 1; + string wake_word = 2; + repeated string trained_languages = 3; + string model_type = 4; + uint32 model_size = 5; + string model_hash = 6; + string url = 7; +} + message VoiceAssistantConfigurationRequest { option (id) = 121; option (source) = SOURCE_CLIENT; option (ifdef) = "USE_VOICE_ASSISTANT"; + + repeated VoiceAssistantExternalWakeWord external_wake_words = 1; } message VoiceAssistantConfigurationResponse { @@ -2274,3 +2311,28 @@ message UpdateCommandRequest { UpdateCommand command = 2; uint32 device_id = 3 [(field_ifdef) = "USE_DEVICES"]; } + +// ==================== Z-WAVE ==================== + +message ZWaveProxyFrame { + option (id) = 128; + option (source) = SOURCE_BOTH; + option (ifdef) = "USE_ZWAVE_PROXY"; + option (no_delay) = true; + + bytes data = 1 [(pointer_to_buffer) = true]; +} + +enum ZWaveProxyRequestType { + ZWAVE_PROXY_REQUEST_TYPE_SUBSCRIBE = 0; + ZWAVE_PROXY_REQUEST_TYPE_UNSUBSCRIBE = 1; + ZWAVE_PROXY_REQUEST_TYPE_HOME_ID_CHANGE = 2; +} +message ZWaveProxyRequest { + option (id) = 129; + option (source) = SOURCE_BOTH; + option (ifdef) = "USE_ZWAVE_PROXY"; + + ZWaveProxyRequestType type = 1; + bytes data = 2 [(pointer_to_buffer) = true]; +} diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 7b7853f040..ae03dfbb33 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -8,9 +8,9 @@ #endif #include #include -#include #include #include +#include #include "esphome/components/network/util.h" #include "esphome/core/application.h" #include "esphome/core/entity_base.h" @@ -30,6 +30,9 @@ #ifdef USE_VOICE_ASSISTANT #include "esphome/components/voice_assistant/voice_assistant.h" #endif +#ifdef USE_ZWAVE_PROXY +#include "esphome/components/zwave_proxy/zwave_proxy.h" +#endif namespace esphome::api { @@ -113,8 +116,7 @@ void APIConnection::start() { APIError err = this->helper_->init(); if (err != APIError::OK) { - on_fatal_error(); - this->log_warning_(LOG_STR("Helper init failed"), err); + this->fatal_error_with_log_(LOG_STR("Helper init failed"), err); return; } this->client_info_.peername = helper_->getpeername(); @@ -144,8 +146,7 @@ void APIConnection::loop() { APIError err = this->helper_->loop(); if (err != APIError::OK) { - on_fatal_error(); - this->log_socket_operation_failed_(err); + this->fatal_error_with_log_(LOG_STR("Socket operation failed"), err); return; } @@ -160,17 +161,13 @@ void APIConnection::loop() { // No more data available break; } else if (err != APIError::OK) { - on_fatal_error(); - this->log_warning_(LOG_STR("Reading failed"), err); + this->fatal_error_with_log_(LOG_STR("Reading failed"), err); return; } else { this->last_traffic_ = now; // read a packet - if (buffer.data_len > 0) { - this->read_message(buffer.data_len, buffer.type, &buffer.container[buffer.data_offset]); - } else { - this->read_message(0, buffer.type, nullptr); - } + this->read_message(buffer.data_len, buffer.type, + buffer.data_len > 0 ? &buffer.container[buffer.data_offset] : nullptr); if (this->flags_.remove) return; } @@ -202,7 +199,8 @@ void APIConnection::loop() { // Disconnect if not responded within 2.5*keepalive if (now - this->last_traffic_ > KEEPALIVE_DISCONNECT_TIMEOUT) { on_fatal_error(); - ESP_LOGW(TAG, "%s is unresponsive; disconnecting", this->get_client_combined_info().c_str()); + ESP_LOGW(TAG, "%s (%s) is unresponsive; disconnecting", this->client_info_.name.c_str(), + this->client_info_.peername.c_str()); } } else if (now - this->last_traffic_ > KEEPALIVE_TIMEOUT_MS && !this->flags_.remove) { // Only send ping if we're not disconnecting @@ -252,7 +250,7 @@ bool APIConnection::send_disconnect_response(const DisconnectRequest &msg) { // remote initiated disconnect_client // don't close yet, we still need to send the disconnect response // close will happen on next loop - ESP_LOGD(TAG, "%s disconnected", this->get_client_combined_info().c_str()); + ESP_LOGD(TAG, "%s (%s) disconnected", this->client_info_.name.c_str(), this->client_info_.peername.c_str()); this->flags_.next_close = true; DisconnectResponse resp; return this->send_message(resp, DisconnectResponse::MESSAGE_TYPE); @@ -1075,8 +1073,14 @@ void APIConnection::on_get_time_response(const GetTimeResponse &value) { if (homeassistant::global_homeassistant_time != nullptr) { homeassistant::global_homeassistant_time->set_epoch_time(value.epoch_seconds); #ifdef USE_TIME_TIMEZONE - if (!value.timezone.empty() && value.timezone != homeassistant::global_homeassistant_time->get_timezone()) { - homeassistant::global_homeassistant_time->set_timezone(value.timezone); + if (value.timezone_len > 0) { + const std::string ¤t_tz = homeassistant::global_homeassistant_time->get_timezone(); + // Compare without allocating a string + if (current_tz.length() != value.timezone_len || + memcmp(current_tz.c_str(), value.timezone, value.timezone_len) != 0) { + homeassistant::global_homeassistant_time->set_timezone( + std::string(reinterpret_cast(value.timezone), value.timezone_len)); + } } #endif } @@ -1193,6 +1197,23 @@ bool APIConnection::send_voice_assistant_get_configuration_response(const VoiceA resp_wake_word.trained_languages.push_back(lang); } } + + // Filter external wake words + for (auto &wake_word : msg.external_wake_words) { + if (wake_word.model_type != "micro") { + // microWakeWord only + continue; + } + + resp.available_wake_words.emplace_back(); + auto &resp_wake_word = resp.available_wake_words.back(); + resp_wake_word.set_id(StringRef(wake_word.id)); + resp_wake_word.set_wake_word(StringRef(wake_word.wake_word)); + for (const auto &lang : wake_word.trained_languages) { + resp_wake_word.trained_languages.push_back(lang); + } + } + resp.active_wake_words = &config.active_wake_words; resp.max_active_wake_words = config.max_active_wake_words; return this->send_message(resp, VoiceAssistantConfigurationResponse::MESSAGE_TYPE); @@ -1203,7 +1224,16 @@ void APIConnection::voice_assistant_set_configuration(const VoiceAssistantSetCon voice_assistant::global_voice_assistant->on_set_configuration(msg.active_wake_words); } } +#endif +#ifdef USE_ZWAVE_PROXY +void APIConnection::zwave_proxy_frame(const ZWaveProxyFrame &msg) { + zwave_proxy::global_zwave_proxy->send_frame(msg.data, msg.data_len); +} + +void APIConnection::zwave_proxy_request(const ZWaveProxyRequest &msg) { + zwave_proxy::global_zwave_proxy->zwave_proxy_request(this, msg.type); +} #endif #ifdef USE_ALARM_CONTROL_PANEL @@ -1350,7 +1380,7 @@ void APIConnection::complete_authentication_() { } this->flags_.connection_state = static_cast(ConnectionState::AUTHENTICATED); - ESP_LOGD(TAG, "%s connected", this->get_client_combined_info().c_str()); + ESP_LOGD(TAG, "%s (%s) connected", this->client_info_.name.c_str(), this->client_info_.peername.c_str()); #ifdef USE_API_CLIENT_CONNECTED_TRIGGER this->parent_->get_client_connected_trigger()->trigger(this->client_info_.name, this->client_info_.peername); #endif @@ -1359,10 +1389,15 @@ void APIConnection::complete_authentication_() { this->send_time_request(); } #endif +#ifdef USE_ZWAVE_PROXY + if (zwave_proxy::global_zwave_proxy != nullptr) { + zwave_proxy::global_zwave_proxy->api_connection_authenticated(this); + } +#endif } bool APIConnection::send_hello_response(const HelloRequest &msg) { - this->client_info_.name = msg.client_info; + this->client_info_.name.assign(reinterpret_cast(msg.client_info), msg.client_info_len); this->client_info_.peername = this->helper_->getpeername(); this->client_api_version_major_ = msg.api_version_major; this->client_api_version_minor_ = msg.api_version_minor; @@ -1386,20 +1421,17 @@ bool APIConnection::send_hello_response(const HelloRequest &msg) { return this->send_message(resp, HelloResponse::MESSAGE_TYPE); } -bool APIConnection::send_connect_response(const ConnectRequest &msg) { - bool correct = true; #ifdef USE_API_PASSWORD - correct = this->parent_->check_password(msg.password); -#endif - - ConnectResponse resp; +bool APIConnection::send_authenticate_response(const AuthenticationRequest &msg) { + AuthenticationResponse resp; // bool invalid_password = 1; - resp.invalid_password = !correct; - if (correct) { + resp.invalid_password = !this->parent_->check_password(msg.password, msg.password_len); + if (!resp.invalid_password) { this->complete_authentication_(); } - return this->send_message(resp, ConnectResponse::MESSAGE_TYPE); + return this->send_message(resp, AuthenticationResponse::MESSAGE_TYPE); } +#endif // USE_API_PASSWORD bool APIConnection::send_ping_response(const PingRequest &msg) { PingResponse resp; @@ -1463,6 +1495,10 @@ bool APIConnection::send_device_info_response(const DeviceInfoRequest &msg) { #ifdef USE_VOICE_ASSISTANT resp.voice_assistant_feature_flags = voice_assistant::global_voice_assistant->get_feature_flags(); #endif +#ifdef USE_ZWAVE_PROXY + resp.zwave_proxy_feature_flags = zwave_proxy::global_zwave_proxy->get_feature_flags(); + resp.zwave_home_id = zwave_proxy::global_zwave_proxy->get_home_id(); +#endif #ifdef USE_API_NOISE resp.api_encryption_supported = true; #endif @@ -1513,6 +1549,20 @@ void APIConnection::execute_service(const ExecuteServiceRequest &msg) { } } #endif + +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES +void APIConnection::on_homeassistant_action_response(const HomeassistantActionResponse &msg) { +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + if (msg.response_data_len > 0) { + this->parent_->handle_action_response(msg.call_id, msg.success, msg.error_message, msg.response_data, + msg.response_data_len); + } else +#endif + { + this->parent_->handle_action_response(msg.call_id, msg.success, msg.error_message); + } +}; +#endif #ifdef USE_API_NOISE bool APIConnection::send_noise_encryption_set_key_response(const NoiseEncryptionSetKeyRequest &msg) { NoiseEncryptionSetKeyResponse resp; @@ -1543,8 +1593,7 @@ bool APIConnection::try_to_clear_buffer(bool log_out_of_space) { delay(0); APIError err = this->helper_->loop(); if (err != APIError::OK) { - on_fatal_error(); - this->log_socket_operation_failed_(err); + this->fatal_error_with_log_(LOG_STR("Socket operation failed"), err); return false; } if (this->helper_->can_write_without_blocking()) @@ -1563,8 +1612,7 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) { if (err == APIError::WOULD_BLOCK) return false; if (err != APIError::OK) { - on_fatal_error(); - this->log_warning_(LOG_STR("Packet write failed"), err); + this->fatal_error_with_log_(LOG_STR("Packet write failed"), err); return false; } // Do not set last_traffic_ on send @@ -1573,12 +1621,12 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) { #ifdef USE_API_PASSWORD void APIConnection::on_unauthenticated_access() { this->on_fatal_error(); - ESP_LOGD(TAG, "%s access without authentication", this->get_client_combined_info().c_str()); + ESP_LOGD(TAG, "%s (%s) no authentication", this->client_info_.name.c_str(), this->client_info_.peername.c_str()); } #endif void APIConnection::on_no_setup_connection() { this->on_fatal_error(); - ESP_LOGD(TAG, "%s access without full connection", this->get_client_combined_info().c_str()); + ESP_LOGD(TAG, "%s (%s) no connection setup", this->client_info_.name.c_str(), this->client_info_.peername.c_str()); } void APIConnection::on_fatal_error() { this->helper_->close(); @@ -1750,8 +1798,7 @@ void APIConnection::process_batch_() { APIError err = this->helper_->write_protobuf_packets(ProtoWriteBuffer{&shared_buf}, std::span(packet_info, packet_count)); if (err != APIError::OK && err != APIError::WOULD_BLOCK) { - on_fatal_error(); - this->log_warning_(LOG_STR("Batch write failed"), err); + this->fatal_error_with_log_(LOG_STR("Batch write failed"), err); } #ifdef HAS_PROTO_MESSAGE_DUMP @@ -1830,12 +1877,8 @@ void APIConnection::process_state_subscriptions_() { #endif // USE_API_HOMEASSISTANT_STATES void APIConnection::log_warning_(const LogString *message, APIError err) { - ESP_LOGW(TAG, "%s: %s %s errno=%d", this->get_client_combined_info().c_str(), LOG_STR_ARG(message), - LOG_STR_ARG(api_error_to_logstr(err)), errno); -} - -void APIConnection::log_socket_operation_failed_(APIError err) { - this->log_warning_(LOG_STR("Socket operation failed"), err); + ESP_LOGW(TAG, "%s (%s): %s %s errno=%d", this->client_info_.name.c_str(), this->client_info_.peername.c_str(), + LOG_STR_ARG(message), LOG_STR_ARG(api_error_to_logstr(err)), errno); } } // namespace esphome::api diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 8f93f38203..284fa11a95 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -10,8 +10,8 @@ #include "esphome/core/component.h" #include "esphome/core/entity_base.h" -#include #include +#include namespace esphome::api { @@ -19,14 +19,6 @@ namespace esphome::api { struct ClientInfo { std::string name; // Client name from Hello message std::string peername; // IP:port from socket - - std::string get_combined_info() const { - if (name == peername) { - // Before Hello message, both are the same - return name; - } - return name + " (" + peername + ")"; - } }; // Keepalive timeout in milliseconds @@ -132,12 +124,15 @@ class APIConnection final : public APIServerConnection { #endif bool try_send_log_message(int level, const char *tag, const char *line, size_t message_len); #ifdef USE_API_HOMEASSISTANT_SERVICES - void send_homeassistant_service_call(const HomeassistantServiceResponse &call) { + void send_homeassistant_action(const HomeassistantActionRequest &call) { if (!this->flags_.service_call_subscription) return; - this->send_message(call, HomeassistantServiceResponse::MESSAGE_TYPE); + this->send_message(call, HomeassistantActionRequest::MESSAGE_TYPE); } -#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + void on_homeassistant_action_response(const HomeassistantActionResponse &msg) override; +#endif // USE_API_HOMEASSISTANT_ACTION_RESPONSES +#endif // USE_API_HOMEASSISTANT_SERVICES #ifdef USE_BLUETOOTH_PROXY void subscribe_bluetooth_le_advertisements(const SubscribeBluetoothLEAdvertisementsRequest &msg) override; void unsubscribe_bluetooth_le_advertisements(const UnsubscribeBluetoothLEAdvertisementsRequest &msg) override; @@ -171,6 +166,11 @@ class APIConnection final : public APIServerConnection { void voice_assistant_set_configuration(const VoiceAssistantSetConfiguration &msg) override; #endif +#ifdef USE_ZWAVE_PROXY + void zwave_proxy_frame(const ZWaveProxyFrame &msg) override; + void zwave_proxy_request(const ZWaveProxyRequest &msg) override; +#endif + #ifdef USE_ALARM_CONTROL_PANEL bool send_alarm_control_panel_state(alarm_control_panel::AlarmControlPanel *a_alarm_control_panel); void alarm_control_panel_command(const AlarmControlPanelCommandRequest &msg) override; @@ -197,7 +197,9 @@ class APIConnection final : public APIServerConnection { void on_get_time_response(const GetTimeResponse &value) override; #endif bool send_hello_response(const HelloRequest &msg) override; - bool send_connect_response(const ConnectRequest &msg) override; +#ifdef USE_API_PASSWORD + bool send_authenticate_response(const AuthenticationRequest &msg) override; +#endif bool send_disconnect_response(const DisconnectRequest &msg) override; bool send_ping_response(const PingRequest &msg) override; bool send_device_info_response(const DeviceInfoRequest &msg) override; @@ -271,7 +273,8 @@ class APIConnection final : public APIServerConnection { bool try_to_clear_buffer(bool log_out_of_space); bool send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) override; - std::string get_client_combined_info() const { return this->client_info_.get_combined_info(); } + const std::string &get_name() const { return this->client_info_.name; } + const std::string &get_peername() const { return this->client_info_.peername; } protected: // Helper function to handle authentication completion @@ -732,8 +735,11 @@ class APIConnection final : public APIServerConnection { // Helper function to log API errors with errno void log_warning_(const LogString *message, APIError err); - // Specific helper for duplicated error message - void log_socket_operation_failed_(APIError err); + // Helper to handle fatal errors with logging + inline void fatal_error_with_log_(const LogString *message, APIError err) { + this->on_fatal_error(); + this->log_warning_(message, err); + } }; } // namespace esphome::api diff --git a/esphome/components/api/api_frame_helper.cpp b/esphome/components/api/api_frame_helper.cpp index a284e09c4a..20f8fcaf61 100644 --- a/esphome/components/api/api_frame_helper.cpp +++ b/esphome/components/api/api_frame_helper.cpp @@ -13,7 +13,8 @@ namespace esphome::api { static const char *const TAG = "api.frame_helper"; -#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->client_info_->get_combined_info().c_str(), ##__VA_ARGS__) +#define HELPER_LOG(msg, ...) \ + ESP_LOGVV(TAG, "%s (%s): " msg, this->client_info_->name.c_str(), this->client_info_->peername.c_str(), ##__VA_ARGS__) #ifdef HELPER_LOG_PACKETS #define LOG_PACKET_RECEIVED(buffer) ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(buffer).c_str()) @@ -80,7 +81,7 @@ const LogString *api_error_to_logstr(APIError err) { // Default implementation for loop - handles sending buffered data APIError APIFrameHelper::loop() { - if (!this->tx_buf_.empty()) { + if (this->tx_buf_count_ > 0) { APIError err = try_send_tx_buf_(); if (err != APIError::OK && err != APIError::WOULD_BLOCK) { return err; @@ -102,9 +103,20 @@ APIError APIFrameHelper::handle_socket_write_error_() { // Helper method to buffer data from IOVs void APIFrameHelper::buffer_data_from_iov_(const struct iovec *iov, int iovcnt, uint16_t total_write_len, uint16_t offset) { - SendBuffer buffer; - buffer.size = total_write_len - offset; - buffer.data = std::make_unique(buffer.size); + // Check if queue is full + if (this->tx_buf_count_ >= API_MAX_SEND_QUEUE) { + HELPER_LOG("Send queue full (%u buffers), dropping connection", this->tx_buf_count_); + this->state_ = State::FAILED; + return; + } + + uint16_t buffer_size = total_write_len - offset; + auto &buffer = this->tx_buf_[this->tx_buf_tail_]; + buffer = std::make_unique(SendBuffer{ + .data = std::make_unique(buffer_size), + .size = buffer_size, + .offset = 0, + }); uint16_t to_skip = offset; uint16_t write_pos = 0; @@ -117,12 +129,15 @@ void APIFrameHelper::buffer_data_from_iov_(const struct iovec *iov, int iovcnt, // Include this segment (partially or fully) const uint8_t *src = reinterpret_cast(iov[i].iov_base) + to_skip; uint16_t len = static_cast(iov[i].iov_len) - to_skip; - std::memcpy(buffer.data.get() + write_pos, src, len); + std::memcpy(buffer->data.get() + write_pos, src, len); write_pos += len; to_skip = 0; } } - this->tx_buf_.push_back(std::move(buffer)); + + // Update circular buffer tracking + this->tx_buf_tail_ = (this->tx_buf_tail_ + 1) % API_MAX_SEND_QUEUE; + this->tx_buf_count_++; } // This method writes data to socket or buffers it @@ -140,7 +155,7 @@ APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, uint16_ #endif // Try to send any existing buffered data first if there is any - if (!this->tx_buf_.empty()) { + if (this->tx_buf_count_ > 0) { APIError send_result = try_send_tx_buf_(); // If real error occurred (not just WOULD_BLOCK), return it if (send_result != APIError::OK && send_result != APIError::WOULD_BLOCK) { @@ -149,7 +164,7 @@ APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, uint16_ // If there is still data in the buffer, we can't send, buffer // the new data and return - if (!this->tx_buf_.empty()) { + if (this->tx_buf_count_ > 0) { this->buffer_data_from_iov_(iov, iovcnt, total_write_len, 0); return APIError::OK; // Success, data buffered } @@ -177,32 +192,31 @@ APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, uint16_ } // Common implementation for trying to send buffered data -// IMPORTANT: Caller MUST ensure tx_buf_ is not empty before calling this method +// IMPORTANT: Caller MUST ensure tx_buf_count_ > 0 before calling this method APIError APIFrameHelper::try_send_tx_buf_() { // Try to send from tx_buf - we assume it's not empty as it's the caller's responsibility to check - bool tx_buf_empty = false; - while (!tx_buf_empty) { + while (this->tx_buf_count_ > 0) { // Get the first buffer in the queue - SendBuffer &front_buffer = this->tx_buf_.front(); + SendBuffer *front_buffer = this->tx_buf_[this->tx_buf_head_].get(); // Try to send the remaining data in this buffer - ssize_t sent = this->socket_->write(front_buffer.current_data(), front_buffer.remaining()); + ssize_t sent = this->socket_->write(front_buffer->current_data(), front_buffer->remaining()); if (sent == -1) { return this->handle_socket_write_error_(); } else if (sent == 0) { // Nothing sent but not an error return APIError::WOULD_BLOCK; - } else if (static_cast(sent) < front_buffer.remaining()) { + } else if (static_cast(sent) < front_buffer->remaining()) { // Partially sent, update offset // Cast to ensure no overflow issues with uint16_t - front_buffer.offset += static_cast(sent); + front_buffer->offset += static_cast(sent); return APIError::WOULD_BLOCK; // Stop processing more buffers if we couldn't send a complete buffer } else { // Buffer completely sent, remove it from the queue - this->tx_buf_.pop_front(); - // Update empty status for the loop condition - tx_buf_empty = this->tx_buf_.empty(); + this->tx_buf_[this->tx_buf_head_].reset(); + this->tx_buf_head_ = (this->tx_buf_head_ + 1) % API_MAX_SEND_QUEUE; + this->tx_buf_count_--; // Continue loop to try sending the next buffer } } diff --git a/esphome/components/api/api_frame_helper.h b/esphome/components/api/api_frame_helper.h index c11d701ffe..9aaada3cf7 100644 --- a/esphome/components/api/api_frame_helper.h +++ b/esphome/components/api/api_frame_helper.h @@ -1,7 +1,8 @@ #pragma once +#include #include -#include #include +#include #include #include #include @@ -17,6 +18,17 @@ namespace esphome::api { // uncomment to log raw packets //#define HELPER_LOG_PACKETS +// Maximum message size limits to prevent OOM on constrained devices +// Handshake messages are limited to a small size for security +static constexpr uint16_t MAX_HANDSHAKE_SIZE = 128; + +// Data message limits vary by platform based on available memory +#ifdef USE_ESP8266 +static constexpr uint16_t MAX_MESSAGE_SIZE = 8192; // 8 KiB for ESP8266 +#else +static constexpr uint16_t MAX_MESSAGE_SIZE = 32768; // 32 KiB for ESP32 and other platforms +#endif + // Forward declaration struct ClientInfo; @@ -79,7 +91,7 @@ class APIFrameHelper { virtual APIError init() = 0; virtual APIError loop(); virtual APIError read_packet(ReadPacketBuffer *buffer) = 0; - bool can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } + bool can_write_without_blocking() { return this->state_ == State::DATA && this->tx_buf_count_ == 0; } std::string getpeername() { return socket_->getpeername(); } int getpeername(struct sockaddr *addr, socklen_t *addrlen) { return socket_->getpeername(addr, addrlen); } APIError close() { @@ -161,7 +173,7 @@ class APIFrameHelper { }; // Containers (size varies, but typically 12+ bytes on 32-bit) - std::deque tx_buf_; + std::array, API_MAX_SEND_QUEUE> tx_buf_; std::vector reusable_iovs_; std::vector rx_buf_; @@ -174,7 +186,10 @@ class APIFrameHelper { State state_{State::INITIALIZE}; uint8_t frame_header_padding_{0}; uint8_t frame_footer_size_{0}; - // 5 bytes total, 3 bytes padding + uint8_t tx_buf_head_{0}; + uint8_t tx_buf_tail_{0}; + uint8_t tx_buf_count_{0}; + // 8 bytes total, 0 bytes padding // Common initialization for both plaintext and noise protocols APIError init_common_(); diff --git a/esphome/components/api/api_frame_helper_noise.cpp b/esphome/components/api/api_frame_helper_noise.cpp index 0e49f93db5..1213e65948 100644 --- a/esphome/components/api/api_frame_helper_noise.cpp +++ b/esphome/components/api/api_frame_helper_noise.cpp @@ -24,7 +24,8 @@ static const char *const PROLOGUE_INIT = "NoiseAPIInit"; #endif static constexpr size_t PROLOGUE_INIT_LEN = 12; // strlen("NoiseAPIInit") -#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->client_info_->get_combined_info().c_str(), ##__VA_ARGS__) +#define HELPER_LOG(msg, ...) \ + ESP_LOGVV(TAG, "%s (%s): " msg, this->client_info_->name.c_str(), this->client_info_->peername.c_str(), ##__VA_ARGS__) #ifdef HELPER_LOG_PACKETS #define LOG_PACKET_RECEIVED(buffer) ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(buffer).c_str()) @@ -131,26 +132,16 @@ APIError APINoiseFrameHelper::loop() { return APIFrameHelper::loop(); } -/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter +/** Read a packet into the rx_buf_. * - * @param frame: The struct to hold the frame information in. - * msg_start: points to the start of the payload - this pointer is only valid until the next - * try_receive_raw_ call - * - * @return 0 if a full packet is in rx_buf_ - * @return -1 if error, check errno. + * @return APIError::OK if a full packet is in rx_buf_ * * errno EWOULDBLOCK: Packet could not be read without blocking. Try again later. * errno ENOMEM: Not enough memory for reading packet. * errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. * errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase. */ -APIError APINoiseFrameHelper::try_read_frame_(std::vector *frame) { - if (frame == nullptr) { - HELPER_LOG("Bad argument for try_read_frame_"); - return APIError::BAD_ARG; - } - +APIError APINoiseFrameHelper::try_read_frame_() { // read header if (rx_header_buf_len_ < 3) { // no header information yet @@ -177,16 +168,17 @@ APIError APINoiseFrameHelper::try_read_frame_(std::vector *frame) { // read body uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2]; - if (state_ != State::DATA && msg_size > 128) { - // for handshake message only permit up to 128 bytes + // Check against size limits to prevent OOM: MAX_HANDSHAKE_SIZE for handshake, MAX_MESSAGE_SIZE for data + uint16_t limit = (state_ == State::DATA) ? MAX_MESSAGE_SIZE : MAX_HANDSHAKE_SIZE; + if (msg_size > limit) { state_ = State::FAILED; - HELPER_LOG("Bad packet len for handshake: %d", msg_size); - return APIError::BAD_HANDSHAKE_PACKET_LEN; + HELPER_LOG("Bad packet: message size %u exceeds maximum %u", msg_size, limit); + return (state_ == State::DATA) ? APIError::BAD_DATA_PACKET : APIError::BAD_HANDSHAKE_PACKET_LEN; } - // reserve space for body - if (rx_buf_.size() != msg_size) { - rx_buf_.resize(msg_size); + // Reserve space for body + if (this->rx_buf_.size() != msg_size) { + this->rx_buf_.resize(msg_size); } if (rx_buf_len_ < msg_size) { @@ -204,12 +196,12 @@ APIError APINoiseFrameHelper::try_read_frame_(std::vector *frame) { } } - LOG_PACKET_RECEIVED(rx_buf_); - *frame = std::move(rx_buf_); - // consume msg - rx_buf_ = {}; - rx_buf_len_ = 0; - rx_header_buf_len_ = 0; + LOG_PACKET_RECEIVED(this->rx_buf_); + + // Clear state for next frame (rx_buf_ still contains data for caller) + this->rx_buf_len_ = 0; + this->rx_header_buf_len_ = 0; + return APIError::OK; } @@ -231,18 +223,17 @@ APIError APINoiseFrameHelper::state_action_() { } if (state_ == State::CLIENT_HELLO) { // waiting for client hello - std::vector frame; - aerr = try_read_frame_(&frame); + aerr = this->try_read_frame_(); if (aerr != APIError::OK) { return handle_handshake_frame_error_(aerr); } // ignore contents, may be used in future for flags // Resize for: existing prologue + 2 size bytes + frame data - size_t old_size = prologue_.size(); - prologue_.resize(old_size + 2 + frame.size()); - prologue_[old_size] = (uint8_t) (frame.size() >> 8); - prologue_[old_size + 1] = (uint8_t) frame.size(); - std::memcpy(prologue_.data() + old_size + 2, frame.data(), frame.size()); + size_t old_size = this->prologue_.size(); + this->prologue_.resize(old_size + 2 + this->rx_buf_.size()); + this->prologue_[old_size] = (uint8_t) (this->rx_buf_.size() >> 8); + this->prologue_[old_size + 1] = (uint8_t) this->rx_buf_.size(); + std::memcpy(this->prologue_.data() + old_size + 2, this->rx_buf_.data(), this->rx_buf_.size()); state_ = State::SERVER_HELLO; } @@ -284,24 +275,23 @@ APIError APINoiseFrameHelper::state_action_() { int action = noise_handshakestate_get_action(handshake_); if (action == NOISE_ACTION_READ_MESSAGE) { // waiting for handshake msg - std::vector frame; - aerr = try_read_frame_(&frame); + aerr = this->try_read_frame_(); if (aerr != APIError::OK) { return handle_handshake_frame_error_(aerr); } - if (frame.empty()) { + if (this->rx_buf_.empty()) { send_explicit_handshake_reject_(LOG_STR("Empty handshake message")); return APIError::BAD_HANDSHAKE_ERROR_BYTE; - } else if (frame[0] != 0x00) { - HELPER_LOG("Bad handshake error byte: %u", frame[0]); + } else if (this->rx_buf_[0] != 0x00) { + HELPER_LOG("Bad handshake error byte: %u", this->rx_buf_[0]); send_explicit_handshake_reject_(LOG_STR("Bad handshake error byte")); return APIError::BAD_HANDSHAKE_ERROR_BYTE; } NoiseBuffer mbuf; noise_buffer_init(mbuf); - noise_buffer_set_input(mbuf, frame.data() + 1, frame.size() - 1); + noise_buffer_set_input(mbuf, this->rx_buf_.data() + 1, this->rx_buf_.size() - 1); err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); if (err != 0) { // Special handling for MAC failure @@ -378,35 +368,33 @@ void APINoiseFrameHelper::send_explicit_handshake_reject_(const LogString *reaso state_ = orig_state; } APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { - int err; - APIError aerr; - aerr = state_action_(); + APIError aerr = this->state_action_(); if (aerr != APIError::OK) { return aerr; } - if (state_ != State::DATA) { + if (this->state_ != State::DATA) { return APIError::WOULD_BLOCK; } - std::vector frame; - aerr = try_read_frame_(&frame); + aerr = this->try_read_frame_(); if (aerr != APIError::OK) return aerr; NoiseBuffer mbuf; noise_buffer_init(mbuf); - noise_buffer_set_inout(mbuf, frame.data(), frame.size(), frame.size()); - err = noise_cipherstate_decrypt(recv_cipher_, &mbuf); + noise_buffer_set_inout(mbuf, this->rx_buf_.data(), this->rx_buf_.size(), this->rx_buf_.size()); + int err = noise_cipherstate_decrypt(this->recv_cipher_, &mbuf); APIError decrypt_err = handle_noise_error_(err, LOG_STR("noise_cipherstate_decrypt"), APIError::CIPHERSTATE_DECRYPT_FAILED); - if (decrypt_err != APIError::OK) + if (decrypt_err != APIError::OK) { return decrypt_err; + } uint16_t msg_size = mbuf.size; - uint8_t *msg_data = frame.data(); + uint8_t *msg_data = this->rx_buf_.data(); if (msg_size < 4) { - state_ = State::FAILED; + this->state_ = State::FAILED; HELPER_LOG("Bad data packet: size %d too short", msg_size); return APIError::BAD_DATA_PACKET; } @@ -414,12 +402,12 @@ APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1]; uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3]; if (data_len > msg_size - 4) { - state_ = State::FAILED; + this->state_ = State::FAILED; HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size); return APIError::BAD_DATA_PACKET; } - buffer->container = std::move(frame); + buffer->container = std::move(this->rx_buf_); buffer->data_offset = 4; buffer->data_len = data_len; buffer->type = type; diff --git a/esphome/components/api/api_frame_helper_noise.h b/esphome/components/api/api_frame_helper_noise.h index 71a217c4ca..e3243e4fa5 100644 --- a/esphome/components/api/api_frame_helper_noise.h +++ b/esphome/components/api/api_frame_helper_noise.h @@ -28,7 +28,7 @@ class APINoiseFrameHelper final : public APIFrameHelper { protected: APIError state_action_(); - APIError try_read_frame_(std::vector *frame); + APIError try_read_frame_(); APIError write_frame_(const uint8_t *data, uint16_t len); APIError init_handshake_(); APIError check_handshake_finished_(); diff --git a/esphome/components/api/api_frame_helper_plaintext.cpp b/esphome/components/api/api_frame_helper_plaintext.cpp index 859bb26630..471e6c5404 100644 --- a/esphome/components/api/api_frame_helper_plaintext.cpp +++ b/esphome/components/api/api_frame_helper_plaintext.cpp @@ -18,7 +18,8 @@ namespace esphome::api { static const char *const TAG = "api.plaintext"; -#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->client_info_->get_combined_info().c_str(), ##__VA_ARGS__) +#define HELPER_LOG(msg, ...) \ + ESP_LOGVV(TAG, "%s (%s): " msg, this->client_info_->name.c_str(), this->client_info_->peername.c_str(), ##__VA_ARGS__) #ifdef HELPER_LOG_PACKETS #define LOG_PACKET_RECEIVED(buffer) ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(buffer).c_str()) @@ -46,21 +47,13 @@ APIError APIPlaintextFrameHelper::loop() { return APIFrameHelper::loop(); } -/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter - * - * @param frame: The struct to hold the frame information in. - * msg: store the parsed frame in that struct +/** Read a packet into the rx_buf_. * * @return See APIError * * error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. */ -APIError APIPlaintextFrameHelper::try_read_frame_(std::vector *frame) { - if (frame == nullptr) { - HELPER_LOG("Bad argument for try_read_frame_"); - return APIError::BAD_ARG; - } - +APIError APIPlaintextFrameHelper::try_read_frame_() { // read header while (!rx_header_parsed_) { // Now that we know when the socket is ready, we can read up to 3 bytes @@ -122,10 +115,10 @@ APIError APIPlaintextFrameHelper::try_read_frame_(std::vector *frame) { continue; } - if (msg_size_varint->as_uint32() > std::numeric_limits::max()) { + if (msg_size_varint->as_uint32() > MAX_MESSAGE_SIZE) { state_ = State::FAILED; HELPER_LOG("Bad packet: message size %" PRIu32 " exceeds maximum %u", msg_size_varint->as_uint32(), - std::numeric_limits::max()); + MAX_MESSAGE_SIZE); return APIError::BAD_DATA_PACKET; } rx_header_parsed_len_ = msg_size_varint->as_uint16(); @@ -149,9 +142,9 @@ APIError APIPlaintextFrameHelper::try_read_frame_(std::vector *frame) { } // header reading done - // reserve space for body - if (rx_buf_.size() != rx_header_parsed_len_) { - rx_buf_.resize(rx_header_parsed_len_); + // Reserve space for body + if (this->rx_buf_.size() != this->rx_header_parsed_len_) { + this->rx_buf_.resize(this->rx_header_parsed_len_); } if (rx_buf_len_ < rx_header_parsed_len_) { @@ -169,24 +162,22 @@ APIError APIPlaintextFrameHelper::try_read_frame_(std::vector *frame) { } } - LOG_PACKET_RECEIVED(rx_buf_); - *frame = std::move(rx_buf_); - // consume msg - rx_buf_ = {}; - rx_buf_len_ = 0; - rx_header_buf_pos_ = 0; - rx_header_parsed_ = false; + LOG_PACKET_RECEIVED(this->rx_buf_); + + // Clear state for next frame (rx_buf_ still contains data for caller) + this->rx_buf_len_ = 0; + this->rx_header_buf_pos_ = 0; + this->rx_header_parsed_ = false; + return APIError::OK; } -APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { - APIError aerr; - if (state_ != State::DATA) { +APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { + if (this->state_ != State::DATA) { return APIError::WOULD_BLOCK; } - std::vector frame; - aerr = try_read_frame_(&frame); + APIError aerr = this->try_read_frame_(); if (aerr != APIError::OK) { if (aerr == APIError::BAD_INDICATOR) { // Make sure to tell the remote that we don't @@ -219,10 +210,10 @@ APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) { return aerr; } - buffer->container = std::move(frame); + buffer->container = std::move(this->rx_buf_); buffer->data_offset = 0; - buffer->data_len = rx_header_parsed_len_; - buffer->type = rx_header_parsed_type_; + buffer->data_len = this->rx_header_parsed_len_; + buffer->type = this->rx_header_parsed_type_; return APIError::OK; } APIError APIPlaintextFrameHelper::write_protobuf_packet(uint8_t type, ProtoWriteBuffer buffer) { diff --git a/esphome/components/api/api_frame_helper_plaintext.h b/esphome/components/api/api_frame_helper_plaintext.h index 55a6d0f744..bba981d26b 100644 --- a/esphome/components/api/api_frame_helper_plaintext.h +++ b/esphome/components/api/api_frame_helper_plaintext.h @@ -24,7 +24,7 @@ class APIPlaintextFrameHelper final : public APIFrameHelper { APIError write_protobuf_packets(ProtoWriteBuffer buffer, std::span packets) override; protected: - APIError try_read_frame_(std::vector *frame); + APIError try_read_frame_(); // Group 2-byte aligned types uint16_t rx_header_parsed_type_ = 0; diff --git a/esphome/components/api/api_options.proto b/esphome/components/api/api_options.proto index 50c43b96fd..633f39b552 100644 --- a/esphome/components/api/api_options.proto +++ b/esphome/components/api/api_options.proto @@ -32,6 +32,13 @@ extend google.protobuf.FieldOptions { optional string fixed_array_size_define = 50010; optional string fixed_array_with_length_define = 50011; + // pointer_to_buffer: Use pointer instead of array for fixed-size byte fields + // When set, the field will be declared as a pointer (const uint8_t *data) + // instead of an array (uint8_t data[N]). This allows zero-copy on decode + // by pointing directly to the protobuf buffer. The buffer must remain valid + // until the message is processed (which is guaranteed for stack-allocated messages). + optional bool pointer_to_buffer = 50012 [default=false]; + // container_pointer: Zero-copy optimization for repeated fields. // // When container_pointer is set on a repeated field, the generated message will diff --git a/esphome/components/api/api_pb2.cpp b/esphome/components/api/api_pb2.cpp index a92fca70d6..70bcf082a6 100644 --- a/esphome/components/api/api_pb2.cpp +++ b/esphome/components/api/api_pb2.cpp @@ -22,9 +22,12 @@ bool HelloRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { } bool HelloRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { switch (field_id) { - case 1: - this->client_info = value.as_string(); + case 1: { + // Use raw data directly to avoid allocation + this->client_info = value.data(); + this->client_info_len = value.size(); break; + } default: return false; } @@ -42,18 +45,23 @@ void HelloResponse::calculate_size(ProtoSize &size) const { size.add_length(1, this->server_info_ref_.size()); size.add_length(1, this->name_ref_.size()); } -bool ConnectRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { +#ifdef USE_API_PASSWORD +bool AuthenticationRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { switch (field_id) { - case 1: - this->password = value.as_string(); + case 1: { + // Use raw data directly to avoid allocation + this->password = value.data(); + this->password_len = value.size(); break; + } default: return false; } return true; } -void ConnectResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->invalid_password); } -void ConnectResponse::calculate_size(ProtoSize &size) const { size.add_bool(1, this->invalid_password); } +void AuthenticationResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->invalid_password); } +void AuthenticationResponse::calculate_size(ProtoSize &size) const { size.add_bool(1, this->invalid_password); } +#endif #ifdef USE_AREAS void AreaInfo::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(1, this->area_id); @@ -127,6 +135,12 @@ void DeviceInfoResponse::encode(ProtoWriteBuffer buffer) const { #ifdef USE_AREAS buffer.encode_message(22, this->area); #endif +#ifdef USE_ZWAVE_PROXY + buffer.encode_uint32(23, this->zwave_proxy_feature_flags); +#endif +#ifdef USE_ZWAVE_PROXY + buffer.encode_uint32(24, this->zwave_home_id); +#endif } void DeviceInfoResponse::calculate_size(ProtoSize &size) const { #ifdef USE_API_PASSWORD @@ -179,6 +193,12 @@ void DeviceInfoResponse::calculate_size(ProtoSize &size) const { #ifdef USE_AREAS size.add_message_object(2, this->area); #endif +#ifdef USE_ZWAVE_PROXY + size.add_uint32(2, this->zwave_proxy_feature_flags); +#endif +#ifdef USE_ZWAVE_PROXY + size.add_uint32(2, this->zwave_home_id); +#endif } #ifdef USE_BINARY_SENSOR void ListEntitiesBinarySensorResponse::encode(ProtoWriteBuffer buffer) const { @@ -852,7 +872,7 @@ void HomeassistantServiceMap::calculate_size(ProtoSize &size) const { size.add_length(1, this->key_ref_.size()); size.add_length(1, this->value.size()); } -void HomeassistantServiceResponse::encode(ProtoWriteBuffer buffer) const { +void HomeassistantActionRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(1, this->service_ref_); for (auto &it : this->data) { buffer.encode_message(2, it, true); @@ -864,13 +884,64 @@ void HomeassistantServiceResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_message(4, it, true); } buffer.encode_bool(5, this->is_event); +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + buffer.encode_uint32(6, this->call_id); +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + buffer.encode_bool(7, this->wants_response); +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + buffer.encode_string(8, this->response_template); +#endif } -void HomeassistantServiceResponse::calculate_size(ProtoSize &size) const { +void HomeassistantActionRequest::calculate_size(ProtoSize &size) const { size.add_length(1, this->service_ref_.size()); size.add_repeated_message(1, this->data); size.add_repeated_message(1, this->data_template); size.add_repeated_message(1, this->variables); size.add_bool(1, this->is_event); +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + size.add_uint32(1, this->call_id); +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + size.add_bool(1, this->wants_response); +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + size.add_length(1, this->response_template.size()); +#endif +} +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES +bool HomeassistantActionResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 1: + this->call_id = value.as_uint32(); + break; + case 2: + this->success = value.as_bool(); + break; + default: + return false; + } + return true; +} +bool HomeassistantActionResponse::decode_length(uint32_t field_id, ProtoLengthDelimited value) { + switch (field_id) { + case 3: + this->error_message = value.as_string(); + break; +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + case 4: { + // Use raw data directly to avoid allocation + this->response_data = value.data(); + this->response_data_len = value.size(); + break; + } +#endif + default: + return false; + } + return true; } #endif #ifdef USE_API_HOMEASSISTANT_STATES @@ -903,9 +974,12 @@ bool HomeAssistantStateResponse::decode_length(uint32_t field_id, ProtoLengthDel #endif bool GetTimeResponse::decode_length(uint32_t field_id, ProtoLengthDelimited value) { switch (field_id) { - case 2: - this->timezone = value.as_string(); + case 2: { + // Use raw data directly to avoid allocation + this->timezone = value.data(); + this->timezone_len = value.size(); break; + } default: return false; } @@ -2014,9 +2088,12 @@ bool BluetoothGATTWriteRequest::decode_varint(uint32_t field_id, ProtoVarInt val } bool BluetoothGATTWriteRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { switch (field_id) { - case 4: - this->data = value.as_string(); + case 4: { + // Use raw data directly to avoid allocation + this->data = value.data(); + this->data_len = value.size(); break; + } default: return false; } @@ -2050,9 +2127,12 @@ bool BluetoothGATTWriteDescriptorRequest::decode_varint(uint32_t field_id, Proto } bool BluetoothGATTWriteDescriptorRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { switch (field_id) { - case 3: - this->data = value.as_string(); + case 3: { + // Use raw data directly to avoid allocation + this->data = value.data(); + this->data_len = value.size(); break; + } default: return false; } @@ -2368,6 +2448,52 @@ void VoiceAssistantWakeWord::calculate_size(ProtoSize &size) const { } } } +bool VoiceAssistantExternalWakeWord::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 5: + this->model_size = value.as_uint32(); + break; + default: + return false; + } + return true; +} +bool VoiceAssistantExternalWakeWord::decode_length(uint32_t field_id, ProtoLengthDelimited value) { + switch (field_id) { + case 1: + this->id = value.as_string(); + break; + case 2: + this->wake_word = value.as_string(); + break; + case 3: + this->trained_languages.push_back(value.as_string()); + break; + case 4: + this->model_type = value.as_string(); + break; + case 6: + this->model_hash = value.as_string(); + break; + case 7: + this->url = value.as_string(); + break; + default: + return false; + } + return true; +} +bool VoiceAssistantConfigurationRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { + switch (field_id) { + case 1: + this->external_wake_words.emplace_back(); + value.decode_to_message(this->external_wake_words.back()); + break; + default: + return false; + } + return true; +} void VoiceAssistantConfigurationResponse::encode(ProtoWriteBuffer buffer) const { for (auto &it : this->available_wake_words) { buffer.encode_message(1, it, true); @@ -3011,5 +3137,53 @@ bool UpdateCommandRequest::decode_32bit(uint32_t field_id, Proto32Bit value) { return true; } #endif +#ifdef USE_ZWAVE_PROXY +bool ZWaveProxyFrame::decode_length(uint32_t field_id, ProtoLengthDelimited value) { + switch (field_id) { + case 1: { + // Use raw data directly to avoid allocation + this->data = value.data(); + this->data_len = value.size(); + break; + } + default: + return false; + } + return true; +} +void ZWaveProxyFrame::encode(ProtoWriteBuffer buffer) const { buffer.encode_bytes(1, this->data, this->data_len); } +void ZWaveProxyFrame::calculate_size(ProtoSize &size) const { size.add_length(1, this->data_len); } +bool ZWaveProxyRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 1: + this->type = static_cast(value.as_uint32()); + break; + default: + return false; + } + return true; +} +bool ZWaveProxyRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { + switch (field_id) { + case 2: { + // Use raw data directly to avoid allocation + this->data = value.data(); + this->data_len = value.size(); + break; + } + default: + return false; + } + return true; +} +void ZWaveProxyRequest::encode(ProtoWriteBuffer buffer) const { + buffer.encode_uint32(1, static_cast(this->type)); + buffer.encode_bytes(2, this->data, this->data_len); +} +void ZWaveProxyRequest::calculate_size(ProtoSize &size) const { + size.add_uint32(1, static_cast(this->type)); + size.add_length(2, this->data_len); +} +#endif } // namespace esphome::api diff --git a/esphome/components/api/api_pb2.h b/esphome/components/api/api_pb2.h index 5b6d694e3b..d9e68ece9b 100644 --- a/esphome/components/api/api_pb2.h +++ b/esphome/components/api/api_pb2.h @@ -276,6 +276,13 @@ enum UpdateCommand : uint32_t { UPDATE_COMMAND_CHECK = 2, }; #endif +#ifdef USE_ZWAVE_PROXY +enum ZWaveProxyRequestType : uint32_t { + ZWAVE_PROXY_REQUEST_TYPE_SUBSCRIBE = 0, + ZWAVE_PROXY_REQUEST_TYPE_UNSUBSCRIBE = 1, + ZWAVE_PROXY_REQUEST_TYPE_HOME_ID_CHANGE = 2, +}; +#endif } // namespace enums @@ -324,11 +331,12 @@ class CommandProtoMessage : public ProtoDecodableMessage { class HelloRequest final : public ProtoDecodableMessage { public: static constexpr uint8_t MESSAGE_TYPE = 1; - static constexpr uint8_t ESTIMATED_SIZE = 17; + static constexpr uint8_t ESTIMATED_SIZE = 27; #ifdef HAS_PROTO_MESSAGE_DUMP const char *message_name() const override { return "hello_request"; } #endif - std::string client_info{}; + const uint8_t *client_info{nullptr}; + uint16_t client_info_len{0}; uint32_t api_version_major{0}; uint32_t api_version_minor{0}; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -360,14 +368,16 @@ class HelloResponse final : public ProtoMessage { protected: }; -class ConnectRequest final : public ProtoDecodableMessage { +#ifdef USE_API_PASSWORD +class AuthenticationRequest final : public ProtoDecodableMessage { public: static constexpr uint8_t MESSAGE_TYPE = 3; - static constexpr uint8_t ESTIMATED_SIZE = 9; + static constexpr uint8_t ESTIMATED_SIZE = 19; #ifdef HAS_PROTO_MESSAGE_DUMP - const char *message_name() const override { return "connect_request"; } + const char *message_name() const override { return "authentication_request"; } #endif - std::string password{}; + const uint8_t *password{nullptr}; + uint16_t password_len{0}; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -375,12 +385,12 @@ class ConnectRequest final : public ProtoDecodableMessage { protected: bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; }; -class ConnectResponse final : public ProtoMessage { +class AuthenticationResponse final : public ProtoMessage { public: static constexpr uint8_t MESSAGE_TYPE = 4; static constexpr uint8_t ESTIMATED_SIZE = 2; #ifdef HAS_PROTO_MESSAGE_DUMP - const char *message_name() const override { return "connect_response"; } + const char *message_name() const override { return "authentication_response"; } #endif bool invalid_password{false}; void encode(ProtoWriteBuffer buffer) const override; @@ -391,6 +401,7 @@ class ConnectResponse final : public ProtoMessage { protected: }; +#endif class DisconnectRequest final : public ProtoMessage { public: static constexpr uint8_t MESSAGE_TYPE = 5; @@ -490,7 +501,7 @@ class DeviceInfo final : public ProtoMessage { class DeviceInfoResponse final : public ProtoMessage { public: static constexpr uint8_t MESSAGE_TYPE = 10; - static constexpr uint8_t ESTIMATED_SIZE = 247; + static constexpr uint16_t ESTIMATED_SIZE = 257; #ifdef HAS_PROTO_MESSAGE_DUMP const char *message_name() const override { return "device_info_response"; } #endif @@ -550,6 +561,12 @@ class DeviceInfoResponse final : public ProtoMessage { #endif #ifdef USE_AREAS AreaInfo area{}; +#endif +#ifdef USE_ZWAVE_PROXY + uint32_t zwave_proxy_feature_flags{0}; +#endif +#ifdef USE_ZWAVE_PROXY + uint32_t zwave_home_id{0}; #endif void encode(ProtoWriteBuffer buffer) const override; void calculate_size(ProtoSize &size) const override; @@ -1084,12 +1101,12 @@ class HomeassistantServiceMap final : public ProtoMessage { protected: }; -class HomeassistantServiceResponse final : public ProtoMessage { +class HomeassistantActionRequest final : public ProtoMessage { public: static constexpr uint8_t MESSAGE_TYPE = 35; - static constexpr uint8_t ESTIMATED_SIZE = 113; + static constexpr uint8_t ESTIMATED_SIZE = 128; #ifdef HAS_PROTO_MESSAGE_DUMP - const char *message_name() const override { return "homeassistant_service_response"; } + const char *message_name() const override { return "homeassistant_action_request"; } #endif StringRef service_ref_{}; void set_service(const StringRef &ref) { this->service_ref_ = ref; } @@ -1097,6 +1114,15 @@ class HomeassistantServiceResponse final : public ProtoMessage { std::vector data_template{}; std::vector variables{}; bool is_event{false}; +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + uint32_t call_id{0}; +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + bool wants_response{false}; +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + std::string response_template{}; +#endif void encode(ProtoWriteBuffer buffer) const override; void calculate_size(ProtoSize &size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP @@ -1106,6 +1132,30 @@ class HomeassistantServiceResponse final : public ProtoMessage { protected: }; #endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES +class HomeassistantActionResponse final : public ProtoDecodableMessage { + public: + static constexpr uint8_t MESSAGE_TYPE = 130; + static constexpr uint8_t ESTIMATED_SIZE = 34; +#ifdef HAS_PROTO_MESSAGE_DUMP + const char *message_name() const override { return "homeassistant_action_response"; } +#endif + uint32_t call_id{0}; + bool success{false}; + std::string error_message{}; +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + const uint8_t *response_data{nullptr}; + uint16_t response_data_len{0}; +#endif +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; +#endif #ifdef USE_API_HOMEASSISTANT_STATES class SubscribeHomeAssistantStatesRequest final : public ProtoMessage { public: @@ -1174,12 +1224,13 @@ class GetTimeRequest final : public ProtoMessage { class GetTimeResponse final : public ProtoDecodableMessage { public: static constexpr uint8_t MESSAGE_TYPE = 37; - static constexpr uint8_t ESTIMATED_SIZE = 14; + static constexpr uint8_t ESTIMATED_SIZE = 24; #ifdef HAS_PROTO_MESSAGE_DUMP const char *message_name() const override { return "get_time_response"; } #endif uint32_t epoch_seconds{0}; - std::string timezone{}; + const uint8_t *timezone{nullptr}; + uint16_t timezone_len{0}; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1971,14 +2022,15 @@ class BluetoothGATTReadResponse final : public ProtoMessage { class BluetoothGATTWriteRequest final : public ProtoDecodableMessage { public: static constexpr uint8_t MESSAGE_TYPE = 75; - static constexpr uint8_t ESTIMATED_SIZE = 19; + static constexpr uint8_t ESTIMATED_SIZE = 29; #ifdef HAS_PROTO_MESSAGE_DUMP const char *message_name() const override { return "bluetooth_gatt_write_request"; } #endif uint64_t address{0}; uint32_t handle{0}; bool response{false}; - std::string data{}; + const uint8_t *data{nullptr}; + uint16_t data_len{0}; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2006,13 +2058,14 @@ class BluetoothGATTReadDescriptorRequest final : public ProtoDecodableMessage { class BluetoothGATTWriteDescriptorRequest final : public ProtoDecodableMessage { public: static constexpr uint8_t MESSAGE_TYPE = 77; - static constexpr uint8_t ESTIMATED_SIZE = 17; + static constexpr uint8_t ESTIMATED_SIZE = 27; #ifdef HAS_PROTO_MESSAGE_DUMP const char *message_name() const override { return "bluetooth_gatt_write_descriptor_request"; } #endif uint64_t address{0}; uint32_t handle{0}; - std::string data{}; + const uint8_t *data{nullptr}; + uint16_t data_len{0}; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2437,18 +2490,37 @@ class VoiceAssistantWakeWord final : public ProtoMessage { protected: }; -class VoiceAssistantConfigurationRequest final : public ProtoMessage { +class VoiceAssistantExternalWakeWord final : public ProtoDecodableMessage { public: - static constexpr uint8_t MESSAGE_TYPE = 121; - static constexpr uint8_t ESTIMATED_SIZE = 0; -#ifdef HAS_PROTO_MESSAGE_DUMP - const char *message_name() const override { return "voice_assistant_configuration_request"; } -#endif + std::string id{}; + std::string wake_word{}; + std::vector trained_languages{}; + std::string model_type{}; + uint32_t model_size{0}; + std::string model_hash{}; + std::string url{}; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif protected: + bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; +class VoiceAssistantConfigurationRequest final : public ProtoDecodableMessage { + public: + static constexpr uint8_t MESSAGE_TYPE = 121; + static constexpr uint8_t ESTIMATED_SIZE = 34; +#ifdef HAS_PROTO_MESSAGE_DUMP + const char *message_name() const override { return "voice_assistant_configuration_request"; } +#endif + std::vector external_wake_words{}; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; }; class VoiceAssistantConfigurationResponse final : public ProtoMessage { public: @@ -2911,5 +2983,45 @@ class UpdateCommandRequest final : public CommandProtoMessage { bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; #endif +#ifdef USE_ZWAVE_PROXY +class ZWaveProxyFrame final : public ProtoDecodableMessage { + public: + static constexpr uint8_t MESSAGE_TYPE = 128; + static constexpr uint8_t ESTIMATED_SIZE = 19; +#ifdef HAS_PROTO_MESSAGE_DUMP + const char *message_name() const override { return "z_wave_proxy_frame"; } +#endif + const uint8_t *data{nullptr}; + uint16_t data_len{0}; + void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(ProtoSize &size) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; +}; +class ZWaveProxyRequest final : public ProtoDecodableMessage { + public: + static constexpr uint8_t MESSAGE_TYPE = 129; + static constexpr uint8_t ESTIMATED_SIZE = 21; +#ifdef HAS_PROTO_MESSAGE_DUMP + const char *message_name() const override { return "z_wave_proxy_request"; } +#endif + enums::ZWaveProxyRequestType type{}; + const uint8_t *data{nullptr}; + uint16_t data_len{0}; + void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(ProtoSize &size) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; +#endif } // namespace esphome::api diff --git a/esphome/components/api/api_pb2_dump.cpp b/esphome/components/api/api_pb2_dump.cpp index b5e98a9f28..cf732e451b 100644 --- a/esphome/components/api/api_pb2_dump.cpp +++ b/esphome/components/api/api_pb2_dump.cpp @@ -655,10 +655,26 @@ template<> const char *proto_enum_to_string(enums::UpdateC } } #endif +#ifdef USE_ZWAVE_PROXY +template<> const char *proto_enum_to_string(enums::ZWaveProxyRequestType value) { + switch (value) { + case enums::ZWAVE_PROXY_REQUEST_TYPE_SUBSCRIBE: + return "ZWAVE_PROXY_REQUEST_TYPE_SUBSCRIBE"; + case enums::ZWAVE_PROXY_REQUEST_TYPE_UNSUBSCRIBE: + return "ZWAVE_PROXY_REQUEST_TYPE_UNSUBSCRIBE"; + case enums::ZWAVE_PROXY_REQUEST_TYPE_HOME_ID_CHANGE: + return "ZWAVE_PROXY_REQUEST_TYPE_HOME_ID_CHANGE"; + default: + return "UNKNOWN"; + } +} +#endif void HelloRequest::dump_to(std::string &out) const { MessageDumpHelper helper(out, "HelloRequest"); - dump_field(out, "client_info", this->client_info); + out.append(" client_info: "); + out.append(format_hex_pretty(this->client_info, this->client_info_len)); + out.append("\n"); dump_field(out, "api_version_major", this->api_version_major); dump_field(out, "api_version_minor", this->api_version_minor); } @@ -669,8 +685,18 @@ void HelloResponse::dump_to(std::string &out) const { dump_field(out, "server_info", this->server_info_ref_); dump_field(out, "name", this->name_ref_); } -void ConnectRequest::dump_to(std::string &out) const { dump_field(out, "password", this->password); } -void ConnectResponse::dump_to(std::string &out) const { dump_field(out, "invalid_password", this->invalid_password); } +#ifdef USE_API_PASSWORD +void AuthenticationRequest::dump_to(std::string &out) const { + MessageDumpHelper helper(out, "AuthenticationRequest"); + out.append(" password: "); + out.append(format_hex_pretty(this->password, this->password_len)); + out.append("\n"); +} +void AuthenticationResponse::dump_to(std::string &out) const { + MessageDumpHelper helper(out, "AuthenticationResponse"); + dump_field(out, "invalid_password", this->invalid_password); +} +#endif void DisconnectRequest::dump_to(std::string &out) const { out.append("DisconnectRequest {}"); } void DisconnectResponse::dump_to(std::string &out) const { out.append("DisconnectResponse {}"); } void PingRequest::dump_to(std::string &out) const { out.append("PingRequest {}"); } @@ -749,6 +775,12 @@ void DeviceInfoResponse::dump_to(std::string &out) const { this->area.dump_to(out); out.append("\n"); #endif +#ifdef USE_ZWAVE_PROXY + dump_field(out, "zwave_proxy_feature_flags", this->zwave_proxy_feature_flags); +#endif +#ifdef USE_ZWAVE_PROXY + dump_field(out, "zwave_home_id", this->zwave_home_id); +#endif } void ListEntitiesRequest::dump_to(std::string &out) const { out.append("ListEntitiesRequest {}"); } void ListEntitiesDoneResponse::dump_to(std::string &out) const { out.append("ListEntitiesDoneResponse {}"); } @@ -1071,8 +1103,8 @@ void HomeassistantServiceMap::dump_to(std::string &out) const { dump_field(out, "key", this->key_ref_); dump_field(out, "value", this->value); } -void HomeassistantServiceResponse::dump_to(std::string &out) const { - MessageDumpHelper helper(out, "HomeassistantServiceResponse"); +void HomeassistantActionRequest::dump_to(std::string &out) const { + MessageDumpHelper helper(out, "HomeassistantActionRequest"); dump_field(out, "service", this->service_ref_); for (const auto &it : this->data) { out.append(" data: "); @@ -1090,6 +1122,28 @@ void HomeassistantServiceResponse::dump_to(std::string &out) const { out.append("\n"); } dump_field(out, "is_event", this->is_event); +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + dump_field(out, "call_id", this->call_id); +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + dump_field(out, "wants_response", this->wants_response); +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + dump_field(out, "response_template", this->response_template); +#endif +} +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES +void HomeassistantActionResponse::dump_to(std::string &out) const { + MessageDumpHelper helper(out, "HomeassistantActionResponse"); + dump_field(out, "call_id", this->call_id); + dump_field(out, "success", this->success); + dump_field(out, "error_message", this->error_message); +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + out.append(" response_data: "); + out.append(format_hex_pretty(this->response_data, this->response_data_len)); + out.append("\n"); +#endif } #endif #ifdef USE_API_HOMEASSISTANT_STATES @@ -1113,7 +1167,9 @@ void GetTimeRequest::dump_to(std::string &out) const { out.append("GetTimeReques void GetTimeResponse::dump_to(std::string &out) const { MessageDumpHelper helper(out, "GetTimeResponse"); dump_field(out, "epoch_seconds", this->epoch_seconds); - dump_field(out, "timezone", this->timezone); + out.append(" timezone: "); + out.append(format_hex_pretty(this->timezone, this->timezone_len)); + out.append("\n"); } #ifdef USE_API_SERVICES void ListEntitiesServicesArgument::dump_to(std::string &out) const { @@ -1626,7 +1682,7 @@ void BluetoothGATTWriteRequest::dump_to(std::string &out) const { dump_field(out, "handle", this->handle); dump_field(out, "response", this->response); out.append(" data: "); - out.append(format_hex_pretty(reinterpret_cast(this->data.data()), this->data.size())); + out.append(format_hex_pretty(this->data, this->data_len)); out.append("\n"); } void BluetoothGATTReadDescriptorRequest::dump_to(std::string &out) const { @@ -1639,7 +1695,7 @@ void BluetoothGATTWriteDescriptorRequest::dump_to(std::string &out) const { dump_field(out, "address", this->address); dump_field(out, "handle", this->handle); out.append(" data: "); - out.append(format_hex_pretty(reinterpret_cast(this->data.data()), this->data.size())); + out.append(format_hex_pretty(this->data, this->data_len)); out.append("\n"); } void BluetoothGATTNotifyRequest::dump_to(std::string &out) const { @@ -1792,8 +1848,25 @@ void VoiceAssistantWakeWord::dump_to(std::string &out) const { dump_field(out, "trained_languages", it, 4); } } +void VoiceAssistantExternalWakeWord::dump_to(std::string &out) const { + MessageDumpHelper helper(out, "VoiceAssistantExternalWakeWord"); + dump_field(out, "id", this->id); + dump_field(out, "wake_word", this->wake_word); + for (const auto &it : this->trained_languages) { + dump_field(out, "trained_languages", it, 4); + } + dump_field(out, "model_type", this->model_type); + dump_field(out, "model_size", this->model_size); + dump_field(out, "model_hash", this->model_hash); + dump_field(out, "url", this->url); +} void VoiceAssistantConfigurationRequest::dump_to(std::string &out) const { - out.append("VoiceAssistantConfigurationRequest {}"); + MessageDumpHelper helper(out, "VoiceAssistantConfigurationRequest"); + for (const auto &it : this->external_wake_words) { + out.append(" external_wake_words: "); + it.dump_to(out); + out.append("\n"); + } } void VoiceAssistantConfigurationResponse::dump_to(std::string &out) const { MessageDumpHelper helper(out, "VoiceAssistantConfigurationResponse"); @@ -2102,6 +2175,21 @@ void UpdateCommandRequest::dump_to(std::string &out) const { #endif } #endif +#ifdef USE_ZWAVE_PROXY +void ZWaveProxyFrame::dump_to(std::string &out) const { + MessageDumpHelper helper(out, "ZWaveProxyFrame"); + out.append(" data: "); + out.append(format_hex_pretty(this->data, this->data_len)); + out.append("\n"); +} +void ZWaveProxyRequest::dump_to(std::string &out) const { + MessageDumpHelper helper(out, "ZWaveProxyRequest"); + dump_field(out, "type", static_cast(this->type)); + out.append(" data: "); + out.append(format_hex_pretty(this->data, this->data_len)); + out.append("\n"); +} +#endif } // namespace esphome::api diff --git a/esphome/components/api/api_pb2_service.cpp b/esphome/components/api/api_pb2_service.cpp index 2598e9a0fb..9d227af0a3 100644 --- a/esphome/components/api/api_pb2_service.cpp +++ b/esphome/components/api/api_pb2_service.cpp @@ -24,15 +24,17 @@ void APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type, this->on_hello_request(msg); break; } - case ConnectRequest::MESSAGE_TYPE: { - ConnectRequest msg; +#ifdef USE_API_PASSWORD + case AuthenticationRequest::MESSAGE_TYPE: { + AuthenticationRequest msg; msg.decode(msg_data, msg_size); #ifdef HAS_PROTO_MESSAGE_DUMP - ESP_LOGVV(TAG, "on_connect_request: %s", msg.dump().c_str()); + ESP_LOGVV(TAG, "on_authentication_request: %s", msg.dump().c_str()); #endif - this->on_connect_request(msg); + this->on_authentication_request(msg); break; } +#endif case DisconnectRequest::MESSAGE_TYPE: { DisconnectRequest msg; // Empty message: no decode needed @@ -546,7 +548,7 @@ void APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type, #ifdef USE_VOICE_ASSISTANT case VoiceAssistantConfigurationRequest::MESSAGE_TYPE: { VoiceAssistantConfigurationRequest msg; - // Empty message: no decode needed + msg.decode(msg_data, msg_size); #ifdef HAS_PROTO_MESSAGE_DUMP ESP_LOGVV(TAG, "on_voice_assistant_configuration_request: %s", msg.dump().c_str()); #endif @@ -586,6 +588,39 @@ void APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type, this->on_bluetooth_scanner_set_mode_request(msg); break; } +#endif +#ifdef USE_ZWAVE_PROXY + case ZWaveProxyFrame::MESSAGE_TYPE: { + ZWaveProxyFrame msg; + msg.decode(msg_data, msg_size); +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "on_z_wave_proxy_frame: %s", msg.dump().c_str()); +#endif + this->on_z_wave_proxy_frame(msg); + break; + } +#endif +#ifdef USE_ZWAVE_PROXY + case ZWaveProxyRequest::MESSAGE_TYPE: { + ZWaveProxyRequest msg; + msg.decode(msg_data, msg_size); +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "on_z_wave_proxy_request: %s", msg.dump().c_str()); +#endif + this->on_z_wave_proxy_request(msg); + break; + } +#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + case HomeassistantActionResponse::MESSAGE_TYPE: { + HomeassistantActionResponse msg; + msg.decode(msg_data, msg_size); +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "on_homeassistant_action_response: %s", msg.dump().c_str()); +#endif + this->on_homeassistant_action_response(msg); + break; + } #endif default: break; @@ -597,11 +632,13 @@ void APIServerConnection::on_hello_request(const HelloRequest &msg) { this->on_fatal_error(); } } -void APIServerConnection::on_connect_request(const ConnectRequest &msg) { - if (!this->send_connect_response(msg)) { +#ifdef USE_API_PASSWORD +void APIServerConnection::on_authentication_request(const AuthenticationRequest &msg) { + if (!this->send_authenticate_response(msg)) { this->on_fatal_error(); } } +#endif void APIServerConnection::on_disconnect_request(const DisconnectRequest &msg) { if (!this->send_disconnect_response(msg)) { this->on_fatal_error(); @@ -613,241 +650,139 @@ void APIServerConnection::on_ping_request(const PingRequest &msg) { } } void APIServerConnection::on_device_info_request(const DeviceInfoRequest &msg) { - if (this->check_connection_setup_() && !this->send_device_info_response(msg)) { + if (!this->send_device_info_response(msg)) { this->on_fatal_error(); } } -void APIServerConnection::on_list_entities_request(const ListEntitiesRequest &msg) { - if (this->check_authenticated_()) { - this->list_entities(msg); - } -} +void APIServerConnection::on_list_entities_request(const ListEntitiesRequest &msg) { this->list_entities(msg); } void APIServerConnection::on_subscribe_states_request(const SubscribeStatesRequest &msg) { - if (this->check_authenticated_()) { - this->subscribe_states(msg); - } -} -void APIServerConnection::on_subscribe_logs_request(const SubscribeLogsRequest &msg) { - if (this->check_authenticated_()) { - this->subscribe_logs(msg); - } + this->subscribe_states(msg); } +void APIServerConnection::on_subscribe_logs_request(const SubscribeLogsRequest &msg) { this->subscribe_logs(msg); } #ifdef USE_API_HOMEASSISTANT_SERVICES void APIServerConnection::on_subscribe_homeassistant_services_request( const SubscribeHomeassistantServicesRequest &msg) { - if (this->check_authenticated_()) { - this->subscribe_homeassistant_services(msg); - } + this->subscribe_homeassistant_services(msg); } #endif #ifdef USE_API_HOMEASSISTANT_STATES void APIServerConnection::on_subscribe_home_assistant_states_request(const SubscribeHomeAssistantStatesRequest &msg) { - if (this->check_authenticated_()) { - this->subscribe_home_assistant_states(msg); - } + this->subscribe_home_assistant_states(msg); } #endif #ifdef USE_API_SERVICES -void APIServerConnection::on_execute_service_request(const ExecuteServiceRequest &msg) { - if (this->check_authenticated_()) { - this->execute_service(msg); - } -} +void APIServerConnection::on_execute_service_request(const ExecuteServiceRequest &msg) { this->execute_service(msg); } #endif #ifdef USE_API_NOISE void APIServerConnection::on_noise_encryption_set_key_request(const NoiseEncryptionSetKeyRequest &msg) { - if (this->check_authenticated_() && !this->send_noise_encryption_set_key_response(msg)) { + if (!this->send_noise_encryption_set_key_response(msg)) { this->on_fatal_error(); } } #endif #ifdef USE_BUTTON -void APIServerConnection::on_button_command_request(const ButtonCommandRequest &msg) { - if (this->check_authenticated_()) { - this->button_command(msg); - } -} +void APIServerConnection::on_button_command_request(const ButtonCommandRequest &msg) { this->button_command(msg); } #endif #ifdef USE_CAMERA -void APIServerConnection::on_camera_image_request(const CameraImageRequest &msg) { - if (this->check_authenticated_()) { - this->camera_image(msg); - } -} +void APIServerConnection::on_camera_image_request(const CameraImageRequest &msg) { this->camera_image(msg); } #endif #ifdef USE_CLIMATE -void APIServerConnection::on_climate_command_request(const ClimateCommandRequest &msg) { - if (this->check_authenticated_()) { - this->climate_command(msg); - } -} +void APIServerConnection::on_climate_command_request(const ClimateCommandRequest &msg) { this->climate_command(msg); } #endif #ifdef USE_COVER -void APIServerConnection::on_cover_command_request(const CoverCommandRequest &msg) { - if (this->check_authenticated_()) { - this->cover_command(msg); - } -} +void APIServerConnection::on_cover_command_request(const CoverCommandRequest &msg) { this->cover_command(msg); } #endif #ifdef USE_DATETIME_DATE -void APIServerConnection::on_date_command_request(const DateCommandRequest &msg) { - if (this->check_authenticated_()) { - this->date_command(msg); - } -} +void APIServerConnection::on_date_command_request(const DateCommandRequest &msg) { this->date_command(msg); } #endif #ifdef USE_DATETIME_DATETIME void APIServerConnection::on_date_time_command_request(const DateTimeCommandRequest &msg) { - if (this->check_authenticated_()) { - this->datetime_command(msg); - } + this->datetime_command(msg); } #endif #ifdef USE_FAN -void APIServerConnection::on_fan_command_request(const FanCommandRequest &msg) { - if (this->check_authenticated_()) { - this->fan_command(msg); - } -} +void APIServerConnection::on_fan_command_request(const FanCommandRequest &msg) { this->fan_command(msg); } #endif #ifdef USE_LIGHT -void APIServerConnection::on_light_command_request(const LightCommandRequest &msg) { - if (this->check_authenticated_()) { - this->light_command(msg); - } -} +void APIServerConnection::on_light_command_request(const LightCommandRequest &msg) { this->light_command(msg); } #endif #ifdef USE_LOCK -void APIServerConnection::on_lock_command_request(const LockCommandRequest &msg) { - if (this->check_authenticated_()) { - this->lock_command(msg); - } -} +void APIServerConnection::on_lock_command_request(const LockCommandRequest &msg) { this->lock_command(msg); } #endif #ifdef USE_MEDIA_PLAYER void APIServerConnection::on_media_player_command_request(const MediaPlayerCommandRequest &msg) { - if (this->check_authenticated_()) { - this->media_player_command(msg); - } + this->media_player_command(msg); } #endif #ifdef USE_NUMBER -void APIServerConnection::on_number_command_request(const NumberCommandRequest &msg) { - if (this->check_authenticated_()) { - this->number_command(msg); - } -} +void APIServerConnection::on_number_command_request(const NumberCommandRequest &msg) { this->number_command(msg); } #endif #ifdef USE_SELECT -void APIServerConnection::on_select_command_request(const SelectCommandRequest &msg) { - if (this->check_authenticated_()) { - this->select_command(msg); - } -} +void APIServerConnection::on_select_command_request(const SelectCommandRequest &msg) { this->select_command(msg); } #endif #ifdef USE_SIREN -void APIServerConnection::on_siren_command_request(const SirenCommandRequest &msg) { - if (this->check_authenticated_()) { - this->siren_command(msg); - } -} +void APIServerConnection::on_siren_command_request(const SirenCommandRequest &msg) { this->siren_command(msg); } #endif #ifdef USE_SWITCH -void APIServerConnection::on_switch_command_request(const SwitchCommandRequest &msg) { - if (this->check_authenticated_()) { - this->switch_command(msg); - } -} +void APIServerConnection::on_switch_command_request(const SwitchCommandRequest &msg) { this->switch_command(msg); } #endif #ifdef USE_TEXT -void APIServerConnection::on_text_command_request(const TextCommandRequest &msg) { - if (this->check_authenticated_()) { - this->text_command(msg); - } -} +void APIServerConnection::on_text_command_request(const TextCommandRequest &msg) { this->text_command(msg); } #endif #ifdef USE_DATETIME_TIME -void APIServerConnection::on_time_command_request(const TimeCommandRequest &msg) { - if (this->check_authenticated_()) { - this->time_command(msg); - } -} +void APIServerConnection::on_time_command_request(const TimeCommandRequest &msg) { this->time_command(msg); } #endif #ifdef USE_UPDATE -void APIServerConnection::on_update_command_request(const UpdateCommandRequest &msg) { - if (this->check_authenticated_()) { - this->update_command(msg); - } -} +void APIServerConnection::on_update_command_request(const UpdateCommandRequest &msg) { this->update_command(msg); } #endif #ifdef USE_VALVE -void APIServerConnection::on_valve_command_request(const ValveCommandRequest &msg) { - if (this->check_authenticated_()) { - this->valve_command(msg); - } -} +void APIServerConnection::on_valve_command_request(const ValveCommandRequest &msg) { this->valve_command(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_subscribe_bluetooth_le_advertisements_request( const SubscribeBluetoothLEAdvertisementsRequest &msg) { - if (this->check_authenticated_()) { - this->subscribe_bluetooth_le_advertisements(msg); - } + this->subscribe_bluetooth_le_advertisements(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_bluetooth_device_request(const BluetoothDeviceRequest &msg) { - if (this->check_authenticated_()) { - this->bluetooth_device_request(msg); - } + this->bluetooth_device_request(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_bluetooth_gatt_get_services_request(const BluetoothGATTGetServicesRequest &msg) { - if (this->check_authenticated_()) { - this->bluetooth_gatt_get_services(msg); - } + this->bluetooth_gatt_get_services(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_bluetooth_gatt_read_request(const BluetoothGATTReadRequest &msg) { - if (this->check_authenticated_()) { - this->bluetooth_gatt_read(msg); - } + this->bluetooth_gatt_read(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_bluetooth_gatt_write_request(const BluetoothGATTWriteRequest &msg) { - if (this->check_authenticated_()) { - this->bluetooth_gatt_write(msg); - } + this->bluetooth_gatt_write(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_bluetooth_gatt_read_descriptor_request(const BluetoothGATTReadDescriptorRequest &msg) { - if (this->check_authenticated_()) { - this->bluetooth_gatt_read_descriptor(msg); - } + this->bluetooth_gatt_read_descriptor(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_bluetooth_gatt_write_descriptor_request(const BluetoothGATTWriteDescriptorRequest &msg) { - if (this->check_authenticated_()) { - this->bluetooth_gatt_write_descriptor(msg); - } + this->bluetooth_gatt_write_descriptor(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_bluetooth_gatt_notify_request(const BluetoothGATTNotifyRequest &msg) { - if (this->check_authenticated_()) { - this->bluetooth_gatt_notify(msg); - } + this->bluetooth_gatt_notify(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_subscribe_bluetooth_connections_free_request( const SubscribeBluetoothConnectionsFreeRequest &msg) { - if (this->check_authenticated_() && !this->send_subscribe_bluetooth_connections_free_response(msg)) { + if (!this->send_subscribe_bluetooth_connections_free_response(msg)) { this->on_fatal_error(); } } @@ -855,45 +790,68 @@ void APIServerConnection::on_subscribe_bluetooth_connections_free_request( #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_unsubscribe_bluetooth_le_advertisements_request( const UnsubscribeBluetoothLEAdvertisementsRequest &msg) { - if (this->check_authenticated_()) { - this->unsubscribe_bluetooth_le_advertisements(msg); - } + this->unsubscribe_bluetooth_le_advertisements(msg); } #endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_bluetooth_scanner_set_mode_request(const BluetoothScannerSetModeRequest &msg) { - if (this->check_authenticated_()) { - this->bluetooth_scanner_set_mode(msg); - } + this->bluetooth_scanner_set_mode(msg); } #endif #ifdef USE_VOICE_ASSISTANT void APIServerConnection::on_subscribe_voice_assistant_request(const SubscribeVoiceAssistantRequest &msg) { - if (this->check_authenticated_()) { - this->subscribe_voice_assistant(msg); - } + this->subscribe_voice_assistant(msg); } #endif #ifdef USE_VOICE_ASSISTANT void APIServerConnection::on_voice_assistant_configuration_request(const VoiceAssistantConfigurationRequest &msg) { - if (this->check_authenticated_() && !this->send_voice_assistant_get_configuration_response(msg)) { + if (!this->send_voice_assistant_get_configuration_response(msg)) { this->on_fatal_error(); } } #endif #ifdef USE_VOICE_ASSISTANT void APIServerConnection::on_voice_assistant_set_configuration(const VoiceAssistantSetConfiguration &msg) { - if (this->check_authenticated_()) { - this->voice_assistant_set_configuration(msg); - } + this->voice_assistant_set_configuration(msg); } #endif #ifdef USE_ALARM_CONTROL_PANEL void APIServerConnection::on_alarm_control_panel_command_request(const AlarmControlPanelCommandRequest &msg) { - if (this->check_authenticated_()) { - this->alarm_control_panel_command(msg); - } + this->alarm_control_panel_command(msg); } #endif +#ifdef USE_ZWAVE_PROXY +void APIServerConnection::on_z_wave_proxy_frame(const ZWaveProxyFrame &msg) { this->zwave_proxy_frame(msg); } +#endif +#ifdef USE_ZWAVE_PROXY +void APIServerConnection::on_z_wave_proxy_request(const ZWaveProxyRequest &msg) { this->zwave_proxy_request(msg); } +#endif + +void APIServerConnection::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) { + // Check authentication/connection requirements for messages + switch (msg_type) { + case HelloRequest::MESSAGE_TYPE: // No setup required +#ifdef USE_API_PASSWORD + case AuthenticationRequest::MESSAGE_TYPE: // No setup required +#endif + case DisconnectRequest::MESSAGE_TYPE: // No setup required + case PingRequest::MESSAGE_TYPE: // No setup required + break; // Skip all checks for these messages + case DeviceInfoRequest::MESSAGE_TYPE: // Connection setup only + if (!this->check_connection_setup_()) { + return; // Connection not setup + } + break; + default: + // All other messages require authentication (which includes connection check) + if (!this->check_authenticated_()) { + return; // Authentication failed + } + break; + } + + // Call base implementation to process the message + APIServerConnectionBase::read_message(msg_size, msg_type, msg_data); +} } // namespace esphome::api diff --git a/esphome/components/api/api_pb2_service.h b/esphome/components/api/api_pb2_service.h index 5b7508e786..549b00ee6a 100644 --- a/esphome/components/api/api_pb2_service.h +++ b/esphome/components/api/api_pb2_service.h @@ -26,7 +26,9 @@ class APIServerConnectionBase : public ProtoService { virtual void on_hello_request(const HelloRequest &value){}; - virtual void on_connect_request(const ConnectRequest &value){}; +#ifdef USE_API_PASSWORD + virtual void on_authentication_request(const AuthenticationRequest &value){}; +#endif virtual void on_disconnect_request(const DisconnectRequest &value){}; virtual void on_disconnect_response(const DisconnectResponse &value){}; @@ -64,6 +66,9 @@ class APIServerConnectionBase : public ProtoService { virtual void on_subscribe_homeassistant_services_request(const SubscribeHomeassistantServicesRequest &value){}; #endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + virtual void on_homeassistant_action_response(const HomeassistantActionResponse &value){}; +#endif #ifdef USE_API_HOMEASSISTANT_STATES virtual void on_subscribe_home_assistant_states_request(const SubscribeHomeAssistantStatesRequest &value){}; #endif @@ -205,6 +210,12 @@ class APIServerConnectionBase : public ProtoService { #ifdef USE_UPDATE virtual void on_update_command_request(const UpdateCommandRequest &value){}; +#endif +#ifdef USE_ZWAVE_PROXY + virtual void on_z_wave_proxy_frame(const ZWaveProxyFrame &value){}; +#endif +#ifdef USE_ZWAVE_PROXY + virtual void on_z_wave_proxy_request(const ZWaveProxyRequest &value){}; #endif protected: void read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override; @@ -213,7 +224,9 @@ class APIServerConnectionBase : public ProtoService { class APIServerConnection : public APIServerConnectionBase { public: virtual bool send_hello_response(const HelloRequest &msg) = 0; - virtual bool send_connect_response(const ConnectRequest &msg) = 0; +#ifdef USE_API_PASSWORD + virtual bool send_authenticate_response(const AuthenticationRequest &msg) = 0; +#endif virtual bool send_disconnect_response(const DisconnectRequest &msg) = 0; virtual bool send_ping_response(const PingRequest &msg) = 0; virtual bool send_device_info_response(const DeviceInfoRequest &msg) = 0; @@ -331,10 +344,18 @@ class APIServerConnection : public APIServerConnectionBase { #endif #ifdef USE_ALARM_CONTROL_PANEL virtual void alarm_control_panel_command(const AlarmControlPanelCommandRequest &msg) = 0; +#endif +#ifdef USE_ZWAVE_PROXY + virtual void zwave_proxy_frame(const ZWaveProxyFrame &msg) = 0; +#endif +#ifdef USE_ZWAVE_PROXY + virtual void zwave_proxy_request(const ZWaveProxyRequest &msg) = 0; #endif protected: void on_hello_request(const HelloRequest &msg) override; - void on_connect_request(const ConnectRequest &msg) override; +#ifdef USE_API_PASSWORD + void on_authentication_request(const AuthenticationRequest &msg) override; +#endif void on_disconnect_request(const DisconnectRequest &msg) override; void on_ping_request(const PingRequest &msg) override; void on_device_info_request(const DeviceInfoRequest &msg) override; @@ -453,6 +474,13 @@ class APIServerConnection : public APIServerConnectionBase { #ifdef USE_ALARM_CONTROL_PANEL void on_alarm_control_panel_command_request(const AlarmControlPanelCommandRequest &msg) override; #endif +#ifdef USE_ZWAVE_PROXY + void on_z_wave_proxy_frame(const ZWaveProxyFrame &msg) override; +#endif +#ifdef USE_ZWAVE_PROXY + void on_z_wave_proxy_request(const ZWaveProxyRequest &msg) override; +#endif + void read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override; }; } // namespace esphome::api diff --git a/esphome/components/api/api_server.cpp b/esphome/components/api/api_server.cpp index a12cf13ce2..778d9389ef 100644 --- a/esphome/components/api/api_server.cpp +++ b/esphome/components/api/api_server.cpp @@ -9,12 +9,16 @@ #include "esphome/core/log.h" #include "esphome/core/util.h" #include "esphome/core/version.h" +#ifdef USE_API_HOMEASSISTANT_SERVICES +#include "homeassistant_service.h" +#endif #ifdef USE_LOGGER #include "esphome/components/logger/logger.h" #endif #include +#include namespace esphome::api { @@ -87,7 +91,7 @@ void APIServer::setup() { return; } - err = this->socket_->listen(4); + err = this->socket_->listen(this->listen_backlog_); if (err != 0) { ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno); this->mark_failed(); @@ -140,9 +144,19 @@ void APIServer::loop() { while (true) { struct sockaddr_storage source_addr; socklen_t addr_len = sizeof(source_addr); + auto sock = this->socket_->accept_loop_monitored((struct sockaddr *) &source_addr, &addr_len); if (!sock) break; + + // Check if we're at the connection limit + if (this->clients_.size() >= this->max_connections_) { + ESP_LOGW(TAG, "Max connections (%d), rejecting %s", this->max_connections_, sock->getpeername().c_str()); + // Immediately close - socket destructor will handle cleanup + sock.reset(); + continue; + } + ESP_LOGD(TAG, "Accept %s", sock->getpeername().c_str()); auto *conn = new APIConnection(std::move(sock), this); @@ -167,7 +181,8 @@ void APIServer::loop() { // Network is down - disconnect all clients for (auto &client : this->clients_) { client->on_fatal_error(); - ESP_LOGW(TAG, "%s: Network down; disconnect", client->get_client_combined_info().c_str()); + ESP_LOGW(TAG, "%s (%s): Network down; disconnect", client->client_info_.name.c_str(), + client->client_info_.peername.c_str()); } // Continue to process and clean up the clients below } @@ -206,8 +221,10 @@ void APIServer::loop() { void APIServer::dump_config() { ESP_LOGCONFIG(TAG, "Server:\n" - " Address: %s:%u", - network::get_use_address().c_str(), this->port_); + " Address: %s:%u\n" + " Listen backlog: %u\n" + " Max connections: %u", + network::get_use_address().c_str(), this->port_, this->listen_backlog_, this->max_connections_); #ifdef USE_API_NOISE ESP_LOGCONFIG(TAG, " Noise encryption: %s", YESNO(this->noise_ctx_->has_psk())); if (!this->noise_ctx_->has_psk()) { @@ -219,12 +236,12 @@ void APIServer::dump_config() { } #ifdef USE_API_PASSWORD -bool APIServer::check_password(const std::string &password) const { +bool APIServer::check_password(const uint8_t *password_data, size_t password_len) const { // depend only on input password length const char *a = this->password_.c_str(); uint32_t len_a = this->password_.length(); - const char *b = password.c_str(); - uint32_t len_b = password.length(); + const char *b = reinterpret_cast(password_data); + uint32_t len_b = password_len; // disable optimization with volatile volatile uint32_t length = len_b; @@ -247,6 +264,7 @@ bool APIServer::check_password(const std::string &password) const { return result == 0; } + #endif void APIServer::handle_disconnect(APIConnection *conn) {} @@ -357,6 +375,15 @@ void APIServer::on_update(update::UpdateEntity *obj) { } #endif +#ifdef USE_ZWAVE_PROXY +void APIServer::on_zwave_proxy_request(const esphome::api::ProtoMessage &msg) { + // We could add code to manage a second subscription type, but, since this message type is + // very infrequent and small, we simply send it to all clients + for (auto &c : this->clients_) + c->send_message(msg, api::ZWaveProxyRequest::MESSAGE_TYPE); +} +#endif + #ifdef USE_ALARM_CONTROL_PANEL API_DISPATCH_UPDATE(alarm_control_panel::AlarmControlPanel, alarm_control_panel) #endif @@ -372,12 +399,43 @@ void APIServer::set_password(const std::string &password) { this->password_ = pa void APIServer::set_batch_delay(uint16_t batch_delay) { this->batch_delay_ = batch_delay; } #ifdef USE_API_HOMEASSISTANT_SERVICES -void APIServer::send_homeassistant_service_call(const HomeassistantServiceResponse &call) { +void APIServer::send_homeassistant_action(const HomeassistantActionRequest &call) { for (auto &client : this->clients_) { - client->send_homeassistant_service_call(call); + client->send_homeassistant_action(call); } } -#endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES +void APIServer::register_action_response_callback(uint32_t call_id, ActionResponseCallback callback) { + this->action_response_callbacks_.push_back({call_id, std::move(callback)}); +} + +void APIServer::handle_action_response(uint32_t call_id, bool success, const std::string &error_message) { + for (auto it = this->action_response_callbacks_.begin(); it != this->action_response_callbacks_.end(); ++it) { + if (it->call_id == call_id) { + auto callback = std::move(it->callback); + this->action_response_callbacks_.erase(it); + ActionResponse response(success, error_message); + callback(response); + return; + } + } +} +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON +void APIServer::handle_action_response(uint32_t call_id, bool success, const std::string &error_message, + const uint8_t *response_data, size_t response_data_len) { + for (auto it = this->action_response_callbacks_.begin(); it != this->action_response_callbacks_.end(); ++it) { + if (it->call_id == call_id) { + auto callback = std::move(it->callback); + this->action_response_callbacks_.erase(it); + ActionResponse response(success, error_message, response_data, response_data_len); + callback(response); + return; + } + } +} +#endif // USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON +#endif // USE_API_HOMEASSISTANT_ACTION_RESPONSES +#endif // USE_API_HOMEASSISTANT_SERVICES #ifdef USE_API_HOMEASSISTANT_STATES void APIServer::subscribe_home_assistant_state(std::string entity_id, optional attribute, diff --git a/esphome/components/api/api_server.h b/esphome/components/api/api_server.h index 8b5e624df2..5d038e5ddd 100644 --- a/esphome/components/api/api_server.h +++ b/esphome/components/api/api_server.h @@ -16,6 +16,7 @@ #include "user_services.h" #endif +#include #include namespace esphome::api { @@ -37,13 +38,15 @@ class APIServer : public Component, public Controller { void on_shutdown() override; bool teardown() override; #ifdef USE_API_PASSWORD - bool check_password(const std::string &password) const; + bool check_password(const uint8_t *password_data, size_t password_len) const; void set_password(const std::string &password); #endif void set_port(uint16_t port); void set_reboot_timeout(uint32_t reboot_timeout); void set_batch_delay(uint16_t batch_delay); uint16_t get_batch_delay() const { return batch_delay_; } + void set_listen_backlog(uint8_t listen_backlog) { this->listen_backlog_ = listen_backlog; } + void set_max_connections(uint8_t max_connections) { this->max_connections_ = max_connections; } // Get reference to shared buffer for API connections std::vector &get_shared_buffer_ref() { return shared_write_buffer_; } @@ -107,8 +110,19 @@ class APIServer : public Component, public Controller { void on_media_player_update(media_player::MediaPlayer *obj) override; #endif #ifdef USE_API_HOMEASSISTANT_SERVICES - void send_homeassistant_service_call(const HomeassistantServiceResponse &call); -#endif + void send_homeassistant_action(const HomeassistantActionRequest &call); + +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + // Action response handling + using ActionResponseCallback = std::function; + void register_action_response_callback(uint32_t call_id, ActionResponseCallback callback); + void handle_action_response(uint32_t call_id, bool success, const std::string &error_message); +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + void handle_action_response(uint32_t call_id, bool success, const std::string &error_message, + const uint8_t *response_data, size_t response_data_len); +#endif // USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON +#endif // USE_API_HOMEASSISTANT_ACTION_RESPONSES +#endif // USE_API_HOMEASSISTANT_SERVICES #ifdef USE_API_SERVICES void register_user_service(UserServiceDescriptor *descriptor) { this->user_services_.push_back(descriptor); } #endif @@ -125,6 +139,9 @@ class APIServer : public Component, public Controller { #ifdef USE_UPDATE void on_update(update::UpdateEntity *obj) override; #endif +#ifdef USE_ZWAVE_PROXY + void on_zwave_proxy_request(const esphome::api::ProtoMessage &msg); +#endif bool is_connected() const; @@ -181,12 +198,23 @@ class APIServer : public Component, public Controller { #ifdef USE_API_SERVICES std::vector user_services_; #endif +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + struct PendingActionResponse { + uint32_t call_id; + ActionResponseCallback callback; + }; + std::vector action_response_callbacks_; +#endif // Group smaller types together uint16_t port_{6053}; uint16_t batch_delay_{100}; + // Connection limits - these defaults will be overridden by config values + // from cv.SplitDefault in __init__.py which sets platform-specific defaults + uint8_t listen_backlog_{4}; + uint8_t max_connections_{8}; bool shutting_down_ = false; - // 5 bytes used, 3 bytes padding + // 7 bytes used, 1 byte padding #ifdef USE_API_NOISE std::shared_ptr noise_ctx_ = std::make_shared(); diff --git a/esphome/components/api/custom_api_device.h b/esphome/components/api/custom_api_device.h index 44f9eee571..0c6e49d6ca 100644 --- a/esphome/components/api/custom_api_device.h +++ b/esphome/components/api/custom_api_device.h @@ -179,9 +179,9 @@ class CustomAPIDevice { * @param service_name The service to call. */ void call_homeassistant_service(const std::string &service_name) { - HomeassistantServiceResponse resp; + HomeassistantActionRequest resp; resp.set_service(StringRef(service_name)); - global_api_server->send_homeassistant_service_call(resp); + global_api_server->send_homeassistant_action(resp); } /** Call a Home Assistant service from ESPHome. @@ -199,7 +199,7 @@ class CustomAPIDevice { * @param data The data for the service call, mapping from string to string. */ void call_homeassistant_service(const std::string &service_name, const std::map &data) { - HomeassistantServiceResponse resp; + HomeassistantActionRequest resp; resp.set_service(StringRef(service_name)); for (auto &it : data) { resp.data.emplace_back(); @@ -207,7 +207,7 @@ class CustomAPIDevice { kv.set_key(StringRef(it.first)); kv.value = it.second; } - global_api_server->send_homeassistant_service_call(resp); + global_api_server->send_homeassistant_action(resp); } /** Fire an ESPHome event in Home Assistant. @@ -221,10 +221,10 @@ class CustomAPIDevice { * @param event_name The event to fire. */ void fire_homeassistant_event(const std::string &event_name) { - HomeassistantServiceResponse resp; + HomeassistantActionRequest resp; resp.set_service(StringRef(event_name)); resp.is_event = true; - global_api_server->send_homeassistant_service_call(resp); + global_api_server->send_homeassistant_action(resp); } /** Fire an ESPHome event in Home Assistant. @@ -241,7 +241,7 @@ class CustomAPIDevice { * @param data The data for the event, mapping from string to string. */ void fire_homeassistant_event(const std::string &service_name, const std::map &data) { - HomeassistantServiceResponse resp; + HomeassistantActionRequest resp; resp.set_service(StringRef(service_name)); resp.is_event = true; for (auto &it : data) { @@ -250,7 +250,7 @@ class CustomAPIDevice { kv.set_key(StringRef(it.first)); kv.value = it.second; } - global_api_server->send_homeassistant_service_call(resp); + global_api_server->send_homeassistant_action(resp); } #else template void call_homeassistant_service(const std::string &service_name) { diff --git a/esphome/components/api/homeassistant_service.h b/esphome/components/api/homeassistant_service.h index 5df9c7c792..730024f7b7 100644 --- a/esphome/components/api/homeassistant_service.h +++ b/esphome/components/api/homeassistant_service.h @@ -3,10 +3,15 @@ #include "api_server.h" #ifdef USE_API #ifdef USE_API_HOMEASSISTANT_SERVICES +#include +#include +#include #include "api_pb2.h" +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON +#include "esphome/components/json/json_util.h" +#endif #include "esphome/core/automation.h" #include "esphome/core/helpers.h" -#include namespace esphome::api { @@ -44,9 +49,47 @@ template class TemplatableKeyValuePair { TemplatableStringValue value; }; +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES +// Represents the response data from a Home Assistant action +class ActionResponse { + public: + ActionResponse(bool success, std::string error_message = "") + : success_(success), error_message_(std::move(error_message)) {} + +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + ActionResponse(bool success, std::string error_message, const uint8_t *data, size_t data_len) + : success_(success), error_message_(std::move(error_message)) { + if (data == nullptr || data_len == 0) + return; + this->json_document_ = json::parse_json(data, data_len); + } +#endif + + bool is_success() const { return this->success_; } + const std::string &get_error_message() const { return this->error_message_; } + +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + // Get data as parsed JSON object (const version returns read-only view) + JsonObjectConst get_json() const { return this->json_document_.as(); } +#endif + + protected: + bool success_; + std::string error_message_; +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + JsonDocument json_document_; +#endif +}; + +// Callback type for action responses +template using ActionResponseCallback = std::function; +#endif + template class HomeAssistantServiceCallAction : public Action { public: - explicit HomeAssistantServiceCallAction(APIServer *parent, bool is_event) : parent_(parent), is_event_(is_event) {} + explicit HomeAssistantServiceCallAction(APIServer *parent, bool is_event) : parent_(parent) { + this->flags_.is_event = is_event; + } template void set_service(T service) { this->service_ = service; } @@ -61,11 +104,29 @@ template class HomeAssistantServiceCallAction : public Actionvariables_.emplace_back(std::move(key), value); } +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + template void set_response_template(T response_template) { + this->response_template_ = response_template; + this->flags_.has_response_template = true; + } + + void set_wants_status() { this->flags_.wants_status = true; } + void set_wants_response() { this->flags_.wants_response = true; } + +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + Trigger *get_success_trigger_with_response() const { + return this->success_trigger_with_response_; + } +#endif + Trigger *get_success_trigger() const { return this->success_trigger_; } + Trigger *get_error_trigger() const { return this->error_trigger_; } +#endif // USE_API_HOMEASSISTANT_ACTION_RESPONSES + void play(Ts... x) override { - HomeassistantServiceResponse resp; + HomeassistantActionRequest resp; std::string service_value = this->service_.value(x...); resp.set_service(StringRef(service_value)); - resp.is_event = this->is_event_; + resp.is_event = this->flags_.is_event; for (auto &it : this->data_) { resp.data.emplace_back(); auto &kv = resp.data.back(); @@ -84,18 +145,74 @@ template class HomeAssistantServiceCallAction : public Actionparent_->send_homeassistant_service_call(resp); + +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES + if (this->flags_.wants_status) { + // Generate a unique call ID for this service call + static uint32_t call_id_counter = 1; + uint32_t call_id = call_id_counter++; + resp.call_id = call_id; +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + if (this->flags_.wants_response) { + resp.wants_response = true; + // Set response template if provided + if (this->flags_.has_response_template) { + std::string response_template_value = this->response_template_.value(x...); + resp.response_template = response_template_value; + } + } +#endif + + auto captured_args = std::make_tuple(x...); + this->parent_->register_action_response_callback(call_id, [this, captured_args](const ActionResponse &response) { + std::apply( + [this, &response](auto &&...args) { + if (response.is_success()) { +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + if (this->flags_.wants_response) { + this->success_trigger_with_response_->trigger(response.get_json(), args...); + } else +#endif + { + this->success_trigger_->trigger(args...); + } + } else { + this->error_trigger_->trigger(response.get_error_message(), args...); + } + }, + captured_args); + }); + } +#endif + + this->parent_->send_homeassistant_action(resp); } protected: APIServer *parent_; - bool is_event_; TemplatableStringValue service_{}; std::vector> data_; std::vector> data_template_; std::vector> variables_; +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES +#ifdef USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + TemplatableStringValue response_template_{""}; + Trigger *success_trigger_with_response_ = new Trigger(); +#endif // USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON + Trigger *success_trigger_ = new Trigger(); + Trigger *error_trigger_ = new Trigger(); +#endif // USE_API_HOMEASSISTANT_ACTION_RESPONSES + + struct Flags { + uint8_t is_event : 1; + uint8_t wants_status : 1; + uint8_t wants_response : 1; + uint8_t has_response_template : 1; + uint8_t reserved : 5; + } flags_{0}; }; } // namespace esphome::api + #endif #endif diff --git a/esphome/components/api/proto.h b/esphome/components/api/proto.h index 0e5ec61050..9d780692ec 100644 --- a/esphome/components/api/proto.h +++ b/esphome/components/api/proto.h @@ -182,6 +182,10 @@ class ProtoLengthDelimited { explicit ProtoLengthDelimited(const uint8_t *value, size_t length) : value_(value), length_(length) {} std::string as_string() const { return std::string(reinterpret_cast(this->value_), this->length_); } + // Direct access to raw data without string allocation + const uint8_t *data() const { return this->value_; } + size_t size() const { return this->length_; } + /** * Decode the length-delimited data into an existing ProtoDecodableMessage instance. * @@ -827,7 +831,7 @@ class ProtoService { } // Authentication helper methods - bool check_connection_setup_() { + inline bool check_connection_setup_() { if (!this->is_connection_setup()) { this->on_no_setup_connection(); return false; @@ -835,7 +839,7 @@ class ProtoService { return true; } - bool check_authenticated_() { + inline bool check_authenticated_() { #ifdef USE_API_PASSWORD if (!this->check_connection_setup_()) { return false; diff --git a/esphome/components/api/user_services.h b/esphome/components/api/user_services.h index 5f040e8433..3996c921a9 100644 --- a/esphome/components/api/user_services.h +++ b/esphome/components/api/user_services.h @@ -35,7 +35,7 @@ template class UserServiceBase : public UserServiceDescriptor { msg.set_name(StringRef(this->name_)); msg.key = this->key_; std::array arg_types = {to_service_arg_type()...}; - for (int i = 0; i < sizeof...(Ts); i++) { + for (size_t i = 0; i < sizeof...(Ts); i++) { msg.args.emplace_back(); auto &arg = msg.args.back(); arg.type = arg_types[i]; @@ -55,7 +55,7 @@ template class UserServiceBase : public UserServiceDescriptor { protected: virtual void execute(Ts... x) = 0; - template void execute_(std::vector args, seq type) { + template void execute_(const std::vector &args, seq type) { this->execute((get_execute_arg_value(args[S]))...); } diff --git a/esphome/components/as7341/sensor.py b/esphome/components/as7341/sensor.py index 2832b7c3df..fa51a1cdfa 100644 --- a/esphome/components/as7341/sensor.py +++ b/esphome/components/as7341/sensor.py @@ -2,6 +2,7 @@ import esphome.codegen as cg from esphome.components import i2c, sensor import esphome.config_validation as cv from esphome.const import ( + CONF_CLEAR, CONF_GAIN, CONF_ID, DEVICE_CLASS_ILLUMINANCE, @@ -29,7 +30,6 @@ CONF_F5 = "f5" CONF_F6 = "f6" CONF_F7 = "f7" CONF_F8 = "f8" -CONF_CLEAR = "clear" CONF_NIR = "nir" UNIT_COUNTS = "#" diff --git a/esphome/components/audio/__init__.py b/esphome/components/audio/__init__.py index f657cb5da3..7b03e4b6a7 100644 --- a/esphome/components/audio/__init__.py +++ b/esphome/components/audio/__init__.py @@ -165,4 +165,4 @@ def final_validate_audio_schema( async def to_code(config): - cg.add_library("esphome/esp-audio-libs", "1.1.4") + cg.add_library("esphome/esp-audio-libs", "2.0.1") diff --git a/esphome/components/audio/audio.cpp b/esphome/components/audio/audio.cpp index 2a58c38ac7..9cc9b7d0da 100644 --- a/esphome/components/audio/audio.cpp +++ b/esphome/components/audio/audio.cpp @@ -57,7 +57,7 @@ const char *audio_file_type_to_string(AudioFileType file_type) { void scale_audio_samples(const int16_t *audio_samples, int16_t *output_buffer, int16_t scale_factor, size_t samples_to_scale) { // Note the assembly dsps_mulc function has audio glitches if the input and output buffers are the same. - for (int i = 0; i < samples_to_scale; i++) { + for (size_t i = 0; i < samples_to_scale; i++) { int32_t acc = (int32_t) audio_samples[i] * (int32_t) scale_factor; output_buffer[i] = (int16_t) (acc >> 15); } diff --git a/esphome/components/audio/audio_decoder.cpp b/esphome/components/audio/audio_decoder.cpp index 90ba1aec1e..d1ad571a52 100644 --- a/esphome/components/audio/audio_decoder.cpp +++ b/esphome/components/audio/audio_decoder.cpp @@ -229,18 +229,18 @@ FileDecoderState AudioDecoder::decode_flac_() { auto result = this->flac_decoder_->read_header(this->input_transfer_buffer_->get_buffer_start(), this->input_transfer_buffer_->available()); - if (result == esp_audio_libs::flac::FLAC_DECODER_HEADER_OUT_OF_DATA) { - return FileDecoderState::POTENTIALLY_FAILED; - } - - if (result != esp_audio_libs::flac::FLAC_DECODER_SUCCESS) { - // Couldn't read FLAC header + if (result > esp_audio_libs::flac::FLAC_DECODER_HEADER_OUT_OF_DATA) { + // Serrious error reading FLAC header, there is no recovery return FileDecoderState::FAILED; } size_t bytes_consumed = this->flac_decoder_->get_bytes_index(); this->input_transfer_buffer_->decrease_buffer_length(bytes_consumed); + if (result == esp_audio_libs::flac::FLAC_DECODER_HEADER_OUT_OF_DATA) { + return FileDecoderState::MORE_TO_PROCESS; + } + // Reallocate the output transfer buffer to the smallest necessary size this->free_buffer_required_ = flac_decoder_->get_output_buffer_size_bytes(); if (!this->output_transfer_buffer_->reallocate(this->free_buffer_required_)) { @@ -256,9 +256,9 @@ FileDecoderState AudioDecoder::decode_flac_() { } uint32_t output_samples = 0; - auto result = this->flac_decoder_->decode_frame( - this->input_transfer_buffer_->get_buffer_start(), this->input_transfer_buffer_->available(), - reinterpret_cast(this->output_transfer_buffer_->get_buffer_end()), &output_samples); + auto result = this->flac_decoder_->decode_frame(this->input_transfer_buffer_->get_buffer_start(), + this->input_transfer_buffer_->available(), + this->output_transfer_buffer_->get_buffer_end(), &output_samples); if (result == esp_audio_libs::flac::FLAC_DECODER_ERROR_OUT_OF_DATA) { // Not an issue, just needs more data that we'll get next time. diff --git a/esphome/components/bl0906/bl0906.cpp b/esphome/components/bl0906/bl0906.cpp index e48715010c..c1cd48a1ac 100644 --- a/esphome/components/bl0906/bl0906.cpp +++ b/esphome/components/bl0906/bl0906.cpp @@ -97,10 +97,10 @@ void BL0906::handle_actions_() { return; } ActionCallbackFuncPtr ptr_func = nullptr; - for (int i = 0; i < this->action_queue_.size(); i++) { + for (size_t i = 0; i < this->action_queue_.size(); i++) { ptr_func = this->action_queue_[i]; if (ptr_func) { - ESP_LOGI(TAG, "HandleActionCallback[%d]", i); + ESP_LOGI(TAG, "HandleActionCallback[%zu]", i); (this->*ptr_func)(); } } diff --git a/esphome/components/bl0942/bl0942.cpp b/esphome/components/bl0942/bl0942.cpp index 894fcbfbb7..95dd689b07 100644 --- a/esphome/components/bl0942/bl0942.cpp +++ b/esphome/components/bl0942/bl0942.cpp @@ -51,7 +51,7 @@ void BL0942::loop() { if (!avail) { return; } - if (avail < sizeof(buffer)) { + if (static_cast(avail) < sizeof(buffer)) { if (!this->rx_start_) { this->rx_start_ = millis(); } else if (millis() > this->rx_start_ + PKT_TIMEOUT_MS) { @@ -148,7 +148,7 @@ void BL0942::setup() { this->write_reg_(BL0942_REG_USR_WRPROT, 0); - if (this->read_reg_(BL0942_REG_MODE) != mode) + if (static_cast(this->read_reg_(BL0942_REG_MODE)) != mode) this->status_set_warning(LOG_STR("BL0942 setup failed!")); this->flush(); diff --git a/esphome/components/ble_client/__init__.py b/esphome/components/ble_client/__init__.py index 5f4ea8afd1..768a345213 100644 --- a/esphome/components/ble_client/__init__.py +++ b/esphome/components/ble_client/__init__.py @@ -116,7 +116,7 @@ CONFIG_SCHEMA = cv.All( ) .extend(cv.COMPONENT_SCHEMA) .extend(esp32_ble_tracker.ESP_BLE_DEVICE_SCHEMA), - esp32_ble_tracker.consume_connection_slots(1, "ble_client"), + esp32_ble.consume_connection_slots(1, "ble_client"), ) CONF_BLE_CLIENT_ID = "ble_client_id" diff --git a/esphome/components/bluetooth_proxy/__init__.py b/esphome/components/bluetooth_proxy/__init__.py index f21b5028c7..ad7528c156 100644 --- a/esphome/components/bluetooth_proxy/__init__.py +++ b/esphome/components/bluetooth_proxy/__init__.py @@ -6,8 +6,6 @@ from esphome.components.esp32 import add_idf_sdkconfig_option from esphome.components.esp32_ble import BTLoggers import esphome.config_validation as cv from esphome.const import CONF_ACTIVE, CONF_ID -from esphome.core import CORE -from esphome.log import AnsiFore, color AUTO_LOAD = ["esp32_ble_client", "esp32_ble_tracker"] DEPENDENCIES = ["api", "esp32"] @@ -44,29 +42,7 @@ def validate_connections(config): ) elif config[CONF_ACTIVE]: connection_slots: int = config[CONF_CONNECTION_SLOTS] - esp32_ble_tracker.consume_connection_slots(connection_slots, "bluetooth_proxy")( - config - ) - - # Warn about connection slot waste when using Arduino framework - if CORE.using_arduino and connection_slots: - _LOGGER.warning( - "Bluetooth Proxy with active connections on Arduino framework has suboptimal performance.\n" - "If BLE connections fail, they can waste connection slots for 10 seconds because\n" - "Arduino doesn't allow configuring the BLE connection timeout (fixed at 30s).\n" - "ESP-IDF framework allows setting it to 20s to match client timeouts.\n" - "\n" - "To switch to ESP-IDF, add this to your YAML:\n" - " esp32:\n" - " framework:\n" - " type: esp-idf\n" - "\n" - "For detailed migration instructions, see:\n" - "%s", - color( - AnsiFore.BLUE, "https://esphome.io/guides/esp32_arduino_to_idf.html" - ), - ) + esp32_ble.consume_connection_slots(connection_slots, "bluetooth_proxy")(config) return { **config, @@ -81,19 +57,17 @@ CONFIG_SCHEMA = cv.All( { cv.GenerateID(): cv.declare_id(BluetoothProxy), cv.Optional(CONF_ACTIVE, default=True): cv.boolean, - cv.SplitDefault(CONF_CACHE_SERVICES, esp32_idf=True): cv.All( - cv.only_with_esp_idf, cv.boolean - ), + cv.Optional(CONF_CACHE_SERVICES, default=True): cv.boolean, cv.Optional( CONF_CONNECTION_SLOTS, default=DEFAULT_CONNECTION_SLOTS, ): cv.All( cv.positive_int, - cv.Range(min=1, max=esp32_ble_tracker.max_connections()), + cv.Range(min=1, max=esp32_ble.IDF_MAX_CONNECTIONS), ), cv.Optional(CONF_CONNECTIONS): cv.All( cv.ensure_list(CONNECTION_SCHEMA), - cv.Length(min=1, max=esp32_ble_tracker.max_connections()), + cv.Length(min=1, max=esp32_ble.IDF_MAX_CONNECTIONS), ), } ) diff --git a/esphome/components/bluetooth_proxy/bluetooth_connection.cpp b/esphome/components/bluetooth_proxy/bluetooth_connection.cpp index 540492f8c5..cde82fbfb0 100644 --- a/esphome/components/bluetooth_proxy/bluetooth_connection.cpp +++ b/esphome/components/bluetooth_proxy/bluetooth_connection.cpp @@ -514,7 +514,8 @@ esp_err_t BluetoothConnection::read_characteristic(uint16_t handle) { return this->check_and_log_error_("esp_ble_gattc_read_char", err); } -esp_err_t BluetoothConnection::write_characteristic(uint16_t handle, const std::string &data, bool response) { +esp_err_t BluetoothConnection::write_characteristic(uint16_t handle, const uint8_t *data, size_t length, + bool response) { if (!this->connected()) { this->log_gatt_not_connected_("write", "characteristic"); return ESP_GATT_NOT_CONNECTED; @@ -522,8 +523,11 @@ esp_err_t BluetoothConnection::write_characteristic(uint16_t handle, const std:: ESP_LOGV(TAG, "[%d] [%s] Writing GATT characteristic handle %d", this->connection_index_, this->address_str_.c_str(), handle); + // ESP-IDF's API requires a non-const uint8_t* but it doesn't modify the data + // The BTC layer immediately copies the data to its own buffer (see btc_gattc.c) + // const_cast is safe here and was previously hidden by a C-style cast esp_err_t err = - esp_ble_gattc_write_char(this->gattc_if_, this->conn_id_, handle, data.size(), (uint8_t *) data.data(), + esp_ble_gattc_write_char(this->gattc_if_, this->conn_id_, handle, length, const_cast(data), response ? ESP_GATT_WRITE_TYPE_RSP : ESP_GATT_WRITE_TYPE_NO_RSP, ESP_GATT_AUTH_REQ_NONE); return this->check_and_log_error_("esp_ble_gattc_write_char", err); } @@ -540,7 +544,7 @@ esp_err_t BluetoothConnection::read_descriptor(uint16_t handle) { return this->check_and_log_error_("esp_ble_gattc_read_char_descr", err); } -esp_err_t BluetoothConnection::write_descriptor(uint16_t handle, const std::string &data, bool response) { +esp_err_t BluetoothConnection::write_descriptor(uint16_t handle, const uint8_t *data, size_t length, bool response) { if (!this->connected()) { this->log_gatt_not_connected_("write", "descriptor"); return ESP_GATT_NOT_CONNECTED; @@ -548,8 +552,11 @@ esp_err_t BluetoothConnection::write_descriptor(uint16_t handle, const std::stri ESP_LOGV(TAG, "[%d] [%s] Writing GATT descriptor handle %d", this->connection_index_, this->address_str_.c_str(), handle); + // ESP-IDF's API requires a non-const uint8_t* but it doesn't modify the data + // The BTC layer immediately copies the data to its own buffer (see btc_gattc.c) + // const_cast is safe here and was previously hidden by a C-style cast esp_err_t err = esp_ble_gattc_write_char_descr( - this->gattc_if_, this->conn_id_, handle, data.size(), (uint8_t *) data.data(), + this->gattc_if_, this->conn_id_, handle, length, const_cast(data), response ? ESP_GATT_WRITE_TYPE_RSP : ESP_GATT_WRITE_TYPE_NO_RSP, ESP_GATT_AUTH_REQ_NONE); return this->check_and_log_error_("esp_ble_gattc_write_char_descr", err); } diff --git a/esphome/components/bluetooth_proxy/bluetooth_connection.h b/esphome/components/bluetooth_proxy/bluetooth_connection.h index e5d5ff2dd6..60bbc93e8b 100644 --- a/esphome/components/bluetooth_proxy/bluetooth_connection.h +++ b/esphome/components/bluetooth_proxy/bluetooth_connection.h @@ -18,9 +18,9 @@ class BluetoothConnection final : public esp32_ble_client::BLEClientBase { esp32_ble_tracker::AdvertisementParserType get_advertisement_parser_type() override; esp_err_t read_characteristic(uint16_t handle); - esp_err_t write_characteristic(uint16_t handle, const std::string &data, bool response); + esp_err_t write_characteristic(uint16_t handle, const uint8_t *data, size_t length, bool response); esp_err_t read_descriptor(uint16_t handle); - esp_err_t write_descriptor(uint16_t handle, const std::string &data, bool response); + esp_err_t write_descriptor(uint16_t handle, const uint8_t *data, size_t length, bool response); esp_err_t notify_characteristic(uint16_t handle, bool enable); diff --git a/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp b/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp index 532aff550e..cd7261d5e5 100644 --- a/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp +++ b/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp @@ -305,7 +305,7 @@ void BluetoothProxy::bluetooth_gatt_write(const api::BluetoothGATTWriteRequest & return; } - auto err = connection->write_characteristic(msg.handle, msg.data, msg.response); + auto err = connection->write_characteristic(msg.handle, msg.data, msg.data_len, msg.response); if (err != ESP_OK) { this->send_gatt_error(msg.address, msg.handle, err); } @@ -331,7 +331,7 @@ void BluetoothProxy::bluetooth_gatt_write_descriptor(const api::BluetoothGATTWri return; } - auto err = connection->write_descriptor(msg.handle, msg.data, true); + auto err = connection->write_descriptor(msg.handle, msg.data, msg.data_len, true); if (err != ESP_OK) { this->send_gatt_error(msg.address, msg.handle, err); } diff --git a/esphome/components/camera_encoder/__init__.py b/esphome/components/camera_encoder/__init__.py index c0f0ca2fe0..89181d27b4 100644 --- a/esphome/components/camera_encoder/__init__.py +++ b/esphome/components/camera_encoder/__init__.py @@ -2,7 +2,6 @@ import esphome.codegen as cg from esphome.components.esp32 import add_idf_component import esphome.config_validation as cv from esphome.const import CONF_BUFFER_SIZE, CONF_ID, CONF_TYPE -from esphome.core import CORE from esphome.types import ConfigType CODEOWNERS = ["@DT-art1"] @@ -51,9 +50,8 @@ async def to_code(config: ConfigType) -> None: buffer = cg.new_Pvariable(config[CONF_ENCODER_BUFFER_ID]) cg.add(buffer.set_buffer_size(config[CONF_BUFFER_SIZE])) if config[CONF_TYPE] == ESP32_CAMERA_ENCODER: - if CORE.using_esp_idf: - add_idf_component(name="espressif/esp32-camera", ref="2.1.0") - cg.add_build_flag("-DUSE_ESP32_CAMERA_JPEG_ENCODER") + add_idf_component(name="espressif/esp32-camera", ref="2.1.1") + cg.add_define("USE_ESP32_CAMERA_JPEG_ENCODER") var = cg.new_Pvariable( config[CONF_ID], config[CONF_QUALITY], diff --git a/esphome/components/camera_encoder/esp32_camera_jpeg_encoder.cpp b/esphome/components/camera_encoder/esp32_camera_jpeg_encoder.cpp index 7e21122087..55a3f0b96c 100644 --- a/esphome/components/camera_encoder/esp32_camera_jpeg_encoder.cpp +++ b/esphome/components/camera_encoder/esp32_camera_jpeg_encoder.cpp @@ -1,3 +1,5 @@ +#include "esphome/core/defines.h" + #ifdef USE_ESP32_CAMERA_JPEG_ENCODER #include "esp32_camera_jpeg_encoder.h" @@ -15,7 +17,7 @@ camera::EncoderError ESP32CameraJPEGEncoder::encode_pixels(camera::CameraImageSp this->bytes_written_ = 0; this->out_of_output_memory_ = false; bool success = fmt2jpg_cb(pixels->get_data_buffer(), pixels->get_data_length(), spec->width, spec->height, - to_internal_(spec->format), this->quality_, callback_, this); + to_internal_(spec->format), this->quality_, callback, this); if (!success) return camera::ENCODER_ERROR_CONFIGURATION; @@ -49,7 +51,7 @@ void ESP32CameraJPEGEncoder::dump_config() { this->output_->get_max_size(), this->quality_, this->buffer_expand_size_); } -size_t ESP32CameraJPEGEncoder::callback_(void *arg, size_t index, const void *data, size_t len) { +size_t ESP32CameraJPEGEncoder::callback(void *arg, size_t index, const void *data, size_t len) { ESP32CameraJPEGEncoder *that = reinterpret_cast(arg); uint8_t *buffer = that->output_->get_data(); size_t buffer_length = that->output_->get_max_size(); diff --git a/esphome/components/camera_encoder/esp32_camera_jpeg_encoder.h b/esphome/components/camera_encoder/esp32_camera_jpeg_encoder.h index b585252584..0ede366e73 100644 --- a/esphome/components/camera_encoder/esp32_camera_jpeg_encoder.h +++ b/esphome/components/camera_encoder/esp32_camera_jpeg_encoder.h @@ -1,5 +1,7 @@ #pragma once +#include "esphome/core/defines.h" + #ifdef USE_ESP32_CAMERA_JPEG_ENCODER #include @@ -24,7 +26,7 @@ class ESP32CameraJPEGEncoder : public camera::Encoder { void dump_config() override; // ------------------------- protected: - static size_t callback_(void *arg, size_t index, const void *data, size_t len); + static size_t callback(void *arg, size_t index, const void *data, size_t len); pixformat_t to_internal_(camera::PixelFormat format); camera::EncoderBuffer *output_{}; diff --git a/esphome/components/canbus/canbus.cpp b/esphome/components/canbus/canbus.cpp index 6e61f05be7..e208b0fd66 100644 --- a/esphome/components/canbus/canbus.cpp +++ b/esphome/components/canbus/canbus.cpp @@ -21,8 +21,8 @@ void Canbus::dump_config() { } } -void Canbus::send_data(uint32_t can_id, bool use_extended_id, bool remote_transmission_request, - const std::vector &data) { +canbus::Error Canbus::send_data(uint32_t can_id, bool use_extended_id, bool remote_transmission_request, + const std::vector &data) { struct CanFrame can_message; uint8_t size = static_cast(data.size()); @@ -45,13 +45,15 @@ void Canbus::send_data(uint32_t can_id, bool use_extended_id, bool remote_transm ESP_LOGVV(TAG, " data[%d]=%02x", i, can_message.data[i]); } - if (this->send_message(&can_message) != canbus::ERROR_OK) { + canbus::Error error = this->send_message(&can_message); + if (error != canbus::ERROR_OK) { if (use_extended_id) { - ESP_LOGW(TAG, "send to extended id=0x%08" PRIx32 " failed!", can_id); + ESP_LOGW(TAG, "send to extended id=0x%08" PRIx32 " failed with error %d!", can_id, error); } else { - ESP_LOGW(TAG, "send to standard id=0x%03" PRIx32 " failed!", can_id); + ESP_LOGW(TAG, "send to standard id=0x%03" PRIx32 " failed with error %d!", can_id, error); } } + return error; } void Canbus::add_trigger(CanbusTrigger *trigger) { diff --git a/esphome/components/canbus/canbus.h b/esphome/components/canbus/canbus.h index 7319bfb4ad..56e2f2719b 100644 --- a/esphome/components/canbus/canbus.h +++ b/esphome/components/canbus/canbus.h @@ -70,11 +70,11 @@ class Canbus : public Component { float get_setup_priority() const override { return setup_priority::HARDWARE; } void loop() override; - void send_data(uint32_t can_id, bool use_extended_id, bool remote_transmission_request, - const std::vector &data); - void send_data(uint32_t can_id, bool use_extended_id, const std::vector &data) { + canbus::Error send_data(uint32_t can_id, bool use_extended_id, bool remote_transmission_request, + const std::vector &data); + canbus::Error send_data(uint32_t can_id, bool use_extended_id, const std::vector &data) { // for backwards compatibility only - this->send_data(can_id, use_extended_id, false, data); + return this->send_data(can_id, use_extended_id, false, data); } void set_can_id(uint32_t can_id) { this->can_id_ = can_id; } void set_use_extended_id(bool use_extended_id) { this->use_extended_id_ = use_extended_id; } diff --git a/esphome/components/captive_portal/__init__.py b/esphome/components/captive_portal/__init__.py index 9f2af0a230..99acb76bcf 100644 --- a/esphome/components/captive_portal/__init__.py +++ b/esphome/components/captive_portal/__init__.py @@ -1,6 +1,7 @@ import esphome.codegen as cg from esphome.components import web_server_base from esphome.components.web_server_base import CONF_WEB_SERVER_BASE_ID +from esphome.config_helpers import filter_source_files_from_platform import esphome.config_validation as cv from esphome.const import ( CONF_ID, @@ -9,11 +10,19 @@ from esphome.const import ( PLATFORM_ESP8266, PLATFORM_LN882X, PLATFORM_RTL87XX, + PlatformFramework, ) from esphome.core import CORE, coroutine_with_priority from esphome.coroutine import CoroPriority -AUTO_LOAD = ["web_server_base", "ota.web_server"] + +def AUTO_LOAD() -> list[str]: + auto_load = ["web_server_base", "ota.web_server"] + if CORE.using_esp_idf: + auto_load.append("socket") + return auto_load + + DEPENDENCIES = ["wifi"] CODEOWNERS = ["@esphome/core"] @@ -58,3 +67,11 @@ async def to_code(config): cg.add_library("DNSServer", None) if CORE.is_libretiny: cg.add_library("DNSServer", None) + + +# Only compile the ESP-IDF DNS server when using ESP-IDF framework +FILTER_SOURCE_FILES = filter_source_files_from_platform( + { + "dns_server_esp32_idf.cpp": {PlatformFramework.ESP32_IDF}, + } +) diff --git a/esphome/components/captive_portal/captive_portal.cpp b/esphome/components/captive_portal/captive_portal.cpp index 7eb0ffa99e..30438747f2 100644 --- a/esphome/components/captive_portal/captive_portal.cpp +++ b/esphome/components/captive_portal/captive_portal.cpp @@ -11,14 +11,14 @@ namespace captive_portal { static const char *const TAG = "captive_portal"; void CaptivePortal::handle_config(AsyncWebServerRequest *request) { - AsyncResponseStream *stream = request->beginResponseStream(F("application/json")); - stream->addHeader(F("cache-control"), F("public, max-age=0, must-revalidate")); + AsyncResponseStream *stream = request->beginResponseStream(ESPHOME_F("application/json")); + stream->addHeader(ESPHOME_F("cache-control"), ESPHOME_F("public, max-age=0, must-revalidate")); #ifdef USE_ESP8266 - stream->print(F("{\"mac\":\"")); + stream->print(ESPHOME_F("{\"mac\":\"")); stream->print(get_mac_address_pretty().c_str()); - stream->print(F("\",\"name\":\"")); + stream->print(ESPHOME_F("\",\"name\":\"")); stream->print(App.get_name().c_str()); - stream->print(F("\",\"aps\":[{}")); + stream->print(ESPHOME_F("\",\"aps\":[{}")); #else stream->printf(R"({"mac":"%s","name":"%s","aps":[{})", get_mac_address_pretty().c_str(), App.get_name().c_str()); #endif @@ -29,37 +29,35 @@ void CaptivePortal::handle_config(AsyncWebServerRequest *request) { // Assumes no " in ssid, possible unicode isses? #ifdef USE_ESP8266 - stream->print(F(",{\"ssid\":\"")); + stream->print(ESPHOME_F(",{\"ssid\":\"")); stream->print(scan.get_ssid().c_str()); - stream->print(F("\",\"rssi\":")); + stream->print(ESPHOME_F("\",\"rssi\":")); stream->print(scan.get_rssi()); - stream->print(F(",\"lock\":")); + stream->print(ESPHOME_F(",\"lock\":")); stream->print(scan.get_with_auth()); - stream->print(F("}")); + stream->print(ESPHOME_F("}")); #else stream->printf(R"(,{"ssid":"%s","rssi":%d,"lock":%d})", scan.get_ssid().c_str(), scan.get_rssi(), scan.get_with_auth()); #endif } - stream->print(F("]}")); + stream->print(ESPHOME_F("]}")); request->send(stream); } void CaptivePortal::handle_wifisave(AsyncWebServerRequest *request) { - std::string ssid = request->arg("ssid").c_str(); - std::string psk = request->arg("psk").c_str(); + std::string ssid = request->arg("ssid").c_str(); // NOLINT(readability-redundant-string-cstr) + std::string psk = request->arg("psk").c_str(); // NOLINT(readability-redundant-string-cstr) ESP_LOGI(TAG, "Requested WiFi Settings Change:"); ESP_LOGI(TAG, " SSID='%s'", ssid.c_str()); ESP_LOGI(TAG, " Password=" LOG_SECRET("'%s'"), psk.c_str()); wifi::global_wifi_component->save_wifi_sta(ssid, psk); wifi::global_wifi_component->start_scanning(); - request->redirect(F("/?save")); + request->redirect(ESPHOME_F("/?save")); } void CaptivePortal::setup() { -#ifndef USE_ARDUINO - // No DNS server needed for non-Arduino frameworks + // Disable loop by default - will be enabled when captive portal starts this->disable_loop(); -#endif } void CaptivePortal::start() { this->base_->init(); @@ -67,51 +65,47 @@ void CaptivePortal::start() { this->base_->add_handler(this); } + network::IPAddress ip = wifi::global_wifi_component->wifi_soft_ap_ip(); + +#ifdef USE_ESP_IDF + // Create DNS server instance for ESP-IDF + this->dns_server_ = make_unique(); + this->dns_server_->start(ip); +#endif #ifdef USE_ARDUINO this->dns_server_ = make_unique(); this->dns_server_->setErrorReplyCode(DNSReplyCode::NoError); - network::IPAddress ip = wifi::global_wifi_component->wifi_soft_ap_ip(); - this->dns_server_->start(53, F("*"), ip); - // Re-enable loop() when DNS server is started - this->enable_loop(); + this->dns_server_->start(53, ESPHOME_F("*"), ip); #endif - this->base_->get_server()->onNotFound([this](AsyncWebServerRequest *req) { - if (!this->active_ || req->host().c_str() == wifi::global_wifi_component->wifi_soft_ap_ip().str()) { - req->send(404, F("text/html"), F("File not found")); - return; - } - -#ifdef USE_ESP8266 - String url = F("http://"); - url += wifi::global_wifi_component->wifi_soft_ap_ip().str().c_str(); -#else - auto url = "http://" + wifi::global_wifi_component->wifi_soft_ap_ip().str(); -#endif - req->redirect(url.c_str()); - }); - this->initialized_ = true; this->active_ = true; + + // Enable loop() now that captive portal is active + this->enable_loop(); + + ESP_LOGV(TAG, "Captive portal started"); } void CaptivePortal::handleRequest(AsyncWebServerRequest *req) { - if (req->url() == F("/")) { -#ifndef USE_ESP8266 - auto *response = req->beginResponse(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ)); -#else - auto *response = req->beginResponse_P(200, F("text/html"), INDEX_GZ, sizeof(INDEX_GZ)); -#endif - response->addHeader(F("Content-Encoding"), F("gzip")); - req->send(response); - return; - } else if (req->url() == F("/config.json")) { + if (req->url() == ESPHOME_F("/config.json")) { this->handle_config(req); return; - } else if (req->url() == F("/wifisave")) { + } else if (req->url() == ESPHOME_F("/wifisave")) { this->handle_wifisave(req); return; } + + // All other requests get the captive portal page + // This includes OS captive portal detection endpoints which will trigger + // the captive portal when they don't receive their expected responses +#ifndef USE_ESP8266 + auto *response = req->beginResponse(200, ESPHOME_F("text/html"), INDEX_GZ, sizeof(INDEX_GZ)); +#else + auto *response = req->beginResponse_P(200, ESPHOME_F("text/html"), INDEX_GZ, sizeof(INDEX_GZ)); +#endif + response->addHeader(ESPHOME_F("Content-Encoding"), ESPHOME_F("gzip")); + req->send(response); } CaptivePortal::CaptivePortal(web_server_base::WebServerBase *base) : base_(base) { global_captive_portal = this; } diff --git a/esphome/components/captive_portal/captive_portal.h b/esphome/components/captive_portal/captive_portal.h index 382afe92f0..f48c286f0c 100644 --- a/esphome/components/captive_portal/captive_portal.h +++ b/esphome/components/captive_portal/captive_portal.h @@ -5,6 +5,9 @@ #ifdef USE_ARDUINO #include #endif +#ifdef USE_ESP_IDF +#include "dns_server_esp32_idf.h" +#endif #include "esphome/core/component.h" #include "esphome/core/helpers.h" #include "esphome/core/preferences.h" @@ -19,41 +22,36 @@ class CaptivePortal : public AsyncWebHandler, public Component { CaptivePortal(web_server_base::WebServerBase *base); void setup() override; void dump_config() override; -#ifdef USE_ARDUINO void loop() override { +#ifdef USE_ARDUINO if (this->dns_server_ != nullptr) { this->dns_server_->processNextRequest(); - } else { - this->disable_loop(); } - } #endif +#ifdef USE_ESP_IDF + if (this->dns_server_ != nullptr) { + this->dns_server_->process_next_request(); + } +#endif + } float get_setup_priority() const override; void start(); bool is_active() const { return this->active_; } void end() { this->active_ = false; + this->disable_loop(); // Stop processing DNS requests this->base_->deinit(); -#ifdef USE_ARDUINO - this->dns_server_->stop(); - this->dns_server_ = nullptr; -#endif + if (this->dns_server_ != nullptr) { + this->dns_server_->stop(); + this->dns_server_ = nullptr; + } } bool canHandle(AsyncWebServerRequest *request) const override { - if (!this->active_) - return false; - - if (request->method() == HTTP_GET) { - if (request->url() == F("/")) - return true; - if (request->url() == F("/config.json")) - return true; - if (request->url() == F("/wifisave")) - return true; - } - - return false; + // Handle all GET requests when captive portal is active + // This allows us to respond with the portal page for any URL, + // triggering OS captive portal detection + return this->active_ && request->method() == HTTP_GET; } void handle_config(AsyncWebServerRequest *request); @@ -66,7 +64,7 @@ class CaptivePortal : public AsyncWebHandler, public Component { web_server_base::WebServerBase *base_; bool initialized_{false}; bool active_{false}; -#ifdef USE_ARDUINO +#if defined(USE_ARDUINO) || defined(USE_ESP_IDF) std::unique_ptr dns_server_{nullptr}; #endif }; diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.cpp b/esphome/components/captive_portal/dns_server_esp32_idf.cpp new file mode 100644 index 0000000000..740107400a --- /dev/null +++ b/esphome/components/captive_portal/dns_server_esp32_idf.cpp @@ -0,0 +1,205 @@ +#include "dns_server_esp32_idf.h" +#ifdef USE_ESP_IDF + +#include "esphome/core/log.h" +#include "esphome/core/hal.h" +#include "esphome/components/socket/socket.h" +#include +#include + +namespace esphome::captive_portal { + +static const char *const TAG = "captive_portal.dns"; + +// DNS constants +static constexpr uint16_t DNS_PORT = 53; +static constexpr uint16_t DNS_QR_FLAG = 1 << 15; +static constexpr uint16_t DNS_OPCODE_MASK = 0x7800; +static constexpr uint16_t DNS_QTYPE_A = 0x0001; +static constexpr uint16_t DNS_QCLASS_IN = 0x0001; +static constexpr uint16_t DNS_ANSWER_TTL = 300; + +// DNS Header structure +struct DNSHeader { + uint16_t id; + uint16_t flags; + uint16_t qd_count; + uint16_t an_count; + uint16_t ns_count; + uint16_t ar_count; +} __attribute__((packed)); + +// DNS Question structure +struct DNSQuestion { + uint16_t type; + uint16_t dns_class; +} __attribute__((packed)); + +// DNS Answer structure +struct DNSAnswer { + uint16_t ptr_offset; + uint16_t type; + uint16_t dns_class; + uint32_t ttl; + uint16_t addr_len; + uint32_t ip_addr; +} __attribute__((packed)); + +void DNSServer::start(const network::IPAddress &ip) { + this->server_ip_ = ip; + ESP_LOGV(TAG, "Starting DNS server on %s", ip.str().c_str()); + + // Create loop-monitored UDP socket + this->socket_ = socket::socket_ip_loop_monitored(SOCK_DGRAM, IPPROTO_UDP); + if (this->socket_ == nullptr) { + ESP_LOGE(TAG, "Socket create failed"); + return; + } + + // Set socket options + int enable = 1; + this->socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)); + + // Bind to port 53 + struct sockaddr_storage server_addr = {}; + socklen_t addr_len = socket::set_sockaddr_any((struct sockaddr *) &server_addr, sizeof(server_addr), DNS_PORT); + + int err = this->socket_->bind((struct sockaddr *) &server_addr, addr_len); + if (err != 0) { + ESP_LOGE(TAG, "Bind failed: %d", errno); + this->socket_ = nullptr; + return; + } + ESP_LOGV(TAG, "Bound to port %d", DNS_PORT); +} + +void DNSServer::stop() { + if (this->socket_ != nullptr) { + this->socket_->close(); + this->socket_ = nullptr; + } + ESP_LOGV(TAG, "Stopped"); +} + +void DNSServer::process_next_request() { + // Process one request if socket is valid and data is available + if (this->socket_ == nullptr || !this->socket_->ready()) { + return; + } + struct sockaddr_in client_addr; + socklen_t client_addr_len = sizeof(client_addr); + + // Receive DNS request using raw fd for recvfrom + int fd = this->socket_->get_fd(); + if (fd < 0) { + return; + } + + ssize_t len = recvfrom(fd, this->buffer_, sizeof(this->buffer_), MSG_DONTWAIT, (struct sockaddr *) &client_addr, + &client_addr_len); + + if (len < 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { + ESP_LOGE(TAG, "recvfrom failed: %d", errno); + } + return; + } + + ESP_LOGVV(TAG, "Received %d bytes from %s:%d", len, inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port)); + + if (len < static_cast(sizeof(DNSHeader) + 1)) { + ESP_LOGV(TAG, "Request too short: %d", len); + return; + } + + // Parse DNS header + DNSHeader *header = (DNSHeader *) this->buffer_; + uint16_t flags = ntohs(header->flags); + uint16_t qd_count = ntohs(header->qd_count); + + // Check if it's a standard query + if ((flags & DNS_QR_FLAG) || (flags & DNS_OPCODE_MASK) || qd_count != 1) { + ESP_LOGV(TAG, "Not a standard query: flags=0x%04X, qd_count=%d", flags, qd_count); + return; // Not a standard query + } + + // Parse domain name (we don't actually care about it - redirect everything) + uint8_t *ptr = this->buffer_ + sizeof(DNSHeader); + uint8_t *end = this->buffer_ + len; + + while (ptr < end && *ptr != 0) { + uint8_t label_len = *ptr; + if (label_len > 63) { // Check for invalid label length + return; + } + // Check if we have room for this label plus the length byte + if (ptr + label_len + 1 > end) { + return; // Would overflow + } + ptr += label_len + 1; + } + + // Check if we reached a proper null terminator + if (ptr >= end || *ptr != 0) { + return; // Name not terminated or truncated + } + ptr++; // Skip the null terminator + + // Check we have room for the question + if (ptr + sizeof(DNSQuestion) > end) { + return; // Request truncated + } + + // Parse DNS question + DNSQuestion *question = (DNSQuestion *) ptr; + uint16_t qtype = ntohs(question->type); + uint16_t qclass = ntohs(question->dns_class); + + // We only handle A queries + if (qtype != DNS_QTYPE_A || qclass != DNS_QCLASS_IN) { + ESP_LOGV(TAG, "Not an A query: type=0x%04X, class=0x%04X", qtype, qclass); + return; // Not an A query + } + + // Build DNS response by modifying the request in-place + header->flags = htons(DNS_QR_FLAG | 0x8000); // Response + Authoritative + header->an_count = htons(1); // One answer + + // Add answer section after the question + size_t question_len = (ptr + sizeof(DNSQuestion)) - this->buffer_ - sizeof(DNSHeader); + size_t answer_offset = sizeof(DNSHeader) + question_len; + + // Check if we have room for the answer + if (answer_offset + sizeof(DNSAnswer) > sizeof(this->buffer_)) { + ESP_LOGW(TAG, "Response too large"); + return; + } + + DNSAnswer *answer = (DNSAnswer *) (this->buffer_ + answer_offset); + + // Pointer to name in question (offset from start of packet) + answer->ptr_offset = htons(0xC000 | sizeof(DNSHeader)); + answer->type = htons(DNS_QTYPE_A); + answer->dns_class = htons(DNS_QCLASS_IN); + answer->ttl = htonl(DNS_ANSWER_TTL); + answer->addr_len = htons(4); + + // Get the raw IP address + ip4_addr_t addr = this->server_ip_; + answer->ip_addr = addr.addr; + + size_t response_len = answer_offset + sizeof(DNSAnswer); + + // Send response + ssize_t sent = + this->socket_->sendto(this->buffer_, response_len, 0, (struct sockaddr *) &client_addr, client_addr_len); + if (sent < 0) { + ESP_LOGV(TAG, "Send failed: %d", errno); + } else { + ESP_LOGV(TAG, "Sent %d bytes", sent); + } +} + +} // namespace esphome::captive_portal + +#endif // USE_ESP_IDF diff --git a/esphome/components/captive_portal/dns_server_esp32_idf.h b/esphome/components/captive_portal/dns_server_esp32_idf.h new file mode 100644 index 0000000000..13d9def8e3 --- /dev/null +++ b/esphome/components/captive_portal/dns_server_esp32_idf.h @@ -0,0 +1,27 @@ +#pragma once +#ifdef USE_ESP_IDF + +#include +#include "esphome/core/helpers.h" +#include "esphome/components/network/ip_address.h" +#include "esphome/components/socket/socket.h" + +namespace esphome::captive_portal { + +class DNSServer { + public: + void start(const network::IPAddress &ip); + void stop(); + void process_next_request(); + + protected: + static constexpr size_t DNS_BUFFER_SIZE = 192; + + std::unique_ptr socket_{nullptr}; + network::IPAddress server_ip_; + uint8_t buffer_[DNS_BUFFER_SIZE]; +}; + +} // namespace esphome::captive_portal + +#endif // USE_ESP_IDF diff --git a/esphome/components/ccs811/ccs811.cpp b/esphome/components/ccs811/ccs811.cpp index 40c5318339..84355f2793 100644 --- a/esphome/components/ccs811/ccs811.cpp +++ b/esphome/components/ccs811/ccs811.cpp @@ -155,7 +155,7 @@ void CCS811Component::dump_config() { LOG_UPDATE_INTERVAL(this); LOG_SENSOR(" ", "CO2 Sensor", this->co2_); LOG_SENSOR(" ", "TVOC Sensor", this->tvoc_); - LOG_TEXT_SENSOR(" ", "Firmware Version Sensor", this->version_) + LOG_TEXT_SENSOR(" ", "Firmware Version Sensor", this->version_); if (this->baseline_) { ESP_LOGCONFIG(TAG, " Baseline: %04X", *this->baseline_); } else { diff --git a/esphome/components/climate/climate.cpp b/esphome/components/climate/climate.cpp index be56310b35..e7a454d459 100644 --- a/esphome/components/climate/climate.cpp +++ b/esphome/components/climate/climate.cpp @@ -367,9 +367,11 @@ void Climate::save_state_() { state.uses_custom_fan_mode = true; const auto &supported = traits.get_supported_custom_fan_modes(); std::vector vec{supported.begin(), supported.end()}; - auto it = std::find(vec.begin(), vec.end(), custom_fan_mode); - if (it != vec.end()) { - state.custom_fan_mode = std::distance(vec.begin(), it); + for (size_t i = 0; i < vec.size(); i++) { + if (vec[i] == custom_fan_mode) { + state.custom_fan_mode = i; + break; + } } } if (traits.get_supports_presets() && preset.has_value()) { @@ -380,10 +382,11 @@ void Climate::save_state_() { state.uses_custom_preset = true; const auto &supported = traits.get_supported_custom_presets(); std::vector vec{supported.begin(), supported.end()}; - auto it = std::find(vec.begin(), vec.end(), custom_preset); - // only set custom preset if value exists, otherwise leave it as is - if (it != vec.cend()) { - state.custom_preset = std::distance(vec.begin(), it); + for (size_t i = 0; i < vec.size(); i++) { + if (vec[i] == custom_preset) { + state.custom_preset = i; + break; + } } } if (traits.get_supports_swing_modes()) { diff --git a/esphome/components/cm1106/cm1106.cpp b/esphome/components/cm1106/cm1106.cpp index 339a1659ac..d88ea2e1da 100644 --- a/esphome/components/cm1106/cm1106.cpp +++ b/esphome/components/cm1106/cm1106.cpp @@ -13,7 +13,7 @@ static const uint8_t C_M1106_CMD_SET_CO2_CALIB_RESPONSE[4] = {0x16, 0x01, 0x03, uint8_t cm1106_checksum(const uint8_t *response, size_t len) { uint8_t crc = 0; - for (int i = 0; i < len - 1; i++) { + for (size_t i = 0; i < len - 1; i++) { crc -= response[i]; } return crc; diff --git a/esphome/components/copy/lock/copy_lock.cpp b/esphome/components/copy/lock/copy_lock.cpp index 67a8acffec..25bd8c33ef 100644 --- a/esphome/components/copy/lock/copy_lock.cpp +++ b/esphome/components/copy/lock/copy_lock.cpp @@ -11,7 +11,7 @@ void CopyLock::setup() { traits.set_assumed_state(source_->traits.get_assumed_state()); traits.set_requires_code(source_->traits.get_requires_code()); - traits.set_supported_states(source_->traits.get_supported_states()); + traits.set_supported_states_mask(source_->traits.get_supported_states_mask()); traits.set_supports_open(source_->traits.get_supports_open()); this->publish_state(source_->state); diff --git a/esphome/components/cover/cover.cpp b/esphome/components/cover/cover.cpp index 700bceec01..3378279371 100644 --- a/esphome/components/cover/cover.cpp +++ b/esphome/components/cover/cover.cpp @@ -1,5 +1,6 @@ #include "cover.h" #include "esphome/core/log.h" +#include namespace esphome { namespace cover { diff --git a/esphome/components/daikin_arc/daikin_arc.cpp b/esphome/components/daikin_arc/daikin_arc.cpp index f806463d00..068819ecd1 100644 --- a/esphome/components/daikin_arc/daikin_arc.cpp +++ b/esphome/components/daikin_arc/daikin_arc.cpp @@ -26,7 +26,7 @@ void DaikinArcClimate::transmit_query_() { uint8_t remote_header[8] = {0x11, 0xDA, 0x27, 0x00, 0x84, 0x87, 0x20, 0x00}; // Calculate checksum - for (int i = 0; i < sizeof(remote_header) - 1; i++) { + for (size_t i = 0; i < sizeof(remote_header) - 1; i++) { remote_header[sizeof(remote_header) - 1] += remote_header[i]; } @@ -102,7 +102,7 @@ void DaikinArcClimate::transmit_state() { remote_state[9] = fan_speed & 0xff; // Calculate checksum - for (int i = 0; i < sizeof(remote_header) - 1; i++) { + for (size_t i = 0; i < sizeof(remote_header) - 1; i++) { remote_header[sizeof(remote_header) - 1] += remote_header[i]; } @@ -350,7 +350,7 @@ bool DaikinArcClimate::on_receive(remote_base::RemoteReceiveData data) { bool valid_daikin_frame = false; if (data.expect_item(DAIKIN_HEADER_MARK, DAIKIN_HEADER_SPACE)) { valid_daikin_frame = true; - int bytes_count = data.size() / 2 / 8; + size_t bytes_count = data.size() / 2 / 8; std::unique_ptr buf(new char[bytes_count * 3 + 1]); buf[0] = '\0'; for (size_t i = 0; i < bytes_count; i++) { @@ -370,7 +370,7 @@ bool DaikinArcClimate::on_receive(remote_base::RemoteReceiveData data) { if (!valid_daikin_frame) { char sbuf[16 * 10 + 1]; sbuf[0] = '\0'; - for (size_t j = 0; j < data.size(); j++) { + for (size_t j = 0; j < static_cast(data.size()); j++) { if ((j - 2) % 16 == 0) { if (j > 0) { ESP_LOGD(TAG, "DATA %04x: %s", (j - 16 > 0xffff ? 0 : j - 16), sbuf); @@ -380,19 +380,26 @@ bool DaikinArcClimate::on_receive(remote_base::RemoteReceiveData data) { char type_ch = ' '; // debug_tolerance = 25% - if (DAIKIN_DBG_LOWER(DAIKIN_ARC_PRE_MARK) <= data[j] && data[j] <= DAIKIN_DBG_UPPER(DAIKIN_ARC_PRE_MARK)) + if (static_cast(DAIKIN_DBG_LOWER(DAIKIN_ARC_PRE_MARK)) <= data[j] && + data[j] <= static_cast(DAIKIN_DBG_UPPER(DAIKIN_ARC_PRE_MARK))) type_ch = 'P'; - if (DAIKIN_DBG_LOWER(DAIKIN_ARC_PRE_SPACE) <= -data[j] && -data[j] <= DAIKIN_DBG_UPPER(DAIKIN_ARC_PRE_SPACE)) + if (static_cast(DAIKIN_DBG_LOWER(DAIKIN_ARC_PRE_SPACE)) <= -data[j] && + -data[j] <= static_cast(DAIKIN_DBG_UPPER(DAIKIN_ARC_PRE_SPACE))) type_ch = 'a'; - if (DAIKIN_DBG_LOWER(DAIKIN_HEADER_MARK) <= data[j] && data[j] <= DAIKIN_DBG_UPPER(DAIKIN_HEADER_MARK)) + if (static_cast(DAIKIN_DBG_LOWER(DAIKIN_HEADER_MARK)) <= data[j] && + data[j] <= static_cast(DAIKIN_DBG_UPPER(DAIKIN_HEADER_MARK))) type_ch = 'H'; - if (DAIKIN_DBG_LOWER(DAIKIN_HEADER_SPACE) <= -data[j] && -data[j] <= DAIKIN_DBG_UPPER(DAIKIN_HEADER_SPACE)) + if (static_cast(DAIKIN_DBG_LOWER(DAIKIN_HEADER_SPACE)) <= -data[j] && + -data[j] <= static_cast(DAIKIN_DBG_UPPER(DAIKIN_HEADER_SPACE))) type_ch = 'h'; - if (DAIKIN_DBG_LOWER(DAIKIN_BIT_MARK) <= data[j] && data[j] <= DAIKIN_DBG_UPPER(DAIKIN_BIT_MARK)) + if (static_cast(DAIKIN_DBG_LOWER(DAIKIN_BIT_MARK)) <= data[j] && + data[j] <= static_cast(DAIKIN_DBG_UPPER(DAIKIN_BIT_MARK))) type_ch = 'B'; - if (DAIKIN_DBG_LOWER(DAIKIN_ONE_SPACE) <= -data[j] && -data[j] <= DAIKIN_DBG_UPPER(DAIKIN_ONE_SPACE)) + if (static_cast(DAIKIN_DBG_LOWER(DAIKIN_ONE_SPACE)) <= -data[j] && + -data[j] <= static_cast(DAIKIN_DBG_UPPER(DAIKIN_ONE_SPACE))) type_ch = '1'; - if (DAIKIN_DBG_LOWER(DAIKIN_ZERO_SPACE) <= -data[j] && -data[j] <= DAIKIN_DBG_UPPER(DAIKIN_ZERO_SPACE)) + if (static_cast(DAIKIN_DBG_LOWER(DAIKIN_ZERO_SPACE)) <= -data[j] && + -data[j] <= static_cast(DAIKIN_DBG_UPPER(DAIKIN_ZERO_SPACE))) type_ch = '0'; if (abs(data[j]) > 100000) { @@ -400,7 +407,7 @@ bool DaikinArcClimate::on_receive(remote_base::RemoteReceiveData data) { } else { sprintf(sbuf, "%s%-5d[%c] ", sbuf, (int) (round(data[j] / 10.) * 10), type_ch); } - if (j == data.size() - 1) { + if (j + 1 == static_cast(data.size())) { ESP_LOGD(TAG, "DATA %04x: %s", (j - 8 > 0xffff ? 0 : j - 8), sbuf); } } diff --git a/esphome/components/deep_sleep/__init__.py b/esphome/components/deep_sleep/__init__.py index 05ae60239d..19fb726016 100644 --- a/esphome/components/deep_sleep/__init__.py +++ b/esphome/components/deep_sleep/__init__.py @@ -197,7 +197,8 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_ESP32_EXT1_WAKEUP): cv.All( cv.only_on_esp32, esp32.only_on_variant( - unsupported=[VARIANT_ESP32C3], msg_prefix="Wakeup from ext1" + unsupported=[VARIANT_ESP32C2, VARIANT_ESP32C3], + msg_prefix="Wakeup from ext1", ), cv.Schema( { @@ -214,7 +215,13 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_TOUCH_WAKEUP): cv.All( cv.only_on_esp32, esp32.only_on_variant( - unsupported=[VARIANT_ESP32C3], msg_prefix="Wakeup from touch" + unsupported=[ + VARIANT_ESP32C2, + VARIANT_ESP32C3, + VARIANT_ESP32C6, + VARIANT_ESP32H2, + ], + msg_prefix="Wakeup from touch", ), cv.boolean, ), diff --git a/esphome/components/deep_sleep/deep_sleep_component.h b/esphome/components/deep_sleep/deep_sleep_component.h index 7a640b9ea5..38744163c7 100644 --- a/esphome/components/deep_sleep/deep_sleep_component.h +++ b/esphome/components/deep_sleep/deep_sleep_component.h @@ -34,7 +34,7 @@ enum WakeupPinMode { WAKEUP_PIN_MODE_INVERT_WAKEUP, }; -#if defined(USE_ESP32) && !defined(USE_ESP32_VARIANT_ESP32C3) +#if defined(USE_ESP32) && !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) struct Ext1Wakeup { uint64_t mask; esp_sleep_ext1_wakeup_mode_t wakeup_mode; @@ -50,7 +50,7 @@ struct WakeupCauseToRunDuration { uint32_t gpio_cause; }; -#endif +#endif // USE_ESP32 template class EnterDeepSleepAction; @@ -73,20 +73,22 @@ class DeepSleepComponent : public Component { void set_wakeup_pin(InternalGPIOPin *pin) { this->wakeup_pin_ = pin; } void set_wakeup_pin_mode(WakeupPinMode wakeup_pin_mode); -#endif +#endif // USE_ESP32 #if defined(USE_ESP32) -#if !defined(USE_ESP32_VARIANT_ESP32C3) - +#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) void set_ext1_wakeup(Ext1Wakeup ext1_wakeup); - - void set_touch_wakeup(bool touch_wakeup); - #endif + +#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) && \ + !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2) + void set_touch_wakeup(bool touch_wakeup); +#endif + // Set the duration in ms for how long the code should run before entering // deep sleep mode, according to the cause the ESP32 has woken. void set_run_duration(WakeupCauseToRunDuration wakeup_cause_to_run_duration); -#endif +#endif // USE_ESP32 /// Set a duration in ms for how long the code should run before entering deep sleep mode. void set_run_duration(uint32_t time_ms); @@ -117,13 +119,13 @@ class DeepSleepComponent : public Component { InternalGPIOPin *wakeup_pin_; WakeupPinMode wakeup_pin_mode_{WAKEUP_PIN_MODE_IGNORE}; -#if !defined(USE_ESP32_VARIANT_ESP32C3) +#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) optional ext1_wakeup_; #endif optional touch_wakeup_; optional wakeup_cause_to_run_duration_; -#endif +#endif // USE_ESP32 optional run_duration_; bool next_enter_deep_sleep_{false}; bool prevent_{false}; diff --git a/esphome/components/deep_sleep/deep_sleep_esp32.cpp b/esphome/components/deep_sleep/deep_sleep_esp32.cpp index e9d0a4981f..b93d9ce601 100644 --- a/esphome/components/deep_sleep/deep_sleep_esp32.cpp +++ b/esphome/components/deep_sleep/deep_sleep_esp32.cpp @@ -7,6 +7,26 @@ namespace esphome { namespace deep_sleep { +// Deep Sleep feature support matrix for ESP32 variants: +// +// | Variant | ext0 | ext1 | Touch | GPIO wakeup | +// |-----------|------|------|-------|-------------| +// | ESP32 | ✓ | ✓ | ✓ | | +// | ESP32-S2 | ✓ | ✓ | ✓ | | +// | ESP32-S3 | ✓ | ✓ | ✓ | | +// | ESP32-C2 | | | | ✓ | +// | ESP32-C3 | | | | ✓ | +// | ESP32-C5 | | (✓) | | (✓) | +// | ESP32-C6 | | ✓ | | ✓ | +// | ESP32-H2 | | ✓ | | | +// +// Notes: +// - (✓) = Supported by hardware but not yet implemented in ESPHome +// - ext0: Single pin wakeup using RTC GPIO (esp_sleep_enable_ext0_wakeup) +// - ext1: Multiple pin wakeup (esp_sleep_enable_ext1_wakeup) +// - Touch: Touch pad wakeup (esp_sleep_enable_touchpad_wakeup) +// - GPIO wakeup: GPIO wakeup for non-RTC pins (esp_deep_sleep_enable_gpio_wakeup) + static const char *const TAG = "deep_sleep"; optional DeepSleepComponent::get_run_duration_() const { @@ -30,13 +50,13 @@ void DeepSleepComponent::set_wakeup_pin_mode(WakeupPinMode wakeup_pin_mode) { this->wakeup_pin_mode_ = wakeup_pin_mode; } -#if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) +#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) void DeepSleepComponent::set_ext1_wakeup(Ext1Wakeup ext1_wakeup) { this->ext1_wakeup_ = ext1_wakeup; } - -#if !defined(USE_ESP32_VARIANT_ESP32H2) -void DeepSleepComponent::set_touch_wakeup(bool touch_wakeup) { this->touch_wakeup_ = touch_wakeup; } #endif +#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) && \ + !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2) +void DeepSleepComponent::set_touch_wakeup(bool touch_wakeup) { this->touch_wakeup_ = touch_wakeup; } #endif void DeepSleepComponent::set_run_duration(WakeupCauseToRunDuration wakeup_cause_to_run_duration) { @@ -72,9 +92,13 @@ bool DeepSleepComponent::prepare_to_sleep_() { } void DeepSleepComponent::deep_sleep_() { -#if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2) + // Timer wakeup - all variants support this if (this->sleep_duration_.has_value()) esp_sleep_enable_timer_wakeup(*this->sleep_duration_); + + // Single pin wakeup (ext0) - ESP32, S2, S3 only +#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) && \ + !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2) if (this->wakeup_pin_ != nullptr) { const auto gpio_pin = gpio_num_t(this->wakeup_pin_->get_pin()); if (this->wakeup_pin_->get_flags() & gpio::FLAG_PULLUP) { @@ -95,32 +119,15 @@ void DeepSleepComponent::deep_sleep_() { } esp_sleep_enable_ext0_wakeup(gpio_pin, level); } - if (this->ext1_wakeup_.has_value()) { - esp_sleep_enable_ext1_wakeup(this->ext1_wakeup_->mask, this->ext1_wakeup_->wakeup_mode); - } - - if (this->touch_wakeup_.has_value() && *(this->touch_wakeup_)) { - esp_sleep_enable_touchpad_wakeup(); - esp_sleep_pd_config(ESP_PD_DOMAIN_RTC_PERIPH, ESP_PD_OPTION_ON); - } #endif -#if defined(USE_ESP32_VARIANT_ESP32H2) - if (this->sleep_duration_.has_value()) - esp_sleep_enable_timer_wakeup(*this->sleep_duration_); - if (this->ext1_wakeup_.has_value()) { - esp_sleep_enable_ext1_wakeup(this->ext1_wakeup_->mask, this->ext1_wakeup_->wakeup_mode); - } -#endif - -#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6) - if (this->sleep_duration_.has_value()) - esp_sleep_enable_timer_wakeup(*this->sleep_duration_); + // GPIO wakeup - C2, C3, C6 only +#if defined(USE_ESP32_VARIANT_ESP32C2) || defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6) if (this->wakeup_pin_ != nullptr) { const auto gpio_pin = gpio_num_t(this->wakeup_pin_->get_pin()); - if (this->wakeup_pin_->get_flags() && gpio::FLAG_PULLUP) { + if (this->wakeup_pin_->get_flags() & gpio::FLAG_PULLUP) { gpio_sleep_set_pull_mode(gpio_pin, GPIO_PULLUP_ONLY); - } else if (this->wakeup_pin_->get_flags() && gpio::FLAG_PULLDOWN) { + } else if (this->wakeup_pin_->get_flags() & gpio::FLAG_PULLDOWN) { gpio_sleep_set_pull_mode(gpio_pin, GPIO_PULLDOWN_ONLY); } gpio_sleep_set_direction(gpio_pin, GPIO_MODE_INPUT); @@ -138,9 +145,26 @@ void DeepSleepComponent::deep_sleep_() { static_cast(level)); } #endif + + // Multiple pin wakeup (ext1) - All except C2, C3 +#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) + if (this->ext1_wakeup_.has_value()) { + esp_sleep_enable_ext1_wakeup(this->ext1_wakeup_->mask, this->ext1_wakeup_->wakeup_mode); + } +#endif + + // Touch wakeup - ESP32, S2, S3 only +#if !defined(USE_ESP32_VARIANT_ESP32C2) && !defined(USE_ESP32_VARIANT_ESP32C3) && \ + !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2) + if (this->touch_wakeup_.has_value() && *(this->touch_wakeup_)) { + esp_sleep_enable_touchpad_wakeup(); + esp_sleep_pd_config(ESP_PD_DOMAIN_RTC_PERIPH, ESP_PD_OPTION_ON); + } +#endif + esp_deep_sleep_start(); } } // namespace deep_sleep } // namespace esphome -#endif +#endif // USE_ESP32 diff --git a/esphome/components/ektf2232/touchscreen/__init__.py b/esphome/components/ektf2232/touchscreen/__init__.py index 7d946fdcb9..123f03ca08 100644 --- a/esphome/components/ektf2232/touchscreen/__init__.py +++ b/esphome/components/ektf2232/touchscreen/__init__.py @@ -2,7 +2,7 @@ from esphome import pins import esphome.codegen as cg from esphome.components import i2c, touchscreen import esphome.config_validation as cv -from esphome.const import CONF_ID, CONF_INTERRUPT_PIN +from esphome.const import CONF_ID, CONF_INTERRUPT_PIN, CONF_RESET_PIN CODEOWNERS = ["@jesserockz"] DEPENDENCIES = ["i2c"] @@ -15,7 +15,7 @@ EKTF2232Touchscreen = ektf2232_ns.class_( ) CONF_EKTF2232_ID = "ektf2232_id" -CONF_RTS_PIN = "rts_pin" +CONF_RTS_PIN = "rts_pin" # To be removed before 2026.4.0 CONFIG_SCHEMA = touchscreen.TOUCHSCREEN_SCHEMA.extend( cv.Schema( @@ -24,7 +24,10 @@ CONFIG_SCHEMA = touchscreen.TOUCHSCREEN_SCHEMA.extend( cv.Required(CONF_INTERRUPT_PIN): cv.All( pins.internal_gpio_input_pin_schema ), - cv.Required(CONF_RTS_PIN): pins.gpio_output_pin_schema, + cv.Required(CONF_RESET_PIN): pins.gpio_output_pin_schema, + cv.Optional(CONF_RTS_PIN): cv.invalid( + f"{CONF_RTS_PIN} has been renamed to {CONF_RESET_PIN}" + ), } ).extend(i2c.i2c_device_schema(0x15)) ) @@ -37,5 +40,5 @@ async def to_code(config): interrupt_pin = await cg.gpio_pin_expression(config[CONF_INTERRUPT_PIN]) cg.add(var.set_interrupt_pin(interrupt_pin)) - rts_pin = await cg.gpio_pin_expression(config[CONF_RTS_PIN]) - cg.add(var.set_rts_pin(rts_pin)) + reset_pin = await cg.gpio_pin_expression(config[CONF_RESET_PIN]) + cg.add(var.set_reset_pin(reset_pin)) diff --git a/esphome/components/ektf2232/touchscreen/ektf2232.cpp b/esphome/components/ektf2232/touchscreen/ektf2232.cpp index 1dacee6a57..63ebb2166b 100644 --- a/esphome/components/ektf2232/touchscreen/ektf2232.cpp +++ b/esphome/components/ektf2232/touchscreen/ektf2232.cpp @@ -21,7 +21,7 @@ void EKTF2232Touchscreen::setup() { this->attach_interrupt_(this->interrupt_pin_, gpio::INTERRUPT_FALLING_EDGE); - this->rts_pin_->setup(); + this->reset_pin_->setup(); this->hard_reset_(); if (!this->soft_reset_()) { @@ -98,9 +98,9 @@ bool EKTF2232Touchscreen::get_power_state() { } void EKTF2232Touchscreen::hard_reset_() { - this->rts_pin_->digital_write(false); + this->reset_pin_->digital_write(false); delay(15); - this->rts_pin_->digital_write(true); + this->reset_pin_->digital_write(true); delay(15); } @@ -127,7 +127,7 @@ void EKTF2232Touchscreen::dump_config() { ESP_LOGCONFIG(TAG, "EKT2232 Touchscreen:"); LOG_I2C_DEVICE(this); LOG_PIN(" Interrupt Pin: ", this->interrupt_pin_); - LOG_PIN(" RTS Pin: ", this->rts_pin_); + LOG_PIN(" Reset Pin: ", this->reset_pin_); } } // namespace ektf2232 diff --git a/esphome/components/ektf2232/touchscreen/ektf2232.h b/esphome/components/ektf2232/touchscreen/ektf2232.h index e9288d0a27..2ddc60851f 100644 --- a/esphome/components/ektf2232/touchscreen/ektf2232.h +++ b/esphome/components/ektf2232/touchscreen/ektf2232.h @@ -17,7 +17,7 @@ class EKTF2232Touchscreen : public Touchscreen, public i2c::I2CDevice { void dump_config() override; void set_interrupt_pin(InternalGPIOPin *pin) { this->interrupt_pin_ = pin; } - void set_rts_pin(GPIOPin *pin) { this->rts_pin_ = pin; } + void set_reset_pin(GPIOPin *pin) { this->reset_pin_ = pin; } void set_power_state(bool enable); bool get_power_state(); @@ -28,7 +28,7 @@ class EKTF2232Touchscreen : public Touchscreen, public i2c::I2CDevice { void update_touches() override; InternalGPIOPin *interrupt_pin_; - GPIOPin *rts_pin_; + GPIOPin *reset_pin_; }; } // namespace ektf2232 diff --git a/esphome/components/epaper_spi/__init__.py b/esphome/components/epaper_spi/__init__.py new file mode 100644 index 0000000000..f70ffa9520 --- /dev/null +++ b/esphome/components/epaper_spi/__init__.py @@ -0,0 +1 @@ +CODEOWNERS = ["@esphome/core"] diff --git a/esphome/components/epaper_spi/display.py b/esphome/components/epaper_spi/display.py new file mode 100644 index 0000000000..20549f049d --- /dev/null +++ b/esphome/components/epaper_spi/display.py @@ -0,0 +1,80 @@ +from esphome import core, pins +import esphome.codegen as cg +from esphome.components import display, spi +import esphome.config_validation as cv +from esphome.const import ( + CONF_BUSY_PIN, + CONF_DC_PIN, + CONF_ID, + CONF_LAMBDA, + CONF_MODEL, + CONF_PAGES, + CONF_RESET_DURATION, + CONF_RESET_PIN, +) + +AUTO_LOAD = ["split_buffer"] +DEPENDENCIES = ["spi"] + +epaper_spi_ns = cg.esphome_ns.namespace("epaper_spi") +EPaperBase = epaper_spi_ns.class_( + "EPaperBase", cg.PollingComponent, spi.SPIDevice, display.DisplayBuffer +) + +EPaperSpectraE6 = epaper_spi_ns.class_("EPaperSpectraE6", EPaperBase) +EPaper7p3InSpectraE6 = epaper_spi_ns.class_("EPaper7p3InSpectraE6", EPaperSpectraE6) + +MODELS = { + "7.3in-spectra-e6": EPaper7p3InSpectraE6, +} + + +CONFIG_SCHEMA = cv.All( + display.FULL_DISPLAY_SCHEMA.extend( + { + cv.GenerateID(): cv.declare_id(EPaperBase), + cv.Required(CONF_DC_PIN): pins.gpio_output_pin_schema, + cv.Required(CONF_MODEL): cv.one_of(*MODELS, lower=True, space="-"), + cv.Optional(CONF_RESET_PIN): pins.gpio_output_pin_schema, + cv.Optional(CONF_BUSY_PIN): pins.gpio_input_pin_schema, + cv.Optional(CONF_RESET_DURATION): cv.All( + cv.positive_time_period_milliseconds, + cv.Range(max=core.TimePeriod(milliseconds=500)), + ), + } + ) + .extend(cv.polling_component_schema("60s")) + .extend(spi.spi_device_schema()), + cv.has_at_most_one_key(CONF_PAGES, CONF_LAMBDA), +) + +FINAL_VALIDATE_SCHEMA = spi.final_validate_device_schema( + "epaper_spi", require_miso=False, require_mosi=True +) + + +async def to_code(config): + model = MODELS[config[CONF_MODEL]] + + rhs = model.new() + var = cg.Pvariable(config[CONF_ID], rhs, model) + + await display.register_display(var, config) + await spi.register_spi_device(var, config) + + dc = await cg.gpio_pin_expression(config[CONF_DC_PIN]) + cg.add(var.set_dc_pin(dc)) + + if CONF_LAMBDA in config: + lambda_ = await cg.process_lambda( + config[CONF_LAMBDA], [(display.DisplayRef, "it")], return_type=cg.void + ) + cg.add(var.set_writer(lambda_)) + if CONF_RESET_PIN in config: + reset = await cg.gpio_pin_expression(config[CONF_RESET_PIN]) + cg.add(var.set_reset_pin(reset)) + if CONF_BUSY_PIN in config: + busy = await cg.gpio_pin_expression(config[CONF_BUSY_PIN]) + cg.add(var.set_busy_pin(busy)) + if CONF_RESET_DURATION in config: + cg.add(var.set_reset_duration(config[CONF_RESET_DURATION])) diff --git a/esphome/components/epaper_spi/epaper_spi.cpp b/esphome/components/epaper_spi/epaper_spi.cpp new file mode 100644 index 0000000000..21be4a2c05 --- /dev/null +++ b/esphome/components/epaper_spi/epaper_spi.cpp @@ -0,0 +1,227 @@ +#include "epaper_spi.h" +#include +#include "esphome/core/application.h" +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" + +namespace esphome::epaper_spi { + +static const char *const TAG = "epaper_spi"; + +static const LogString *epaper_state_to_string(EPaperState state) { + switch (state) { + case EPaperState::IDLE: + return LOG_STR("IDLE"); + case EPaperState::UPDATE: + return LOG_STR("UPDATE"); + case EPaperState::RESET: + return LOG_STR("RESET"); + case EPaperState::INITIALISE: + return LOG_STR("INITIALISE"); + case EPaperState::TRANSFER_DATA: + return LOG_STR("TRANSFER_DATA"); + case EPaperState::POWER_ON: + return LOG_STR("POWER_ON"); + case EPaperState::REFRESH_SCREEN: + return LOG_STR("REFRESH_SCREEN"); + case EPaperState::POWER_OFF: + return LOG_STR("POWER_OFF"); + case EPaperState::DEEP_SLEEP: + return LOG_STR("DEEP_SLEEP"); + default: + return LOG_STR("UNKNOWN"); + } +} + +void EPaperBase::setup() { + if (!this->init_buffer_(this->get_buffer_length())) { + this->mark_failed("Failed to initialise buffer"); + return; + } + this->setup_pins_(); + this->spi_setup(); +} + +bool EPaperBase::init_buffer_(size_t buffer_length) { + if (!this->buffer_.init(buffer_length)) { + return false; + } + this->clear(); + return true; +} + +void EPaperBase::setup_pins_() { + this->dc_pin_->setup(); // OUTPUT + this->dc_pin_->digital_write(false); + + if (this->reset_pin_ != nullptr) { + this->reset_pin_->setup(); // OUTPUT + this->reset_pin_->digital_write(true); + } + + if (this->busy_pin_ != nullptr) { + this->busy_pin_->setup(); // INPUT + } +} + +float EPaperBase::get_setup_priority() const { return setup_priority::PROCESSOR; } + +void EPaperBase::command(uint8_t value) { + this->start_command_(); + this->write_byte(value); + this->end_command_(); +} + +void EPaperBase::data(uint8_t value) { + this->start_data_(); + this->write_byte(value); + this->end_data_(); +} + +// write a command followed by zero or more bytes of data. +// The command is the first byte, length is the length of data only in the second byte, followed by the data. +// [COMMAND, LENGTH, DATA...] +void EPaperBase::cmd_data(const uint8_t *data) { + const uint8_t command = data[0]; + const uint8_t length = data[1]; + const uint8_t *ptr = data + 2; + + ESP_LOGVV(TAG, "Command: 0x%02X, Length: %d, Data: %s", command, length, + format_hex_pretty(ptr, length, '.', false).c_str()); + + this->dc_pin_->digital_write(false); + this->enable(); + this->write_byte(command); + if (length > 0) { + this->dc_pin_->digital_write(true); + this->write_array(ptr, length); + } + this->disable(); +} + +bool EPaperBase::is_idle_() { + if (this->busy_pin_ == nullptr) { + return true; + } + return !this->busy_pin_->digital_read(); +} + +void EPaperBase::reset() { + if (this->reset_pin_ != nullptr) { + this->reset_pin_->digital_write(false); + this->disable_loop(); + this->set_timeout(this->reset_duration_, [this] { + this->reset_pin_->digital_write(true); + this->set_timeout(20, [this] { this->enable_loop(); }); + }); + } +} + +void EPaperBase::update() { + if (!this->state_queue_.empty()) { + ESP_LOGE(TAG, "Display update already in progress - %s", + LOG_STR_ARG(epaper_state_to_string(this->state_queue_.front()))); + return; + } + + this->state_queue_.push(EPaperState::UPDATE); + this->state_queue_.push(EPaperState::RESET); + this->state_queue_.push(EPaperState::INITIALISE); + this->state_queue_.push(EPaperState::TRANSFER_DATA); + this->state_queue_.push(EPaperState::POWER_ON); + this->state_queue_.push(EPaperState::REFRESH_SCREEN); + this->state_queue_.push(EPaperState::POWER_OFF); + this->state_queue_.push(EPaperState::DEEP_SLEEP); + this->state_queue_.push(EPaperState::IDLE); + + this->enable_loop(); +} + +void EPaperBase::loop() { + if (this->waiting_for_idle_) { + if (this->is_idle_()) { + this->waiting_for_idle_ = false; + } else { + if (App.get_loop_component_start_time() - this->waiting_for_idle_last_print_ >= 1000) { + ESP_LOGV(TAG, "Waiting for idle"); + this->waiting_for_idle_last_print_ = App.get_loop_component_start_time(); + } + return; + } + } + + auto state = this->state_queue_.front(); + + switch (state) { + case EPaperState::IDLE: + this->disable_loop(); + break; + case EPaperState::UPDATE: + this->do_update_(); // Calls ESPHome (current page) lambda + break; + case EPaperState::RESET: + this->reset(); + break; + case EPaperState::INITIALISE: + this->initialise_(); + break; + case EPaperState::TRANSFER_DATA: + if (!this->transfer_data()) { + return; // Not done yet, come back next loop + } + break; + case EPaperState::POWER_ON: + this->power_on(); + break; + case EPaperState::REFRESH_SCREEN: + this->refresh_screen(); + break; + case EPaperState::POWER_OFF: + this->power_off(); + break; + case EPaperState::DEEP_SLEEP: + this->deep_sleep(); + break; + } + this->state_queue_.pop(); +} + +void EPaperBase::start_command_() { + this->dc_pin_->digital_write(false); + this->enable(); +} + +void EPaperBase::end_command_() { this->disable(); } + +void EPaperBase::start_data_() { + this->dc_pin_->digital_write(true); + this->enable(); +} +void EPaperBase::end_data_() { this->disable(); } + +void EPaperBase::on_safe_shutdown() { this->deep_sleep(); } + +void EPaperBase::initialise_() { + size_t index = 0; + const auto &sequence = this->init_sequence_; + const size_t sequence_size = this->init_sequence_length_; + while (index != sequence_size) { + if (sequence_size - index < 2) { + this->mark_failed("Malformed init sequence"); + return; + } + const auto *ptr = sequence + index; + const uint8_t length = ptr[1]; + if (sequence_size - index < length + 2) { + this->mark_failed("Malformed init sequence"); + return; + } + + this->cmd_data(ptr); + index += length + 2; + } + + this->power_on(); +} + +} // namespace esphome::epaper_spi diff --git a/esphome/components/epaper_spi/epaper_spi.h b/esphome/components/epaper_spi/epaper_spi.h new file mode 100644 index 0000000000..f6b2d41c65 --- /dev/null +++ b/esphome/components/epaper_spi/epaper_spi.h @@ -0,0 +1,93 @@ +#pragma once + +#include "esphome/components/display/display_buffer.h" +#include "esphome/components/spi/spi.h" +#include "esphome/components/split_buffer/split_buffer.h" +#include "esphome/core/component.h" + +#include + +namespace esphome::epaper_spi { + +enum class EPaperState : uint8_t { + IDLE, + UPDATE, + RESET, + INITIALISE, + TRANSFER_DATA, + POWER_ON, + REFRESH_SCREEN, + POWER_OFF, + DEEP_SLEEP, +}; + +static const uint8_t MAX_TRANSFER_TIME = 10; // Transfer in 10ms blocks to allow the loop to run + +class EPaperBase : public display::DisplayBuffer, + public spi::SPIDevice { + public: + EPaperBase(const uint8_t *init_sequence, const size_t init_sequence_length) + : init_sequence_length_(init_sequence_length), init_sequence_(init_sequence) {} + void set_dc_pin(GPIOPin *dc_pin) { dc_pin_ = dc_pin; } + float get_setup_priority() const override; + void set_reset_pin(GPIOPin *reset) { this->reset_pin_ = reset; } + void set_busy_pin(GPIOPin *busy) { this->busy_pin_ = busy; } + void set_reset_duration(uint32_t reset_duration) { this->reset_duration_ = reset_duration; } + + void command(uint8_t value); + void data(uint8_t value); + void cmd_data(const uint8_t *data); + + void update() override; + void loop() override; + + void setup() override; + + void on_safe_shutdown() override; + + protected: + bool is_idle_(); + void setup_pins_(); + virtual void reset(); + void initialise_(); + bool init_buffer_(size_t buffer_length); + + virtual int get_width_controller() { return this->get_width_internal(); }; + virtual void deep_sleep() = 0; + /** + * Send data to the device via SPI + * @return true if done, false if should be called next loop + */ + virtual bool transfer_data() = 0; + virtual void refresh_screen() = 0; + + virtual void power_on() = 0; + virtual void power_off() = 0; + virtual uint32_t get_buffer_length() = 0; + + void start_command_(); + void end_command_(); + void start_data_(); + void end_data_(); + + const size_t init_sequence_length_{0}; + + size_t current_data_index_{0}; + uint32_t reset_duration_{200}; + uint32_t waiting_for_idle_last_print_{0}; + + GPIOPin *dc_pin_; + GPIOPin *busy_pin_{nullptr}; + GPIOPin *reset_pin_{nullptr}; + + const uint8_t *init_sequence_{nullptr}; + + bool waiting_for_idle_{false}; + + split_buffer::SplitBuffer buffer_; + + std::queue state_queue_{{EPaperState::IDLE}}; +}; + +} // namespace esphome::epaper_spi diff --git a/esphome/components/epaper_spi/epaper_spi_model_7p3in_spectra_e6.cpp b/esphome/components/epaper_spi/epaper_spi_model_7p3in_spectra_e6.cpp new file mode 100644 index 0000000000..f6273b392f --- /dev/null +++ b/esphome/components/epaper_spi/epaper_spi_model_7p3in_spectra_e6.cpp @@ -0,0 +1,42 @@ +#include "epaper_spi_model_7p3in_spectra_e6.h" + +namespace esphome::epaper_spi { + +static constexpr const char *const TAG = "epaper_spi.7.3in-spectra-e6"; + +void EPaper7p3InSpectraE6::power_on() { + ESP_LOGI(TAG, "Power on"); + this->command(0x04); + this->waiting_for_idle_ = true; +} + +void EPaper7p3InSpectraE6::power_off() { + ESP_LOGI(TAG, "Power off"); + this->command(0x02); + this->data(0x00); + this->waiting_for_idle_ = true; +} + +void EPaper7p3InSpectraE6::refresh_screen() { + ESP_LOGI(TAG, "Refresh"); + this->command(0x12); + this->data(0x00); + this->waiting_for_idle_ = true; +} + +void EPaper7p3InSpectraE6::deep_sleep() { + ESP_LOGI(TAG, "Deep sleep"); + this->command(0x07); + this->data(0xA5); +} + +void EPaper7p3InSpectraE6::dump_config() { + LOG_DISPLAY("", "E-Paper SPI", this); + ESP_LOGCONFIG(TAG, " Model: 7.3in Spectra E6"); + LOG_PIN(" Reset Pin: ", this->reset_pin_); + LOG_PIN(" DC Pin: ", this->dc_pin_); + LOG_PIN(" Busy Pin: ", this->busy_pin_); + LOG_UPDATE_INTERVAL(this); +} + +} // namespace esphome::epaper_spi diff --git a/esphome/components/epaper_spi/epaper_spi_model_7p3in_spectra_e6.h b/esphome/components/epaper_spi/epaper_spi_model_7p3in_spectra_e6.h new file mode 100644 index 0000000000..6e850085ac --- /dev/null +++ b/esphome/components/epaper_spi/epaper_spi_model_7p3in_spectra_e6.h @@ -0,0 +1,45 @@ +#pragma once + +#include "epaper_spi_spectra_e6.h" + +namespace esphome::epaper_spi { + +class EPaper7p3InSpectraE6 : public EPaperSpectraE6 { + static constexpr const uint16_t WIDTH = 800; + static constexpr const uint16_t HEIGHT = 480; + // clang-format off + + // Command, data length, data + static constexpr uint8_t INIT_SEQUENCE[] = { + 0xAA, 6, 0x49, 0x55, 0x20, 0x08, 0x09, 0x18, + 0x01, 1, 0x3F, + 0x00, 2, 0x5F, 0x69, + 0x03, 4, 0x00, 0x54, 0x00, 0x44, + 0x05, 4, 0x40, 0x1F, 0x1F, 0x2C, + 0x06, 4, 0x6F, 0x1F, 0x17, 0x49, + 0x08, 4, 0x6F, 0x1F, 0x1F, 0x22, + 0x30, 1, 0x03, + 0x50, 1, 0x3F, + 0x60, 2, 0x02, 0x00, + 0x61, 4, WIDTH / 256, WIDTH % 256, HEIGHT / 256, HEIGHT % 256, + 0x84, 1, 0x01, + 0xE3, 1, 0x2F, + }; + // clang-format on + + public: + EPaper7p3InSpectraE6() : EPaperSpectraE6(INIT_SEQUENCE, sizeof(INIT_SEQUENCE)) {} + + void dump_config() override; + + protected: + int get_width_internal() override { return WIDTH; }; + int get_height_internal() override { return HEIGHT; }; + + void refresh_screen() override; + void power_on() override; + void power_off() override; + void deep_sleep() override; +}; + +} // namespace esphome::epaper_spi diff --git a/esphome/components/epaper_spi/epaper_spi_spectra_e6.cpp b/esphome/components/epaper_spi/epaper_spi_spectra_e6.cpp new file mode 100644 index 0000000000..dccc691252 --- /dev/null +++ b/esphome/components/epaper_spi/epaper_spi_spectra_e6.cpp @@ -0,0 +1,135 @@ +#include "epaper_spi_spectra_e6.h" + +#include "esphome/core/log.h" + +namespace esphome::epaper_spi { + +static constexpr const char *const TAG = "epaper_spi.6c"; + +static inline uint8_t color_to_hex(Color color) { + if (color.red > 127) { + if (color.green > 170) { + if (color.blue > 127) { + return 0x1; // White + } else { + return 0x2; // Yellow + } + } else { + return 0x3; // Red (or Magenta) + } + } else { + if (color.green > 127) { + if (color.blue > 127) { + return 0x5; // Cyan -> Blue + } else { + return 0x6; // Green + } + } else { + if (color.blue > 127) { + return 0x5; // Blue + } else { + return 0x0; // Black + } + } + } +} + +void EPaperSpectraE6::fill(Color color) { + uint8_t pixel_color; + if (color.is_on()) { + pixel_color = color_to_hex(color); + } else { + pixel_color = 0x1; + } + + // We store 8 bitset<3> in 3 bytes + // | byte 1 | byte 2 | byte 3 | + // |aaabbbaa|abbbaaab|bbaaabbb| + uint8_t byte_1 = pixel_color << 5 | pixel_color << 2 | pixel_color >> 1; + uint8_t byte_2 = pixel_color << 7 | pixel_color << 4 | pixel_color << 1 | pixel_color >> 2; + uint8_t byte_3 = pixel_color << 6 | pixel_color << 3 | pixel_color << 0; + + const size_t buffer_length = this->get_buffer_length(); + for (size_t i = 0; i < buffer_length; i += 3) { + this->buffer_[i + 0] = byte_1; + this->buffer_[i + 1] = byte_2; + this->buffer_[i + 2] = byte_3; + } +} + +uint32_t EPaperSpectraE6::get_buffer_length() { + // 6 colors buffer, 1 pixel = 3 bits, we will store 8 pixels in 24 bits = 3 bytes + return this->get_width_controller() * this->get_height_internal() / 8u * 3u; +} + +void HOT EPaperSpectraE6::draw_absolute_pixel_internal(int x, int y, Color color) { + if (x >= this->get_width_internal() || y >= this->get_height_internal() || x < 0 || y < 0) + return; + + uint8_t pixel_bits = color_to_hex(color); + uint32_t pixel_position = x + y * this->get_width_controller(); + uint32_t first_bit_position = pixel_position * 3; + uint32_t byte_position = first_bit_position / 8u; + uint32_t byte_subposition = first_bit_position % 8u; + + if (byte_subposition <= 5) { + this->buffer_[byte_position] = (this->buffer_[byte_position] & (0xFF ^ (0b111 << (5 - byte_subposition)))) | + (pixel_bits << (5 - byte_subposition)); + } else { + this->buffer_[byte_position] = (this->buffer_[byte_position] & (0xFF ^ (0b111 >> (byte_subposition - 5)))) | + (pixel_bits >> (byte_subposition - 5)); + + this->buffer_[byte_position + 1] = + (this->buffer_[byte_position + 1] & (0xFF ^ (0xFF & (0b111 << (13 - byte_subposition))))) | + (pixel_bits << (13 - byte_subposition)); + } +} + +bool HOT EPaperSpectraE6::transfer_data() { + const uint32_t start_time = App.get_loop_component_start_time(); + if (this->current_data_index_ == 0) { + ESP_LOGV(TAG, "Sending data"); + this->command(0x10); + } + + uint8_t bytes_to_send[4]{0}; + const size_t buffer_length = this->get_buffer_length(); + for (size_t i = this->current_data_index_; i < buffer_length; i += 3) { + const uint32_t triplet = encode_uint24(this->buffer_[i + 0], this->buffer_[i + 1], this->buffer_[i + 2]); + // 8 pixels are stored in 3 bytes + // |aaabbbaa|abbbaaab|bbaaabbb| + // | byte 1 | byte 2 | byte 3 | + bytes_to_send[0] = ((triplet >> 17) & 0b01110000) | ((triplet >> 18) & 0b00000111); + bytes_to_send[1] = ((triplet >> 11) & 0b01110000) | ((triplet >> 12) & 0b00000111); + bytes_to_send[2] = ((triplet >> 5) & 0b01110000) | ((triplet >> 6) & 0b00000111); + bytes_to_send[3] = ((triplet << 1) & 0b01110000) | ((triplet << 0) & 0b00000111); + + this->start_data_(); + this->write_array(bytes_to_send, sizeof(bytes_to_send)); + this->end_data_(); + + if (millis() - start_time > MAX_TRANSFER_TIME) { + // Let the main loop run and come back next loop + this->current_data_index_ = i + 3; + return false; + } + } + // Finished the entire dataset + this->current_data_index_ = 0; + return true; +} + +void EPaperSpectraE6::reset() { + if (this->reset_pin_ != nullptr) { + this->disable_loop(); + this->reset_pin_->digital_write(true); + this->set_timeout(20, [this] { + this->reset_pin_->digital_write(false); + delay(2); + this->reset_pin_->digital_write(true); + this->set_timeout(20, [this] { this->enable_loop(); }); + }); + } +} + +} // namespace esphome::epaper_spi diff --git a/esphome/components/epaper_spi/epaper_spi_spectra_e6.h b/esphome/components/epaper_spi/epaper_spi_spectra_e6.h new file mode 100644 index 0000000000..9f0652f79d --- /dev/null +++ b/esphome/components/epaper_spi/epaper_spi_spectra_e6.h @@ -0,0 +1,23 @@ +#pragma once + +#include "epaper_spi.h" + +namespace esphome::epaper_spi { + +class EPaperSpectraE6 : public EPaperBase { + public: + EPaperSpectraE6(const uint8_t *init_sequence, const size_t init_sequence_length) + : EPaperBase(init_sequence, init_sequence_length) {} + + display::DisplayType get_display_type() override { return display::DisplayType::DISPLAY_TYPE_COLOR; } + void fill(Color color) override; + + protected: + void draw_absolute_pixel_internal(int x, int y, Color color) override; + uint32_t get_buffer_length() override; + + bool transfer_data() override; + void reset() override; +}; + +} // namespace esphome::epaper_spi diff --git a/esphome/components/es7210/es7210.cpp b/esphome/components/es7210/es7210.cpp index e5729703ed..1358121c1b 100644 --- a/esphome/components/es7210/es7210.cpp +++ b/esphome/components/es7210/es7210.cpp @@ -97,12 +97,12 @@ bool ES7210::set_mic_gain(float mic_gain) { } bool ES7210::configure_sample_rate_() { - int mclk_fre = this->sample_rate_ * MCLK_DIV_FRE; + uint32_t mclk_fre = this->sample_rate_ * MCLK_DIV_FRE; int coeff = -1; - for (int i = 0; i < (sizeof(ES7210_COEFFICIENTS) / sizeof(ES7210_COEFFICIENTS[0])); ++i) { + for (size_t i = 0; i < (sizeof(ES7210_COEFFICIENTS) / sizeof(ES7210_COEFFICIENTS[0])); ++i) { if (ES7210_COEFFICIENTS[i].lrclk == this->sample_rate_ && ES7210_COEFFICIENTS[i].mclk == mclk_fre) - coeff = i; + coeff = static_cast(i); } if (coeff >= 0) { diff --git a/esphome/components/esp32/__init__.py b/esphome/components/esp32/__init__.py index 12d84dd4b3..860f2450e6 100644 --- a/esphome/components/esp32/__init__.py +++ b/esphome/components/esp32/__init__.py @@ -36,9 +36,8 @@ from esphome.const import ( __version__, ) from esphome.core import CORE, HexInt, TimePeriod -from esphome.cpp_generator import RawExpression import esphome.final_validate as fv -from esphome.helpers import copy_file_if_changed, mkdir_p, write_file_if_changed +from esphome.helpers import copy_file_if_changed, write_file_if_changed from esphome.types import ConfigType from esphome.writer import clean_cmake_cache @@ -157,8 +156,6 @@ def set_core_data(config): conf = config[CONF_FRAMEWORK] if conf[CONF_TYPE] == FRAMEWORK_ESP_IDF: CORE.data[KEY_CORE][KEY_TARGET_FRAMEWORK] = "esp-idf" - CORE.data[KEY_ESP32][KEY_SDKCONFIG_OPTIONS] = {} - CORE.data[KEY_ESP32][KEY_COMPONENTS] = {} elif conf[CONF_TYPE] == FRAMEWORK_ARDUINO: CORE.data[KEY_CORE][KEY_TARGET_FRAMEWORK] = "arduino" if variant not in ARDUINO_ALLOWED_VARIANTS: @@ -166,6 +163,8 @@ def set_core_data(config): f"ESPHome does not support using the Arduino framework for the {variant}. Please use the ESP-IDF framework instead.", path=[CONF_FRAMEWORK, CONF_TYPE], ) + CORE.data[KEY_ESP32][KEY_SDKCONFIG_OPTIONS] = {} + CORE.data[KEY_ESP32][KEY_COMPONENTS] = {} CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] = cv.Version.parse( config[CONF_FRAMEWORK][CONF_VERSION] ) @@ -236,8 +235,6 @@ SdkconfigValueType = bool | int | HexInt | str | RawSdkconfigValue def add_idf_sdkconfig_option(name: str, value: SdkconfigValueType): """Set an esp-idf sdkconfig value.""" - if not CORE.using_esp_idf: - raise ValueError("Not an esp-idf project") CORE.data[KEY_ESP32][KEY_SDKCONFIG_OPTIONS][name] = value @@ -252,8 +249,6 @@ def add_idf_component( submodules: list[str] | None = None, ): """Add an esp-idf component to the project.""" - if not CORE.using_esp_idf: - raise ValueError("Not an esp-idf project") if not repo and not ref and not path: raise ValueError("Requires at least one of repo, ref or path") if refresh or submodules or components: @@ -277,14 +272,14 @@ def add_idf_component( } -def add_extra_script(stage: str, filename: str, path: str): +def add_extra_script(stage: str, filename: str, path: Path): """Add an extra script to the project.""" key = f"{stage}:{filename}" if add_extra_build_file(filename, path): cg.add_platformio_option("extra_scripts", [key]) -def add_extra_build_file(filename: str, path: str) -> bool: +def add_extra_build_file(filename: str, path: Path) -> bool: """Add an extra build file to the project.""" if filename not in CORE.data[KEY_ESP32][KEY_EXTRA_BUILD_FILES]: CORE.data[KEY_ESP32][KEY_EXTRA_BUILD_FILES][filename] = { @@ -301,14 +296,9 @@ def _format_framework_arduino_version(ver: cv.Version) -> str: return f"pioarduino/framework-arduinoespressif32@https://github.com/espressif/arduino-esp32/releases/download/{str(ver)}/esp32-{str(ver)}.zip" -def _format_framework_espidf_version( - ver: cv.Version, release: str, for_platformio: bool -) -> str: - # format the given arduino (https://github.com/espressif/esp-idf/releases) version to +def _format_framework_espidf_version(ver: cv.Version, release: str) -> str: + # format the given espidf (https://github.com/pioarduino/esp-idf/releases) version to # a PIO platformio/framework-espidf value - # List of package versions: https://api.registry.platformio.org/v3/packages/platformio/tool/framework-espidf - if for_platformio: - return f"platformio/framework-espidf@~3.{ver.major}{ver.minor:02d}{ver.patch:02d}.0" if release: return f"pioarduino/framework-espidf@https://github.com/pioarduino/esp-idf/releases/download/v{str(ver)}.{release}/esp-idf-v{str(ver)}.zip" return f"pioarduino/framework-espidf@https://github.com/pioarduino/esp-idf/releases/download/v{str(ver)}/esp-idf-v{str(ver)}.zip" @@ -322,154 +312,114 @@ def _format_framework_espidf_version( # The default/recommended arduino framework version # - https://github.com/espressif/arduino-esp32/releases -RECOMMENDED_ARDUINO_FRAMEWORK_VERSION = cv.Version(3, 2, 1) -# The platform-espressif32 version to use for arduino frameworks -# - https://github.com/pioarduino/platform-espressif32/releases -ARDUINO_PLATFORM_VERSION = cv.Version(54, 3, 21, "2") +ARDUINO_FRAMEWORK_VERSION_LOOKUP = { + "recommended": cv.Version(3, 2, 1), + "latest": cv.Version(3, 3, 1), + "dev": cv.Version(3, 3, 1), +} +ARDUINO_PLATFORM_VERSION_LOOKUP = { + cv.Version(3, 3, 1): cv.Version(55, 3, 31), + cv.Version(3, 3, 0): cv.Version(55, 3, 30, "2"), + cv.Version(3, 2, 1): cv.Version(54, 3, 21, "2"), + cv.Version(3, 2, 0): cv.Version(54, 3, 20), + cv.Version(3, 1, 3): cv.Version(53, 3, 13), + cv.Version(3, 1, 2): cv.Version(53, 3, 12), + cv.Version(3, 1, 1): cv.Version(53, 3, 11), + cv.Version(3, 1, 0): cv.Version(53, 3, 10), +} # The default/recommended esp-idf framework version # - https://github.com/espressif/esp-idf/releases -# - https://api.registry.platformio.org/v3/packages/platformio/tool/framework-espidf -RECOMMENDED_ESP_IDF_FRAMEWORK_VERSION = cv.Version(5, 4, 2) -# The platformio/espressif32 version to use for esp-idf frameworks -# - https://github.com/platformio/platform-espressif32/releases -# - https://api.registry.platformio.org/v3/packages/platformio/platform/espressif32 -ESP_IDF_PLATFORM_VERSION = cv.Version(54, 3, 21, "2") +ESP_IDF_FRAMEWORK_VERSION_LOOKUP = { + "recommended": cv.Version(5, 4, 2), + "latest": cv.Version(5, 5, 1), + "dev": cv.Version(5, 5, 1), +} +ESP_IDF_PLATFORM_VERSION_LOOKUP = { + cv.Version(5, 5, 1): cv.Version(55, 3, 31), + cv.Version(5, 5, 0): cv.Version(55, 3, 31), + cv.Version(5, 4, 2): cv.Version(54, 3, 21, "2"), + cv.Version(5, 4, 1): cv.Version(54, 3, 21, "2"), + cv.Version(5, 4, 0): cv.Version(54, 3, 21, "2"), + cv.Version(5, 3, 2): cv.Version(53, 3, 13), + cv.Version(5, 3, 1): cv.Version(53, 3, 13), + cv.Version(5, 3, 0): cv.Version(53, 3, 13), + cv.Version(5, 1, 6): cv.Version(51, 3, 7), + cv.Version(5, 1, 5): cv.Version(51, 3, 7), +} -# List based on https://registry.platformio.org/tools/platformio/framework-espidf/versions -SUPPORTED_PLATFORMIO_ESP_IDF_5X = [ - cv.Version(5, 3, 1), - cv.Version(5, 3, 0), - cv.Version(5, 2, 2), - cv.Version(5, 2, 1), - cv.Version(5, 1, 2), - cv.Version(5, 1, 1), - cv.Version(5, 1, 0), - cv.Version(5, 0, 2), - cv.Version(5, 0, 1), - cv.Version(5, 0, 0), -] - -# pioarduino versions that don't require a release number -# List based on https://github.com/pioarduino/esp-idf/releases -SUPPORTED_PIOARDUINO_ESP_IDF_5X = [ - cv.Version(5, 5, 0), - cv.Version(5, 4, 2), - cv.Version(5, 4, 1), - cv.Version(5, 4, 0), - cv.Version(5, 3, 3), - cv.Version(5, 3, 2), - cv.Version(5, 3, 1), - cv.Version(5, 3, 0), - cv.Version(5, 1, 5), - cv.Version(5, 1, 6), -] +# The platform-espressif32 version +# - https://github.com/pioarduino/platform-espressif32/releases +PLATFORM_VERSION_LOOKUP = { + "recommended": cv.Version(54, 3, 21, "2"), + "latest": cv.Version(55, 3, 31), + "dev": "https://github.com/pioarduino/platform-espressif32.git#develop", +} -def _arduino_check_versions(value): +def _check_versions(value): value = value.copy() - lookups = { - "dev": (cv.Version(3, 2, 1), "https://github.com/espressif/arduino-esp32.git"), - "latest": (cv.Version(3, 2, 1), None), - "recommended": (RECOMMENDED_ARDUINO_FRAMEWORK_VERSION, None), - } - if value[CONF_VERSION] in lookups: - if CONF_SOURCE in value: + if value[CONF_VERSION] in PLATFORM_VERSION_LOOKUP: + if CONF_SOURCE in value or CONF_PLATFORM_VERSION in value: raise cv.Invalid( - "Framework version needs to be explicitly specified when custom source is used." + "Version needs to be explicitly set when a custom source or platform_version is used." ) - version, source = lookups[value[CONF_VERSION]] + platform_lookup = PLATFORM_VERSION_LOOKUP[value[CONF_VERSION]] + value[CONF_PLATFORM_VERSION] = _parse_platform_version(str(platform_lookup)) + + if value[CONF_TYPE] == FRAMEWORK_ARDUINO: + version = ARDUINO_FRAMEWORK_VERSION_LOOKUP[value[CONF_VERSION]] + else: + version = ESP_IDF_FRAMEWORK_VERSION_LOOKUP[value[CONF_VERSION]] else: version = cv.Version.parse(cv.version_number(value[CONF_VERSION])) - source = value.get(CONF_SOURCE, None) value[CONF_VERSION] = str(version) - value[CONF_SOURCE] = source or _format_framework_arduino_version(version) - value[CONF_PLATFORM_VERSION] = value.get( - CONF_PLATFORM_VERSION, _parse_platform_version(str(ARDUINO_PLATFORM_VERSION)) - ) + if value[CONF_TYPE] == FRAMEWORK_ARDUINO: + if version < cv.Version(3, 0, 0): + raise cv.Invalid("Only Arduino 3.0+ is supported.") + recommended_version = ARDUINO_FRAMEWORK_VERSION_LOOKUP["recommended"] + platform_lookup = ARDUINO_PLATFORM_VERSION_LOOKUP.get(version) + value[CONF_SOURCE] = value.get( + CONF_SOURCE, _format_framework_arduino_version(version) + ) + if value[CONF_SOURCE].startswith("http"): + value[CONF_SOURCE] = ( + f"pioarduino/framework-arduinoespressif32@{value[CONF_SOURCE]}" + ) + else: + if version < cv.Version(5, 0, 0): + raise cv.Invalid("Only ESP-IDF 5.0+ is supported.") + recommended_version = ESP_IDF_FRAMEWORK_VERSION_LOOKUP["recommended"] + platform_lookup = ESP_IDF_PLATFORM_VERSION_LOOKUP.get(version) + value[CONF_SOURCE] = value.get( + CONF_SOURCE, + _format_framework_espidf_version(version, value.get(CONF_RELEASE, None)), + ) + if value[CONF_SOURCE].startswith("http"): + value[CONF_SOURCE] = f"pioarduino/framework-espidf@{value[CONF_SOURCE]}" - if value[CONF_SOURCE].startswith("http"): - # prefix is necessary or platformio will complain with a cryptic error - value[CONF_SOURCE] = f"framework-arduinoespressif32@{value[CONF_SOURCE]}" + if CONF_PLATFORM_VERSION not in value: + if platform_lookup is None: + raise cv.Invalid( + "Framework version not recognized; please specify platform_version" + ) + value[CONF_PLATFORM_VERSION] = _parse_platform_version(str(platform_lookup)) - if version != RECOMMENDED_ARDUINO_FRAMEWORK_VERSION: + if version != recommended_version: _LOGGER.warning( - "The selected Arduino framework version is not the recommended one. " + "The selected framework version is not the recommended one. " "If there are connectivity or build issues please remove the manual version." ) - return value - - -def _esp_idf_check_versions(value): - value = value.copy() - lookups = { - "dev": (cv.Version(5, 4, 2), "https://github.com/espressif/esp-idf.git"), - "latest": (cv.Version(5, 2, 2), None), - "recommended": (RECOMMENDED_ESP_IDF_FRAMEWORK_VERSION, None), - } - - if value[CONF_VERSION] in lookups: - if CONF_SOURCE in value: - raise cv.Invalid( - "Framework version needs to be explicitly specified when custom source is used." - ) - - version, source = lookups[value[CONF_VERSION]] - else: - version = cv.Version.parse(cv.version_number(value[CONF_VERSION])) - source = value.get(CONF_SOURCE, None) - - if version < cv.Version(5, 0, 0): - raise cv.Invalid("Only ESP-IDF 5.0+ is supported.") - - # flag this for later *before* we set value[CONF_PLATFORM_VERSION] below - has_platform_ver = CONF_PLATFORM_VERSION in value - - value[CONF_PLATFORM_VERSION] = value.get( - CONF_PLATFORM_VERSION, _parse_platform_version(str(ESP_IDF_PLATFORM_VERSION)) - ) - - if ( - is_platformio := _platform_is_platformio(value[CONF_PLATFORM_VERSION]) - ) and version not in SUPPORTED_PLATFORMIO_ESP_IDF_5X: - raise cv.Invalid( - f"ESP-IDF {str(version)} not supported by platformio/espressif32" - ) - - if ( - version in SUPPORTED_PLATFORMIO_ESP_IDF_5X - and version not in SUPPORTED_PIOARDUINO_ESP_IDF_5X - ) and not has_platform_ver: - raise cv.Invalid( - f"ESP-IDF {value[CONF_VERSION]} may be supported by platformio/espressif32; please specify '{CONF_PLATFORM_VERSION}'" - ) - - if ( - not is_platformio - and CONF_RELEASE not in value - and version not in SUPPORTED_PIOARDUINO_ESP_IDF_5X + if value[CONF_PLATFORM_VERSION] != _parse_platform_version( + str(PLATFORM_VERSION_LOOKUP["recommended"]) ): - raise cv.Invalid( - f"ESP-IDF {value[CONF_VERSION]} is not available with pioarduino; you may need to specify '{CONF_RELEASE}'" - ) - - value[CONF_VERSION] = str(version) - value[CONF_SOURCE] = source or _format_framework_espidf_version( - version, value.get(CONF_RELEASE, None), is_platformio - ) - - if value[CONF_SOURCE].startswith("http"): - # prefix is necessary or platformio will complain with a cryptic error - value[CONF_SOURCE] = f"framework-espidf@{value[CONF_SOURCE]}" - - if version != RECOMMENDED_ESP_IDF_FRAMEWORK_VERSION: _LOGGER.warning( - "The selected ESP-IDF framework version is not the recommended one. " + "The selected platform version is not the recommended one. " "If there are connectivity or build issues please remove the manual version." ) @@ -479,26 +429,14 @@ def _esp_idf_check_versions(value): def _parse_platform_version(value): try: ver = cv.Version.parse(cv.version_number(value)) - if ver.major >= 50: # a pioarduino version - release = f"{ver.major}.{ver.minor:02d}.{ver.patch:02d}" - if ver.extra: - release += f"-{ver.extra}" - return f"https://github.com/pioarduino/platform-espressif32/releases/download/{release}/platform-espressif32.zip" - # if platform version is a valid version constraint, prefix the default package - cv.platformio_version_constraint(value) - return f"platformio/espressif32@{value}" + release = f"{ver.major}.{ver.minor:02d}.{ver.patch:02d}" + if ver.extra: + release += f"-{ver.extra}" + return f"https://github.com/pioarduino/platform-espressif32/releases/download/{release}/platform-espressif32.zip" except cv.Invalid: return value -def _platform_is_platformio(value): - try: - ver = cv.Version.parse(cv.version_number(value)) - return ver.major < 50 - except cv.Invalid: - return "platformio" in value - - def _detect_variant(value): board = value.get(CONF_BOARD) variant = value.get(CONF_VARIANT) @@ -588,24 +526,6 @@ def final_validate(config): return config -ARDUINO_FRAMEWORK_SCHEMA = cv.All( - cv.Schema( - { - cv.Optional(CONF_VERSION, default="recommended"): cv.string_strict, - cv.Optional(CONF_SOURCE): cv.string_strict, - cv.Optional(CONF_PLATFORM_VERSION): _parse_platform_version, - cv.Optional(CONF_ADVANCED, default={}): cv.Schema( - { - cv.Optional( - CONF_IGNORE_EFUSE_CUSTOM_MAC, default=False - ): cv.boolean, - } - ), - } - ), - _arduino_check_versions, -) - CONF_SDKCONFIG_OPTIONS = "sdkconfig_options" CONF_ENABLE_LWIP_DHCP_SERVER = "enable_lwip_dhcp_server" CONF_ENABLE_LWIP_MDNS_QUERIES = "enable_lwip_mdns_queries" @@ -624,9 +544,14 @@ def _validate_idf_component(config: ConfigType) -> ConfigType: return config -ESP_IDF_FRAMEWORK_SCHEMA = cv.All( +FRAMEWORK_ESP_IDF = "esp-idf" +FRAMEWORK_ARDUINO = "arduino" +FRAMEWORK_SCHEMA = cv.All( cv.Schema( { + cv.Optional(CONF_TYPE, default=FRAMEWORK_ARDUINO): cv.one_of( + FRAMEWORK_ESP_IDF, FRAMEWORK_ARDUINO + ), cv.Optional(CONF_VERSION, default="recommended"): cv.string_strict, cv.Optional(CONF_RELEASE): cv.string_strict, cv.Optional(CONF_SOURCE): cv.string_strict, @@ -690,7 +615,7 @@ ESP_IDF_FRAMEWORK_SCHEMA = cv.All( ), } ), - _esp_idf_check_versions, + _check_versions, ) @@ -757,32 +682,18 @@ def _set_default_framework(config): config = config.copy() variant = config[CONF_VARIANT] + config[CONF_FRAMEWORK] = FRAMEWORK_SCHEMA({}) if variant in ARDUINO_ALLOWED_VARIANTS: - config[CONF_FRAMEWORK] = ARDUINO_FRAMEWORK_SCHEMA({}) config[CONF_FRAMEWORK][CONF_TYPE] = FRAMEWORK_ARDUINO - # Show the migration message _show_framework_migration_message( config.get(CONF_NAME, "This device"), variant ) else: - config[CONF_FRAMEWORK] = ESP_IDF_FRAMEWORK_SCHEMA({}) config[CONF_FRAMEWORK][CONF_TYPE] = FRAMEWORK_ESP_IDF return config -FRAMEWORK_ESP_IDF = "esp-idf" -FRAMEWORK_ARDUINO = "arduino" -FRAMEWORK_SCHEMA = cv.typed_schema( - { - FRAMEWORK_ESP_IDF: ESP_IDF_FRAMEWORK_SCHEMA, - FRAMEWORK_ARDUINO: ARDUINO_FRAMEWORK_SCHEMA, - }, - lower=True, - space="-", -) - - FLASH_SIZES = [ "2MB", "4MB", @@ -837,6 +748,8 @@ async def to_code(config): conf = config[CONF_FRAMEWORK] cg.add_platformio_option("platform", conf[CONF_PLATFORM_VERSION]) + if CONF_SOURCE in conf: + cg.add_platformio_option("platform_packages", [conf[CONF_SOURCE]]) if conf[CONF_ADVANCED][CONF_IGNORE_EFUSE_CUSTOM_MAC]: cg.add_define("USE_ESP32_IGNORE_EFUSE_CUSTOM_MAC") @@ -847,142 +760,146 @@ async def to_code(config): add_extra_script( "post", "post_build.py", - os.path.join(os.path.dirname(__file__), "post_build.py.script"), + Path(__file__).parent / "post_build.py.script", ) - freq = config[CONF_CPU_FREQUENCY][:-3] if conf[CONF_TYPE] == FRAMEWORK_ESP_IDF: cg.add_platformio_option("framework", "espidf") cg.add_build_flag("-DUSE_ESP_IDF") cg.add_build_flag("-DUSE_ESP32_FRAMEWORK_ESP_IDF") - cg.add_build_flag("-Wno-nonnull-compare") - - cg.add_platformio_option("platform_packages", [conf[CONF_SOURCE]]) - - add_idf_sdkconfig_option(f"CONFIG_IDF_TARGET_{variant}", True) - add_idf_sdkconfig_option( - f"CONFIG_ESPTOOLPY_FLASHSIZE_{config[CONF_FLASH_SIZE]}", True - ) - add_idf_sdkconfig_option("CONFIG_PARTITION_TABLE_SINGLE_APP", False) - add_idf_sdkconfig_option("CONFIG_PARTITION_TABLE_CUSTOM", True) - add_idf_sdkconfig_option( - "CONFIG_PARTITION_TABLE_CUSTOM_FILENAME", "partitions.csv" - ) - - # Increase freertos tick speed from 100Hz to 1kHz so that delay() resolution is 1ms - add_idf_sdkconfig_option("CONFIG_FREERTOS_HZ", 1000) - - # Setup watchdog - add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT", True) - add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_PANIC", True) - add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU0", False) - add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU1", False) - - # Disable dynamic log level control to save memory - add_idf_sdkconfig_option("CONFIG_LOG_DYNAMIC_LEVEL_CONTROL", False) - - # Set default CPU frequency - add_idf_sdkconfig_option(f"CONFIG_ESP_DEFAULT_CPU_FREQ_MHZ_{freq}", True) - - # Apply LWIP optimization settings - advanced = conf[CONF_ADVANCED] - # DHCP server: only disable if explicitly set to false - # WiFi component handles its own optimization when AP mode is not used - if ( - CONF_ENABLE_LWIP_DHCP_SERVER in advanced - and not advanced[CONF_ENABLE_LWIP_DHCP_SERVER] - ): - add_idf_sdkconfig_option("CONFIG_LWIP_DHCPS", False) - if not advanced.get(CONF_ENABLE_LWIP_MDNS_QUERIES, True): - add_idf_sdkconfig_option("CONFIG_LWIP_DNS_SUPPORT_MDNS_QUERIES", False) - if not advanced.get(CONF_ENABLE_LWIP_BRIDGE_INTERFACE, False): - add_idf_sdkconfig_option("CONFIG_LWIP_BRIDGEIF_MAX_PORTS", 0) - if advanced.get(CONF_EXECUTE_FROM_PSRAM, False): - add_idf_sdkconfig_option("CONFIG_SPIRAM_FETCH_INSTRUCTIONS", True) - add_idf_sdkconfig_option("CONFIG_SPIRAM_RODATA", True) - - # Apply LWIP core locking for better socket performance - # This is already enabled by default in Arduino framework, where it provides - # significant performance benefits. Our benchmarks show socket operations are - # 24-200% faster with core locking enabled: - # - select() on 4 sockets: ~190μs (Arduino/core locking) vs ~235μs (ESP-IDF default) - # - Up to 200% slower under load when all operations queue through tcpip_thread - # Enabling this makes ESP-IDF socket performance match Arduino framework. - if advanced.get(CONF_ENABLE_LWIP_TCPIP_CORE_LOCKING, True): - add_idf_sdkconfig_option("CONFIG_LWIP_TCPIP_CORE_LOCKING", True) - if advanced.get(CONF_ENABLE_LWIP_CHECK_THREAD_SAFETY, True): - add_idf_sdkconfig_option("CONFIG_LWIP_CHECK_THREAD_SAFETY", True) - - cg.add_platformio_option("board_build.partitions", "partitions.csv") - if CONF_PARTITIONS in config: - add_extra_build_file( - "partitions.csv", CORE.relative_config_path(config[CONF_PARTITIONS]) - ) - - if assertion_level := advanced.get(CONF_ASSERTION_LEVEL): - for key, flag in ASSERTION_LEVELS.items(): - add_idf_sdkconfig_option(flag, assertion_level == key) - - add_idf_sdkconfig_option("CONFIG_COMPILER_OPTIMIZATION_DEFAULT", False) - compiler_optimization = advanced.get(CONF_COMPILER_OPTIMIZATION) - for key, flag in COMPILER_OPTIMIZATIONS.items(): - add_idf_sdkconfig_option(flag, compiler_optimization == key) - - add_idf_sdkconfig_option( - "CONFIG_LWIP_ESP_LWIP_ASSERT", - conf[CONF_ADVANCED][CONF_ENABLE_LWIP_ASSERT], - ) - - if advanced.get(CONF_IGNORE_EFUSE_MAC_CRC): - add_idf_sdkconfig_option("CONFIG_ESP_MAC_IGNORE_MAC_CRC_ERROR", True) - add_idf_sdkconfig_option( - "CONFIG_ESP_PHY_CALIBRATION_AND_DATA_STORAGE", False - ) - if advanced.get(CONF_ENABLE_IDF_EXPERIMENTAL_FEATURES): - _LOGGER.warning( - "Using experimental features in ESP-IDF may result in unexpected failures." - ) - add_idf_sdkconfig_option("CONFIG_IDF_EXPERIMENTAL_FEATURES", True) - - cg.add_define( - "USE_ESP_IDF_VERSION_CODE", - cg.RawExpression( - f"VERSION_CODE({framework_ver.major}, {framework_ver.minor}, {framework_ver.patch})" - ), - ) - - add_idf_sdkconfig_option( - f"CONFIG_LOG_DEFAULT_LEVEL_{conf[CONF_LOG_LEVEL]}", True - ) - - for name, value in conf[CONF_SDKCONFIG_OPTIONS].items(): - add_idf_sdkconfig_option(name, RawSdkconfigValue(value)) - - for component in conf[CONF_COMPONENTS]: - add_idf_component( - name=component[CONF_NAME], - repo=component.get(CONF_SOURCE), - ref=component.get(CONF_REF), - path=component.get(CONF_PATH), - ) - elif conf[CONF_TYPE] == FRAMEWORK_ARDUINO: - cg.add_platformio_option("framework", "arduino") + else: + cg.add_platformio_option("framework", "arduino, espidf") cg.add_build_flag("-DUSE_ARDUINO") cg.add_build_flag("-DUSE_ESP32_FRAMEWORK_ARDUINO") - cg.add_platformio_option("platform_packages", [conf[CONF_SOURCE]]) - - if CONF_PARTITIONS in config: - cg.add_platformio_option("board_build.partitions", config[CONF_PARTITIONS]) - else: - cg.add_platformio_option("board_build.partitions", "partitions.csv") - + cg.add_platformio_option( + "board_build.embed_txtfiles", + [ + "managed_components/espressif__esp_insights/server_certs/https_server.crt", + "managed_components/espressif__esp_rainmaker/server_certs/rmaker_mqtt_server.crt", + "managed_components/espressif__esp_rainmaker/server_certs/rmaker_claim_service_server.crt", + "managed_components/espressif__esp_rainmaker/server_certs/rmaker_ota_server.crt", + ], + ) cg.add_define( "USE_ARDUINO_VERSION_CODE", cg.RawExpression( f"VERSION_CODE({framework_ver.major}, {framework_ver.minor}, {framework_ver.patch})" ), ) - cg.add(RawExpression(f"setCpuFrequencyMhz({freq})")) + add_idf_sdkconfig_option("CONFIG_AUTOSTART_ARDUINO", True) + add_idf_sdkconfig_option("CONFIG_MBEDTLS_PSK_MODES", True) + add_idf_sdkconfig_option("CONFIG_MBEDTLS_CERTIFICATE_BUNDLE", True) + + cg.add_build_flag("-Wno-nonnull-compare") + + add_idf_sdkconfig_option(f"CONFIG_IDF_TARGET_{variant}", True) + add_idf_sdkconfig_option( + f"CONFIG_ESPTOOLPY_FLASHSIZE_{config[CONF_FLASH_SIZE]}", True + ) + add_idf_sdkconfig_option("CONFIG_PARTITION_TABLE_SINGLE_APP", False) + add_idf_sdkconfig_option("CONFIG_PARTITION_TABLE_CUSTOM", True) + add_idf_sdkconfig_option("CONFIG_PARTITION_TABLE_CUSTOM_FILENAME", "partitions.csv") + + # Increase freertos tick speed from 100Hz to 1kHz so that delay() resolution is 1ms + add_idf_sdkconfig_option("CONFIG_FREERTOS_HZ", 1000) + + # Setup watchdog + add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT", True) + add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_PANIC", True) + add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU0", False) + add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU1", False) + + # Disable dynamic log level control to save memory + add_idf_sdkconfig_option("CONFIG_LOG_DYNAMIC_LEVEL_CONTROL", False) + + # Set default CPU frequency + add_idf_sdkconfig_option( + f"CONFIG_ESP_DEFAULT_CPU_FREQ_MHZ_{config[CONF_CPU_FREQUENCY][:-3]}", True + ) + + # Apply LWIP optimization settings + advanced = conf[CONF_ADVANCED] + # DHCP server: only disable if explicitly set to false + # WiFi component handles its own optimization when AP mode is not used + # When using Arduino with Ethernet, DHCP server functions must be available + # for the Network library to compile, even if not actively used + if ( + CONF_ENABLE_LWIP_DHCP_SERVER in advanced + and not advanced[CONF_ENABLE_LWIP_DHCP_SERVER] + and not ( + conf[CONF_TYPE] == FRAMEWORK_ARDUINO + and "ethernet" in CORE.loaded_integrations + ) + ): + add_idf_sdkconfig_option("CONFIG_LWIP_DHCPS", False) + if not advanced.get(CONF_ENABLE_LWIP_MDNS_QUERIES, True): + add_idf_sdkconfig_option("CONFIG_LWIP_DNS_SUPPORT_MDNS_QUERIES", False) + if not advanced.get(CONF_ENABLE_LWIP_BRIDGE_INTERFACE, False): + add_idf_sdkconfig_option("CONFIG_LWIP_BRIDGEIF_MAX_PORTS", 0) + if advanced.get(CONF_EXECUTE_FROM_PSRAM, False): + add_idf_sdkconfig_option("CONFIG_SPIRAM_FETCH_INSTRUCTIONS", True) + add_idf_sdkconfig_option("CONFIG_SPIRAM_RODATA", True) + + # Apply LWIP core locking for better socket performance + # This is already enabled by default in Arduino framework, where it provides + # significant performance benefits. Our benchmarks show socket operations are + # 24-200% faster with core locking enabled: + # - select() on 4 sockets: ~190μs (Arduino/core locking) vs ~235μs (ESP-IDF default) + # - Up to 200% slower under load when all operations queue through tcpip_thread + # Enabling this makes ESP-IDF socket performance match Arduino framework. + if advanced.get(CONF_ENABLE_LWIP_TCPIP_CORE_LOCKING, True): + add_idf_sdkconfig_option("CONFIG_LWIP_TCPIP_CORE_LOCKING", True) + if advanced.get(CONF_ENABLE_LWIP_CHECK_THREAD_SAFETY, True): + add_idf_sdkconfig_option("CONFIG_LWIP_CHECK_THREAD_SAFETY", True) + + cg.add_platformio_option("board_build.partitions", "partitions.csv") + if CONF_PARTITIONS in config: + add_extra_build_file( + "partitions.csv", CORE.relative_config_path(config[CONF_PARTITIONS]) + ) + + if assertion_level := advanced.get(CONF_ASSERTION_LEVEL): + for key, flag in ASSERTION_LEVELS.items(): + add_idf_sdkconfig_option(flag, assertion_level == key) + + add_idf_sdkconfig_option("CONFIG_COMPILER_OPTIMIZATION_DEFAULT", False) + compiler_optimization = advanced.get(CONF_COMPILER_OPTIMIZATION) + for key, flag in COMPILER_OPTIMIZATIONS.items(): + add_idf_sdkconfig_option(flag, compiler_optimization == key) + + add_idf_sdkconfig_option( + "CONFIG_LWIP_ESP_LWIP_ASSERT", + conf[CONF_ADVANCED][CONF_ENABLE_LWIP_ASSERT], + ) + + if advanced.get(CONF_IGNORE_EFUSE_MAC_CRC): + add_idf_sdkconfig_option("CONFIG_ESP_MAC_IGNORE_MAC_CRC_ERROR", True) + add_idf_sdkconfig_option("CONFIG_ESP_PHY_CALIBRATION_AND_DATA_STORAGE", False) + if advanced.get(CONF_ENABLE_IDF_EXPERIMENTAL_FEATURES): + _LOGGER.warning( + "Using experimental features in ESP-IDF may result in unexpected failures." + ) + add_idf_sdkconfig_option("CONFIG_IDF_EXPERIMENTAL_FEATURES", True) + + cg.add_define( + "USE_ESP_IDF_VERSION_CODE", + cg.RawExpression( + f"VERSION_CODE({framework_ver.major}, {framework_ver.minor}, {framework_ver.patch})" + ), + ) + + add_idf_sdkconfig_option(f"CONFIG_LOG_DEFAULT_LEVEL_{conf[CONF_LOG_LEVEL]}", True) + + for name, value in conf[CONF_SDKCONFIG_OPTIONS].items(): + add_idf_sdkconfig_option(name, RawSdkconfigValue(value)) + + for component in conf[CONF_COMPONENTS]: + add_idf_component( + name=component[CONF_NAME], + repo=component.get(CONF_SOURCE), + ref=component.get(CONF_REF), + path=component.get(CONF_PATH), + ) APP_PARTITION_SIZES = { @@ -1056,13 +973,14 @@ def _write_sdkconfig(): ) + "\n" ) + if write_file_if_changed(internal_path, contents): # internal changed, update real one write_file_if_changed(sdk_path, contents) def _write_idf_component_yml(): - yml_path = Path(CORE.relative_build_path("src/idf_component.yml")) + yml_path = CORE.relative_build_path("src/idf_component.yml") if CORE.data[KEY_ESP32][KEY_COMPONENTS]: components: dict = CORE.data[KEY_ESP32][KEY_COMPONENTS] dependencies = {} @@ -1080,51 +998,48 @@ def _write_idf_component_yml(): contents = "" if write_file_if_changed(yml_path, contents): dependencies_lock = CORE.relative_build_path("dependencies.lock") - if os.path.isfile(dependencies_lock): - os.remove(dependencies_lock) + if dependencies_lock.is_file(): + dependencies_lock.unlink() clean_cmake_cache() # Called by writer.py def copy_files(): - if ( - CORE.using_arduino - and "partitions.csv" not in CORE.data[KEY_ESP32][KEY_EXTRA_BUILD_FILES] - ): - write_file_if_changed( - CORE.relative_build_path("partitions.csv"), - get_arduino_partition_csv( - CORE.platformio_options.get("board_upload.flash_size") - ), - ) - if CORE.using_esp_idf: - _write_sdkconfig() - _write_idf_component_yml() - if "partitions.csv" not in CORE.data[KEY_ESP32][KEY_EXTRA_BUILD_FILES]: + _write_sdkconfig() + _write_idf_component_yml() + + if "partitions.csv" not in CORE.data[KEY_ESP32][KEY_EXTRA_BUILD_FILES]: + if CORE.using_arduino: + write_file_if_changed( + CORE.relative_build_path("partitions.csv"), + get_arduino_partition_csv( + CORE.platformio_options.get("board_upload.flash_size") + ), + ) + else: write_file_if_changed( CORE.relative_build_path("partitions.csv"), get_idf_partition_csv( CORE.platformio_options.get("board_upload.flash_size") ), ) - # IDF build scripts look for version string to put in the build. - # However, if the build path does not have an initialized git repo, - # and no version.txt file exists, the CMake script fails for some setups. - # Fix by manually pasting a version.txt file, containing the ESPHome version - write_file_if_changed( - CORE.relative_build_path("version.txt"), - __version__, - ) + # IDF build scripts look for version string to put in the build. + # However, if the build path does not have an initialized git repo, + # and no version.txt file exists, the CMake script fails for some setups. + # Fix by manually pasting a version.txt file, containing the ESPHome version + write_file_if_changed( + CORE.relative_build_path("version.txt"), + __version__, + ) for file in CORE.data[KEY_ESP32][KEY_EXTRA_BUILD_FILES].values(): - if file[KEY_PATH].startswith("http"): + name: str = file[KEY_NAME] + path: Path = file[KEY_PATH] + if str(path).startswith("http"): import requests - mkdir_p(CORE.relative_build_path(os.path.dirname(file[KEY_NAME]))) - with open(CORE.relative_build_path(file[KEY_NAME]), "wb") as f: - f.write(requests.get(file[KEY_PATH], timeout=30).content) + CORE.relative_build_path(name).parent.mkdir(parents=True, exist_ok=True) + content = requests.get(path, timeout=30).content + CORE.relative_build_path(name).write_bytes(content) else: - copy_file_if_changed( - file[KEY_PATH], - CORE.relative_build_path(file[KEY_NAME]), - ) + copy_file_if_changed(path, CORE.relative_build_path(name)) diff --git a/esphome/components/esp32/boards.py b/esphome/components/esp32/boards.py index cf6cf8cbe5..5f039492c8 100644 --- a/esphome/components/esp32/boards.py +++ b/esphome/components/esp32/boards.py @@ -1504,6 +1504,10 @@ BOARDS = { "name": "BPI-Bit", "variant": VARIANT_ESP32, }, + "bpi-centi-s3": { + "name": "BPI-Centi-S3", + "variant": VARIANT_ESP32S3, + }, "bpi_leaf_s3": { "name": "BPI-Leaf-S3", "variant": VARIANT_ESP32S3, @@ -1664,10 +1668,46 @@ BOARDS = { "name": "Espressif ESP32-S3-DevKitC-1-N8 (8 MB QD, No PSRAM)", "variant": VARIANT_ESP32S3, }, + "esp32-s3-devkitc-1-n32r8v": { + "name": "Espressif ESP32-S3-DevKitC-1-N32R8V (32 MB Flash Octal, 8 MB PSRAM Octal)", + "variant": VARIANT_ESP32S3, + }, + "esp32-s3-devkitc1-n16r16": { + "name": "Espressif ESP32-S3-DevKitC-1-N16R16V (16 MB Flash Quad, 16 MB PSRAM Octal)", + "variant": VARIANT_ESP32S3, + }, + "esp32-s3-devkitc1-n16r2": { + "name": "Espressif ESP32-S3-DevKitC-1-N16R2 (16 MB Flash Quad, 2 MB PSRAM Quad)", + "variant": VARIANT_ESP32S3, + }, + "esp32-s3-devkitc1-n16r8": { + "name": "Espressif ESP32-S3-DevKitC-1-N16R8V (16 MB Flash Quad, 8 MB PSRAM Octal)", + "variant": VARIANT_ESP32S3, + }, + "esp32-s3-devkitc1-n4r2": { + "name": "Espressif ESP32-S3-DevKitC-1-N4R2 (4 MB Flash Quad, 2 MB PSRAM Quad)", + "variant": VARIANT_ESP32S3, + }, + "esp32-s3-devkitc1-n4r8": { + "name": "Espressif ESP32-S3-DevKitC-1-N4R8 (4 MB Flash Quad, 8 MB PSRAM Octal)", + "variant": VARIANT_ESP32S3, + }, + "esp32-s3-devkitc1-n8r2": { + "name": "Espressif ESP32-S3-DevKitC-1-N8R2 (8 MB Flash Quad, 2 MB PSRAM quad)", + "variant": VARIANT_ESP32S3, + }, + "esp32-s3-devkitc1-n8r8": { + "name": "Espressif ESP32-S3-DevKitC-1-N8R8 (8 MB Flash Quad, 8 MB PSRAM Octal)", + "variant": VARIANT_ESP32S3, + }, "esp32-s3-devkitm-1": { "name": "Espressif ESP32-S3-DevKitM-1", "variant": VARIANT_ESP32S3, }, + "esp32-s3-fh4r2": { + "name": "Espressif ESP32-S3-FH4R2 (4 MB QD, 2MB PSRAM)", + "variant": VARIANT_ESP32S3, + }, "esp32-solo1": { "name": "Espressif Generic ESP32-solo1 4M Flash", "variant": VARIANT_ESP32, @@ -1764,6 +1804,10 @@ BOARDS = { "name": "Franzininho WiFi MSC", "variant": VARIANT_ESP32S2, }, + "freenove-esp32-s3-n8r8": { + "name": "Freenove ESP32-S3 WROOM N8R8 (8MB Flash / 8MB PSRAM)", + "variant": VARIANT_ESP32S3, + }, "freenove_esp32_s3_wroom": { "name": "Freenove ESP32-S3 WROOM N8R8 (8MB Flash / 8MB PSRAM)", "variant": VARIANT_ESP32S3, @@ -1964,6 +2008,10 @@ BOARDS = { "name": "M5Stack AtomS3", "variant": VARIANT_ESP32S3, }, + "m5stack-atoms3u": { + "name": "M5Stack AtomS3U", + "variant": VARIANT_ESP32S3, + }, "m5stack-core-esp32": { "name": "M5Stack Core ESP32", "variant": VARIANT_ESP32, @@ -2084,6 +2132,10 @@ BOARDS = { "name": "Ai-Thinker NodeMCU-32S2 (ESP-12K)", "variant": VARIANT_ESP32S2, }, + "nologo_esp32c3_super_mini": { + "name": "Nologo ESP32C3 SuperMini", + "variant": VARIANT_ESP32C3, + }, "nscreen-32": { "name": "YeaCreate NSCREEN-32", "variant": VARIANT_ESP32, @@ -2192,6 +2244,10 @@ BOARDS = { "name": "SparkFun LoRa Gateway 1-Channel", "variant": VARIANT_ESP32, }, + "sparkfun_pro_micro_esp32c3": { + "name": "SparkFun Pro Micro ESP32-C3", + "variant": VARIANT_ESP32C3, + }, "sparkfun_qwiic_pocket_esp32c6": { "name": "SparkFun ESP32-C6 Qwiic Pocket", "variant": VARIANT_ESP32C6, @@ -2256,6 +2312,14 @@ BOARDS = { "name": "Turta IoT Node", "variant": VARIANT_ESP32, }, + "um_bling": { + "name": "Unexpected Maker BLING!", + "variant": VARIANT_ESP32S3, + }, + "um_edges3_d": { + "name": "Unexpected Maker EDGES3[D]", + "variant": VARIANT_ESP32S3, + }, "um_feathers2": { "name": "Unexpected Maker FeatherS2", "variant": VARIANT_ESP32S2, @@ -2268,10 +2332,18 @@ BOARDS = { "name": "Unexpected Maker FeatherS3", "variant": VARIANT_ESP32S3, }, + "um_feathers3_neo": { + "name": "Unexpected Maker FeatherS3 Neo", + "variant": VARIANT_ESP32S3, + }, "um_nanos3": { "name": "Unexpected Maker NanoS3", "variant": VARIANT_ESP32S3, }, + "um_omgs3": { + "name": "Unexpected Maker OMGS3", + "variant": VARIANT_ESP32S3, + }, "um_pros3": { "name": "Unexpected Maker PROS3", "variant": VARIANT_ESP32S3, @@ -2280,6 +2352,14 @@ BOARDS = { "name": "Unexpected Maker RMP", "variant": VARIANT_ESP32S2, }, + "um_squixl": { + "name": "Unexpected Maker SQUiXL", + "variant": VARIANT_ESP32S3, + }, + "um_tinyc6": { + "name": "Unexpected Maker TinyC6", + "variant": VARIANT_ESP32C6, + }, "um_tinys2": { "name": "Unexpected Maker TinyS2", "variant": VARIANT_ESP32S2, @@ -2401,3 +2481,4 @@ BOARDS = { "variant": VARIANT_ESP32S3, }, } +# DO NOT ADD ANYTHING BELOW THIS LINE diff --git a/esphome/components/esp32/preferences.cpp b/esphome/components/esp32/preferences.cpp index c5b07b497c..7bdbb265ca 100644 --- a/esphome/components/esp32/preferences.cpp +++ b/esphome/components/esp32/preferences.cpp @@ -17,7 +17,14 @@ static const char *const TAG = "esp32.preferences"; struct NVSData { std::string key; - std::vector data; + std::unique_ptr data; + size_t len; + + void set_data(const uint8_t *src, size_t size) { + data = std::make_unique(size); + memcpy(data.get(), src, size); + len = size; + } }; static std::vector s_pending_save; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) @@ -30,26 +37,26 @@ class ESP32PreferenceBackend : public ESPPreferenceBackend { // try find in pending saves and update that for (auto &obj : s_pending_save) { if (obj.key == key) { - obj.data.assign(data, data + len); + obj.set_data(data, len); return true; } } NVSData save{}; save.key = key; - save.data.assign(data, data + len); - s_pending_save.emplace_back(save); - ESP_LOGVV(TAG, "s_pending_save: key: %s, len: %d", key.c_str(), len); + save.set_data(data, len); + s_pending_save.emplace_back(std::move(save)); + ESP_LOGVV(TAG, "s_pending_save: key: %s, len: %zu", key.c_str(), len); return true; } bool load(uint8_t *data, size_t len) override { // try find in pending saves and load from that for (auto &obj : s_pending_save) { if (obj.key == key) { - if (obj.data.size() != len) { + if (obj.len != len) { // size mismatch return false; } - memcpy(data, obj.data.data(), len); + memcpy(data, obj.data.get(), len); return true; } } @@ -61,7 +68,7 @@ class ESP32PreferenceBackend : public ESPPreferenceBackend { return false; } if (actual_len != len) { - ESP_LOGVV(TAG, "NVS length does not match (%u!=%u)", actual_len, len); + ESP_LOGVV(TAG, "NVS length does not match (%zu!=%zu)", actual_len, len); return false; } err = nvs_get_blob(nvs_handle, key.c_str(), data, &len); @@ -69,7 +76,7 @@ class ESP32PreferenceBackend : public ESPPreferenceBackend { ESP_LOGV(TAG, "nvs_get_blob('%s') failed: %s", key.c_str(), esp_err_to_name(err)); return false; } else { - ESP_LOGVV(TAG, "nvs_get_blob: key: %s, len: %d", key.c_str(), len); + ESP_LOGVV(TAG, "nvs_get_blob: key: %s, len: %zu", key.c_str(), len); } return true; } @@ -112,7 +119,7 @@ class ESP32Preferences : public ESPPreferences { if (s_pending_save.empty()) return true; - ESP_LOGV(TAG, "Saving %d items...", s_pending_save.size()); + ESP_LOGV(TAG, "Saving %zu items...", s_pending_save.size()); // goal try write all pending saves even if one fails int cached = 0, written = 0, failed = 0; esp_err_t last_err = ESP_OK; @@ -123,11 +130,10 @@ class ESP32Preferences : public ESPPreferences { const auto &save = s_pending_save[i]; ESP_LOGVV(TAG, "Checking if NVS data %s has changed", save.key.c_str()); if (is_changed(nvs_handle, save)) { - esp_err_t err = nvs_set_blob(nvs_handle, save.key.c_str(), save.data.data(), save.data.size()); - ESP_LOGV(TAG, "sync: key: %s, len: %d", save.key.c_str(), save.data.size()); + esp_err_t err = nvs_set_blob(nvs_handle, save.key.c_str(), save.data.get(), save.len); + ESP_LOGV(TAG, "sync: key: %s, len: %zu", save.key.c_str(), save.len); if (err != 0) { - ESP_LOGV(TAG, "nvs_set_blob('%s', len=%u) failed: %s", save.key.c_str(), save.data.size(), - esp_err_to_name(err)); + ESP_LOGV(TAG, "nvs_set_blob('%s', len=%zu) failed: %s", save.key.c_str(), save.len, esp_err_to_name(err)); failed++; last_err = err; last_key = save.key; @@ -135,7 +141,7 @@ class ESP32Preferences : public ESPPreferences { } written++; } else { - ESP_LOGV(TAG, "NVS data not changed skipping %s len=%u", save.key.c_str(), save.data.size()); + ESP_LOGV(TAG, "NVS data not changed skipping %s len=%zu", save.key.c_str(), save.len); cached++; } s_pending_save.erase(s_pending_save.begin() + i); @@ -164,7 +170,7 @@ class ESP32Preferences : public ESPPreferences { return true; } // Check size first before allocating memory - if (actual_len != to_save.data.size()) { + if (actual_len != to_save.len) { return true; } auto stored_data = std::make_unique(actual_len); @@ -173,7 +179,7 @@ class ESP32Preferences : public ESPPreferences { ESP_LOGV(TAG, "nvs_get_blob('%s') failed: %s", to_save.key.c_str(), esp_err_to_name(err)); return true; } - return memcmp(to_save.data.data(), stored_data.get(), to_save.data.size()) != 0; + return memcmp(to_save.data.get(), stored_data.get(), to_save.len) != 0; } bool reset() override { diff --git a/esphome/components/esp32_ble/__init__.py b/esphome/components/esp32_ble/__init__.py index d2eaa3ce6f..15afb22ab8 100644 --- a/esphome/components/esp32_ble/__init__.py +++ b/esphome/components/esp32_ble/__init__.py @@ -1,5 +1,8 @@ +from collections.abc import Callable, MutableMapping from enum import Enum +import logging import re +from typing import Any from esphome import automation import esphome.codegen as cg @@ -9,6 +12,7 @@ from esphome.const import ( CONF_ENABLE_ON_BOOT, CONF_ESPHOME, CONF_ID, + CONF_MAX_CONNECTIONS, CONF_NAME, CONF_NAME_ADD_MAC_SUFFIX, ) @@ -19,6 +23,8 @@ DEPENDENCIES = ["esp32"] CODEOWNERS = ["@jesserockz", "@Rapsssito", "@bdraco"] DOMAIN = "esp32_ble" +_LOGGER = logging.getLogger(__name__) + class BTLoggers(Enum): """Bluetooth logger categories available in ESP-IDF. @@ -127,6 +133,28 @@ CONF_DISABLE_BT_LOGS = "disable_bt_logs" CONF_CONNECTION_TIMEOUT = "connection_timeout" CONF_MAX_NOTIFICATIONS = "max_notifications" +# BLE connection limits +# ESP-IDF CONFIG_BT_ACL_CONNECTIONS has range 1-9, default 4 +# Total instances: 10 (ADV + SCAN + connections) +# - ADV only: up to 9 connections +# - SCAN only: up to 9 connections +# - ADV + SCAN: up to 8 connections +DEFAULT_MAX_CONNECTIONS = 3 +IDF_MAX_CONNECTIONS = 9 + +# Connection slot tracking keys +KEY_ESP32_BLE = "esp32_ble" +KEY_USED_CONNECTION_SLOTS = "used_connection_slots" + +# Export for use by other components (bluetooth_proxy, etc.) +__all__ = [ + "DEFAULT_MAX_CONNECTIONS", + "IDF_MAX_CONNECTIONS", + "KEY_ESP32_BLE", + "KEY_USED_CONNECTION_SLOTS", + "consume_connection_slots", +] + NO_BLUETOOTH_VARIANTS = [const.VARIANT_ESP32S2] esp32_ble_ns = cg.esphome_ns.namespace("esp32_ble") @@ -174,19 +202,18 @@ CONFIG_SCHEMA = cv.Schema( cv.Optional( CONF_ADVERTISING_CYCLE_TIME, default="10s" ): cv.positive_time_period_milliseconds, - cv.SplitDefault(CONF_DISABLE_BT_LOGS, esp32_idf=True): cv.All( - cv.only_with_esp_idf, cv.boolean - ), - cv.SplitDefault(CONF_CONNECTION_TIMEOUT, esp32_idf="20s"): cv.All( - cv.only_with_esp_idf, + cv.Optional(CONF_DISABLE_BT_LOGS, default=True): cv.boolean, + cv.Optional(CONF_CONNECTION_TIMEOUT, default="20s"): cv.All( cv.positive_time_period_seconds, cv.Range(min=TimePeriod(seconds=10), max=TimePeriod(seconds=180)), ), - cv.SplitDefault(CONF_MAX_NOTIFICATIONS, esp32_idf=12): cv.All( - cv.only_with_esp_idf, + cv.Optional(CONF_MAX_NOTIFICATIONS, default=12): cv.All( cv.positive_int, cv.Range(min=1, max=64), ), + cv.Optional(CONF_MAX_CONNECTIONS, default=DEFAULT_MAX_CONNECTIONS): cv.All( + cv.positive_int, cv.Range(min=1, max=IDF_MAX_CONNECTIONS) + ), } ).extend(cv.COMPONENT_SCHEMA) @@ -234,6 +261,56 @@ def validate_variant(_): raise cv.Invalid(f"{variant} does not support Bluetooth") +def consume_connection_slots( + value: int, consumer: str +) -> Callable[[MutableMapping], MutableMapping]: + """Reserve BLE connection slots for a component. + + Args: + value: Number of connection slots to reserve + consumer: Name of the component consuming the slots + + Returns: + A validator function that records the slot usage + """ + + def _consume_connection_slots(config: MutableMapping) -> MutableMapping: + data: dict[str, Any] = CORE.data.setdefault(KEY_ESP32_BLE, {}) + slots: list[str] = data.setdefault(KEY_USED_CONNECTION_SLOTS, []) + slots.extend([consumer] * value) + return config + + return _consume_connection_slots + + +def validate_connection_slots(max_connections: int) -> None: + """Validate that BLE connection slots don't exceed the configured maximum.""" + ble_data = CORE.data.get(KEY_ESP32_BLE, {}) + used_slots = ble_data.get(KEY_USED_CONNECTION_SLOTS, []) + num_used = len(used_slots) + + if num_used <= max_connections: + return + + slot_users = ", ".join(used_slots) + + if num_used > IDF_MAX_CONNECTIONS: + raise cv.Invalid( + f"BLE components require {num_used} connection slots but maximum is {IDF_MAX_CONNECTIONS}. " + f"Reduce the number of BLE clients. Components: {slot_users}" + ) + + _LOGGER.warning( + "BLE components require %d connection slot(s) but only %d configured. " + "Please set 'max_connections: %d' in the 'esp32_ble' component. " + "Components: %s", + num_used, + max_connections, + num_used, + slot_users, + ) + + def final_validation(config): validate_variant(config) if (name := config.get(CONF_NAME)) is not None: @@ -246,6 +323,43 @@ def final_validation(config): f"Name '{name}' is too long, maximum length is {max_length} characters" ) + # Set GATT Client/Server sdkconfig options based on which components are loaded + full_config = fv.full_config.get() + + # Validate connection slots usage + max_connections = config.get(CONF_MAX_CONNECTIONS, DEFAULT_MAX_CONNECTIONS) + validate_connection_slots(max_connections) + + # Check if BLE Server is needed + has_ble_server = "esp32_ble_server" in full_config + add_idf_sdkconfig_option("CONFIG_BT_GATTS_ENABLE", has_ble_server) + + # Check if BLE Client is needed (via esp32_ble_tracker or esp32_ble_client) + has_ble_client = ( + "esp32_ble_tracker" in full_config or "esp32_ble_client" in full_config + ) + add_idf_sdkconfig_option("CONFIG_BT_GATTC_ENABLE", has_ble_client) + + # Handle max_connections: check for deprecated location in esp32_ble_tracker + max_connections = config.get(CONF_MAX_CONNECTIONS, DEFAULT_MAX_CONNECTIONS) + + # Use value from tracker if esp32_ble doesn't have it explicitly set (backward compat) + if "esp32_ble_tracker" in full_config: + tracker_config = full_config["esp32_ble_tracker"] + if "max_connections" in tracker_config and CONF_MAX_CONNECTIONS not in config: + max_connections = tracker_config["max_connections"] + + # Set CONFIG_BT_ACL_CONNECTIONS to the maximum connections needed + 1 for ADV/SCAN + # This is the Bluedroid host stack total instance limit (range 1-9, default 4) + # Total instances = ADV/SCAN (1) + connection slots (max_connections) + # Shared between client (tracker/ble_client) and server + add_idf_sdkconfig_option("CONFIG_BT_ACL_CONNECTIONS", max_connections + 1) + + # Set controller-specific max connections for ESP32 (classic) + # CONFIG_BTDM_CTRL_BLE_MAX_CONN is ESP32-specific controller limit (just connections, not ADV/SCAN) + # For newer chips (C3/S3/etc), different configs are used automatically + add_idf_sdkconfig_option("CONFIG_BTDM_CTRL_BLE_MAX_CONN", max_connections) + return config @@ -261,43 +375,44 @@ async def to_code(config): cg.add(var.set_name(name)) await cg.register_component(var, config) - if CORE.using_esp_idf: - add_idf_sdkconfig_option("CONFIG_BT_ENABLED", True) - add_idf_sdkconfig_option("CONFIG_BT_BLE_42_FEATURES_SUPPORTED", True) + # Define max connections for use in C++ code (e.g., ble_server.h) + max_connections = config.get(CONF_MAX_CONNECTIONS, DEFAULT_MAX_CONNECTIONS) + cg.add_define("USE_ESP32_BLE_MAX_CONNECTIONS", max_connections) - # Register the core BLE loggers that are always needed - register_bt_logger(BTLoggers.GAP, BTLoggers.BTM, BTLoggers.HCI) + add_idf_sdkconfig_option("CONFIG_BT_ENABLED", True) + add_idf_sdkconfig_option("CONFIG_BT_BLE_42_FEATURES_SUPPORTED", True) - # Apply logger settings if log disabling is enabled - if config.get(CONF_DISABLE_BT_LOGS, False): - # Disable all Bluetooth loggers that are not required - for logger in BTLoggers: - if logger not in _required_loggers: - add_idf_sdkconfig_option(f"{logger.value}_NONE", True) + # Register the core BLE loggers that are always needed + register_bt_logger(BTLoggers.GAP, BTLoggers.BTM, BTLoggers.HCI) - # Set BLE connection establishment timeout to match aioesphomeapi/bleak-retry-connector - # Default is 20 seconds instead of ESP-IDF's 30 seconds. Because there is no way to - # cancel a BLE connection in progress, when aioesphomeapi times out at 20 seconds, - # the connection slot remains occupied for the remaining time, preventing new connection - # attempts and wasting valuable connection slots. - if CONF_CONNECTION_TIMEOUT in config: - timeout_seconds = int(config[CONF_CONNECTION_TIMEOUT].total_seconds) - add_idf_sdkconfig_option( - "CONFIG_BT_BLE_ESTAB_LINK_CONN_TOUT", timeout_seconds - ) - # Increase GATT client connection retry count for problematic devices - # Default in ESP-IDF is 3, we increase to 10 for better reliability with - # low-power/timing-sensitive devices - add_idf_sdkconfig_option("CONFIG_BT_GATTC_CONNECT_RETRY_COUNT", 10) + # Apply logger settings if log disabling is enabled + if config.get(CONF_DISABLE_BT_LOGS, False): + # Disable all Bluetooth loggers that are not required + for logger in BTLoggers: + if logger not in _required_loggers: + add_idf_sdkconfig_option(f"{logger.value}_NONE", True) - # Set the maximum number of notification registrations - # This controls how many BLE characteristics can have notifications enabled - # across all connections for a single GATT client interface - # https://github.com/esphome/issues/issues/6808 - if CONF_MAX_NOTIFICATIONS in config: - add_idf_sdkconfig_option( - "CONFIG_BT_GATTC_NOTIF_REG_MAX", config[CONF_MAX_NOTIFICATIONS] - ) + # Set BLE connection establishment timeout to match aioesphomeapi/bleak-retry-connector + # Default is 20 seconds instead of ESP-IDF's 30 seconds. Because there is no way to + # cancel a BLE connection in progress, when aioesphomeapi times out at 20 seconds, + # the connection slot remains occupied for the remaining time, preventing new connection + # attempts and wasting valuable connection slots. + if CONF_CONNECTION_TIMEOUT in config: + timeout_seconds = int(config[CONF_CONNECTION_TIMEOUT].total_seconds) + add_idf_sdkconfig_option("CONFIG_BT_BLE_ESTAB_LINK_CONN_TOUT", timeout_seconds) + # Increase GATT client connection retry count for problematic devices + # Default in ESP-IDF is 3, we increase to 10 for better reliability with + # low-power/timing-sensitive devices + add_idf_sdkconfig_option("CONFIG_BT_GATTC_CONNECT_RETRY_COUNT", 10) + + # Set the maximum number of notification registrations + # This controls how many BLE characteristics can have notifications enabled + # across all connections for a single GATT client interface + # https://github.com/esphome/issues/issues/6808 + if CONF_MAX_NOTIFICATIONS in config: + add_idf_sdkconfig_option( + "CONFIG_BT_GATTC_NOTIF_REG_MAX", config[CONF_MAX_NOTIFICATIONS] + ) cg.add_define("USE_ESP32_BLE") diff --git a/esphome/components/esp32_ble/ble.cpp b/esphome/components/esp32_ble/ble.cpp index e22d43c0cc..0c340c55cc 100644 --- a/esphome/components/esp32_ble/ble.cpp +++ b/esphome/components/esp32_ble/ble.cpp @@ -73,6 +73,28 @@ void ESP32BLE::advertising_set_manufacturer_data(const std::vector &dat this->advertising_start(); } +void ESP32BLE::advertising_set_service_data_and_name(std::span data, bool include_name) { + // This method atomically updates both service data and device name inclusion in BLE advertising. + // When include_name is true, the device name is included in the advertising packet making it + // visible to passive BLE scanners. When false, the name is only visible in scan response + // (requires active scanning). This atomic operation ensures we only restart advertising once + // when changing both properties, avoiding the brief gap that would occur with separate calls. + + this->advertising_init_(); + + if (include_name) { + // When including name, clear service data first to avoid packet overflow + this->advertising_->set_service_data(std::span{}); + this->advertising_->set_include_name(true); + } else { + // When including service data, clear name first to avoid packet overflow + this->advertising_->set_include_name(false); + this->advertising_->set_service_data(data); + } + + this->advertising_start(); +} + void ESP32BLE::advertising_register_raw_advertisement_callback(std::function &&callback) { this->advertising_init_(); this->advertising_->register_raw_advertisement_callback(std::move(callback)); @@ -167,6 +189,7 @@ bool ESP32BLE::ble_setup_() { } } +#ifdef USE_ESP32_BLE_SERVER if (!this->gatts_event_handlers_.empty()) { err = esp_ble_gatts_register_callback(ESP32BLE::gatts_event_handler); if (err != ESP_OK) { @@ -174,7 +197,9 @@ bool ESP32BLE::ble_setup_() { return false; } } +#endif +#ifdef USE_ESP32_BLE_CLIENT if (!this->gattc_event_handlers_.empty()) { err = esp_ble_gattc_register_callback(ESP32BLE::gattc_event_handler); if (err != ESP_OK) { @@ -182,20 +207,23 @@ bool ESP32BLE::ble_setup_() { return false; } } +#endif std::string name; if (this->name_.has_value()) { name = this->name_.value(); if (App.is_name_add_mac_suffix_enabled()) { - name += "-" + get_mac_address().substr(6); + name += "-"; + name += get_mac_address().substr(6); } } else { name = App.get_name(); if (name.length() > 20) { if (App.is_name_add_mac_suffix_enabled()) { - name.erase(name.begin() + 13, name.end() - 7); // Remove characters between 13 and the mac address + // Keep first 13 chars and last 7 chars (MAC suffix), remove middle + name.erase(13, name.length() - 20); } else { - name = name.substr(0, 20); + name.resize(20); } } } @@ -303,6 +331,7 @@ void ESP32BLE::loop() { BLEEvent *ble_event = this->ble_events_.pop(); while (ble_event != nullptr) { switch (ble_event->type_) { +#ifdef USE_ESP32_BLE_SERVER case BLEEvent::GATTS: { esp_gatts_cb_event_t event = ble_event->event_.gatts.gatts_event; esp_gatt_if_t gatts_if = ble_event->event_.gatts.gatts_if; @@ -313,6 +342,8 @@ void ESP32BLE::loop() { } break; } +#endif +#ifdef USE_ESP32_BLE_CLIENT case BLEEvent::GATTC: { esp_gattc_cb_event_t event = ble_event->event_.gattc.gattc_event; esp_gatt_if_t gattc_if = ble_event->event_.gattc.gattc_if; @@ -323,6 +354,7 @@ void ESP32BLE::loop() { } break; } +#endif case BLEEvent::GAP: { esp_gap_ble_cb_event_t gap_event = ble_event->event_.gap.gap_event; switch (gap_event) { @@ -416,13 +448,17 @@ void load_ble_event(BLEEvent *event, esp_gap_ble_cb_event_t e, esp_ble_gap_cb_pa event->load_gap_event(e, p); } +#ifdef USE_ESP32_BLE_CLIENT void load_ble_event(BLEEvent *event, esp_gattc_cb_event_t e, esp_gatt_if_t i, esp_ble_gattc_cb_param_t *p) { event->load_gattc_event(e, i, p); } +#endif +#ifdef USE_ESP32_BLE_SERVER void load_ble_event(BLEEvent *event, esp_gatts_cb_event_t e, esp_gatt_if_t i, esp_ble_gatts_cb_param_t *p) { event->load_gatts_event(e, i, p); } +#endif template void enqueue_ble_event(Args... args) { // Allocate an event from the pool @@ -443,8 +479,12 @@ template void enqueue_ble_event(Args... args) { // Explicit template instantiations for the friend function template void enqueue_ble_event(esp_gap_ble_cb_event_t, esp_ble_gap_cb_param_t *); +#ifdef USE_ESP32_BLE_SERVER template void enqueue_ble_event(esp_gatts_cb_event_t, esp_gatt_if_t, esp_ble_gatts_cb_param_t *); +#endif +#ifdef USE_ESP32_BLE_CLIENT template void enqueue_ble_event(esp_gattc_cb_event_t, esp_gatt_if_t, esp_ble_gattc_cb_param_t *); +#endif void ESP32BLE::gap_event_handler(esp_gap_ble_cb_event_t event, esp_ble_gap_cb_param_t *param) { switch (event) { @@ -484,15 +524,19 @@ void ESP32BLE::gap_event_handler(esp_gap_ble_cb_event_t event, esp_ble_gap_cb_pa ESP_LOGW(TAG, "Ignoring unexpected GAP event type: %d", event); } +#ifdef USE_ESP32_BLE_SERVER void ESP32BLE::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt_if_t gatts_if, esp_ble_gatts_cb_param_t *param) { enqueue_ble_event(event, gatts_if, param); } +#endif +#ifdef USE_ESP32_BLE_CLIENT void ESP32BLE::gattc_event_handler(esp_gattc_cb_event_t event, esp_gatt_if_t gattc_if, esp_ble_gattc_cb_param_t *param) { enqueue_ble_event(event, gattc_if, param); } +#endif float ESP32BLE::get_setup_priority() const { return setup_priority::BLUETOOTH; } diff --git a/esphome/components/esp32_ble/ble.h b/esphome/components/esp32_ble/ble.h index 712787fe53..1aa3bc86ef 100644 --- a/esphome/components/esp32_ble/ble.h +++ b/esphome/components/esp32_ble/ble.h @@ -9,6 +9,7 @@ #endif #include +#include #include "esphome/core/automation.h" #include "esphome/core/component.h" @@ -74,17 +75,21 @@ class GAPScanEventHandler { virtual void gap_scan_event_handler(const BLEScanResult &scan_result) = 0; }; +#ifdef USE_ESP32_BLE_CLIENT class GATTcEventHandler { public: virtual void gattc_event_handler(esp_gattc_cb_event_t event, esp_gatt_if_t gattc_if, esp_ble_gattc_cb_param_t *param) = 0; }; +#endif +#ifdef USE_ESP32_BLE_SERVER class GATTsEventHandler { public: virtual void gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt_if_t gatts_if, esp_ble_gatts_cb_param_t *param) = 0; }; +#endif class BLEStatusEventHandler { public: @@ -114,6 +119,7 @@ class ESP32BLE : public Component { void advertising_set_service_data(const std::vector &data); void advertising_set_manufacturer_data(const std::vector &data); void advertising_set_appearance(uint16_t appearance) { this->appearance_ = appearance; } + void advertising_set_service_data_and_name(std::span data, bool include_name); void advertising_add_service_uuid(ESPBTUUID uuid); void advertising_remove_service_uuid(ESPBTUUID uuid); void advertising_register_raw_advertisement_callback(std::function &&callback); @@ -123,16 +129,24 @@ class ESP32BLE : public Component { void register_gap_scan_event_handler(GAPScanEventHandler *handler) { this->gap_scan_event_handlers_.push_back(handler); } +#ifdef USE_ESP32_BLE_CLIENT void register_gattc_event_handler(GATTcEventHandler *handler) { this->gattc_event_handlers_.push_back(handler); } +#endif +#ifdef USE_ESP32_BLE_SERVER void register_gatts_event_handler(GATTsEventHandler *handler) { this->gatts_event_handlers_.push_back(handler); } +#endif void register_ble_status_event_handler(BLEStatusEventHandler *handler) { this->ble_status_event_handlers_.push_back(handler); } void set_enable_on_boot(bool enable_on_boot) { this->enable_on_boot_ = enable_on_boot; } protected: +#ifdef USE_ESP32_BLE_SERVER static void gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt_if_t gatts_if, esp_ble_gatts_cb_param_t *param); +#endif +#ifdef USE_ESP32_BLE_CLIENT static void gattc_event_handler(esp_gattc_cb_event_t event, esp_gatt_if_t gattc_if, esp_ble_gattc_cb_param_t *param); +#endif static void gap_event_handler(esp_gap_ble_cb_event_t event, esp_ble_gap_cb_param_t *param); bool ble_setup_(); @@ -148,8 +162,12 @@ class ESP32BLE : public Component { // Vectors (12 bytes each on 32-bit, naturally aligned to 4 bytes) std::vector gap_event_handlers_; std::vector gap_scan_event_handlers_; +#ifdef USE_ESP32_BLE_CLIENT std::vector gattc_event_handlers_; +#endif +#ifdef USE_ESP32_BLE_SERVER std::vector gatts_event_handlers_; +#endif std::vector ble_status_event_handlers_; // Large objects (size depends on template parameters, but typically aligned to 4 bytes) diff --git a/esphome/components/esp32_ble/ble_advertising.cpp b/esphome/components/esp32_ble/ble_advertising.cpp index d8b9b1cc36..68704e49e2 100644 --- a/esphome/components/esp32_ble/ble_advertising.cpp +++ b/esphome/components/esp32_ble/ble_advertising.cpp @@ -43,7 +43,7 @@ void BLEAdvertising::remove_service_uuid(ESPBTUUID uuid) { this->advertising_uuids_.end()); } -void BLEAdvertising::set_service_data(const std::vector &data) { +void BLEAdvertising::set_service_data(std::span data) { delete[] this->advertising_data_.p_service_data; this->advertising_data_.p_service_data = nullptr; this->advertising_data_.service_data_len = data.size(); @@ -54,6 +54,10 @@ void BLEAdvertising::set_service_data(const std::vector &data) { } } +void BLEAdvertising::set_service_data(const std::vector &data) { + this->set_service_data(std::span(data)); +} + void BLEAdvertising::set_manufacturer_data(const std::vector &data) { delete[] this->advertising_data_.p_manufacturer_data; this->advertising_data_.p_manufacturer_data = nullptr; @@ -84,7 +88,7 @@ esp_err_t BLEAdvertising::services_advertisement_() { esp_err_t err; this->advertising_data_.set_scan_rsp = false; - this->advertising_data_.include_name = !this->scan_response_; + this->advertising_data_.include_name = this->include_name_in_adv_ || !this->scan_response_; this->advertising_data_.include_txpower = !this->scan_response_; err = esp_ble_gap_config_adv_data(&this->advertising_data_); if (err != ESP_OK) { @@ -148,7 +152,7 @@ void BLEAdvertising::loop() { if (now - this->last_advertisement_time_ > this->advertising_cycle_time_) { this->stop(); this->current_adv_index_ += 1; - if (this->current_adv_index_ >= this->raw_advertisements_callbacks_.size()) { + if (static_cast(this->current_adv_index_) >= this->raw_advertisements_callbacks_.size()) { this->current_adv_index_ = -1; } this->start(); diff --git a/esphome/components/esp32_ble/ble_advertising.h b/esphome/components/esp32_ble/ble_advertising.h index e373554ea9..7a31d926f6 100644 --- a/esphome/components/esp32_ble/ble_advertising.h +++ b/esphome/components/esp32_ble/ble_advertising.h @@ -4,6 +4,7 @@ #include #include +#include #include #ifdef USE_ESP32 @@ -36,6 +37,8 @@ class BLEAdvertising { void set_manufacturer_data(const std::vector &data); void set_appearance(uint16_t appearance) { this->advertising_data_.appearance = appearance; } void set_service_data(const std::vector &data); + void set_service_data(std::span data); + void set_include_name(bool include_name) { this->include_name_in_adv_ = include_name; } void register_raw_advertisement_callback(std::function &&callback); void start(); @@ -45,6 +48,7 @@ class BLEAdvertising { esp_err_t services_advertisement_(); bool scan_response_; + bool include_name_in_adv_{false}; esp_ble_adv_data_t advertising_data_; esp_ble_adv_data_t scan_response_data_; esp_ble_adv_params_t advertising_params_; diff --git a/esphome/components/esp32_ble/ble_uuid.cpp b/esphome/components/esp32_ble/ble_uuid.cpp index 5f83e2ba0b..dcbb285e07 100644 --- a/esphome/components/esp32_ble/ble_uuid.cpp +++ b/esphome/components/esp32_ble/ble_uuid.cpp @@ -42,32 +42,18 @@ ESPBTUUID ESPBTUUID::from_raw_reversed(const uint8_t *data) { ESPBTUUID ESPBTUUID::from_raw(const std::string &data) { ESPBTUUID ret; if (data.length() == 4) { - ret.uuid_.len = ESP_UUID_LEN_16; - ret.uuid_.uuid.uuid16 = 0; - for (uint i = 0; i < data.length(); i += 2) { - uint8_t msb = data.c_str()[i]; - uint8_t lsb = data.c_str()[i + 1]; - uint8_t lsb_shift = i <= 2 ? (2 - i) * 4 : 0; - - if (msb > '9') - msb -= 7; - if (lsb > '9') - lsb -= 7; - ret.uuid_.uuid.uuid16 += (((msb & 0x0F) << 4) | (lsb & 0x0F)) << lsb_shift; + // 16-bit UUID as 4-character hex string + auto parsed = parse_hex(data); + if (parsed.has_value()) { + ret.uuid_.len = ESP_UUID_LEN_16; + ret.uuid_.uuid.uuid16 = parsed.value(); } } else if (data.length() == 8) { - ret.uuid_.len = ESP_UUID_LEN_32; - ret.uuid_.uuid.uuid32 = 0; - for (uint i = 0; i < data.length(); i += 2) { - uint8_t msb = data.c_str()[i]; - uint8_t lsb = data.c_str()[i + 1]; - uint8_t lsb_shift = i <= 6 ? (6 - i) * 4 : 0; - - if (msb > '9') - msb -= 7; - if (lsb > '9') - lsb -= 7; - ret.uuid_.uuid.uuid32 += (((msb & 0x0F) << 4) | (lsb & 0x0F)) << lsb_shift; + // 32-bit UUID as 8-character hex string + auto parsed = parse_hex(data); + if (parsed.has_value()) { + ret.uuid_.len = ESP_UUID_LEN_32; + ret.uuid_.uuid.uuid32 = parsed.value(); } } else if (data.length() == 16) { // how we can have 16 byte length string reprezenting 128 bit uuid??? needs to be // investigated (lack of time) @@ -145,28 +131,16 @@ bool ESPBTUUID::operator==(const ESPBTUUID &uuid) const { if (this->uuid_.len == uuid.uuid_.len) { switch (this->uuid_.len) { case ESP_UUID_LEN_16: - if (uuid.uuid_.uuid.uuid16 == this->uuid_.uuid.uuid16) { - return true; - } - break; + return this->uuid_.uuid.uuid16 == uuid.uuid_.uuid.uuid16; case ESP_UUID_LEN_32: - if (uuid.uuid_.uuid.uuid32 == this->uuid_.uuid.uuid32) { - return true; - } - break; + return this->uuid_.uuid.uuid32 == uuid.uuid_.uuid.uuid32; case ESP_UUID_LEN_128: - for (uint8_t i = 0; i < ESP_UUID_LEN_128; i++) { - if (uuid.uuid_.uuid.uuid128[i] != this->uuid_.uuid.uuid128[i]) { - return false; - } - } - return true; - break; + return memcmp(this->uuid_.uuid.uuid128, uuid.uuid_.uuid.uuid128, ESP_UUID_LEN_128) == 0; + default: + return false; } - } else { - return this->as_128bit() == uuid.as_128bit(); } - return false; + return this->as_128bit() == uuid.as_128bit(); } esp_bt_uuid_t ESPBTUUID::get_uuid() const { return this->uuid_; } std::string ESPBTUUID::to_string() const { diff --git a/esphome/components/esp32_ble_beacon/__init__.py b/esphome/components/esp32_ble_beacon/__init__.py index 8fc4fe941d..794f5637a4 100644 --- a/esphome/components/esp32_ble_beacon/__init__.py +++ b/esphome/components/esp32_ble_beacon/__init__.py @@ -4,7 +4,7 @@ from esphome.components.esp32 import add_idf_sdkconfig_option from esphome.components.esp32_ble import CONF_BLE_ID import esphome.config_validation as cv from esphome.const import CONF_ID, CONF_TX_POWER, CONF_TYPE, CONF_UUID -from esphome.core import CORE, TimePeriod +from esphome.core import TimePeriod AUTO_LOAD = ["esp32_ble"] DEPENDENCIES = ["esp32"] @@ -86,6 +86,5 @@ async def to_code(config): cg.add_define("USE_ESP32_BLE_ADVERTISING") - if CORE.using_esp_idf: - add_idf_sdkconfig_option("CONFIG_BT_ENABLED", True) - add_idf_sdkconfig_option("CONFIG_BT_BLE_42_FEATURES_SUPPORTED", True) + add_idf_sdkconfig_option("CONFIG_BT_ENABLED", True) + add_idf_sdkconfig_option("CONFIG_BT_BLE_42_FEATURES_SUPPORTED", True) diff --git a/esphome/components/esp32_ble_client/ble_client_base.cpp b/esphome/components/esp32_ble_client/ble_client_base.cpp index af5162afb0..18321ef91c 100644 --- a/esphome/components/esp32_ble_client/ble_client_base.cpp +++ b/esphome/components/esp32_ble_client/ble_client_base.cpp @@ -43,13 +43,6 @@ void BLEClientBase::setup() { void BLEClientBase::set_state(espbt::ClientState st) { ESP_LOGV(TAG, "[%d] [%s] Set state %d", this->connection_index_, this->address_str_.c_str(), (int) st); ESPBTClient::set_state(st); - - if (st == espbt::ClientState::READY_TO_CONNECT) { - // Enable loop for state processing - this->enable_loop(); - // Connect immediately instead of waiting for next loop - this->connect(); - } } void BLEClientBase::loop() { @@ -65,8 +58,8 @@ void BLEClientBase::loop() { } this->set_state(espbt::ClientState::IDLE); } - // If its idle, we can disable the loop as set_state - // will enable it again when we need to connect. + // If idle, we can disable the loop as connect() + // will enable it again when a connection is needed. else if (this->state_ == espbt::ClientState::IDLE) { this->disable_loop(); } @@ -108,9 +101,20 @@ bool BLEClientBase::parse_device(const espbt::ESPBTDevice &device) { #endif void BLEClientBase::connect() { + // Prevent duplicate connection attempts + if (this->state_ == espbt::ClientState::CONNECTING || this->state_ == espbt::ClientState::CONNECTED || + this->state_ == espbt::ClientState::ESTABLISHED) { + ESP_LOGW(TAG, "[%d] [%s] Connection already in progress, state=%s", this->connection_index_, + this->address_str_.c_str(), espbt::client_state_to_string(this->state_)); + return; + } ESP_LOGI(TAG, "[%d] [%s] 0x%02x Connecting", this->connection_index_, this->address_str_.c_str(), this->remote_addr_type_); this->paired_ = false; + // Enable loop for state processing + this->enable_loop(); + // Immediately transition to CONNECTING to prevent duplicate connection attempts + this->set_state(espbt::ClientState::CONNECTING); // Determine connection parameters based on connection type if (this->connection_type_ == espbt::ConnectionType::V3_WITHOUT_CACHE) { @@ -168,7 +172,7 @@ void BLEClientBase::unconditional_disconnect() { this->log_gattc_warning_("esp_ble_gattc_close", err); } - if (this->state_ == espbt::ClientState::READY_TO_CONNECT || this->state_ == espbt::ClientState::DISCOVERED) { + if (this->state_ == espbt::ClientState::DISCOVERED) { this->set_address(0); this->set_state(espbt::ClientState::IDLE); } else { @@ -212,8 +216,6 @@ void BLEClientBase::handle_connection_result_(esp_err_t ret) { if (ret) { this->log_gattc_warning_("esp_ble_gattc_open", ret); this->set_state(espbt::ClientState::IDLE); - } else { - this->set_state(espbt::ClientState::CONNECTING); } } diff --git a/esphome/components/esp32_ble_server/__init__.py b/esphome/components/esp32_ble_server/__init__.py index 8ddb15a7f8..10fa09fcc3 100644 --- a/esphome/components/esp32_ble_server/__init__.py +++ b/esphome/components/esp32_ble_server/__init__.py @@ -26,7 +26,7 @@ from esphome.const import ( from esphome.core import CORE from esphome.schema_extractors import SCHEMA_EXTRACT -AUTO_LOAD = ["esp32_ble", "bytebuffer", "event_emitter"] +AUTO_LOAD = ["esp32_ble", "bytebuffer"] CODEOWNERS = ["@jesserockz", "@clydebarrow", "@Rapsssito"] DEPENDENCIES = ["esp32"] DOMAIN = "esp32_ble_server" @@ -488,6 +488,7 @@ async def to_code_descriptor(descriptor_conf, char_var): cg.add(desc_var.set_value(value)) if CONF_ON_WRITE in descriptor_conf: on_write_conf = descriptor_conf[CONF_ON_WRITE] + cg.add_define("USE_ESP32_BLE_SERVER_DESCRIPTOR_ON_WRITE") await automation.build_automation( BLETriggers_ns.create_descriptor_on_write_trigger(desc_var), [(cg.std_vector.template(cg.uint8), "x"), (cg.uint16, "id")], @@ -505,23 +506,32 @@ async def to_code_characteristic(service_var, char_conf): ) if CONF_ON_WRITE in char_conf: on_write_conf = char_conf[CONF_ON_WRITE] + cg.add_define("USE_ESP32_BLE_SERVER_CHARACTERISTIC_ON_WRITE") await automation.build_automation( BLETriggers_ns.create_characteristic_on_write_trigger(char_var), [(cg.std_vector.template(cg.uint8), "x"), (cg.uint16, "id")], on_write_conf, ) if CONF_VALUE in char_conf: - action_conf = { - CONF_ID: char_conf[CONF_ID], - CONF_VALUE: char_conf[CONF_VALUE], - } - value_action = await ble_server_characteristic_set_value( - action_conf, - char_conf[CONF_CHAR_VALUE_ACTION_ID_], - cg.TemplateArguments(), - {}, - ) - cg.add(value_action.play()) + # Check if the value is templated (Lambda) + value_data = char_conf[CONF_VALUE][CONF_DATA] + if isinstance(value_data, cv.Lambda): + # Templated value - need the full action infrastructure + action_conf = { + CONF_ID: char_conf[CONF_ID], + CONF_VALUE: char_conf[CONF_VALUE], + } + value_action = await ble_server_characteristic_set_value( + action_conf, + char_conf[CONF_CHAR_VALUE_ACTION_ID_], + cg.TemplateArguments(), + {}, + ) + cg.add(value_action.play()) + else: + # Static value - just set it directly without action infrastructure + value = await parse_value(char_conf[CONF_VALUE], {}) + cg.add(char_var.set_value(value)) for descriptor_conf in char_conf[CONF_DESCRIPTORS]: await to_code_descriptor(descriptor_conf, char_var) @@ -560,12 +570,14 @@ async def to_code(config): else: cg.add(var.enqueue_start_service(service_var)) if CONF_ON_CONNECT in config: + cg.add_define("USE_ESP32_BLE_SERVER_ON_CONNECT") await automation.build_automation( BLETriggers_ns.create_server_on_connect_trigger(var), [(cg.uint16, "id")], config[CONF_ON_CONNECT], ) if CONF_ON_DISCONNECT in config: + cg.add_define("USE_ESP32_BLE_SERVER_ON_DISCONNECT") await automation.build_automation( BLETriggers_ns.create_server_on_disconnect_trigger(var), [(cg.uint16, "id")], @@ -573,8 +585,7 @@ async def to_code(config): ) cg.add_define("USE_ESP32_BLE_SERVER") cg.add_define("USE_ESP32_BLE_ADVERTISING") - if CORE.using_esp_idf: - add_idf_sdkconfig_option("CONFIG_BT_ENABLED", True) + add_idf_sdkconfig_option("CONFIG_BT_ENABLED", True) @automation.register_action( @@ -595,6 +606,7 @@ async def ble_server_characteristic_set_value(config, action_id, template_arg, a var = cg.new_Pvariable(action_id, template_arg, paren) value = await parse_value(config[CONF_VALUE], args) cg.add(var.set_buffer(value)) + cg.add_define("USE_ESP32_BLE_SERVER_SET_VALUE_ACTION") return var @@ -613,6 +625,7 @@ async def ble_server_descriptor_set_value(config, action_id, template_arg, args) var = cg.new_Pvariable(action_id, template_arg, paren) value = await parse_value(config[CONF_VALUE], args) cg.add(var.set_buffer(value)) + cg.add_define("USE_ESP32_BLE_SERVER_DESCRIPTOR_SET_VALUE_ACTION") return var @@ -630,4 +643,5 @@ async def ble_server_descriptor_set_value(config, action_id, template_arg, args) ) async def ble_server_characteristic_notify(config, action_id, template_arg, args): paren = await cg.get_variable(config[CONF_ID]) + cg.add_define("USE_ESP32_BLE_SERVER_NOTIFY_ACTION") return cg.new_Pvariable(action_id, template_arg, paren) diff --git a/esphome/components/esp32_ble_server/ble_characteristic.cpp b/esphome/components/esp32_ble_server/ble_characteristic.cpp index 373d57436e..87f562a250 100644 --- a/esphome/components/esp32_ble_server/ble_characteristic.cpp +++ b/esphome/components/esp32_ble_server/ble_characteristic.cpp @@ -49,13 +49,17 @@ void BLECharacteristic::notify() { this->service_->get_server()->get_connected_client_count() == 0) return; - for (auto &client : this->service_->get_server()->get_clients()) { + const uint16_t *clients = this->service_->get_server()->get_clients(); + uint8_t client_count = this->service_->get_server()->get_client_count(); + + for (uint8_t i = 0; i < client_count; i++) { + uint16_t client = clients[i]; size_t length = this->value_.size(); - // If the client is not in the list of clients to notify, skip it - if (this->clients_to_notify_.count(client) == 0) + // Find the client in the list of clients to notify + auto *entry = this->find_client_in_notify_list_(client); + if (entry == nullptr) continue; - // If the client is in the list of clients to notify, check if it requires an ack (i.e. INDICATE) - bool require_ack = this->clients_to_notify_[client]; + bool require_ack = entry->indicate; // TODO: Remove this block when INDICATE acknowledgment is supported if (require_ack) { ESP_LOGW(TAG, "INDICATE acknowledgment is not yet supported (i.e. it works as a NOTIFY)"); @@ -73,16 +77,17 @@ void BLECharacteristic::notify() { void BLECharacteristic::add_descriptor(BLEDescriptor *descriptor) { // If the descriptor is the CCCD descriptor, listen to its write event to know if the client wants to be notified if (descriptor->get_uuid() == ESPBTUUID::from_uint16(ESP_GATT_UUID_CHAR_CLIENT_CONFIG)) { - descriptor->on(BLEDescriptorEvt::VectorEvt::ON_WRITE, [this](const std::vector &value, uint16_t conn_id) { + descriptor->on_write([this](std::span value, uint16_t conn_id) { if (value.size() != 2) return; uint16_t cccd = encode_uint16(value[1], value[0]); bool notify = (cccd & 1) != 0; bool indicate = (cccd & 2) != 0; + // Remove existing entry if present + this->remove_client_from_notify_list_(conn_id); + // Add new entry if needed if (notify || indicate) { - this->clients_to_notify_[conn_id] = indicate; - } else { - this->clients_to_notify_.erase(conn_id); + this->clients_to_notify_.push_back({conn_id, indicate}); } }); } @@ -120,69 +125,49 @@ bool BLECharacteristic::is_created() { if (this->state_ != CREATING_DEPENDENTS) return false; - bool created = true; for (auto *descriptor : this->descriptors_) { - created &= descriptor->is_created(); + if (!descriptor->is_created()) + return false; } - if (created) - this->state_ = CREATED; - return this->state_ == CREATED; + // All descriptors are created if we reach here + this->state_ = CREATED; + return true; } bool BLECharacteristic::is_failed() { if (this->state_ == FAILED) return true; - bool failed = false; for (auto *descriptor : this->descriptors_) { - failed |= descriptor->is_failed(); + if (descriptor->is_failed()) { + this->state_ = FAILED; + return true; + } + } + return false; +} + +void BLECharacteristic::set_property_bit_(esp_gatt_char_prop_t bit, bool value) { + if (value) { + this->properties_ = (esp_gatt_char_prop_t) (this->properties_ | bit); + } else { + this->properties_ = (esp_gatt_char_prop_t) (this->properties_ & ~bit); } - if (failed) - this->state_ = FAILED; - return this->state_ == FAILED; } void BLECharacteristic::set_broadcast_property(bool value) { - if (value) { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ | ESP_GATT_CHAR_PROP_BIT_BROADCAST); - } else { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ & ~ESP_GATT_CHAR_PROP_BIT_BROADCAST); - } + this->set_property_bit_(ESP_GATT_CHAR_PROP_BIT_BROADCAST, value); } void BLECharacteristic::set_indicate_property(bool value) { - if (value) { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ | ESP_GATT_CHAR_PROP_BIT_INDICATE); - } else { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ & ~ESP_GATT_CHAR_PROP_BIT_INDICATE); - } + this->set_property_bit_(ESP_GATT_CHAR_PROP_BIT_INDICATE, value); } void BLECharacteristic::set_notify_property(bool value) { - if (value) { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ | ESP_GATT_CHAR_PROP_BIT_NOTIFY); - } else { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ & ~ESP_GATT_CHAR_PROP_BIT_NOTIFY); - } -} -void BLECharacteristic::set_read_property(bool value) { - if (value) { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ | ESP_GATT_CHAR_PROP_BIT_READ); - } else { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ & ~ESP_GATT_CHAR_PROP_BIT_READ); - } -} -void BLECharacteristic::set_write_property(bool value) { - if (value) { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ | ESP_GATT_CHAR_PROP_BIT_WRITE); - } else { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ & ~ESP_GATT_CHAR_PROP_BIT_WRITE); - } + this->set_property_bit_(ESP_GATT_CHAR_PROP_BIT_NOTIFY, value); } +void BLECharacteristic::set_read_property(bool value) { this->set_property_bit_(ESP_GATT_CHAR_PROP_BIT_READ, value); } +void BLECharacteristic::set_write_property(bool value) { this->set_property_bit_(ESP_GATT_CHAR_PROP_BIT_WRITE, value); } void BLECharacteristic::set_write_no_response_property(bool value) { - if (value) { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ | ESP_GATT_CHAR_PROP_BIT_WRITE_NR); - } else { - this->properties_ = (esp_gatt_char_prop_t) (this->properties_ & ~ESP_GATT_CHAR_PROP_BIT_WRITE_NR); - } + this->set_property_bit_(ESP_GATT_CHAR_PROP_BIT_WRITE_NR, value); } void BLECharacteristic::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt_if_t gatts_if, @@ -207,8 +192,9 @@ void BLECharacteristic::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt if (!param->read.need_rsp) break; // For some reason you can request a read but not want a response - this->EventEmitter::emit_(BLECharacteristicEvt::EmptyEvt::ON_READ, - param->read.conn_id); + if (this->on_read_callback_) { + (*this->on_read_callback_)(param->read.conn_id); + } uint16_t max_offset = 22; @@ -276,8 +262,9 @@ void BLECharacteristic::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt } if (!param->write.is_prep) { - this->EventEmitter, uint16_t>::emit_( - BLECharacteristicEvt::VectorEvt::ON_WRITE, this->value_, param->write.conn_id); + if (this->on_write_callback_) { + (*this->on_write_callback_)(this->value_, param->write.conn_id); + } } break; @@ -288,8 +275,9 @@ void BLECharacteristic::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt break; this->write_event_ = false; if (param->exec_write.exec_write_flag == ESP_GATT_PREP_WRITE_EXEC) { - this->EventEmitter, uint16_t>::emit_( - BLECharacteristicEvt::VectorEvt::ON_WRITE, this->value_, param->exec_write.conn_id); + if (this->on_write_callback_) { + (*this->on_write_callback_)(this->value_, param->exec_write.conn_id); + } } esp_err_t err = esp_ble_gatts_send_response(gatts_if, param->write.conn_id, param->write.trans_id, ESP_GATT_OK, nullptr); @@ -307,6 +295,28 @@ void BLECharacteristic::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt } } +void BLECharacteristic::remove_client_from_notify_list_(uint16_t conn_id) { + // Since we typically have very few clients (often just 1), we can optimize + // for the common case by swapping with the last element and popping + for (size_t i = 0; i < this->clients_to_notify_.size(); i++) { + if (this->clients_to_notify_[i].conn_id == conn_id) { + // Swap with last element and pop (safe even when i is the last element) + this->clients_to_notify_[i] = this->clients_to_notify_.back(); + this->clients_to_notify_.pop_back(); + return; + } + } +} + +BLECharacteristic::ClientNotificationEntry *BLECharacteristic::find_client_in_notify_list_(uint16_t conn_id) { + for (auto &entry : this->clients_to_notify_) { + if (entry.conn_id == conn_id) { + return &entry; + } + } + return nullptr; +} + } // namespace esp32_ble_server } // namespace esphome diff --git a/esphome/components/esp32_ble_server/ble_characteristic.h b/esphome/components/esp32_ble_server/ble_characteristic.h index 3698b8c4aa..7cceec0ef1 100644 --- a/esphome/components/esp32_ble_server/ble_characteristic.h +++ b/esphome/components/esp32_ble_server/ble_characteristic.h @@ -2,11 +2,12 @@ #include "ble_descriptor.h" #include "esphome/components/esp32_ble/ble_uuid.h" -#include "esphome/components/event_emitter/event_emitter.h" #include "esphome/components/bytebuffer/bytebuffer.h" #include -#include +#include +#include +#include #ifdef USE_ESP32 @@ -23,22 +24,10 @@ namespace esp32_ble_server { using namespace esp32_ble; using namespace bytebuffer; -using namespace event_emitter; class BLEService; -namespace BLECharacteristicEvt { -enum VectorEvt { - ON_WRITE, -}; - -enum EmptyEvt { - ON_READ, -}; -} // namespace BLECharacteristicEvt - -class BLECharacteristic : public EventEmitter, uint16_t>, - public EventEmitter { +class BLECharacteristic { public: BLECharacteristic(ESPBTUUID uuid, uint32_t properties); ~BLECharacteristic(); @@ -77,6 +66,15 @@ class BLECharacteristic : public EventEmitter, uint16_t)> &&callback) { + this->on_write_callback_ = + std::make_unique, uint16_t)>>(std::move(callback)); + } + void on_read(std::function &&callback) { + this->on_read_callback_ = std::make_unique>(std::move(callback)); + } + protected: bool write_event_{false}; BLEService *service_{}; @@ -89,7 +87,20 @@ class BLECharacteristic : public EventEmitter descriptors_; - std::unordered_map clients_to_notify_; + + struct ClientNotificationEntry { + uint16_t conn_id; + bool indicate; // true = indicate, false = notify + }; + std::vector clients_to_notify_; + + void remove_client_from_notify_list_(uint16_t conn_id); + ClientNotificationEntry *find_client_in_notify_list_(uint16_t conn_id); + + void set_property_bit_(esp_gatt_char_prop_t bit, bool value); + + std::unique_ptr, uint16_t)>> on_write_callback_; + std::unique_ptr> on_read_callback_; esp_gatt_perm_t permissions_ = ESP_GATT_PERM_READ | ESP_GATT_PERM_WRITE; diff --git a/esphome/components/esp32_ble_server/ble_descriptor.cpp b/esphome/components/esp32_ble_server/ble_descriptor.cpp index afbe579513..16941cca0f 100644 --- a/esphome/components/esp32_ble_server/ble_descriptor.cpp +++ b/esphome/components/esp32_ble_server/ble_descriptor.cpp @@ -74,9 +74,10 @@ void BLEDescriptor::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt_if_ break; this->value_.attr_len = param->write.len; memcpy(this->value_.attr_value, param->write.value, param->write.len); - this->emit_(BLEDescriptorEvt::VectorEvt::ON_WRITE, - std::vector(param->write.value, param->write.value + param->write.len), - param->write.conn_id); + if (this->on_write_callback_) { + (*this->on_write_callback_)(std::span(param->write.value, param->write.len), + param->write.conn_id); + } break; } default: diff --git a/esphome/components/esp32_ble_server/ble_descriptor.h b/esphome/components/esp32_ble_server/ble_descriptor.h index 8d3c22c5a1..425462a316 100644 --- a/esphome/components/esp32_ble_server/ble_descriptor.h +++ b/esphome/components/esp32_ble_server/ble_descriptor.h @@ -1,30 +1,26 @@ #pragma once #include "esphome/components/esp32_ble/ble_uuid.h" -#include "esphome/components/event_emitter/event_emitter.h" #include "esphome/components/bytebuffer/bytebuffer.h" #ifdef USE_ESP32 #include #include +#include +#include +#include namespace esphome { namespace esp32_ble_server { using namespace esp32_ble; using namespace bytebuffer; -using namespace event_emitter; class BLECharacteristic; -namespace BLEDescriptorEvt { -enum VectorEvt { - ON_WRITE, -}; -} // namespace BLEDescriptorEvt - -class BLEDescriptor : public EventEmitter, uint16_t> { +// Base class for BLE descriptors +class BLEDescriptor { public: BLEDescriptor(ESPBTUUID uuid, uint16_t max_len = 100, bool read = true, bool write = true); virtual ~BLEDescriptor(); @@ -39,6 +35,12 @@ class BLEDescriptor : public EventEmitterstate_ == CREATED; } bool is_failed() { return this->state_ == FAILED; } + // Direct callback registration - only allocates when callback is set + void on_write(std::function, uint16_t)> &&callback) { + this->on_write_callback_ = + std::make_unique, uint16_t)>>(std::move(callback)); + } + protected: BLECharacteristic *characteristic_{nullptr}; ESPBTUUID uuid_; @@ -46,6 +48,8 @@ class BLEDescriptor : public EventEmitter, uint16_t)>> on_write_callback_; + esp_gatt_perm_t permissions_{}; enum State : uint8_t { diff --git a/esphome/components/esp32_ble_server/ble_server.cpp b/esphome/components/esp32_ble_server/ble_server.cpp index 5339bf8aed..25cc97eeaf 100644 --- a/esphome/components/esp32_ble_server/ble_server.cpp +++ b/esphome/components/esp32_ble_server/ble_server.cpp @@ -70,11 +70,11 @@ void BLEServer::loop() { // it is at the top of the GATT table this->device_information_service_->do_create(this); // Create all services previously created - for (auto &pair : this->services_) { - if (pair.second == this->device_information_service_) { + for (auto &entry : this->services_) { + if (entry.service == this->device_information_service_) { continue; } - pair.second->do_create(this); + entry.service->do_create(this); } this->state_ = STARTING_SERVICE; } @@ -118,7 +118,7 @@ BLEService *BLEServer::create_service(ESPBTUUID uuid, bool advertise, uint16_t n } BLEService *service = // NOLINT(cppcoreguidelines-owning-memory) new BLEService(uuid, num_handles, inst_id, advertise); - this->services_.emplace(BLEServer::get_service_key(uuid, inst_id), service); + this->services_.push_back({uuid, inst_id, service}); if (this->parent_->is_active() && this->registered_) { service->do_create(this); } @@ -127,26 +127,32 @@ BLEService *BLEServer::create_service(ESPBTUUID uuid, bool advertise, uint16_t n void BLEServer::remove_service(ESPBTUUID uuid, uint8_t inst_id) { ESP_LOGV(TAG, "Removing BLE service - %s %d", uuid.to_string().c_str(), inst_id); - BLEService *service = this->get_service(uuid, inst_id); - if (service == nullptr) { - ESP_LOGW(TAG, "BLE service %s %d does not exist", uuid.to_string().c_str(), inst_id); - return; + for (auto it = this->services_.begin(); it != this->services_.end(); ++it) { + if (it->uuid == uuid && it->inst_id == inst_id) { + it->service->do_delete(); + delete it->service; // NOLINT(cppcoreguidelines-owning-memory) + this->services_.erase(it); + return; + } } - service->do_delete(); - delete service; // NOLINT(cppcoreguidelines-owning-memory) - this->services_.erase(BLEServer::get_service_key(uuid, inst_id)); + ESP_LOGW(TAG, "BLE service %s %d does not exist", uuid.to_string().c_str(), inst_id); } BLEService *BLEServer::get_service(ESPBTUUID uuid, uint8_t inst_id) { - BLEService *service = nullptr; - if (this->services_.count(BLEServer::get_service_key(uuid, inst_id)) > 0) { - service = this->services_.at(BLEServer::get_service_key(uuid, inst_id)); + for (auto &entry : this->services_) { + if (entry.uuid == uuid && entry.inst_id == inst_id) { + return entry.service; + } } - return service; + return nullptr; } -std::string BLEServer::get_service_key(ESPBTUUID uuid, uint8_t inst_id) { - return uuid.to_string() + std::to_string(inst_id); +void BLEServer::dispatch_callbacks_(CallbackType type, uint16_t conn_id) { + for (auto &entry : this->callbacks_) { + if (entry.type == type) { + entry.callback(conn_id); + } + } } void BLEServer::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt_if_t gatts_if, @@ -155,14 +161,14 @@ void BLEServer::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt_if_t ga case ESP_GATTS_CONNECT_EVT: { ESP_LOGD(TAG, "BLE Client connected"); this->add_client_(param->connect.conn_id); - this->emit_(BLEServerEvt::EmptyEvt::ON_CONNECT, param->connect.conn_id); + this->dispatch_callbacks_(CallbackType::ON_CONNECT, param->connect.conn_id); break; } case ESP_GATTS_DISCONNECT_EVT: { ESP_LOGD(TAG, "BLE Client disconnected"); this->remove_client_(param->disconnect.conn_id); this->parent_->advertising_start(); - this->emit_(BLEServerEvt::EmptyEvt::ON_DISCONNECT, param->disconnect.conn_id); + this->dispatch_callbacks_(CallbackType::ON_DISCONNECT, param->disconnect.conn_id); break; } case ESP_GATTS_REG_EVT: { @@ -174,17 +180,46 @@ void BLEServer::gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt_if_t ga break; } - for (const auto &pair : this->services_) { - pair.second->gatts_event_handler(event, gatts_if, param); + for (auto &entry : this->services_) { + entry.service->gatts_event_handler(event, gatts_if, param); + } +} + +int8_t BLEServer::find_client_index_(uint16_t conn_id) const { + for (uint8_t i = 0; i < this->client_count_; i++) { + if (this->clients_[i] == conn_id) + return i; + } + return -1; +} + +void BLEServer::add_client_(uint16_t conn_id) { + // Check if already in list + if (this->find_client_index_(conn_id) >= 0) + return; + // Add if there's space + if (this->client_count_ < USE_ESP32_BLE_MAX_CONNECTIONS) { + this->clients_[this->client_count_++] = conn_id; + } else { + // This should never happen since max clients is known at compile time + ESP_LOGE(TAG, "Client array full"); + } +} + +void BLEServer::remove_client_(uint16_t conn_id) { + int8_t index = this->find_client_index_(conn_id); + if (index >= 0) { + // Replace with last element and decrement count (client order not preserved) + this->clients_[index] = this->clients_[--this->client_count_]; } } void BLEServer::ble_before_disabled_event_handler() { // Delete all clients - this->clients_.clear(); + this->client_count_ = 0; // Delete all services - for (auto &pair : this->services_) { - pair.second->do_delete(); + for (auto &entry : this->services_) { + entry.service->do_delete(); } this->registered_ = false; this->state_ = INIT; diff --git a/esphome/components/esp32_ble_server/ble_server.h b/esphome/components/esp32_ble_server/ble_server.h index 531b52d6b9..6fa86dd67f 100644 --- a/esphome/components/esp32_ble_server/ble_server.h +++ b/esphome/components/esp32_ble_server/ble_server.h @@ -12,7 +12,7 @@ #include #include #include -#include +#include #ifdef USE_ESP32 @@ -24,18 +24,7 @@ namespace esp32_ble_server { using namespace esp32_ble; using namespace bytebuffer; -namespace BLEServerEvt { -enum EmptyEvt { - ON_CONNECT, - ON_DISCONNECT, -}; -} // namespace BLEServerEvt - -class BLEServer : public Component, - public GATTsEventHandler, - public BLEStatusEventHandler, - public Parented, - public EventEmitter { +class BLEServer : public Component, public GATTsEventHandler, public BLEStatusEventHandler, public Parented { public: void setup() override; void loop() override; @@ -57,27 +46,56 @@ class BLEServer : public Component, void set_device_information_service(BLEService *service) { this->device_information_service_ = service; } esp_gatt_if_t get_gatts_if() { return this->gatts_if_; } - uint32_t get_connected_client_count() { return this->clients_.size(); } - const std::unordered_set &get_clients() { return this->clients_; } + uint32_t get_connected_client_count() { return this->client_count_; } + const uint16_t *get_clients() const { return this->clients_; } + uint8_t get_client_count() const { return this->client_count_; } void gatts_event_handler(esp_gatts_cb_event_t event, esp_gatt_if_t gatts_if, esp_ble_gatts_cb_param_t *param) override; void ble_before_disabled_event_handler() override; + // Direct callback registration - supports multiple callbacks + void on_connect(std::function &&callback) { + this->callbacks_.push_back({CallbackType::ON_CONNECT, std::move(callback)}); + } + void on_disconnect(std::function &&callback) { + this->callbacks_.push_back({CallbackType::ON_DISCONNECT, std::move(callback)}); + } + protected: - static std::string get_service_key(ESPBTUUID uuid, uint8_t inst_id); + enum class CallbackType : uint8_t { + ON_CONNECT, + ON_DISCONNECT, + }; + + struct CallbackEntry { + CallbackType type; + std::function callback; + }; + + struct ServiceEntry { + ESPBTUUID uuid; + uint8_t inst_id; + BLEService *service; + }; + void restart_advertising_(); - void add_client_(uint16_t conn_id) { this->clients_.insert(conn_id); } - void remove_client_(uint16_t conn_id) { this->clients_.erase(conn_id); } + int8_t find_client_index_(uint16_t conn_id) const; + void add_client_(uint16_t conn_id); + void remove_client_(uint16_t conn_id); + void dispatch_callbacks_(CallbackType type, uint16_t conn_id); + + std::vector callbacks_; std::vector manufacturer_data_{}; esp_gatt_if_t gatts_if_{0}; bool registered_{false}; - std::unordered_set clients_; - std::unordered_map services_{}; + uint16_t clients_[USE_ESP32_BLE_MAX_CONNECTIONS]{}; + uint8_t client_count_{0}; + std::vector services_{}; std::vector services_to_start_{}; BLEService *device_information_service_{}; diff --git a/esphome/components/esp32_ble_server/ble_server_automations.cpp b/esphome/components/esp32_ble_server/ble_server_automations.cpp index 41ef2b8bfe..0761de994a 100644 --- a/esphome/components/esp32_ble_server/ble_server_automations.cpp +++ b/esphome/components/esp32_ble_server/ble_server_automations.cpp @@ -9,67 +9,83 @@ namespace esp32_ble_server_automations { using namespace esp32_ble; +#ifdef USE_ESP32_BLE_SERVER_CHARACTERISTIC_ON_WRITE Trigger, uint16_t> *BLETriggers::create_characteristic_on_write_trigger( BLECharacteristic *characteristic) { Trigger, uint16_t> *on_write_trigger = // NOLINT(cppcoreguidelines-owning-memory) new Trigger, uint16_t>(); - characteristic->EventEmitter, uint16_t>::on( - BLECharacteristicEvt::VectorEvt::ON_WRITE, - [on_write_trigger](const std::vector &data, uint16_t id) { on_write_trigger->trigger(data, id); }); + characteristic->on_write([on_write_trigger](std::span data, uint16_t id) { + // Convert span to vector for trigger + on_write_trigger->trigger(std::vector(data.begin(), data.end()), id); + }); return on_write_trigger; } +#endif +#ifdef USE_ESP32_BLE_SERVER_DESCRIPTOR_ON_WRITE Trigger, uint16_t> *BLETriggers::create_descriptor_on_write_trigger(BLEDescriptor *descriptor) { Trigger, uint16_t> *on_write_trigger = // NOLINT(cppcoreguidelines-owning-memory) new Trigger, uint16_t>(); - descriptor->on( - BLEDescriptorEvt::VectorEvt::ON_WRITE, - [on_write_trigger](const std::vector &data, uint16_t id) { on_write_trigger->trigger(data, id); }); + descriptor->on_write([on_write_trigger](std::span data, uint16_t id) { + // Convert span to vector for trigger + on_write_trigger->trigger(std::vector(data.begin(), data.end()), id); + }); return on_write_trigger; } +#endif +#ifdef USE_ESP32_BLE_SERVER_ON_CONNECT Trigger *BLETriggers::create_server_on_connect_trigger(BLEServer *server) { Trigger *on_connect_trigger = new Trigger(); // NOLINT(cppcoreguidelines-owning-memory) - server->on(BLEServerEvt::EmptyEvt::ON_CONNECT, - [on_connect_trigger](uint16_t conn_id) { on_connect_trigger->trigger(conn_id); }); + server->on_connect([on_connect_trigger](uint16_t conn_id) { on_connect_trigger->trigger(conn_id); }); return on_connect_trigger; } +#endif +#ifdef USE_ESP32_BLE_SERVER_ON_DISCONNECT Trigger *BLETriggers::create_server_on_disconnect_trigger(BLEServer *server) { Trigger *on_disconnect_trigger = new Trigger(); // NOLINT(cppcoreguidelines-owning-memory) - server->on(BLEServerEvt::EmptyEvt::ON_DISCONNECT, - [on_disconnect_trigger](uint16_t conn_id) { on_disconnect_trigger->trigger(conn_id); }); + server->on_disconnect([on_disconnect_trigger](uint16_t conn_id) { on_disconnect_trigger->trigger(conn_id); }); return on_disconnect_trigger; } +#endif +#ifdef USE_ESP32_BLE_SERVER_SET_VALUE_ACTION void BLECharacteristicSetValueActionManager::set_listener(BLECharacteristic *characteristic, - EventEmitterListenerID listener_id, const std::function &pre_notify_listener) { - // Check if there is already a listener for this characteristic - if (this->listeners_.count(characteristic) > 0) { - // Unpack the pair listener_id, pre_notify_listener_id - auto listener_pairs = this->listeners_[characteristic]; - EventEmitterListenerID old_listener_id = listener_pairs.first; - EventEmitterListenerID old_pre_notify_listener_id = listener_pairs.second; - // Remove the previous listener - characteristic->EventEmitter::off(BLECharacteristicEvt::EmptyEvt::ON_READ, - old_listener_id); - // Remove the pre-notify listener - this->off(BLECharacteristicSetValueActionEvt::PRE_NOTIFY, old_pre_notify_listener_id); + // Find and remove existing listener for this characteristic + auto *existing = this->find_listener_(characteristic); + if (existing != nullptr) { + // Remove from vector + this->remove_listener_(characteristic); } - // Create a new listener for the pre-notify event - EventEmitterListenerID pre_notify_listener_id = - this->on(BLECharacteristicSetValueActionEvt::PRE_NOTIFY, - [pre_notify_listener, characteristic](const BLECharacteristic *evt_characteristic) { - // Only call the pre-notify listener if the characteristic is the one we are interested in - if (characteristic == evt_characteristic) { - pre_notify_listener(); - } - }); - // Save the pair listener_id, pre_notify_listener_id to the map - this->listeners_[characteristic] = std::make_pair(listener_id, pre_notify_listener_id); + // Save the entry to the vector + this->listeners_.push_back({characteristic, pre_notify_listener}); } +BLECharacteristicSetValueActionManager::ListenerEntry *BLECharacteristicSetValueActionManager::find_listener_( + BLECharacteristic *characteristic) { + for (auto &entry : this->listeners_) { + if (entry.characteristic == characteristic) { + return &entry; + } + } + return nullptr; +} + +void BLECharacteristicSetValueActionManager::remove_listener_(BLECharacteristic *characteristic) { + // Since we typically have very few listeners, optimize by swapping with back and popping + for (size_t i = 0; i < this->listeners_.size(); i++) { + if (this->listeners_[i].characteristic == characteristic) { + // Swap with last element and pop (safe even when i is the last element) + this->listeners_[i] = this->listeners_.back(); + this->listeners_.pop_back(); + return; + } + } +} +#endif + } // namespace esp32_ble_server_automations } // namespace esp32_ble_server } // namespace esphome diff --git a/esphome/components/esp32_ble_server/ble_server_automations.h b/esphome/components/esp32_ble_server/ble_server_automations.h index eab6b05f05..543b1153fc 100644 --- a/esphome/components/esp32_ble_server/ble_server_automations.h +++ b/esphome/components/esp32_ble_server/ble_server_automations.h @@ -4,11 +4,9 @@ #include "ble_characteristic.h" #include "ble_descriptor.h" -#include "esphome/components/event_emitter/event_emitter.h" #include "esphome/core/automation.h" #include -#include #include #ifdef USE_ESP32 @@ -19,41 +17,53 @@ namespace esp32_ble_server { namespace esp32_ble_server_automations { using namespace esp32_ble; -using namespace event_emitter; class BLETriggers { public: +#ifdef USE_ESP32_BLE_SERVER_CHARACTERISTIC_ON_WRITE static Trigger, uint16_t> *create_characteristic_on_write_trigger( BLECharacteristic *characteristic); +#endif +#ifdef USE_ESP32_BLE_SERVER_DESCRIPTOR_ON_WRITE static Trigger, uint16_t> *create_descriptor_on_write_trigger(BLEDescriptor *descriptor); +#endif +#ifdef USE_ESP32_BLE_SERVER_ON_CONNECT static Trigger *create_server_on_connect_trigger(BLEServer *server); +#endif +#ifdef USE_ESP32_BLE_SERVER_ON_DISCONNECT static Trigger *create_server_on_disconnect_trigger(BLEServer *server); +#endif }; -enum BLECharacteristicSetValueActionEvt { - PRE_NOTIFY, -}; - +#ifdef USE_ESP32_BLE_SERVER_SET_VALUE_ACTION // Class to make sure only one BLECharacteristicSetValueAction is active at a time for each characteristic -class BLECharacteristicSetValueActionManager - : public EventEmitter { +class BLECharacteristicSetValueActionManager { public: // Singleton pattern static BLECharacteristicSetValueActionManager *get_instance() { static BLECharacteristicSetValueActionManager instance; return &instance; } - void set_listener(BLECharacteristic *characteristic, EventEmitterListenerID listener_id, - const std::function &pre_notify_listener); - EventEmitterListenerID get_listener(BLECharacteristic *characteristic) { - return this->listeners_[characteristic].first; - } + void set_listener(BLECharacteristic *characteristic, const std::function &pre_notify_listener); + bool has_listener(BLECharacteristic *characteristic) { return this->find_listener_(characteristic) != nullptr; } void emit_pre_notify(BLECharacteristic *characteristic) { - this->emit_(BLECharacteristicSetValueActionEvt::PRE_NOTIFY, characteristic); + for (const auto &entry : this->listeners_) { + if (entry.characteristic == characteristic) { + entry.pre_notify_listener(); + break; + } + } } private: - std::unordered_map> listeners_; + struct ListenerEntry { + BLECharacteristic *characteristic; + std::function pre_notify_listener; + }; + std::vector listeners_; + + ListenerEntry *find_listener_(BLECharacteristic *characteristic); + void remove_listener_(BLECharacteristic *characteristic); }; template class BLECharacteristicSetValueAction : public Action { @@ -63,32 +73,34 @@ template class BLECharacteristicSetValueAction : public Actionset_buffer(buffer.get_data()); } void play(Ts... x) override { // If the listener is already set, do nothing - if (BLECharacteristicSetValueActionManager::get_instance()->get_listener(this->parent_) == this->listener_id_) + if (BLECharacteristicSetValueActionManager::get_instance()->has_listener(this->parent_)) return; // Set initial value this->parent_->set_value(this->buffer_.value(x...)); // Set the listener for read events - this->listener_id_ = this->parent_->EventEmitter::on( - BLECharacteristicEvt::EmptyEvt::ON_READ, [this, x...](uint16_t id) { - // Set the value of the characteristic every time it is read - this->parent_->set_value(this->buffer_.value(x...)); - }); + this->parent_->on_read([this, x...](uint16_t id) { + // Set the value of the characteristic every time it is read + this->parent_->set_value(this->buffer_.value(x...)); + }); // Set the listener in the global manager so only one BLECharacteristicSetValueAction is set for each characteristic BLECharacteristicSetValueActionManager::get_instance()->set_listener( - this->parent_, this->listener_id_, [this, x...]() { this->parent_->set_value(this->buffer_.value(x...)); }); + this->parent_, [this, x...]() { this->parent_->set_value(this->buffer_.value(x...)); }); } protected: BLECharacteristic *parent_; - EventEmitterListenerID listener_id_; }; +#endif // USE_ESP32_BLE_SERVER_SET_VALUE_ACTION +#ifdef USE_ESP32_BLE_SERVER_NOTIFY_ACTION template class BLECharacteristicNotifyAction : public Action { public: BLECharacteristicNotifyAction(BLECharacteristic *characteristic) : parent_(characteristic) {} void play(Ts... x) override { +#ifdef USE_ESP32_BLE_SERVER_SET_VALUE_ACTION // Call the pre-notify event BLECharacteristicSetValueActionManager::get_instance()->emit_pre_notify(this->parent_); +#endif // Notify the characteristic this->parent_->notify(); } @@ -96,7 +108,9 @@ template class BLECharacteristicNotifyAction : public Action class BLEDescriptorSetValueAction : public Action { public: BLEDescriptorSetValueAction(BLEDescriptor *descriptor) : parent_(descriptor) {} @@ -107,6 +121,7 @@ template class BLEDescriptorSetValueAction : public Action ConfigType: + if CONF_MAX_CONNECTIONS in config: + _LOGGER.warning( + "The 'max_connections' option in 'esp32_ble_tracker' is deprecated. " + "Please move it to the 'esp32_ble' component instead." + ) + return config + + def as_hex(value): return cg.RawExpression(f"0x{value}ULL") @@ -150,29 +152,13 @@ def as_reversed_hex_array(value): ) -def max_connections() -> int: - return IDF_MAX_CONNECTIONS if CORE.using_esp_idf else DEFAULT_MAX_CONNECTIONS - - -def consume_connection_slots( - value: int, consumer: str -) -> Callable[[MutableMapping], MutableMapping]: - def _consume_connection_slots(config: MutableMapping) -> MutableMapping: - data: dict[str, Any] = CORE.data.setdefault(KEY_ESP32_BLE_TRACKER, {}) - slots: list[str] = data.setdefault(KEY_USED_CONNECTION_SLOTS, []) - slots.extend([consumer] * value) - return config - - return _consume_connection_slots - - CONFIG_SCHEMA = cv.All( cv.Schema( { cv.GenerateID(): cv.declare_id(ESP32BLETracker), cv.GenerateID(esp32_ble.CONF_BLE_ID): cv.use_id(esp32_ble.ESP32BLE), - cv.Optional(CONF_MAX_CONNECTIONS, default=DEFAULT_MAX_CONNECTIONS): cv.All( - cv.positive_int, cv.Range(min=0, max=max_connections()) + cv.Optional(CONF_MAX_CONNECTIONS): cv.All( + cv.positive_int, cv.Range(min=0, max=IDF_MAX_CONNECTIONS) ), cv.Optional(CONF_SCAN_PARAMETERS, default={}): cv.All( cv.Schema( @@ -228,49 +214,11 @@ CONFIG_SCHEMA = cv.All( cv.OnlyWith(CONF_SOFTWARE_COEXISTENCE, "wifi", default=True): bool, } ).extend(cv.COMPONENT_SCHEMA), + validate_max_connections_deprecated, ) -def validate_remaining_connections(config): - data: dict[str, Any] = CORE.data.get(KEY_ESP32_BLE_TRACKER, {}) - slots: list[str] = data.get(KEY_USED_CONNECTION_SLOTS, []) - used_slots = len(slots) - if used_slots <= config[CONF_MAX_CONNECTIONS]: - return config - slot_users = ", ".join(slots) - hard_limit = max_connections() - - if used_slots < hard_limit: - _LOGGER.warning( - "esp32_ble_tracker exceeded `%s`: components attempted to consume %d " - "connection slot(s) out of available configured maximum %d connection " - "slot(s); The system automatically increased `%s` to %d to match the " - "number of used connection slot(s) by components: %s.", - CONF_MAX_CONNECTIONS, - used_slots, - config[CONF_MAX_CONNECTIONS], - CONF_MAX_CONNECTIONS, - used_slots, - slot_users, - ) - config[CONF_MAX_CONNECTIONS] = used_slots - return config - - msg = ( - f"esp32_ble_tracker exceeded `{CONF_MAX_CONNECTIONS}`: " - f"components attempted to consume {used_slots} connection slot(s) " - f"out of available configured maximum {config[CONF_MAX_CONNECTIONS]} " - f"connection slot(s); Decrease the number of BLE clients ({slot_users})" - ) - if config[CONF_MAX_CONNECTIONS] < hard_limit: - msg += f" or increase {CONF_MAX_CONNECTIONS}` to {used_slots}" - msg += f" to stay under the {hard_limit} connection slot(s) limit." - raise cv.Invalid(msg) - - -FINAL_VALIDATE_SCHEMA = cv.All( - validate_remaining_connections, esp32_ble.validate_variant -) +FINAL_VALIDATE_SCHEMA = esp32_ble.validate_variant ESP_BLE_DEVICE_SCHEMA = cv.Schema( { @@ -342,19 +290,16 @@ async def to_code(config): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await automation.build_automation(trigger, [], conf) - if CORE.using_esp_idf: - add_idf_sdkconfig_option("CONFIG_BT_ENABLED", True) - if config.get(CONF_SOFTWARE_COEXISTENCE): - add_idf_sdkconfig_option("CONFIG_SW_COEXIST_ENABLE", True) - # https://github.com/espressif/esp-idf/issues/4101 - # https://github.com/espressif/esp-idf/issues/2503 - # Match arduino CONFIG_BTU_TASK_STACK_SIZE - # https://github.com/espressif/arduino-esp32/blob/fd72cf46ad6fc1a6de99c1d83ba8eba17d80a4ee/tools/sdk/esp32/sdkconfig#L1866 - add_idf_sdkconfig_option("CONFIG_BT_BTU_TASK_STACK_SIZE", 8192) - add_idf_sdkconfig_option("CONFIG_BT_ACL_CONNECTIONS", 9) - add_idf_sdkconfig_option( - "CONFIG_BTDM_CTRL_BLE_MAX_CONN", config[CONF_MAX_CONNECTIONS] - ) + add_idf_sdkconfig_option("CONFIG_BT_ENABLED", True) + if config.get(CONF_SOFTWARE_COEXISTENCE): + add_idf_sdkconfig_option("CONFIG_SW_COEXIST_ENABLE", True) + # https://github.com/espressif/esp-idf/issues/4101 + # https://github.com/espressif/esp-idf/issues/2503 + # Match arduino CONFIG_BTU_TASK_STACK_SIZE + # https://github.com/espressif/arduino-esp32/blob/fd72cf46ad6fc1a6de99c1d83ba8eba17d80a4ee/tools/sdk/esp32/sdkconfig#L1866 + add_idf_sdkconfig_option("CONFIG_BT_BTU_TASK_STACK_SIZE", 8192) + # Note: CONFIG_BT_ACL_CONNECTIONS and CONFIG_BTDM_CTRL_BLE_MAX_CONN are now + # configured in esp32_ble component based on max_connections setting cg.add_define("USE_OTA_STATE_CALLBACK") # To be notified when an OTA update starts cg.add_define("USE_ESP32_BLE_CLIENT") diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp index 63fb3b8b32..a7d73a9709 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp @@ -51,8 +51,6 @@ const char *client_state_to_string(ClientState state) { return "IDLE"; case ClientState::DISCOVERED: return "DISCOVERED"; - case ClientState::READY_TO_CONNECT: - return "READY_TO_CONNECT"; case ClientState::CONNECTING: return "CONNECTING"; case ClientState::CONNECTED: @@ -297,7 +295,7 @@ void ESP32BLETracker::gap_event_handler(esp_gap_ble_cb_event_t event, esp_ble_ga void ESP32BLETracker::gap_scan_event_handler(const BLEScanResult &scan_result) { // Note: This handler is called from the main loop context via esp32_ble's event queue. // We process advertisements immediately instead of buffering them. - ESP_LOGV(TAG, "gap_scan_result - event %d", scan_result.search_evt); + ESP_LOGVV(TAG, "gap_scan_result - event %d", scan_result.search_evt); if (scan_result.search_evt == ESP_GAP_SEARCH_INQ_RES_EVT) { // Process the scan result immediately @@ -794,7 +792,7 @@ void ESP32BLETracker::try_promote_discovered_clients_() { #ifdef USE_ESP32_BLE_SOFTWARE_COEXISTENCE this->update_coex_preference_(true); #endif - client->set_state(ClientState::READY_TO_CONNECT); + client->connect(); break; } } diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h index dd67156108..e53c2ac097 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h @@ -159,8 +159,6 @@ enum class ClientState : uint8_t { IDLE, // Device advertisement found. DISCOVERED, - // Device is discovered and the scanner is stopped - READY_TO_CONNECT, // Connection in progress. CONNECTING, // Initial connection established. @@ -313,7 +311,6 @@ class ESP32BLETracker : public Component, counts.discovered++; break; case ClientState::CONNECTING: - case ClientState::READY_TO_CONNECT: counts.connecting++; break; default: diff --git a/esphome/components/esp32_camera/__init__.py b/esphome/components/esp32_camera/__init__.py index 6206fe4682..d8ba098645 100644 --- a/esphome/components/esp32_camera/__init__.py +++ b/esphome/components/esp32_camera/__init__.py @@ -21,7 +21,6 @@ from esphome.const import ( CONF_TRIGGER_ID, CONF_VSYNC_PIN, ) -from esphome.core import CORE from esphome.core.entity_helpers import setup_entity import esphome.final_validate as fv @@ -344,8 +343,7 @@ async def to_code(config): cg.add_define("USE_CAMERA") - if CORE.using_esp_idf: - add_idf_component(name="espressif/esp32-camera", ref="2.1.1") + add_idf_component(name="espressif/esp32-camera", ref="2.1.1") for conf in config.get(CONF_ON_STREAM_START, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) diff --git a/esphome/components/esp32_can/esp32_can.cpp b/esphome/components/esp32_can/esp32_can.cpp index b5e72497ce..cdef7b1930 100644 --- a/esphome/components/esp32_can/esp32_can.cpp +++ b/esphome/components/esp32_can/esp32_can.cpp @@ -67,8 +67,16 @@ static bool get_bitrate(canbus::CanSpeed bitrate, twai_timing_config_t *t_config } bool ESP32Can::setup_internal() { + static int next_twai_ctrl_num = 0; + if (static_cast(next_twai_ctrl_num) >= SOC_TWAI_CONTROLLER_NUM) { + ESP_LOGW(TAG, "Maximum number of esp32_can components created already"); + this->mark_failed(); + return false; + } + twai_general_config_t g_config = TWAI_GENERAL_CONFIG_DEFAULT((gpio_num_t) this->tx_, (gpio_num_t) this->rx_, TWAI_MODE_NORMAL); + g_config.controller_id = next_twai_ctrl_num++; if (this->tx_queue_len_.has_value()) { g_config.tx_queue_len = this->tx_queue_len_.value(); } @@ -86,14 +94,14 @@ bool ESP32Can::setup_internal() { } // Install TWAI driver - if (twai_driver_install(&g_config, &t_config, &f_config) != ESP_OK) { + if (twai_driver_install_v2(&g_config, &t_config, &f_config, &(this->twai_handle_)) != ESP_OK) { // Failed to install driver this->mark_failed(); return false; } // Start TWAI driver - if (twai_start() != ESP_OK) { + if (twai_start_v2(this->twai_handle_) != ESP_OK) { // Failed to start driver this->mark_failed(); return false; @@ -102,6 +110,11 @@ bool ESP32Can::setup_internal() { } canbus::Error ESP32Can::send_message(struct canbus::CanFrame *frame) { + if (this->twai_handle_ == nullptr) { + // not setup yet or setup failed + return canbus::ERROR_FAIL; + } + if (frame->can_data_length_code > canbus::CAN_MAX_DATA_LENGTH) { return canbus::ERROR_FAILTX; } @@ -124,7 +137,7 @@ canbus::Error ESP32Can::send_message(struct canbus::CanFrame *frame) { memcpy(message.data, frame->data, frame->can_data_length_code); } - if (twai_transmit(&message, this->tx_enqueue_timeout_ticks_) == ESP_OK) { + if (twai_transmit_v2(this->twai_handle_, &message, this->tx_enqueue_timeout_ticks_) == ESP_OK) { return canbus::ERROR_OK; } else { return canbus::ERROR_ALLTXBUSY; @@ -132,9 +145,14 @@ canbus::Error ESP32Can::send_message(struct canbus::CanFrame *frame) { } canbus::Error ESP32Can::read_message(struct canbus::CanFrame *frame) { + if (this->twai_handle_ == nullptr) { + // not setup yet or setup failed + return canbus::ERROR_FAIL; + } + twai_message_t message; - if (twai_receive(&message, 0) != ESP_OK) { + if (twai_receive_v2(this->twai_handle_, &message, 0) != ESP_OK) { return canbus::ERROR_NOMSG; } diff --git a/esphome/components/esp32_can/esp32_can.h b/esphome/components/esp32_can/esp32_can.h index 416f037083..dc44aceb36 100644 --- a/esphome/components/esp32_can/esp32_can.h +++ b/esphome/components/esp32_can/esp32_can.h @@ -5,6 +5,8 @@ #include "esphome/components/canbus/canbus.h" #include "esphome/core/component.h" +#include + namespace esphome { namespace esp32_can { @@ -29,6 +31,7 @@ class ESP32Can : public canbus::Canbus { TickType_t tx_enqueue_timeout_ticks_{}; optional tx_queue_len_{}; optional rx_queue_len_{}; + twai_handle_t twai_handle_{nullptr}; }; } // namespace esp32_can diff --git a/esphome/components/esp32_hosted/__init__.py b/esphome/components/esp32_hosted/__init__.py index 330800df12..9cea02c322 100644 --- a/esphome/components/esp32_hosted/__init__.py +++ b/esphome/components/esp32_hosted/__init__.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from esphome import pins from esphome.components import esp32 @@ -97,5 +98,5 @@ async def to_code(config): esp32.add_extra_script( "post", "esp32_hosted.py", - os.path.join(os.path.dirname(__file__), "esp32_hosted.py.script"), + Path(__file__).parent / "esp32_hosted.py.script", ) diff --git a/esphome/components/esp32_improv/esp32_improv_component.cpp b/esphome/components/esp32_improv/esp32_improv_component.cpp index c5a0b89f99..f773083890 100644 --- a/esphome/components/esp32_improv/esp32_improv_component.cpp +++ b/esphome/components/esp32_improv/esp32_improv_component.cpp @@ -17,6 +17,13 @@ static const char *const TAG = "esp32_improv.component"; static const char *const ESPHOME_MY_LINK = "https://my.home-assistant.io/redirect/config_flow_start?domain=esphome"; static constexpr uint16_t STOP_ADVERTISING_DELAY = 10000; // Delay (ms) before stopping service to allow BLE clients to read the final state +static constexpr uint16_t NAME_ADVERTISING_INTERVAL = 60000; // Advertise name every 60 seconds +static constexpr uint16_t NAME_ADVERTISING_DURATION = 1000; // Advertise name for 1 second + +// Improv service data constants +static constexpr uint8_t IMPROV_SERVICE_DATA_SIZE = 8; +static constexpr uint8_t IMPROV_PROTOCOL_ID_1 = 0x77; // 'P' << 1 | 'R' >> 7 +static constexpr uint8_t IMPROV_PROTOCOL_ID_2 = 0x46; // 'I' << 1 | 'M' >> 7 ESP32ImprovComponent::ESP32ImprovComponent() { global_improv_component = this; } @@ -31,8 +38,7 @@ void ESP32ImprovComponent::setup() { }); } #endif - global_ble_server->on(BLEServerEvt::EmptyEvt::ON_DISCONNECT, - [this](uint16_t conn_id) { this->set_error_(improv::ERROR_NONE); }); + global_ble_server->on_disconnect([this](uint16_t conn_id) { this->set_error_(improv::ERROR_NONE); }); // Start with loop disabled - will be enabled by start() when needed this->disable_loop(); @@ -50,12 +56,11 @@ void ESP32ImprovComponent::setup_characteristics() { this->error_->add_descriptor(error_descriptor); this->rpc_ = this->service_->create_characteristic(improv::RPC_COMMAND_UUID, BLECharacteristic::PROPERTY_WRITE); - this->rpc_->EventEmitter, uint16_t>::on( - BLECharacteristicEvt::VectorEvt::ON_WRITE, [this](const std::vector &data, uint16_t id) { - if (!data.empty()) { - this->incoming_data_.insert(this->incoming_data_.end(), data.begin(), data.end()); - } - }); + this->rpc_->on_write([this](std::span data, uint16_t id) { + if (!data.empty()) { + this->incoming_data_.insert(this->incoming_data_.end(), data.begin(), data.end()); + } + }); BLEDescriptor *rpc_descriptor = new BLE2902(); this->rpc_->add_descriptor(rpc_descriptor); @@ -99,6 +104,11 @@ void ESP32ImprovComponent::loop() { this->process_incoming_data_(); uint32_t now = App.get_loop_component_start_time(); + // Check if we need to update advertising type + if (this->state_ != improv::STATE_STOPPED && this->state_ != improv::STATE_PROVISIONED) { + this->update_advertising_type_(); + } + switch (this->state_) { case improv::STATE_STOPPED: this->set_status_indicator_state_(false); @@ -107,9 +117,15 @@ void ESP32ImprovComponent::loop() { if (this->service_->is_created()) { this->service_->start(); } else if (this->service_->is_running()) { + // Start by advertising the device name first BEFORE setting any state + ESP_LOGV(TAG, "Starting with device name advertising"); + this->advertising_device_name_ = true; + this->last_name_adv_time_ = App.get_loop_component_start_time(); + esp32_ble::global_ble->advertising_set_service_data_and_name(std::span{}, true); esp32_ble::global_ble->advertising_start(); - this->set_state_(improv::STATE_AWAITING_AUTHORIZATION); + // Set initial state based on whether we have an authorizer + this->set_state_(this->get_initial_state_(), false); this->set_error_(improv::ERROR_NONE); ESP_LOGD(TAG, "Service started!"); } @@ -120,24 +136,21 @@ void ESP32ImprovComponent::loop() { if (this->authorizer_ == nullptr || (this->authorized_start_ != 0 && ((now - this->authorized_start_) < this->authorized_duration_))) { this->set_state_(improv::STATE_AUTHORIZED); - } else -#else - { this->set_state_(improv::STATE_AUTHORIZED); } -#endif - { + } else { if (!this->check_identify_()) this->set_status_indicator_state_(true); } +#else + this->set_state_(improv::STATE_AUTHORIZED); +#endif break; } case improv::STATE_AUTHORIZED: { #ifdef USE_BINARY_SENSOR - if (this->authorizer_ != nullptr) { - if (now - this->authorized_start_ > this->authorized_duration_) { - ESP_LOGD(TAG, "Authorization timeout"); - this->set_state_(improv::STATE_AWAITING_AUTHORIZATION); - return; - } + if (this->authorizer_ != nullptr && now - this->authorized_start_ > this->authorized_duration_) { + ESP_LOGD(TAG, "Authorization timeout"); + this->set_state_(improv::STATE_AWAITING_AUTHORIZATION); + return; } #endif if (!this->check_identify_()) { @@ -226,12 +239,15 @@ bool ESP32ImprovComponent::check_identify_() { return identify; } -void ESP32ImprovComponent::set_state_(improv::State state) { -#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_DEBUG - if (this->state_ != state) { - ESP_LOGD(TAG, "State transition: %s (0x%02X) -> %s (0x%02X)", this->state_to_string_(this->state_), this->state_, - this->state_to_string_(state), state); +void ESP32ImprovComponent::set_state_(improv::State state, bool update_advertising) { + // Skip if state hasn't changed + if (this->state_ == state) { + return; } + +#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_DEBUG + ESP_LOGD(TAG, "State transition: %s (0x%02X) -> %s (0x%02X)", this->state_to_string_(this->state_), this->state_, + this->state_to_string_(state), state); #endif this->state_ = state; if (this->status_ != nullptr && (this->status_->get_value().empty() || this->status_->get_value()[0] != state)) { @@ -243,25 +259,13 @@ void ESP32ImprovComponent::set_state_(improv::State state) { // STATE_STOPPED (0x00) is internal only and not part of the Improv spec. // Advertising 0x00 causes undefined behavior in some clients and makes them // repeatedly connect trying to determine the actual state. - if (state != improv::STATE_STOPPED) { - std::vector service_data(8, 0); - service_data[0] = 0x77; // PR - service_data[1] = 0x46; // IM - service_data[2] = static_cast(state); - - uint8_t capabilities = 0x00; -#ifdef USE_OUTPUT - if (this->status_indicator_ != nullptr) - capabilities |= improv::CAPABILITY_IDENTIFY; -#endif - - service_data[3] = capabilities; - service_data[4] = 0x00; // Reserved - service_data[5] = 0x00; // Reserved - service_data[6] = 0x00; // Reserved - service_data[7] = 0x00; // Reserved - - esp32_ble::global_ble->advertising_set_service_data(service_data); + if (state != improv::STATE_STOPPED && update_advertising) { + // State change always overrides name advertising and resets the timer + this->advertising_device_name_ = false; + // Reset the timer so we wait another 60 seconds before advertising name + this->last_name_adv_time_ = App.get_loop_component_start_time(); + // Advertise the new state via service data + this->advertise_service_data_(); } #ifdef USE_ESP32_IMPROV_STATE_CALLBACK this->state_callback_.call(this->state_, this->error_state_); @@ -388,6 +392,60 @@ void ESP32ImprovComponent::on_wifi_connect_timeout_() { wifi::global_wifi_component->clear_sta(); } +void ESP32ImprovComponent::advertise_service_data_() { + uint8_t service_data[IMPROV_SERVICE_DATA_SIZE] = {}; + service_data[0] = IMPROV_PROTOCOL_ID_1; // PR + service_data[1] = IMPROV_PROTOCOL_ID_2; // IM + service_data[2] = static_cast(this->state_); + + uint8_t capabilities = 0x00; +#ifdef USE_OUTPUT + if (this->status_indicator_ != nullptr) + capabilities |= improv::CAPABILITY_IDENTIFY; +#endif + + service_data[3] = capabilities; + // service_data[4-7] are already 0 (Reserved) + + // Atomically set service data and disable name in advertising + esp32_ble::global_ble->advertising_set_service_data_and_name(std::span(service_data), false); +} + +void ESP32ImprovComponent::update_advertising_type_() { + uint32_t now = App.get_loop_component_start_time(); + + // If we're advertising the device name and it's been more than NAME_ADVERTISING_DURATION, switch back to service data + if (this->advertising_device_name_) { + if (now - this->last_name_adv_time_ >= NAME_ADVERTISING_DURATION) { + ESP_LOGV(TAG, "Switching back to service data advertising"); + this->advertising_device_name_ = false; + // Restore service data advertising + this->advertise_service_data_(); + } + return; + } + + // Check if it's time to advertise the device name (every NAME_ADVERTISING_INTERVAL) + if (now - this->last_name_adv_time_ >= NAME_ADVERTISING_INTERVAL) { + ESP_LOGV(TAG, "Switching to device name advertising"); + this->advertising_device_name_ = true; + this->last_name_adv_time_ = now; + + // Atomically clear service data and enable name in advertising data + esp32_ble::global_ble->advertising_set_service_data_and_name(std::span{}, true); + } +} + +improv::State ESP32ImprovComponent::get_initial_state_() const { +#ifdef USE_BINARY_SENSOR + // If we have an authorizer, start in awaiting authorization state + return this->authorizer_ == nullptr ? improv::STATE_AUTHORIZED : improv::STATE_AWAITING_AUTHORIZATION; +#else + // No binary_sensor support = no authorizer possible, start as authorized + return improv::STATE_AUTHORIZED; +#endif +} + ESP32ImprovComponent *global_improv_component = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) } // namespace esp32_improv diff --git a/esphome/components/esp32_improv/esp32_improv_component.h b/esphome/components/esp32_improv/esp32_improv_component.h index 686da08111..eb07e09dce 100644 --- a/esphome/components/esp32_improv/esp32_improv_component.h +++ b/esphome/components/esp32_improv/esp32_improv_component.h @@ -100,14 +100,19 @@ class ESP32ImprovComponent : public Component { #endif bool status_indicator_state_{false}; + uint32_t last_name_adv_time_{0}; + bool advertising_device_name_{false}; void set_status_indicator_state_(bool state); + void update_advertising_type_(); - void set_state_(improv::State state); + void set_state_(improv::State state, bool update_advertising = true); void set_error_(improv::Error error); + improv::State get_initial_state_() const; void send_response_(std::vector &response); void process_incoming_data_(); void on_wifi_connect_timeout_(); bool check_identify_(); + void advertise_service_data_(); #if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_DEBUG const char *state_to_string_(improv::State state); #endif diff --git a/esphome/components/esp32_rmt_led_strip/led_strip.cpp b/esphome/components/esp32_rmt_led_strip/led_strip.cpp index 344ea35e81..fa43aa5950 100644 --- a/esphome/components/esp32_rmt_led_strip/led_strip.cpp +++ b/esphome/components/esp32_rmt_led_strip/led_strip.cpp @@ -35,7 +35,7 @@ static size_t IRAM_ATTR HOT encoder_callback(const void *data, size_t size, size if (symbols_free < RMT_SYMBOLS_PER_BYTE) { return 0; } - for (int32_t i = 0; i < RMT_SYMBOLS_PER_BYTE; i++) { + for (size_t i = 0; i < RMT_SYMBOLS_PER_BYTE; i++) { if (bytes[index] & (1 << (7 - i))) { symbols[i] = params->bit1; } else { diff --git a/esphome/components/esp8266/__init__.py b/esphome/components/esp8266/__init__.py index b85314214e..8a7fbbcb0a 100644 --- a/esphome/components/esp8266/__init__.py +++ b/esphome/components/esp8266/__init__.py @@ -1,5 +1,5 @@ import logging -import os +from pathlib import Path import esphome.codegen as cg import esphome.config_validation as cv @@ -259,8 +259,8 @@ async def to_code(config): # Called by writer.py def copy_files(): - dir = os.path.dirname(__file__) - post_build_file = os.path.join(dir, "post_build.py.script") + dir = Path(__file__).parent + post_build_file = dir / "post_build.py.script" copy_file_if_changed( post_build_file, CORE.relative_build_path("post_build.py"), diff --git a/esphome/components/esphome/ota/__init__.py b/esphome/components/esphome/ota/__init__.py index 93216f9425..e6f249e021 100644 --- a/esphome/components/esphome/ota/__init__.py +++ b/esphome/components/esphome/ota/__init__.py @@ -16,7 +16,7 @@ from esphome.const import ( CONF_SAFE_MODE, CONF_VERSION, ) -from esphome.core import coroutine_with_priority +from esphome.core import CORE, coroutine_with_priority from esphome.coroutine import CoroPriority import esphome.final_validate as fv @@ -24,9 +24,22 @@ _LOGGER = logging.getLogger(__name__) CODEOWNERS = ["@esphome/core"] -AUTO_LOAD = ["md5", "socket"] DEPENDENCIES = ["network"] + +def supports_sha256() -> bool: + """Check if the current platform supports SHA256 for OTA authentication.""" + return bool(CORE.is_esp32 or CORE.is_esp8266 or CORE.is_rp2040 or CORE.is_libretiny) + + +def AUTO_LOAD() -> list[str]: + """Conditionally auto-load sha256 only on platforms that support it.""" + base_components = ["md5", "socket"] + if supports_sha256(): + return base_components + ["sha256"] + return base_components + + esphome = cg.esphome_ns.namespace("esphome") ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", OTAComponent) @@ -126,9 +139,15 @@ FINAL_VALIDATE_SCHEMA = ota_esphome_final_validate async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) cg.add(var.set_port(config[CONF_PORT])) + if CONF_PASSWORD in config: cg.add(var.set_auth_password(config[CONF_PASSWORD])) cg.add_define("USE_OTA_PASSWORD") + # Only include hash algorithms when password is configured + cg.add_define("USE_OTA_MD5") + # Only include SHA256 support on platforms that have it + if supports_sha256(): + cg.add_define("USE_OTA_SHA256") cg.add_define("USE_OTA_VERSION", config[CONF_VERSION]) await cg.register_component(var, config) diff --git a/esphome/components/esphome/ota/ota_esphome.cpp b/esphome/components/esphome/ota/ota_esphome.cpp index 6654ef8748..b65bfc5ab8 100644 --- a/esphome/components/esphome/ota/ota_esphome.cpp +++ b/esphome/components/esphome/ota/ota_esphome.cpp @@ -1,6 +1,13 @@ #include "ota_esphome.h" #ifdef USE_OTA +#ifdef USE_OTA_PASSWORD +#ifdef USE_OTA_MD5 #include "esphome/components/md5/md5.h" +#endif +#ifdef USE_OTA_SHA256 +#include "esphome/components/sha256/sha256.h" +#endif +#endif #include "esphome/components/network/util.h" #include "esphome/components/ota/ota_backend.h" #include "esphome/components/ota/ota_backend_arduino_esp32.h" @@ -10,6 +17,7 @@ #include "esphome/components/ota/ota_backend_esp_idf.h" #include "esphome/core/application.h" #include "esphome/core/hal.h" +#include "esphome/core/helpers.h" #include "esphome/core/log.h" #include "esphome/core/util.h" @@ -20,9 +28,19 @@ namespace esphome { static const char *const TAG = "esphome.ota"; static constexpr uint16_t OTA_BLOCK_SIZE = 8192; +static constexpr size_t OTA_BUFFER_SIZE = 1024; // buffer size for OTA data transfer static constexpr uint32_t OTA_SOCKET_TIMEOUT_HANDSHAKE = 10000; // milliseconds for initial handshake static constexpr uint32_t OTA_SOCKET_TIMEOUT_DATA = 90000; // milliseconds for data transfer +#ifdef USE_OTA_PASSWORD +#ifdef USE_OTA_MD5 +static constexpr size_t MD5_HEX_SIZE = 32; // MD5 hash as hex string (16 bytes * 2) +#endif +#ifdef USE_OTA_SHA256 +static constexpr size_t SHA256_HEX_SIZE = 64; // SHA256 hash as hex string (32 bytes * 2) +#endif +#endif // USE_OTA_PASSWORD + void ESPHomeOTAComponent::setup() { #ifdef USE_OTA_STATE_CALLBACK ota::register_ota_platform(this); @@ -63,7 +81,7 @@ void ESPHomeOTAComponent::setup() { return; } - err = this->server_->listen(4); + err = this->server_->listen(1); // Only one client at a time if (err != 0) { this->log_socket_error_(LOG_STR("listen")); this->mark_failed(); @@ -95,13 +113,22 @@ void ESPHomeOTAComponent::loop() { } static const uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01; +#ifdef USE_OTA_SHA256 +static const uint8_t FEATURE_SUPPORTS_SHA256_AUTH = 0x02; +#endif + +// Temporary flag to allow MD5 downgrade for ~3 versions (until 2026.1.0) +// This allows users to downgrade via OTA if they encounter issues after updating. +// Without this, users would need to do a serial flash to downgrade. +// TODO: Remove this flag and all associated code in 2026.1.0 +#define ALLOW_OTA_DOWNGRADE_MD5 void ESPHomeOTAComponent::handle_handshake_() { - /// Handle the initial OTA handshake. + /// Handle the OTA handshake and authentication. /// /// This method is non-blocking and will return immediately if no data is available. - /// It reads all 5 magic bytes (0x6C, 0x26, 0xF7, 0x5C, 0x45) non-blocking - /// before proceeding to handle_data_(). A 10-second timeout is enforced from initial connection. + /// It manages the state machine through connection, magic bytes validation, feature + /// negotiation, and authentication before entering the blocking data transfer phase. if (this->client_ == nullptr) { // We already checked server_->ready() in loop(), so we can accept directly @@ -126,7 +153,8 @@ void ESPHomeOTAComponent::handle_handshake_() { } this->log_start_(LOG_STR("handshake")); this->client_connect_time_ = App.get_loop_component_start_time(); - this->magic_buf_pos_ = 0; // Reset magic buffer position + this->handshake_buf_pos_ = 0; // Reset handshake buffer position + this->ota_state_ = OTAState::MAGIC_READ; } // Check for handshake timeout @@ -137,46 +165,99 @@ void ESPHomeOTAComponent::handle_handshake_() { return; } - // Try to read remaining magic bytes - if (this->magic_buf_pos_ < 5) { - // Read as many bytes as available - uint8_t bytes_to_read = 5 - this->magic_buf_pos_; - ssize_t read = this->client_->read(this->magic_buf_ + this->magic_buf_pos_, bytes_to_read); - - if (read == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { - return; // No data yet, try again next loop - } - - if (read <= 0) { - // Error or connection closed - if (read == -1) { - this->log_socket_error_(LOG_STR("reading magic bytes")); - } else { - ESP_LOGW(TAG, "Remote closed during handshake"); + switch (this->ota_state_) { + case OTAState::MAGIC_READ: { + // Try to read remaining magic bytes (5 total) + if (!this->try_read_(5, LOG_STR("read magic"))) { + return; } - this->cleanup_connection_(); - return; + + // Validate magic bytes + static const uint8_t MAGIC_BYTES[5] = {0x6C, 0x26, 0xF7, 0x5C, 0x45}; + if (memcmp(this->handshake_buf_, MAGIC_BYTES, 5) != 0) { + ESP_LOGW(TAG, "Magic bytes mismatch! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", this->handshake_buf_[0], + this->handshake_buf_[1], this->handshake_buf_[2], this->handshake_buf_[3], this->handshake_buf_[4]); + this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_MAGIC); + return; + } + + // Magic bytes valid, move to next state + this->transition_ota_state_(OTAState::MAGIC_ACK); + this->handshake_buf_[0] = ota::OTA_RESPONSE_OK; + this->handshake_buf_[1] = USE_OTA_VERSION; + [[fallthrough]]; } - this->magic_buf_pos_ += read; - } - - // Check if we have all 5 magic bytes - if (this->magic_buf_pos_ == 5) { - // Validate magic bytes - static const uint8_t MAGIC_BYTES[5] = {0x6C, 0x26, 0xF7, 0x5C, 0x45}; - if (memcmp(this->magic_buf_, MAGIC_BYTES, 5) != 0) { - ESP_LOGW(TAG, "Magic bytes mismatch! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", this->magic_buf_[0], - this->magic_buf_[1], this->magic_buf_[2], this->magic_buf_[3], this->magic_buf_[4]); - // Send error response (non-blocking, best effort) - uint8_t error = static_cast(ota::OTA_RESPONSE_ERROR_MAGIC); - this->client_->write(&error, 1); - this->cleanup_connection_(); - return; + case OTAState::MAGIC_ACK: { + // Send OK and version - 2 bytes + if (!this->try_write_(2, LOG_STR("ack magic"))) { + return; + } + // All bytes sent, create backend and move to next state + this->backend_ = ota::make_ota_backend(); + this->transition_ota_state_(OTAState::FEATURE_READ); + [[fallthrough]]; } - // All 5 magic bytes are valid, continue with data handling - this->handle_data_(); + case OTAState::FEATURE_READ: { + // Read features - 1 byte + if (!this->try_read_(1, LOG_STR("read feature"))) { + return; + } + this->ota_features_ = this->handshake_buf_[0]; + ESP_LOGV(TAG, "Features: 0x%02X", this->ota_features_); + this->transition_ota_state_(OTAState::FEATURE_ACK); + this->handshake_buf_[0] = + ((this->ota_features_ & FEATURE_SUPPORTS_COMPRESSION) != 0 && this->backend_->supports_compression()) + ? ota::OTA_RESPONSE_SUPPORTS_COMPRESSION + : ota::OTA_RESPONSE_HEADER_OK; + [[fallthrough]]; + } + + case OTAState::FEATURE_ACK: { + // Acknowledge header - 1 byte + if (!this->try_write_(1, LOG_STR("ack feature"))) { + return; + } +#ifdef USE_OTA_PASSWORD + // If password is set, move to auth phase + if (!this->password_.empty()) { + this->transition_ota_state_(OTAState::AUTH_SEND); + } else +#endif + { + // No password, move directly to data phase + this->transition_ota_state_(OTAState::DATA); + } + [[fallthrough]]; + } + +#ifdef USE_OTA_PASSWORD + case OTAState::AUTH_SEND: { + // Non-blocking authentication send + if (!this->handle_auth_send_()) { + return; + } + this->transition_ota_state_(OTAState::AUTH_READ); + [[fallthrough]]; + } + + case OTAState::AUTH_READ: { + // Non-blocking authentication read & verify + if (!this->handle_auth_read_()) { + return; + } + this->transition_ota_state_(OTAState::DATA); + [[fallthrough]]; + } +#endif + + case OTAState::DATA: + this->handle_data_(); + return; + + default: + break; } } @@ -184,104 +265,21 @@ void ESPHomeOTAComponent::handle_data_() { /// Handle the OTA data transfer and update process. /// /// This method is blocking and will not return until the OTA update completes, - /// fails, or times out. It handles authentication, receives the firmware data, - /// writes it to flash, and reboots on success. + /// fails, or times out. It receives the firmware data, writes it to flash, + /// and reboots on success. + /// + /// Authentication has already been handled in the non-blocking states AUTH_SEND/AUTH_READ. ota::OTAResponseTypes error_code = ota::OTA_RESPONSE_ERROR_UNKNOWN; bool update_started = false; size_t total = 0; uint32_t last_progress = 0; - uint8_t buf[1024]; + uint8_t buf[OTA_BUFFER_SIZE]; char *sbuf = reinterpret_cast(buf); size_t ota_size; - uint8_t ota_features; - std::unique_ptr backend; - (void) ota_features; #if USE_OTA_VERSION == 2 size_t size_acknowledged = 0; #endif - // Send OK and version - 2 bytes - buf[0] = ota::OTA_RESPONSE_OK; - buf[1] = USE_OTA_VERSION; - this->writeall_(buf, 2); - - backend = ota::make_ota_backend(); - - // Read features - 1 byte - if (!this->readall_(buf, 1)) { - this->log_read_error_(LOG_STR("features")); - goto error; // NOLINT(cppcoreguidelines-avoid-goto) - } - ota_features = buf[0]; // NOLINT - ESP_LOGV(TAG, "Features: 0x%02X", ota_features); - - // Acknowledge header - 1 byte - buf[0] = ota::OTA_RESPONSE_HEADER_OK; - if ((ota_features & FEATURE_SUPPORTS_COMPRESSION) != 0 && backend->supports_compression()) { - buf[0] = ota::OTA_RESPONSE_SUPPORTS_COMPRESSION; - } - - this->writeall_(buf, 1); - -#ifdef USE_OTA_PASSWORD - if (!this->password_.empty()) { - buf[0] = ota::OTA_RESPONSE_REQUEST_AUTH; - this->writeall_(buf, 1); - md5::MD5Digest md5{}; - md5.init(); - sprintf(sbuf, "%08" PRIx32, random_uint32()); - md5.add(sbuf, 8); - md5.calculate(); - md5.get_hex(sbuf); - ESP_LOGV(TAG, "Auth: Nonce is %s", sbuf); - - // Send nonce, 32 bytes hex MD5 - if (!this->writeall_(reinterpret_cast(sbuf), 32)) { - ESP_LOGW(TAG, "Auth: Writing nonce failed"); - goto error; // NOLINT(cppcoreguidelines-avoid-goto) - } - - // prepare challenge - md5.init(); - md5.add(this->password_.c_str(), this->password_.length()); - // add nonce - md5.add(sbuf, 32); - - // Receive cnonce, 32 bytes hex MD5 - if (!this->readall_(buf, 32)) { - ESP_LOGW(TAG, "Auth: Reading cnonce failed"); - goto error; // NOLINT(cppcoreguidelines-avoid-goto) - } - sbuf[32] = '\0'; - ESP_LOGV(TAG, "Auth: CNonce is %s", sbuf); - // add cnonce - md5.add(sbuf, 32); - - // calculate result - md5.calculate(); - md5.get_hex(sbuf); - ESP_LOGV(TAG, "Auth: Result is %s", sbuf); - - // Receive result, 32 bytes hex MD5 - if (!this->readall_(buf + 64, 32)) { - ESP_LOGW(TAG, "Auth: Reading response failed"); - goto error; // NOLINT(cppcoreguidelines-avoid-goto) - } - sbuf[64 + 32] = '\0'; - ESP_LOGV(TAG, "Auth: Response is %s", sbuf + 64); - - bool matches = true; - for (uint8_t i = 0; i < 32; i++) - matches = matches && buf[i] == buf[64 + i]; - - if (!matches) { - ESP_LOGW(TAG, "Auth failed! Passwords do not match"); - error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID; - goto error; // NOLINT(cppcoreguidelines-avoid-goto) - } - } -#endif // USE_OTA_PASSWORD - // Acknowledge auth OK - 1 byte buf[0] = ota::OTA_RESPONSE_AUTH_OK; this->writeall_(buf, 1); @@ -309,7 +307,7 @@ void ESPHomeOTAComponent::handle_data_() { #endif // This will block for a few seconds as it locks flash - error_code = backend->begin(ota_size); + error_code = this->backend_->begin(ota_size); if (error_code != ota::OTA_RESPONSE_OK) goto error; // NOLINT(cppcoreguidelines-avoid-goto) update_started = true; @@ -325,7 +323,7 @@ void ESPHomeOTAComponent::handle_data_() { } sbuf[32] = '\0'; ESP_LOGV(TAG, "Update: Binary MD5 is %s", sbuf); - backend->set_update_md5(sbuf); + this->backend_->set_update_md5(sbuf); // Acknowledge MD5 OK - 1 byte buf[0] = ota::OTA_RESPONSE_BIN_MD5_OK; @@ -333,26 +331,24 @@ void ESPHomeOTAComponent::handle_data_() { while (total < ota_size) { // TODO: timeout check - size_t requested = std::min(sizeof(buf), ota_size - total); + size_t remaining = ota_size - total; + size_t requested = remaining < OTA_BUFFER_SIZE ? remaining : OTA_BUFFER_SIZE; ssize_t read = this->client_->read(buf, requested); if (read == -1) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { + if (this->would_block_(errno)) { this->yield_and_feed_watchdog_(); continue; } - ESP_LOGW(TAG, "Read error, errno %d", errno); + ESP_LOGW(TAG, "Read err %d", errno); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } else if (read == 0) { - // $ man recv - // "When a stream socket peer has performed an orderly shutdown, the return value will - // be 0 (the traditional "end-of-file" return)." - ESP_LOGW(TAG, "Remote closed connection"); + ESP_LOGW(TAG, "Remote closed"); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } - error_code = backend->write(buf, read); + error_code = this->backend_->write(buf, read); if (error_code != ota::OTA_RESPONSE_OK) { - ESP_LOGW(TAG, "Flash write error, code: %d", error_code); + ESP_LOGW(TAG, "Flash write err %d", error_code); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } total += read; @@ -381,9 +377,9 @@ void ESPHomeOTAComponent::handle_data_() { buf[0] = ota::OTA_RESPONSE_RECEIVE_OK; this->writeall_(buf, 1); - error_code = backend->end(); + error_code = this->backend_->end(); if (error_code != ota::OTA_RESPONSE_OK) { - ESP_LOGW(TAG, "Error ending update! code: %d", error_code); + ESP_LOGW(TAG, "End update err %d", error_code); goto error; // NOLINT(cppcoreguidelines-avoid-goto) } @@ -412,8 +408,8 @@ error: this->writeall_(buf, 1); this->cleanup_connection_(); - if (backend != nullptr && update_started) { - backend->abort(); + if (this->backend_ != nullptr && update_started) { + this->backend_->abort(); } this->status_momentary_error("onerror", 5000); @@ -434,12 +430,12 @@ bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) { ssize_t read = this->client_->read(buf + at, len - at); if (read == -1) { - if (errno != EAGAIN && errno != EWOULDBLOCK) { - ESP_LOGW(TAG, "Error reading %d bytes, errno %d", len, errno); + if (!this->would_block_(errno)) { + ESP_LOGW(TAG, "Read err %d bytes, errno %d", len, errno); return false; } } else if (read == 0) { - ESP_LOGW(TAG, "Remote closed connection"); + ESP_LOGW(TAG, "Remote closed"); return false; } else { at += read; @@ -461,8 +457,8 @@ bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) { ssize_t written = this->client_->write(buf + at, len - at); if (written == -1) { - if (errno != EAGAIN && errno != EWOULDBLOCK) { - ESP_LOGW(TAG, "Error writing %d bytes, errno %d", len, errno); + if (!this->would_block_(errno)) { + ESP_LOGW(TAG, "Write err %d bytes, errno %d", len, errno); return false; } } else { @@ -487,11 +483,74 @@ void ESPHomeOTAComponent::log_start_(const LogString *phase) { ESP_LOGD(TAG, "Starting %s from %s", LOG_STR_ARG(phase), this->client_->getpeername().c_str()); } +void ESPHomeOTAComponent::log_remote_closed_(const LogString *during) { + ESP_LOGW(TAG, "Remote closed at %s", LOG_STR_ARG(during)); +} + +bool ESPHomeOTAComponent::handle_read_error_(ssize_t read, const LogString *desc) { + if (read == -1 && this->would_block_(errno)) { + return false; // No data yet, try again next loop + } + + if (read <= 0) { + read == 0 ? this->log_remote_closed_(desc) : this->log_socket_error_(desc); + this->cleanup_connection_(); + return false; + } + return true; +} + +bool ESPHomeOTAComponent::handle_write_error_(ssize_t written, const LogString *desc) { + if (written == -1) { + if (this->would_block_(errno)) { + return false; // Try again next loop + } + this->log_socket_error_(desc); + this->cleanup_connection_(); + return false; + } + return true; +} + +bool ESPHomeOTAComponent::try_read_(size_t to_read, const LogString *desc) { + // Read bytes into handshake buffer, starting at handshake_buf_pos_ + size_t bytes_to_read = to_read - this->handshake_buf_pos_; + ssize_t read = this->client_->read(this->handshake_buf_ + this->handshake_buf_pos_, bytes_to_read); + + if (!this->handle_read_error_(read, desc)) { + return false; + } + + this->handshake_buf_pos_ += read; + // Return true only if we have all the requested bytes + return this->handshake_buf_pos_ >= to_read; +} + +bool ESPHomeOTAComponent::try_write_(size_t to_write, const LogString *desc) { + // Write bytes from handshake buffer, starting at handshake_buf_pos_ + size_t bytes_to_write = to_write - this->handshake_buf_pos_; + ssize_t written = this->client_->write(this->handshake_buf_ + this->handshake_buf_pos_, bytes_to_write); + + if (!this->handle_write_error_(written, desc)) { + return false; + } + + this->handshake_buf_pos_ += written; + // Return true only if we have written all the requested bytes + return this->handshake_buf_pos_ >= to_write; +} + void ESPHomeOTAComponent::cleanup_connection_() { this->client_->close(); this->client_ = nullptr; this->client_connect_time_ = 0; - this->magic_buf_pos_ = 0; + this->handshake_buf_pos_ = 0; + this->ota_state_ = OTAState::IDLE; + this->ota_features_ = 0; + this->backend_ = nullptr; +#ifdef USE_OTA_PASSWORD + this->cleanup_auth_(); +#endif } void ESPHomeOTAComponent::yield_and_feed_watchdog_() { @@ -499,5 +558,256 @@ void ESPHomeOTAComponent::yield_and_feed_watchdog_() { delay(1); } +#ifdef USE_OTA_PASSWORD +void ESPHomeOTAComponent::log_auth_warning_(const LogString *msg) { ESP_LOGW(TAG, "Auth: %s", LOG_STR_ARG(msg)); } + +bool ESPHomeOTAComponent::select_auth_type_() { +#ifdef USE_OTA_SHA256 + bool client_supports_sha256 = (this->ota_features_ & FEATURE_SUPPORTS_SHA256_AUTH) != 0; + +#ifdef ALLOW_OTA_DOWNGRADE_MD5 + // Allow fallback to MD5 if client doesn't support SHA256 + if (client_supports_sha256) { + this->auth_type_ = ota::OTA_RESPONSE_REQUEST_SHA256_AUTH; + return true; + } +#ifdef USE_OTA_MD5 + this->log_auth_warning_(LOG_STR("Using deprecated MD5")); + this->auth_type_ = ota::OTA_RESPONSE_REQUEST_AUTH; + return true; +#else + this->log_auth_warning_(LOG_STR("SHA256 required")); + this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_AUTH_INVALID); + return false; +#endif // USE_OTA_MD5 + +#else // !ALLOW_OTA_DOWNGRADE_MD5 + // Require SHA256 + if (!client_supports_sha256) { + this->log_auth_warning_(LOG_STR("SHA256 required")); + this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_AUTH_INVALID); + return false; + } + this->auth_type_ = ota::OTA_RESPONSE_REQUEST_SHA256_AUTH; + return true; +#endif // ALLOW_OTA_DOWNGRADE_MD5 + +#else // !USE_OTA_SHA256 +#ifdef USE_OTA_MD5 + // Only MD5 available + this->auth_type_ = ota::OTA_RESPONSE_REQUEST_AUTH; + return true; +#else + // No auth methods available + this->log_auth_warning_(LOG_STR("No auth methods available")); + this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_AUTH_INVALID); + return false; +#endif // USE_OTA_MD5 +#endif // USE_OTA_SHA256 +} + +bool ESPHomeOTAComponent::handle_auth_send_() { + // Initialize auth buffer if not already done + if (!this->auth_buf_) { + // Select auth type based on client capabilities and configuration + if (!this->select_auth_type_()) { + return false; + } + + // Generate nonce - hasher must be created and used in same stack frame + // CRITICAL ESP32-S3 HARDWARE SHA ACCELERATION REQUIREMENTS: + // 1. Hash objects must NEVER be passed to another function (different stack frame) + // 2. NO Variable Length Arrays (VLAs) - they corrupt the stack with hardware DMA + // 3. All hash operations (init/add/calculate) must happen in the SAME function where object is created + // Violating these causes truncated hash output (20 bytes instead of 32) or memory corruption. + // + // Buffer layout after AUTH_READ completes: + // [0]: auth_type (1 byte) + // [1...hex_size]: nonce (hex_size bytes) - our random nonce sent in AUTH_SEND + // [1+hex_size...1+2*hex_size-1]: cnonce (hex_size bytes) - client's nonce + // [1+2*hex_size...1+3*hex_size-1]: response (hex_size bytes) - client's hash + + // Declare both hash objects in same stack frame, use pointer to select. + // NOTE: Both objects are declared here even though only one is used. This is REQUIRED for ESP32-S3 + // hardware SHA acceleration - the object must exist in this stack frame for all operations. + // Do NOT try to "optimize" by creating the object inside the if block, as it would go out of scope. +#ifdef USE_OTA_SHA256 + sha256::SHA256 sha_hasher; +#endif +#ifdef USE_OTA_MD5 + md5::MD5Digest md5_hasher; +#endif + HashBase *hasher = nullptr; + +#ifdef USE_OTA_SHA256 + if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_SHA256_AUTH) { + hasher = &sha_hasher; + } +#endif +#ifdef USE_OTA_MD5 + if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_AUTH) { + hasher = &md5_hasher; + } +#endif + + const size_t hex_size = hasher->get_size() * 2; + const size_t nonce_len = hasher->get_size() / 4; + const size_t auth_buf_size = 1 + 3 * hex_size; + this->auth_buf_ = std::make_unique(auth_buf_size); + this->auth_buf_pos_ = 0; + + char *buf = reinterpret_cast(this->auth_buf_.get() + 1); + if (!random_bytes(reinterpret_cast(buf), nonce_len)) { + this->log_auth_warning_(LOG_STR("Random failed")); + this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_UNKNOWN); + return false; + } + + hasher->init(); + hasher->add(buf, nonce_len); + hasher->calculate(); + this->auth_buf_[0] = this->auth_type_; + hasher->get_hex(buf); + +#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE + char log_buf[65]; // Fixed size for SHA256 hex (64) + null, works for MD5 (32) too + memcpy(log_buf, buf, hex_size); + log_buf[hex_size] = '\0'; + ESP_LOGV(TAG, "Auth: Nonce is %s", log_buf); +#endif + } + + // Try to write auth_type + nonce + size_t hex_size = this->get_auth_hex_size_(); + const size_t to_write = 1 + hex_size; + size_t remaining = to_write - this->auth_buf_pos_; + + ssize_t written = this->client_->write(this->auth_buf_.get() + this->auth_buf_pos_, remaining); + if (!this->handle_write_error_(written, LOG_STR("ack auth"))) { + return false; + } + + this->auth_buf_pos_ += written; + + // Check if we still have more to write + if (this->auth_buf_pos_ < to_write) { + return false; // More to write, try again next loop + } + + // All written, prepare for reading phase + this->auth_buf_pos_ = 0; + return true; +} + +bool ESPHomeOTAComponent::handle_auth_read_() { + size_t hex_size = this->get_auth_hex_size_(); + const size_t to_read = hex_size * 2; // CNonce + Response + + // Try to read remaining bytes (CNonce + Response) + // We read cnonce+response starting at offset 1+hex_size (after auth_type and our nonce) + size_t cnonce_offset = 1 + hex_size; // Offset where cnonce should be stored in buffer + size_t remaining = to_read - this->auth_buf_pos_; + ssize_t read = this->client_->read(this->auth_buf_.get() + cnonce_offset + this->auth_buf_pos_, remaining); + + if (!this->handle_read_error_(read, LOG_STR("read auth"))) { + return false; + } + + this->auth_buf_pos_ += read; + + // Check if we still need more data + if (this->auth_buf_pos_ < to_read) { + return false; // More to read, try again next loop + } + + // We have all the data, verify it + const char *nonce = reinterpret_cast(this->auth_buf_.get() + 1); + const char *cnonce = nonce + hex_size; + const char *response = cnonce + hex_size; + + // CRITICAL ESP32-S3: Hash objects must stay in same stack frame (no passing to other functions). + // Declare both hash objects in same stack frame, use pointer to select. + // NOTE: Both objects are declared here even though only one is used. This is REQUIRED for ESP32-S3 + // hardware SHA acceleration - the object must exist in this stack frame for all operations. + // Do NOT try to "optimize" by creating the object inside the if block, as it would go out of scope. +#ifdef USE_OTA_SHA256 + sha256::SHA256 sha_hasher; +#endif +#ifdef USE_OTA_MD5 + md5::MD5Digest md5_hasher; +#endif + HashBase *hasher = nullptr; + +#ifdef USE_OTA_SHA256 + if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_SHA256_AUTH) { + hasher = &sha_hasher; + } +#endif +#ifdef USE_OTA_MD5 + if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_AUTH) { + hasher = &md5_hasher; + } +#endif + + hasher->init(); + hasher->add(this->password_.c_str(), this->password_.length()); + hasher->add(nonce, hex_size * 2); // Add both nonce and cnonce (contiguous in buffer) + hasher->calculate(); + +#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE + char log_buf[65]; // Fixed size for SHA256 hex (64) + null, works for MD5 (32) too + // Log CNonce + memcpy(log_buf, cnonce, hex_size); + log_buf[hex_size] = '\0'; + ESP_LOGV(TAG, "Auth: CNonce is %s", log_buf); + + // Log computed hash + hasher->get_hex(log_buf); + log_buf[hex_size] = '\0'; + ESP_LOGV(TAG, "Auth: Result is %s", log_buf); + + // Log received response + memcpy(log_buf, response, hex_size); + log_buf[hex_size] = '\0'; + ESP_LOGV(TAG, "Auth: Response is %s", log_buf); +#endif + + // Compare response + bool matches = hasher->equals_hex(response); + + if (!matches) { + this->log_auth_warning_(LOG_STR("Password mismatch")); + this->send_error_and_cleanup_(ota::OTA_RESPONSE_ERROR_AUTH_INVALID); + return false; + } + + // Authentication successful - clean up auth state + this->cleanup_auth_(); + + return true; +} + +size_t ESPHomeOTAComponent::get_auth_hex_size_() const { +#ifdef USE_OTA_SHA256 + if (this->auth_type_ == ota::OTA_RESPONSE_REQUEST_SHA256_AUTH) { + return SHA256_HEX_SIZE; + } +#endif +#ifdef USE_OTA_MD5 + return MD5_HEX_SIZE; +#else +#ifndef USE_OTA_SHA256 +#error "Either USE_OTA_MD5 or USE_OTA_SHA256 must be defined when USE_OTA_PASSWORD is enabled" +#endif +#endif +} + +void ESPHomeOTAComponent::cleanup_auth_() { + this->auth_buf_ = nullptr; + this->auth_buf_pos_ = 0; + this->auth_type_ = 0; +} +#endif // USE_OTA_PASSWORD + } // namespace esphome #endif diff --git a/esphome/components/esphome/ota/ota_esphome.h b/esphome/components/esphome/ota/ota_esphome.h index f5a3e43ae3..d4a8410d35 100644 --- a/esphome/components/esphome/ota/ota_esphome.h +++ b/esphome/components/esphome/ota/ota_esphome.h @@ -7,12 +7,25 @@ #include "esphome/core/helpers.h" #include "esphome/core/log.h" #include "esphome/core/preferences.h" +#include "esphome/core/hash_base.h" namespace esphome { /// ESPHomeOTAComponent provides a simple way to integrate Over-the-Air updates into your app using ArduinoOTA. class ESPHomeOTAComponent : public ota::OTAComponent { public: + enum class OTAState : uint8_t { + IDLE, + MAGIC_READ, // Reading magic bytes + MAGIC_ACK, // Sending OK and version after magic bytes + FEATURE_READ, // Reading feature flags from client + FEATURE_ACK, // Sending feature acknowledgment +#ifdef USE_OTA_PASSWORD + AUTH_SEND, // Sending authentication request + AUTH_READ, // Reading authentication data +#endif // USE_OTA_PASSWORD + DATA, // BLOCKING! Processing OTA data (update, etc.) + }; #ifdef USE_OTA_PASSWORD void set_auth_password(const std::string &password) { password_ = password; } #endif // USE_OTA_PASSWORD @@ -30,12 +43,38 @@ class ESPHomeOTAComponent : public ota::OTAComponent { protected: void handle_handshake_(); void handle_data_(); +#ifdef USE_OTA_PASSWORD + bool handle_auth_send_(); + bool handle_auth_read_(); + bool select_auth_type_(); + size_t get_auth_hex_size_() const; + void cleanup_auth_(); + void log_auth_warning_(const LogString *msg); +#endif // USE_OTA_PASSWORD bool readall_(uint8_t *buf, size_t len); bool writeall_(const uint8_t *buf, size_t len); + + bool try_read_(size_t to_read, const LogString *desc); + bool try_write_(size_t to_write, const LogString *desc); + + inline bool would_block_(int error_code) const { return error_code == EAGAIN || error_code == EWOULDBLOCK; } + bool handle_read_error_(ssize_t read, const LogString *desc); + bool handle_write_error_(ssize_t written, const LogString *desc); + inline void transition_ota_state_(OTAState next_state) { + this->ota_state_ = next_state; + this->handshake_buf_pos_ = 0; // Reset buffer position for next state + } + void log_socket_error_(const LogString *msg); void log_read_error_(const LogString *what); void log_start_(const LogString *phase); + void log_remote_closed_(const LogString *during); void cleanup_connection_(); + inline void send_error_and_cleanup_(ota::OTAResponseTypes error) { + uint8_t error_byte = static_cast(error); + this->client_->write(&error_byte, 1); // Best effort, non-blocking + this->cleanup_connection_(); + } void yield_and_feed_watchdog_(); #ifdef USE_OTA_PASSWORD @@ -44,11 +83,19 @@ class ESPHomeOTAComponent : public ota::OTAComponent { std::unique_ptr server_; std::unique_ptr client_; + std::unique_ptr backend_; uint32_t client_connect_time_{0}; uint16_t port_; - uint8_t magic_buf_[5]; - uint8_t magic_buf_pos_{0}; + uint8_t handshake_buf_[5]; + OTAState ota_state_{OTAState::IDLE}; + uint8_t handshake_buf_pos_{0}; + uint8_t ota_features_{0}; +#ifdef USE_OTA_PASSWORD + std::unique_ptr auth_buf_; + uint8_t auth_buf_pos_{0}; + uint8_t auth_type_{0}; // Store auth type to know which hasher to use +#endif // USE_OTA_PASSWORD }; } // namespace esphome diff --git a/esphome/components/ethernet/__init__.py b/esphome/components/ethernet/__init__.py index 151da7d0e5..7384bb26d3 100644 --- a/esphome/components/ethernet/__init__.py +++ b/esphome/components/ethernet/__init__.py @@ -2,9 +2,15 @@ import logging from esphome import pins import esphome.codegen as cg -from esphome.components.esp32 import add_idf_sdkconfig_option, get_esp32_variant +from esphome.components.esp32 import ( + add_idf_component, + add_idf_sdkconfig_option, + get_esp32_variant, +) from esphome.components.esp32.const import ( + VARIANT_ESP32, VARIANT_ESP32C3, + VARIANT_ESP32P4, VARIANT_ESP32S2, VARIANT_ESP32S3, ) @@ -21,6 +27,7 @@ from esphome.const import ( CONF_GATEWAY, CONF_ID, CONF_INTERRUPT_PIN, + CONF_MAC_ADDRESS, CONF_MANUAL_IP, CONF_MISO_PIN, CONF_MODE, @@ -75,12 +82,14 @@ ETHERNET_TYPES = { "W5500": EthernetType.ETHERNET_TYPE_W5500, "OPENETH": EthernetType.ETHERNET_TYPE_OPENETH, "DM9051": EthernetType.ETHERNET_TYPE_DM9051, + "LAN8670": EthernetType.ETHERNET_TYPE_LAN8670, } # PHY types that need compile-time defines for conditional compilation _PHY_TYPE_TO_DEFINE = { "KSZ8081": "USE_ETHERNET_KSZ8081", "KSZ8081RNA": "USE_ETHERNET_KSZ8081", + "LAN8670": "USE_ETHERNET_LAN8670", # Add other PHY types here only if they need conditional compilation } @@ -117,19 +126,15 @@ ManualIP = ethernet_ns.struct("ManualIP") def _is_framework_spi_polling_mode_supported(): # SPI Ethernet without IRQ feature is added in - # esp-idf >= (5.3+ ,5.2.1+, 5.1.4) and arduino-esp32 >= 3.0.0 + # esp-idf >= (5.3+ ,5.2.1+, 5.1.4) + # Note: Arduino now uses ESP-IDF as a component, so we only check IDF version framework_version = CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] - if CORE.using_esp_idf: - if framework_version >= cv.Version(5, 3, 0): - return True - if cv.Version(5, 3, 0) > framework_version >= cv.Version(5, 2, 1): - return True - if cv.Version(5, 2, 0) > framework_version >= cv.Version(5, 1, 4): # noqa: SIM103 - return True - return False - if CORE.using_arduino: - return framework_version >= cv.Version(3, 0, 0) - # fail safe: Unknown framework + if framework_version >= cv.Version(5, 3, 0): + return True + if cv.Version(5, 3, 0) > framework_version >= cv.Version(5, 2, 1): + return True + if cv.Version(5, 2, 0) > framework_version >= cv.Version(5, 1, 4): # noqa: SIM103 + return True return False @@ -140,6 +145,7 @@ def _validate(config): else: use_address = CORE.name + config[CONF_DOMAIN] config[CONF_USE_ADDRESS] = use_address + if config[CONF_TYPE] in SPI_ETHERNET_TYPES: if _is_framework_spi_polling_mode_supported(): if CONF_POLLING_INTERVAL in config and CONF_INTERRUPT_PIN in config: @@ -172,6 +178,12 @@ def _validate(config): del config[CONF_CLK_MODE] elif CONF_CLK not in config: raise cv.Invalid("'clk' is a required option for [ethernet].") + variant = get_esp32_variant() + if variant not in (VARIANT_ESP32, VARIANT_ESP32P4): + raise cv.Invalid( + f"{config[CONF_TYPE]} PHY requires RMII interface and is only supported " + f"on ESP32 classic and ESP32-P4, not {variant}" + ) return config @@ -186,6 +198,7 @@ BASE_SCHEMA = cv.Schema( "This option has been removed. Please use the [disabled] option under the " "new mdns component instead." ), + cv.Optional(CONF_MAC_ADDRESS): cv.mac_address, } ).extend(cv.COMPONENT_SCHEMA) @@ -252,6 +265,7 @@ CONFIG_SCHEMA = cv.All( "W5500": SPI_SCHEMA, "OPENETH": BASE_SCHEMA, "DM9051": SPI_SCHEMA, + "LAN8670": RMII_SCHEMA, }, upper=True, ), @@ -322,11 +336,8 @@ async def to_code(config): cg.add(var.set_clock_speed(config[CONF_CLOCK_SPEED])) cg.add_define("USE_ETHERNET_SPI") - if CORE.using_esp_idf: - add_idf_sdkconfig_option("CONFIG_ETH_USE_SPI_ETHERNET", True) - add_idf_sdkconfig_option( - f"CONFIG_ETH_SPI_ETHERNET_{config[CONF_TYPE]}", True - ) + add_idf_sdkconfig_option("CONFIG_ETH_USE_SPI_ETHERNET", True) + add_idf_sdkconfig_option(f"CONFIG_ETH_SPI_ETHERNET_{config[CONF_TYPE]}", True) elif config[CONF_TYPE] == "OPENETH": cg.add_define("USE_ETHERNET_OPENETH") add_idf_sdkconfig_option("CONFIG_ETH_USE_OPENETH", True) @@ -356,13 +367,19 @@ async def to_code(config): if phy_define := _PHY_TYPE_TO_DEFINE.get(config[CONF_TYPE]): cg.add_define(phy_define) + if mac_address := config.get(CONF_MAC_ADDRESS): + cg.add(var.set_fixed_mac(mac_address.parts)) + cg.add_define("USE_ETHERNET") # Disable WiFi when using Ethernet to save memory - if CORE.using_esp_idf: - add_idf_sdkconfig_option("CONFIG_ESP_WIFI_ENABLED", False) - # Also disable WiFi/BT coexistence since WiFi is disabled - add_idf_sdkconfig_option("CONFIG_SW_COEXIST_ENABLE", False) + add_idf_sdkconfig_option("CONFIG_ESP_WIFI_ENABLED", False) + # Also disable WiFi/BT coexistence since WiFi is disabled + add_idf_sdkconfig_option("CONFIG_SW_COEXIST_ENABLE", False) + + if config[CONF_TYPE] == "LAN8670": + # Add LAN867x 10BASE-T1S PHY support component + add_idf_component(name="espressif/lan867x", ref="2.0.0") if CORE.using_arduino: cg.add_library("WiFi", None) diff --git a/esphome/components/ethernet/ethernet_component.cpp b/esphome/components/ethernet/ethernet_component.cpp index ff14d19427..28043dd969 100644 --- a/esphome/components/ethernet/ethernet_component.cpp +++ b/esphome/components/ethernet/ethernet_component.cpp @@ -9,6 +9,10 @@ #include #include "esp_event.h" +#ifdef USE_ETHERNET_LAN8670 +#include "esp_eth_phy_lan867x.h" +#endif + #ifdef USE_ETHERNET_SPI #include #include @@ -37,17 +41,20 @@ static const char *const TAG = "ethernet"; EthernetComponent *global_eth_component; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) +void EthernetComponent::log_error_and_mark_failed_(esp_err_t err, const char *message) { + ESP_LOGE(TAG, "%s: (%d) %s", message, err, esp_err_to_name(err)); + this->mark_failed(); +} + #define ESPHL_ERROR_CHECK(err, message) \ if ((err) != ESP_OK) { \ - ESP_LOGE(TAG, message ": (%d) %s", err, esp_err_to_name(err)); \ - this->mark_failed(); \ + this->log_error_and_mark_failed_(err, message); \ return; \ } #define ESPHL_ERROR_CHECK_RET(err, message, ret) \ if ((err) != ESP_OK) { \ - ESP_LOGE(TAG, message ": (%d) %s", err, esp_err_to_name(err)); \ - this->mark_failed(); \ + this->log_error_and_mark_failed_(err, message); \ return ret; \ } @@ -200,6 +207,12 @@ void EthernetComponent::setup() { this->phy_ = esp_eth_phy_new_ksz80xx(&phy_config); break; } +#ifdef USE_ETHERNET_LAN8670 + case ETHERNET_TYPE_LAN8670: { + this->phy_ = esp_eth_phy_new_lan867x(&phy_config); + break; + } +#endif #endif #ifdef USE_ETHERNET_SPI #if CONFIG_ETH_SPI_ETHERNET_W5500 @@ -243,7 +256,11 @@ void EthernetComponent::setup() { // use ESP internal eth mac uint8_t mac_addr[6]; - esp_read_mac(mac_addr, ESP_MAC_ETH); + if (this->fixed_mac_.has_value()) { + memcpy(mac_addr, this->fixed_mac_->data(), 6); + } else { + esp_read_mac(mac_addr, ESP_MAC_ETH); + } err = esp_eth_ioctl(this->eth_handle_, ETH_CMD_S_MAC_ADDR, mac_addr); ESPHL_ERROR_CHECK(err, "set mac address error"); @@ -353,6 +370,12 @@ void EthernetComponent::dump_config() { eth_type = "DM9051"; break; +#ifdef USE_ETHERNET_LAN8670 + case ETHERNET_TYPE_LAN8670: + eth_type = "LAN8670"; + break; +#endif + default: eth_type = "Unknown"; break; diff --git a/esphome/components/ethernet/ethernet_component.h b/esphome/components/ethernet/ethernet_component.h index bbb9d7fb60..6b4e342df5 100644 --- a/esphome/components/ethernet/ethernet_component.h +++ b/esphome/components/ethernet/ethernet_component.h @@ -28,6 +28,7 @@ enum EthernetType : uint8_t { ETHERNET_TYPE_W5500, ETHERNET_TYPE_OPENETH, ETHERNET_TYPE_DM9051, + ETHERNET_TYPE_LAN8670, }; struct ManualIP { @@ -83,6 +84,7 @@ class EthernetComponent : public Component { #endif void set_type(EthernetType type); void set_manual_ip(const ManualIP &manual_ip); + void set_fixed_mac(const std::array &mac) { this->fixed_mac_ = mac; } network::IPAddresses get_ip_addresses(); network::IPAddress get_dns_address(uint8_t num); @@ -104,6 +106,7 @@ class EthernetComponent : public Component { void start_connect_(); void finish_connect_(); void dump_connect_params_(); + void log_error_and_mark_failed_(esp_err_t err, const char *message); #ifdef USE_ETHERNET_KSZ8081 /// @brief Set `RMII Reference Clock Select` bit for KSZ8081. void ksz8081_set_clock_reference_(esp_eth_mac_t *mac); @@ -154,12 +157,13 @@ class EthernetComponent : public Component { esp_netif_t *eth_netif_{nullptr}; esp_eth_handle_t eth_handle_; esp_eth_phy_t *phy_{nullptr}; + optional> fixed_mac_; }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) extern EthernetComponent *global_eth_component; -#if defined(USE_ARDUINO) || ESP_IDF_VERSION < ESP_IDF_VERSION_VAL(5, 4, 2) +#if ESP_IDF_VERSION < ESP_IDF_VERSION_VAL(5, 4, 2) extern "C" esp_eth_phy_t *esp_eth_phy_new_jl1101(const eth_phy_config_t *config); #endif diff --git a/esphome/components/event_emitter/__init__.py b/esphome/components/event_emitter/__init__.py deleted file mode 100644 index fcbbf26f02..0000000000 --- a/esphome/components/event_emitter/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -CODEOWNERS = ["@Rapsssito"] - -# Allows event_emitter to be configured in yaml, to allow use of the C++ api. - -CONFIG_SCHEMA = {} diff --git a/esphome/components/event_emitter/event_emitter.cpp b/esphome/components/event_emitter/event_emitter.cpp deleted file mode 100644 index 8487e19c2f..0000000000 --- a/esphome/components/event_emitter/event_emitter.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include "event_emitter.h" - -namespace esphome { -namespace event_emitter { - -static const char *const TAG = "event_emitter"; - -void raise_event_emitter_full_error() { - ESP_LOGE(TAG, "EventEmitter has reached the maximum number of listeners for event"); - ESP_LOGW(TAG, "Removing listener to make space for new listener"); -} - -} // namespace event_emitter -} // namespace esphome diff --git a/esphome/components/event_emitter/event_emitter.h b/esphome/components/event_emitter/event_emitter.h deleted file mode 100644 index 3876a2cc14..0000000000 --- a/esphome/components/event_emitter/event_emitter.h +++ /dev/null @@ -1,63 +0,0 @@ -#pragma once -#include -#include -#include -#include - -#include "esphome/core/log.h" - -namespace esphome { -namespace event_emitter { - -using EventEmitterListenerID = uint32_t; -void raise_event_emitter_full_error(); - -// EventEmitter class that can emit events with a specific name (it is highly recommended to use an enum class for this) -// and a list of arguments. Supports multiple listeners for each event. -template class EventEmitter { - public: - EventEmitterListenerID on(EvtType event, std::function listener) { - EventEmitterListenerID listener_id = get_next_id_(event); - listeners_[event][listener_id] = listener; - return listener_id; - } - - void off(EvtType event, EventEmitterListenerID id) { - if (listeners_.count(event) == 0) - return; - listeners_[event].erase(id); - } - - protected: - void emit_(EvtType event, Args... args) { - if (listeners_.count(event) == 0) - return; - for (const auto &listener : listeners_[event]) { - listener.second(args...); - } - } - - EventEmitterListenerID get_next_id_(EvtType event) { - // Check if the map is full - if (listeners_[event].size() == std::numeric_limits::max()) { - // Raise an error if the map is full - raise_event_emitter_full_error(); - off(event, 0); - return 0; - } - // Get the next ID for the given event. - EventEmitterListenerID next_id = (current_id_ + 1) % std::numeric_limits::max(); - while (listeners_[event].count(next_id) > 0) { - next_id = (next_id + 1) % std::numeric_limits::max(); - } - current_id_ = next_id; - return current_id_; - } - - private: - std::unordered_map>> listeners_; - EventEmitterListenerID current_id_ = 0; -}; - -} // namespace event_emitter -} // namespace esphome diff --git a/esphome/components/external_components/__init__.py b/esphome/components/external_components/__init__.py index a09217fd21..ceb402c5b7 100644 --- a/esphome/components/external_components/__init__.py +++ b/esphome/components/external_components/__init__.py @@ -39,11 +39,13 @@ async def to_code(config): pass -def _process_git_config(config: dict, refresh) -> str: +def _process_git_config(config: dict, refresh, skip_update: bool = False) -> str: + # When skip_update is True, use NEVER_REFRESH to prevent updates + actual_refresh = git.NEVER_REFRESH if skip_update else refresh repo_dir, _ = git.clone_or_update( url=config[CONF_URL], ref=config.get(CONF_REF), - refresh=refresh, + refresh=actual_refresh, domain=DOMAIN, username=config.get(CONF_USERNAME), password=config.get(CONF_PASSWORD), @@ -70,12 +72,12 @@ def _process_git_config(config: dict, refresh) -> str: return components_dir -def _process_single_config(config: dict): +def _process_single_config(config: dict, skip_update: bool = False): conf = config[CONF_SOURCE] if conf[CONF_TYPE] == TYPE_GIT: with cv.prepend_path([CONF_SOURCE]): components_dir = _process_git_config( - config[CONF_SOURCE], config[CONF_REFRESH] + config[CONF_SOURCE], config[CONF_REFRESH], skip_update ) elif conf[CONF_TYPE] == TYPE_LOCAL: components_dir = Path(CORE.relative_config_path(conf[CONF_PATH])) @@ -105,7 +107,7 @@ def _process_single_config(config: dict): loader.install_meta_finder(components_dir, allowed_components=allowed_components) -def do_external_components_pass(config: dict) -> None: +def do_external_components_pass(config: dict, skip_update: bool = False) -> None: conf = config.get(DOMAIN) if conf is None: return @@ -113,4 +115,4 @@ def do_external_components_pass(config: dict) -> None: conf = CONFIG_SCHEMA(conf) for i, c in enumerate(conf): with cv.prepend_path(i): - _process_single_config(c) + _process_single_config(c, skip_update) diff --git a/esphome/components/fingerprint_grow/fingerprint_grow.cpp b/esphome/components/fingerprint_grow/fingerprint_grow.cpp index 54a267a404..eb7ede8fe9 100644 --- a/esphome/components/fingerprint_grow/fingerprint_grow.cpp +++ b/esphome/components/fingerprint_grow/fingerprint_grow.cpp @@ -80,7 +80,7 @@ void FingerprintGrowComponent::setup() { delay(20); // This delay guarantees the sensor will in fact be powered power. if (this->check_password_()) { - if (this->new_password_ != -1) { + if (this->new_password_ != std::numeric_limits::max()) { if (this->set_password_()) return; } else { diff --git a/esphome/components/fingerprint_grow/fingerprint_grow.h b/esphome/components/fingerprint_grow/fingerprint_grow.h index 1c3098ef14..590c709c22 100644 --- a/esphome/components/fingerprint_grow/fingerprint_grow.h +++ b/esphome/components/fingerprint_grow/fingerprint_grow.h @@ -6,6 +6,7 @@ #include "esphome/components/binary_sensor/binary_sensor.h" #include "esphome/components/uart/uart.h" +#include #include namespace esphome { @@ -177,7 +178,7 @@ class FingerprintGrowComponent : public PollingComponent, public uart::UARTDevic uint8_t address_[4] = {0xFF, 0xFF, 0xFF, 0xFF}; uint16_t capacity_ = 64; uint32_t password_ = 0x0; - uint32_t new_password_ = -1; + uint32_t new_password_ = std::numeric_limits::max(); GPIOPin *sensing_pin_{nullptr}; GPIOPin *sensor_power_pin_{nullptr}; uint8_t enrollment_image_ = 0; diff --git a/esphome/components/font/__init__.py b/esphome/components/font/__init__.py index 4ecc76c561..ddcee14635 100644 --- a/esphome/components/font/__init__.py +++ b/esphome/components/font/__init__.py @@ -3,7 +3,6 @@ import functools import hashlib from itertools import accumulate import logging -import os from pathlib import Path import re @@ -38,6 +37,7 @@ from esphome.const import ( ) from esphome.core import CORE, HexInt from esphome.helpers import cpp_string_escape +from esphome.types import ConfigType _LOGGER = logging.getLogger(__name__) @@ -253,11 +253,11 @@ def validate_truetype_file(value): return CORE.relative_config_path(cv.file_(value)) -def add_local_file(value): +def add_local_file(value: ConfigType) -> ConfigType: if value in FONT_CACHE: return value - path = value[CONF_PATH] - if not os.path.isfile(path): + path = Path(value[CONF_PATH]) + if not path.is_file(): raise cv.Invalid(f"File '{path}' not found.") FONT_CACHE[value] = path return value @@ -318,7 +318,7 @@ def download_gfont(value): external_files.compute_local_file_dir(DOMAIN) / f"{value[CONF_FAMILY]}@{value[CONF_WEIGHT]}@{value[CONF_ITALIC]}@v1.ttf" ) - if not external_files.is_file_recent(str(path), value[CONF_REFRESH]): + if not external_files.is_file_recent(path, value[CONF_REFRESH]): _LOGGER.debug("download_gfont: path=%s", path) try: req = requests.get(url, timeout=external_files.NETWORK_TIMEOUT) diff --git a/esphome/components/graph/graph.cpp b/esphome/components/graph/graph.cpp index 5abf2ade0d..ac6ace96ee 100644 --- a/esphome/components/graph/graph.cpp +++ b/esphome/components/graph/graph.cpp @@ -179,7 +179,7 @@ void Graph::draw(Display *buff, uint16_t x_offset, uint16_t y_offset, Color colo if (b) { int16_t y = (int16_t) roundf((this->height_ - 1) * (1.0 - v)) - thick / 2 + y_offset; auto draw_pixel_at = [&buff, c, y_offset, this](int16_t x, int16_t y) { - if (y >= y_offset && y < y_offset + this->height_) + if (y >= y_offset && static_cast(y) < y_offset + this->height_) buff->draw_pixel_at(x, y, c); }; if (!continuous || !has_prev || !prev_b || (abs(y - prev_y) <= thick)) { diff --git a/esphome/components/graphical_display_menu/graphical_display_menu.cpp b/esphome/components/graphical_display_menu/graphical_display_menu.cpp index 1a29536b46..2b120a746f 100644 --- a/esphome/components/graphical_display_menu/graphical_display_menu.cpp +++ b/esphome/components/graphical_display_menu/graphical_display_menu.cpp @@ -116,7 +116,7 @@ void GraphicalDisplayMenu::draw_menu_internal_(display::Display *display, const int number_items_fit_to_screen = 0; const int max_item_index = this->displayed_item_->items_size() - 1; - for (size_t i = 0; i <= max_item_index; i++) { + for (size_t i = 0; max_item_index >= 0 && i <= static_cast(max_item_index); i++) { const auto *item = this->displayed_item_->get_item(i); const bool selected = i == this->cursor_index_; const display::Rect item_dimensions = this->measure_item(display, item, bounds, selected); @@ -174,7 +174,8 @@ void GraphicalDisplayMenu::draw_menu_internal_(display::Display *display, const display->filled_rectangle(bounds->x, bounds->y, max_width, total_height, this->background_color_); auto y_offset = bounds->y; - for (size_t i = first_item_index; i <= last_item_index; i++) { + for (size_t i = static_cast(first_item_index); + last_item_index >= 0 && i <= static_cast(last_item_index); i++) { const auto *item = this->displayed_item_->get_item(i); const bool selected = i == this->cursor_index_; display::Rect dimensions = menu_dimensions[i]; diff --git a/esphome/components/haier/hon_climate.cpp b/esphome/components/haier/hon_climate.cpp index 9614bb1e47..76558f2ebb 100644 --- a/esphome/components/haier/hon_climate.cpp +++ b/esphome/components/haier/hon_climate.cpp @@ -213,7 +213,7 @@ haier_protocol::HandlerError HonClimate::status_handler_(haier_protocol::FrameTy this->real_control_packet_size_); this->status_message_callback_.call((const char *) data, data_size); } else { - ESP_LOGW(TAG, "Status packet too small: %d (should be >= %d)", data_size, this->real_control_packet_size_); + ESP_LOGW(TAG, "Status packet too small: %zu (should be >= %zu)", data_size, this->real_control_packet_size_); } switch (this->protocol_phase_) { case ProtocolPhases::SENDING_FIRST_STATUS_REQUEST: @@ -827,7 +827,7 @@ haier_protocol::HandlerError HonClimate::process_status_message_(const uint8_t * size_t expected_size = 2 + this->status_message_header_size_ + this->real_control_packet_size_ + this->real_sensors_packet_size_; if (size < expected_size) { - ESP_LOGW(TAG, "Unexpected message size %d (expexted >= %d)", size, expected_size); + ESP_LOGW(TAG, "Unexpected message size %u (expexted >= %zu)", size, expected_size); return haier_protocol::HandlerError::WRONG_MESSAGE_STRUCTURE; } uint16_t subtype = (((uint16_t) packet_buffer[0]) << 8) + packet_buffer[1]; diff --git a/esphome/components/haier/hon_climate.h b/esphome/components/haier/hon_climate.h index 58173f8154..a567ab1d89 100644 --- a/esphome/components/haier/hon_climate.h +++ b/esphome/components/haier/hon_climate.h @@ -178,7 +178,7 @@ class HonClimate : public HaierClimateBase { int extra_control_packet_bytes_{0}; int extra_sensors_packet_bytes_{4}; int status_message_header_size_{0}; - int real_control_packet_size_{sizeof(hon_protocol::HaierPacketControl)}; + size_t real_control_packet_size_{sizeof(hon_protocol::HaierPacketControl)}; int real_sensors_packet_size_{sizeof(hon_protocol::HaierPacketSensors) + 4}; HonControlMethod control_method_; std::queue control_messages_queue_; diff --git a/esphome/components/hdc1080/hdc1080.cpp b/esphome/components/hdc1080/hdc1080.cpp index 6d16133c36..71b7cd7e6e 100644 --- a/esphome/components/hdc1080/hdc1080.cpp +++ b/esphome/components/hdc1080/hdc1080.cpp @@ -7,24 +7,20 @@ namespace hdc1080 { static const char *const TAG = "hdc1080"; -static const uint8_t HDC1080_ADDRESS = 0x40; // 0b1000000 from datasheet static const uint8_t HDC1080_CMD_CONFIGURATION = 0x02; static const uint8_t HDC1080_CMD_TEMPERATURE = 0x00; static const uint8_t HDC1080_CMD_HUMIDITY = 0x01; void HDC1080Component::setup() { - const uint8_t data[2] = { - 0b00000000, // resolution 14bit for both humidity and temperature - 0b00000000 // reserved - }; + const uint8_t config[2] = {0x00, 0x00}; // resolution 14bit for both humidity and temperature - if (!this->write_bytes(HDC1080_CMD_CONFIGURATION, data, 2)) { - // as instruction is same as powerup defaults (for now), interpret as warning if this fails - ESP_LOGW(TAG, "HDC1080 initial config instruction error"); - this->status_set_warning(); + // if configuration fails - there is a problem + if (this->write_register(HDC1080_CMD_CONFIGURATION, config, 2) != i2c::ERROR_OK) { + this->mark_failed(); return; } } + void HDC1080Component::dump_config() { ESP_LOGCONFIG(TAG, "HDC1080:"); LOG_I2C_DEVICE(this); @@ -35,39 +31,51 @@ void HDC1080Component::dump_config() { LOG_SENSOR(" ", "Temperature", this->temperature_); LOG_SENSOR(" ", "Humidity", this->humidity_); } + void HDC1080Component::update() { - uint16_t raw_temp; + // regardless of what sensor/s are defined in yaml configuration + // the hdc1080 setup configuration used, requires both temperature and humidity to be read + + this->status_clear_warning(); + if (this->write(&HDC1080_CMD_TEMPERATURE, 1) != i2c::ERROR_OK) { this->status_set_warning(); return; } - delay(20); - if (this->read(reinterpret_cast(&raw_temp), 2) != i2c::ERROR_OK) { - this->status_set_warning(); - return; - } - raw_temp = i2c::i2ctohs(raw_temp); - float temp = raw_temp * 0.0025177f - 40.0f; // raw * 2^-16 * 165 - 40 - this->temperature_->publish_state(temp); - uint16_t raw_humidity; - if (this->write(&HDC1080_CMD_HUMIDITY, 1) != i2c::ERROR_OK) { - this->status_set_warning(); - return; - } - delay(20); - if (this->read(reinterpret_cast(&raw_humidity), 2) != i2c::ERROR_OK) { - this->status_set_warning(); - return; - } - raw_humidity = i2c::i2ctohs(raw_humidity); - float humidity = raw_humidity * 0.001525879f; // raw * 2^-16 * 100 - this->humidity_->publish_state(humidity); + this->set_timeout(20, [this]() { + uint16_t raw_temperature; + if (this->read(reinterpret_cast(&raw_temperature), 2) != i2c::ERROR_OK) { + this->status_set_warning(); + return; + } - ESP_LOGD(TAG, "Got temperature=%.1f°C humidity=%.1f%%", temp, humidity); - this->status_clear_warning(); + if (this->temperature_ != nullptr) { + raw_temperature = i2c::i2ctohs(raw_temperature); + float temperature = raw_temperature * 0.0025177f - 40.0f; // raw * 2^-16 * 165 - 40 + this->temperature_->publish_state(temperature); + } + + if (this->write(&HDC1080_CMD_HUMIDITY, 1) != i2c::ERROR_OK) { + this->status_set_warning(); + return; + } + + this->set_timeout(20, [this]() { + uint16_t raw_humidity; + if (this->read(reinterpret_cast(&raw_humidity), 2) != i2c::ERROR_OK) { + this->status_set_warning(); + return; + } + + if (this->humidity_ != nullptr) { + raw_humidity = i2c::i2ctohs(raw_humidity); + float humidity = raw_humidity * 0.001525879f; // raw * 2^-16 * 100 + this->humidity_->publish_state(humidity); + } + }); + }); } -float HDC1080Component::get_setup_priority() const { return setup_priority::DATA; } } // namespace hdc1080 } // namespace esphome diff --git a/esphome/components/hdc1080/hdc1080.h b/esphome/components/hdc1080/hdc1080.h index 2ff7b6dc33..7ad0764f1f 100644 --- a/esphome/components/hdc1080/hdc1080.h +++ b/esphome/components/hdc1080/hdc1080.h @@ -12,13 +12,11 @@ class HDC1080Component : public PollingComponent, public i2c::I2CDevice { void set_temperature(sensor::Sensor *temperature) { temperature_ = temperature; } void set_humidity(sensor::Sensor *humidity) { humidity_ = humidity; } - /// Setup the sensor and check for connection. void setup() override; void dump_config() override; - /// Retrieve the latest sensor values. This operation takes approximately 16ms. void update() override; - float get_setup_priority() const override; + float get_setup_priority() const override { return setup_priority::DATA; } protected: sensor::Sensor *temperature_{nullptr}; diff --git a/esphome/components/homeassistant/number/homeassistant_number.cpp b/esphome/components/homeassistant/number/homeassistant_number.cpp index 87bf6727f2..c9fb006568 100644 --- a/esphome/components/homeassistant/number/homeassistant_number.cpp +++ b/esphome/components/homeassistant/number/homeassistant_number.cpp @@ -87,7 +87,7 @@ void HomeassistantNumber::control(float value) { static constexpr auto ENTITY_ID_KEY = StringRef::from_lit("entity_id"); static constexpr auto VALUE_KEY = StringRef::from_lit("value"); - api::HomeassistantServiceResponse resp; + api::HomeassistantActionRequest resp; resp.set_service(SERVICE_NAME); resp.data.emplace_back(); @@ -100,7 +100,7 @@ void HomeassistantNumber::control(float value) { entity_value.set_key(VALUE_KEY); entity_value.value = to_string(value); - api::global_api_server->send_homeassistant_service_call(resp); + api::global_api_server->send_homeassistant_action(resp); } } // namespace homeassistant diff --git a/esphome/components/homeassistant/switch/homeassistant_switch.cpp b/esphome/components/homeassistant/switch/homeassistant_switch.cpp index b3300335b9..8feec26fe6 100644 --- a/esphome/components/homeassistant/switch/homeassistant_switch.cpp +++ b/esphome/components/homeassistant/switch/homeassistant_switch.cpp @@ -44,7 +44,7 @@ void HomeassistantSwitch::write_state(bool state) { static constexpr auto SERVICE_OFF = StringRef::from_lit("homeassistant.turn_off"); static constexpr auto ENTITY_ID_KEY = StringRef::from_lit("entity_id"); - api::HomeassistantServiceResponse resp; + api::HomeassistantActionRequest resp; if (state) { resp.set_service(SERVICE_ON); } else { @@ -56,7 +56,7 @@ void HomeassistantSwitch::write_state(bool state) { entity_id_kv.set_key(ENTITY_ID_KEY); entity_id_kv.value = this->entity_id_; - api::global_api_server->send_homeassistant_service_call(resp); + api::global_api_server->send_homeassistant_action(resp); } } // namespace homeassistant diff --git a/esphome/components/http_request/__init__.py b/esphome/components/http_request/__init__.py index 146458f53b..e428838c83 100644 --- a/esphome/components/http_request/__init__.py +++ b/esphome/components/http_request/__init__.py @@ -5,10 +5,12 @@ from esphome.components.const import CONF_REQUEST_HEADERS from esphome.config_helpers import filter_source_files_from_platform import esphome.config_validation as cv from esphome.const import ( + CONF_CAPTURE_RESPONSE, CONF_ESP8266_DISABLE_SSL_SUPPORT, CONF_ID, CONF_METHOD, CONF_ON_ERROR, + CONF_ON_RESPONSE, CONF_TIMEOUT, CONF_TRIGGER_ID, CONF_URL, @@ -52,12 +54,10 @@ CONF_BUFFER_SIZE_TX = "buffer_size_tx" CONF_CA_CERTIFICATE_PATH = "ca_certificate_path" CONF_MAX_RESPONSE_BUFFER_SIZE = "max_response_buffer_size" -CONF_ON_RESPONSE = "on_response" CONF_HEADERS = "headers" CONF_COLLECT_HEADERS = "collect_headers" CONF_BODY = "body" CONF_JSON = "json" -CONF_CAPTURE_RESPONSE = "capture_response" def validate_url(value): @@ -194,7 +194,7 @@ async def to_code(config): cg.add_define("CPPHTTPLIB_OPENSSL_SUPPORT") elif path := config.get(CONF_CA_CERTIFICATE_PATH): cg.add_define("CPPHTTPLIB_OPENSSL_SUPPORT") - cg.add(var.set_ca_path(path)) + cg.add(var.set_ca_path(str(path))) cg.add_build_flag("-lssl") cg.add_build_flag("-lcrypto") diff --git a/esphome/components/htu21d/htu21d.cpp b/esphome/components/htu21d/htu21d.cpp index f2e7ae93cb..a7aae16f17 100644 --- a/esphome/components/htu21d/htu21d.cpp +++ b/esphome/components/htu21d/htu21d.cpp @@ -9,8 +9,8 @@ static const char *const TAG = "htu21d"; static const uint8_t HTU21D_ADDRESS = 0x40; static const uint8_t HTU21D_REGISTER_RESET = 0xFE; -static const uint8_t HTU21D_REGISTER_TEMPERATURE = 0xF3; -static const uint8_t HTU21D_REGISTER_HUMIDITY = 0xF5; +static const uint8_t HTU21D_REGISTER_TEMPERATURE = 0xE3; +static const uint8_t HTU21D_REGISTER_HUMIDITY = 0xE5; static const uint8_t HTU21D_WRITERHT_REG_CMD = 0xE6; /**< Write RH/T User Register 1 */ static const uint8_t HTU21D_REGISTER_STATUS = 0xE7; static const uint8_t HTU21D_WRITEHEATER_REG_CMD = 0x51; /**< Write Heater Control Register */ @@ -57,7 +57,6 @@ void HTU21DComponent::update() { if (this->temperature_ != nullptr) this->temperature_->publish_state(temperature); - this->status_clear_warning(); if (this->write(&HTU21D_REGISTER_HUMIDITY, 1) != i2c::ERROR_OK) { this->status_set_warning(); @@ -79,10 +78,11 @@ void HTU21DComponent::update() { if (this->humidity_ != nullptr) this->humidity_->publish_state(humidity); - int8_t heater_level; + this->status_clear_warning(); // HTU21D does have a heater module but does not have heater level // Setting heater level to 1 in case the heater is ON + uint8_t heater_level = 0; if (this->sensor_model_ == HTU21D_SENSOR_MODEL_HTU21D) { if (this->is_heater_enabled()) { heater_level = 1; @@ -97,34 +97,30 @@ void HTU21DComponent::update() { if (this->heater_ != nullptr) this->heater_->publish_state(heater_level); - this->status_clear_warning(); }); }); } bool HTU21DComponent::is_heater_enabled() { uint8_t raw_heater; - if (this->read_register(HTU21D_REGISTER_STATUS, reinterpret_cast(&raw_heater), 2) != i2c::ERROR_OK) { + if (this->read_register(HTU21D_REGISTER_STATUS, &raw_heater, 1) != i2c::ERROR_OK) { this->status_set_warning(); return false; } - raw_heater = i2c::i2ctohs(raw_heater); - return (bool) (((raw_heater) >> (HTU21D_REG_HTRE_BIT)) & 0x01); + return (bool) ((raw_heater >> HTU21D_REG_HTRE_BIT) & 0x01); } void HTU21DComponent::set_heater(bool status) { uint8_t raw_heater; - if (this->read_register(HTU21D_REGISTER_STATUS, reinterpret_cast(&raw_heater), 2) != i2c::ERROR_OK) { + if (this->read_register(HTU21D_REGISTER_STATUS, &raw_heater, 1) != i2c::ERROR_OK) { this->status_set_warning(); return; } - raw_heater = i2c::i2ctohs(raw_heater); if (status) { - raw_heater |= (1 << (HTU21D_REG_HTRE_BIT)); + raw_heater |= (1 << HTU21D_REG_HTRE_BIT); } else { - raw_heater &= ~(1 << (HTU21D_REG_HTRE_BIT)); + raw_heater &= ~(1 << HTU21D_REG_HTRE_BIT); } - if (this->write_register(HTU21D_WRITERHT_REG_CMD, &raw_heater, 1) != i2c::ERROR_OK) { this->status_set_warning(); return; @@ -138,14 +134,13 @@ void HTU21DComponent::set_heater_level(uint8_t level) { } } -int8_t HTU21DComponent::get_heater_level() { - int8_t raw_heater; - if (this->read_register(HTU21D_READHEATER_REG_CMD, reinterpret_cast(&raw_heater), 2) != i2c::ERROR_OK) { +uint8_t HTU21DComponent::get_heater_level() { + uint8_t raw_heater; + if (this->read_register(HTU21D_READHEATER_REG_CMD, &raw_heater, 1) != i2c::ERROR_OK) { this->status_set_warning(); return 0; } - raw_heater = i2c::i2ctohs(raw_heater); - return raw_heater; + return raw_heater & 0xF; } float HTU21DComponent::get_setup_priority() const { return setup_priority::DATA; } diff --git a/esphome/components/htu21d/htu21d.h b/esphome/components/htu21d/htu21d.h index 8533875d43..9b3831b784 100644 --- a/esphome/components/htu21d/htu21d.h +++ b/esphome/components/htu21d/htu21d.h @@ -26,7 +26,7 @@ class HTU21DComponent : public PollingComponent, public i2c::I2CDevice { bool is_heater_enabled(); void set_heater(bool status); void set_heater_level(uint8_t level); - int8_t get_heater_level(); + uint8_t get_heater_level(); float get_setup_priority() const override; diff --git a/esphome/components/i2s_audio/__init__.py b/esphome/components/i2s_audio/__init__.py index cff91a546f..8ceff26d84 100644 --- a/esphome/components/i2s_audio/__init__.py +++ b/esphome/components/i2s_audio/__init__.py @@ -262,8 +262,7 @@ async def to_code(config): cg.add_define("USE_I2S_LEGACY") # Helps avoid callbacks being skipped due to processor load - if CORE.using_esp_idf: - add_idf_sdkconfig_option("CONFIG_I2S_ISR_IRAM_SAFE", True) + add_idf_sdkconfig_option("CONFIG_I2S_ISR_IRAM_SAFE", True) cg.add(var.set_lrclk_pin(config[CONF_I2S_LRCLK_PIN])) if CONF_I2S_BCLK_PIN in config: diff --git a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp index 7ae3ec8b3b..53e378c41e 100644 --- a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp +++ b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp @@ -377,7 +377,7 @@ void I2SAudioSpeaker::speaker_task(void *params) { this_speaker->current_stream_info_.get_bits_per_sample() <= 16) { size_t len = bytes_read / sizeof(int16_t); int16_t *tmp_buf = (int16_t *) new_data; - for (int i = 0; i < len; i += 2) { + for (size_t i = 0; i < len; i += 2) { int16_t tmp = tmp_buf[i]; tmp_buf[i] = tmp_buf[i + 1]; tmp_buf[i + 1] = tmp; diff --git a/esphome/components/ili9xxx/ili9xxx_display.cpp b/esphome/components/ili9xxx/ili9xxx_display.cpp index ec0a860aa8..2a3d0edca7 100644 --- a/esphome/components/ili9xxx/ili9xxx_display.cpp +++ b/esphome/components/ili9xxx/ili9xxx_display.cpp @@ -325,7 +325,7 @@ void ILI9XXXDisplay::draw_pixels_at(int x_start, int y_start, int w, int h, cons // we could deal here with a non-zero y_offset, but if x_offset is zero, y_offset probably will be so don't bother this->write_array(ptr, w * h * 2); } else { - for (size_t y = 0; y != h; y++) { + for (size_t y = 0; y != static_cast(h); y++) { this->write_array(ptr + (y + y_offset) * stride + x_offset, w * 2); } } @@ -349,7 +349,7 @@ void ILI9XXXDisplay::draw_pixels_at(int x_start, int y_start, int w, int h, cons App.feed_wdt(); } // end of line? Skip to the next. - if (++pixel == w) { + if (++pixel == static_cast(w)) { pixel = 0; ptr += (x_pad + x_offset) * 2; } diff --git a/esphome/components/improv_serial/improv_serial_component.cpp b/esphome/components/improv_serial/improv_serial_component.cpp index ae4927828b..528a155a7f 100644 --- a/esphome/components/improv_serial/improv_serial_component.cpp +++ b/esphome/components/improv_serial/improv_serial_component.cpp @@ -15,11 +15,10 @@ static const char *const TAG = "improv_serial"; void ImprovSerialComponent::setup() { global_improv_serial_component = this; -#ifdef USE_ARDUINO - this->hw_serial_ = logger::global_logger->get_hw_serial(); -#endif -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 this->uart_num_ = logger::global_logger->get_uart_num(); +#elif defined(USE_ARDUINO) + this->hw_serial_ = logger::global_logger->get_hw_serial(); #endif if (wifi::global_wifi_component->has_sta()) { @@ -34,13 +33,7 @@ void ImprovSerialComponent::dump_config() { ESP_LOGCONFIG(TAG, "Improv Serial:") optional ImprovSerialComponent::read_byte_() { optional byte; uint8_t data = 0; -#ifdef USE_ARDUINO - if (this->hw_serial_->available()) { - this->hw_serial_->readBytes(&data, 1); - byte = data; - } -#endif -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 switch (logger::global_logger->get_uart()) { case logger::UART_SELECTION_UART0: case logger::UART_SELECTION_UART1: @@ -76,16 +69,18 @@ optional ImprovSerialComponent::read_byte_() { default: break; } +#elif defined(USE_ARDUINO) + if (this->hw_serial_->available()) { + this->hw_serial_->readBytes(&data, 1); + byte = data; + } #endif return byte; } void ImprovSerialComponent::write_data_(std::vector &data) { data.push_back('\n'); -#ifdef USE_ARDUINO - this->hw_serial_->write(data.data(), data.size()); -#endif -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 switch (logger::global_logger->get_uart()) { case logger::UART_SELECTION_UART0: case logger::UART_SELECTION_UART1: @@ -112,6 +107,8 @@ void ImprovSerialComponent::write_data_(std::vector &data) { default: break; } +#elif defined(USE_ARDUINO) + this->hw_serial_->write(data.data(), data.size()); #endif } diff --git a/esphome/components/improv_serial/improv_serial_component.h b/esphome/components/improv_serial/improv_serial_component.h index 5d2534c2fc..c3c9aee24e 100644 --- a/esphome/components/improv_serial/improv_serial_component.h +++ b/esphome/components/improv_serial/improv_serial_component.h @@ -9,10 +9,7 @@ #include #include -#ifdef USE_ARDUINO -#include -#endif -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #include #if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6) || defined(USE_ESP32_VARIANT_ESP32S3) || \ defined(USE_ESP32_VARIANT_ESP32H2) @@ -22,6 +19,8 @@ #if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) #include #endif +#elif defined(USE_ARDUINO) +#include #endif namespace esphome { @@ -60,11 +59,10 @@ class ImprovSerialComponent : public Component, public improv_base::ImprovBase { optional read_byte_(); void write_data_(std::vector &data); -#ifdef USE_ARDUINO - Stream *hw_serial_{nullptr}; -#endif -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 uart_port_t uart_num_; +#elif defined(USE_ARDUINO) + Stream *hw_serial_{nullptr}; #endif std::vector rx_buffer_; diff --git a/esphome/components/json/json_util.cpp b/esphome/components/json/json_util.cpp index 94c531222a..dbdf6e3486 100644 --- a/esphome/components/json/json_util.cpp +++ b/esphome/components/json/json_util.cpp @@ -8,70 +8,62 @@ namespace json { static const char *const TAG = "json"; -// Build an allocator for the JSON Library using the RAMAllocator class -struct SpiRamAllocator : ArduinoJson::Allocator { - void *allocate(size_t size) override { return this->allocator_.allocate(size); } - - void deallocate(void *pointer) override { - // ArduinoJson's Allocator interface doesn't provide the size parameter in deallocate. - // RAMAllocator::deallocate() requires the size, which we don't have access to here. - // RAMAllocator::deallocate implementation just calls free() regardless of whether - // the memory was allocated with heap_caps_malloc or malloc. - // This is safe because ESP-IDF's heap implementation internally tracks the memory region - // and routes free() to the appropriate heap. - free(pointer); // NOLINT(cppcoreguidelines-owning-memory,cppcoreguidelines-no-malloc) - } - - void *reallocate(void *ptr, size_t new_size) override { - return this->allocator_.reallocate(static_cast(ptr), new_size); - } - - protected: - RAMAllocator allocator_{RAMAllocator(RAMAllocator::NONE)}; -}; - std::string build_json(const json_build_t &f) { // NOLINTBEGIN(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson - auto doc_allocator = SpiRamAllocator(); - JsonDocument json_document(&doc_allocator); - if (json_document.overflowed()) { - ESP_LOGE(TAG, "Could not allocate memory for JSON document!"); - return "{}"; - } - JsonObject root = json_document.to(); + JsonBuilder builder; + JsonObject root = builder.root(); f(root); - if (json_document.overflowed()) { - ESP_LOGE(TAG, "Could not allocate memory for JSON document!"); - return "{}"; - } - std::string output; - serializeJson(json_document, output); - return output; + return builder.serialize(); // NOLINTEND(clang-analyzer-cplusplus.NewDeleteLeaks) } bool parse_json(const std::string &data, const json_parse_t &f) { // NOLINTBEGIN(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson + JsonDocument doc = parse_json(reinterpret_cast(data.c_str()), data.size()); + if (doc.overflowed() || doc.isNull()) + return false; + return f(doc.as()); + // NOLINTEND(clang-analyzer-cplusplus.NewDeleteLeaks) +} + +JsonDocument parse_json(const uint8_t *data, size_t len) { + // NOLINTBEGIN(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson + if (data == nullptr || len == 0) { + ESP_LOGE(TAG, "No data to parse"); + return JsonObject(); // return unbound object + } +#ifdef USE_PSRAM auto doc_allocator = SpiRamAllocator(); JsonDocument json_document(&doc_allocator); +#else + JsonDocument json_document; +#endif if (json_document.overflowed()) { ESP_LOGE(TAG, "Could not allocate memory for JSON document!"); - return false; + return JsonObject(); // return unbound object } - DeserializationError err = deserializeJson(json_document, data); - - JsonObject root = json_document.as(); + DeserializationError err = deserializeJson(json_document, data, len); if (err == DeserializationError::Ok) { - return f(root); + return json_document; } else if (err == DeserializationError::NoMemory) { ESP_LOGE(TAG, "Can not allocate more memory for deserialization. Consider making source string smaller"); - return false; + return JsonObject(); // return unbound object } ESP_LOGE(TAG, "Parse error: %s", err.c_str()); - return false; + return JsonObject(); // return unbound object // NOLINTEND(clang-analyzer-cplusplus.NewDeleteLeaks) } +std::string JsonBuilder::serialize() { + if (doc_.overflowed()) { + ESP_LOGE(TAG, "JSON document overflow"); + return "{}"; + } + std::string output; + serializeJson(doc_, output); + return output; +} + } // namespace json } // namespace esphome diff --git a/esphome/components/json/json_util.h b/esphome/components/json/json_util.h index 72d31c8afe..91cc84dc14 100644 --- a/esphome/components/json/json_util.h +++ b/esphome/components/json/json_util.h @@ -2,6 +2,7 @@ #include +#include "esphome/core/defines.h" #include "esphome/core/helpers.h" #define ARDUINOJSON_ENABLE_STD_STRING 1 // NOLINT @@ -13,6 +14,31 @@ namespace esphome { namespace json { +#ifdef USE_PSRAM +// Build an allocator for the JSON Library using the RAMAllocator class +// This is only compiled when PSRAM is enabled +struct SpiRamAllocator : ArduinoJson::Allocator { + void *allocate(size_t size) override { return allocator_.allocate(size); } + + void deallocate(void *ptr) override { + // ArduinoJson's Allocator interface doesn't provide the size parameter in deallocate. + // RAMAllocator::deallocate() requires the size, which we don't have access to here. + // RAMAllocator::deallocate implementation just calls free() regardless of whether + // the memory was allocated with heap_caps_malloc or malloc. + // This is safe because ESP-IDF's heap implementation internally tracks the memory region + // and routes free() to the appropriate heap. + free(ptr); // NOLINT(cppcoreguidelines-owning-memory,cppcoreguidelines-no-malloc) + } + + void *reallocate(void *ptr, size_t new_size) override { + return allocator_.reallocate(static_cast(ptr), new_size); + } + + protected: + RAMAllocator allocator_{RAMAllocator::NONE}; +}; +#endif + /// Callback function typedef for parsing JsonObjects. using json_parse_t = std::function; @@ -25,5 +51,36 @@ std::string build_json(const json_build_t &f); /// Parse a JSON string and run the provided json parse function if it's valid. bool parse_json(const std::string &data, const json_parse_t &f); +/// Parse a JSON string and return the root JsonDocument (or an unbound object on error) +JsonDocument parse_json(const uint8_t *data, size_t len); +/// Parse a JSON string and return the root JsonDocument (or an unbound object on error) +inline JsonDocument parse_json(const std::string &data) { + return parse_json(reinterpret_cast(data.c_str()), data.size()); +} + +/// Builder class for creating JSON documents without lambdas +class JsonBuilder { + public: + JsonObject root() { + if (!root_created_) { + root_ = doc_.to(); + root_created_ = true; + } + return root_; + } + + std::string serialize(); + + private: +#ifdef USE_PSRAM + SpiRamAllocator allocator_; + JsonDocument doc_{&allocator_}; +#else + JsonDocument doc_; +#endif + JsonObject root_; + bool root_created_{false}; +}; + } // namespace json } // namespace esphome diff --git a/esphome/components/kamstrup_kmp/kamstrup_kmp.cpp b/esphome/components/kamstrup_kmp/kamstrup_kmp.cpp index c058c7b3aa..e5fa035682 100644 --- a/esphome/components/kamstrup_kmp/kamstrup_kmp.cpp +++ b/esphome/components/kamstrup_kmp/kamstrup_kmp.cpp @@ -22,7 +22,7 @@ void KamstrupKMPComponent::dump_config() { LOG_SENSOR(" ", "Flow", this->flow_sensor_); LOG_SENSOR(" ", "Volume", this->volume_sensor_); - for (int i = 0; i < this->custom_sensors_.size(); i++) { + for (size_t i = 0; i < this->custom_sensors_.size(); i++) { LOG_SENSOR(" ", "Custom Sensor", this->custom_sensors_[i]); ESP_LOGCONFIG(TAG, " Command: 0x%04X", this->custom_commands_[i]); } @@ -268,7 +268,7 @@ void KamstrupKMPComponent::set_sensor_value_(uint16_t command, float value, uint } // Custom sensors - for (int i = 0; i < this->custom_commands_.size(); i++) { + for (size_t i = 0; i < this->custom_commands_.size(); i++) { if (command == this->custom_commands_[i]) { this->custom_sensors_[i]->publish_state(value); } diff --git a/esphome/components/key_collector/key_collector.h b/esphome/components/key_collector/key_collector.h index 6e585ddd8e..35e8141ce5 100644 --- a/esphome/components/key_collector/key_collector.h +++ b/esphome/components/key_collector/key_collector.h @@ -13,8 +13,8 @@ class KeyCollector : public Component { void loop() override; void dump_config() override; void set_provider(key_provider::KeyProvider *provider); - void set_min_length(int min_length) { this->min_length_ = min_length; }; - void set_max_length(int max_length) { this->max_length_ = max_length; }; + void set_min_length(uint32_t min_length) { this->min_length_ = min_length; }; + void set_max_length(uint32_t max_length) { this->max_length_ = max_length; }; void set_start_keys(std::string start_keys) { this->start_keys_ = std::move(start_keys); }; void set_end_keys(std::string end_keys) { this->end_keys_ = std::move(end_keys); }; void set_end_key_required(bool end_key_required) { this->end_key_required_ = end_key_required; }; @@ -33,8 +33,8 @@ class KeyCollector : public Component { protected: void key_pressed_(uint8_t key); - int min_length_{0}; - int max_length_{0}; + uint32_t min_length_{0}; + uint32_t max_length_{0}; std::string start_keys_; std::string end_keys_; bool end_key_required_{false}; diff --git a/esphome/components/libretiny/__init__.py b/esphome/components/libretiny/__init__.py index 178660cb40..c63d6d7faa 100644 --- a/esphome/components/libretiny/__init__.py +++ b/esphome/components/libretiny/__init__.py @@ -1,6 +1,5 @@ import json import logging -from os.path import dirname, isfile, join import esphome.codegen as cg import esphome.config_validation as cv @@ -24,6 +23,7 @@ from esphome.const import ( __version__, ) from esphome.core import CORE +from esphome.storage_json import StorageJSON from . import gpio # noqa from .const import ( @@ -129,7 +129,7 @@ def only_on_family(*, supported=None, unsupported=None): return validator_ -def get_download_types(storage_json=None): +def get_download_types(storage_json: StorageJSON = None): types = [ { "title": "UF2 package (recommended)", @@ -139,11 +139,11 @@ def get_download_types(storage_json=None): }, ] - build_dir = dirname(storage_json.firmware_bin_path) - outputs = join(build_dir, "firmware.json") - if not isfile(outputs): + build_dir = storage_json.firmware_bin_path.parent + outputs = build_dir / "firmware.json" + if not outputs.is_file(): return types - with open(outputs, encoding="utf-8") as f: + with outputs.open(encoding="utf-8") as f: outputs = json.load(f) for output in outputs: if not output["public"]: diff --git a/esphome/components/libretiny/preferences.cpp b/esphome/components/libretiny/preferences.cpp index fc535c99b4..871b186d8e 100644 --- a/esphome/components/libretiny/preferences.cpp +++ b/esphome/components/libretiny/preferences.cpp @@ -15,7 +15,14 @@ static const char *const TAG = "lt.preferences"; struct NVSData { std::string key; - std::vector data; + std::unique_ptr data; + size_t len; + + void set_data(const uint8_t *src, size_t size) { + data = std::make_unique(size); + memcpy(data.get(), src, size); + len = size; + } }; static std::vector s_pending_save; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) @@ -30,15 +37,15 @@ class LibreTinyPreferenceBackend : public ESPPreferenceBackend { // try find in pending saves and update that for (auto &obj : s_pending_save) { if (obj.key == key) { - obj.data.assign(data, data + len); + obj.set_data(data, len); return true; } } NVSData save{}; save.key = key; - save.data.assign(data, data + len); - s_pending_save.emplace_back(save); - ESP_LOGVV(TAG, "s_pending_save: key: %s, len: %d", key.c_str(), len); + save.set_data(data, len); + s_pending_save.emplace_back(std::move(save)); + ESP_LOGVV(TAG, "s_pending_save: key: %s, len: %zu", key.c_str(), len); return true; } @@ -46,11 +53,11 @@ class LibreTinyPreferenceBackend : public ESPPreferenceBackend { // try find in pending saves and load from that for (auto &obj : s_pending_save) { if (obj.key == key) { - if (obj.data.size() != len) { + if (obj.len != len) { // size mismatch return false; } - memcpy(data, obj.data.data(), len); + memcpy(data, obj.data.get(), len); return true; } } @@ -58,10 +65,10 @@ class LibreTinyPreferenceBackend : public ESPPreferenceBackend { fdb_blob_make(blob, data, len); size_t actual_len = fdb_kv_get_blob(db, key.c_str(), blob); if (actual_len != len) { - ESP_LOGVV(TAG, "NVS length does not match (%u!=%u)", actual_len, len); + ESP_LOGVV(TAG, "NVS length does not match (%zu!=%zu)", actual_len, len); return false; } else { - ESP_LOGVV(TAG, "fdb_kv_get_blob: key: %s, len: %d", key.c_str(), len); + ESP_LOGVV(TAG, "fdb_kv_get_blob: key: %s, len: %zu", key.c_str(), len); } return true; } @@ -101,7 +108,7 @@ class LibreTinyPreferences : public ESPPreferences { if (s_pending_save.empty()) return true; - ESP_LOGV(TAG, "Saving %d items...", s_pending_save.size()); + ESP_LOGV(TAG, "Saving %zu items...", s_pending_save.size()); // goal try write all pending saves even if one fails int cached = 0, written = 0, failed = 0; fdb_err_t last_err = FDB_NO_ERR; @@ -112,11 +119,11 @@ class LibreTinyPreferences : public ESPPreferences { const auto &save = s_pending_save[i]; ESP_LOGVV(TAG, "Checking if FDB data %s has changed", save.key.c_str()); if (is_changed(&db, save)) { - ESP_LOGV(TAG, "sync: key: %s, len: %d", save.key.c_str(), save.data.size()); - fdb_blob_make(&blob, save.data.data(), save.data.size()); + ESP_LOGV(TAG, "sync: key: %s, len: %zu", save.key.c_str(), save.len); + fdb_blob_make(&blob, save.data.get(), save.len); fdb_err_t err = fdb_kv_set_blob(&db, save.key.c_str(), &blob); if (err != FDB_NO_ERR) { - ESP_LOGV(TAG, "fdb_kv_set_blob('%s', len=%u) failed: %d", save.key.c_str(), save.data.size(), err); + ESP_LOGV(TAG, "fdb_kv_set_blob('%s', len=%zu) failed: %d", save.key.c_str(), save.len, err); failed++; last_err = err; last_key = save.key; @@ -124,7 +131,7 @@ class LibreTinyPreferences : public ESPPreferences { } written++; } else { - ESP_LOGD(TAG, "FDB data not changed; skipping %s len=%u", save.key.c_str(), save.data.size()); + ESP_LOGD(TAG, "FDB data not changed; skipping %s len=%zu", save.key.c_str(), save.len); cached++; } s_pending_save.erase(s_pending_save.begin() + i); @@ -147,7 +154,7 @@ class LibreTinyPreferences : public ESPPreferences { } // Check size first - if different, data has changed - if (kv.value_len != to_save.data.size()) { + if (kv.value_len != to_save.len) { return true; } @@ -161,7 +168,7 @@ class LibreTinyPreferences : public ESPPreferences { } // Compare the actual data - return memcmp(to_save.data.data(), stored_data.get(), kv.value_len) != 0; + return memcmp(to_save.data.get(), stored_data.get(), kv.value_len) != 0; } bool reset() override { diff --git a/esphome/components/light/light_call.cpp b/esphome/components/light/light_call.cpp index cbe9ed0454..915b8fdf89 100644 --- a/esphome/components/light/light_call.cpp +++ b/esphome/components/light/light_call.cpp @@ -10,11 +10,15 @@ namespace light { static const char *const TAG = "light"; // Helper functions to reduce code size for logging -#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_WARN -static void log_validation_warning(const char *name, const LogString *param_name, float val, float min, float max) { - ESP_LOGW(TAG, "'%s': %s value %.2f is out of range [%.1f - %.1f]", name, LOG_STR_ARG(param_name), val, min, max); +static void clamp_and_log_if_invalid(const char *name, float &value, const LogString *param_name, float min = 0.0f, + float max = 1.0f) { + if (value < min || value > max) { + ESP_LOGW(TAG, "'%s': %s value %.2f is out of range [%.1f - %.1f]", name, LOG_STR_ARG(param_name), value, min, max); + value = clamp(value, min, max); + } } +#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_WARN static void log_feature_not_supported(const char *name, const LogString *feature) { ESP_LOGW(TAG, "'%s': %s not supported", name, LOG_STR_ARG(feature)); } @@ -27,7 +31,6 @@ static void log_invalid_parameter(const char *name, const LogString *message) { ESP_LOGW(TAG, "'%s': %s", name, LOG_STR_ARG(message)); } #else -#define log_validation_warning(name, param_name, val, min, max) #define log_feature_not_supported(name, feature) #define log_color_mode_not_supported(name, feature) #define log_invalid_parameter(name, message) @@ -44,7 +47,7 @@ static void log_invalid_parameter(const char *name, const LogString *message) { } \ LightCall &LightCall::set_##name(type name) { \ this->name##_ = name; \ - this->set_flag_(flag, true); \ + this->set_flag_(flag); \ return *this; \ } @@ -181,6 +184,16 @@ void LightCall::perform() { } } +void LightCall::log_and_clear_unsupported_(FieldFlags flag, const LogString *feature, bool use_color_mode_log) { + auto *name = this->parent_->get_name().c_str(); + if (use_color_mode_log) { + log_color_mode_not_supported(name, feature); + } else { + log_feature_not_supported(name, feature); + } + this->clear_flag_(flag); +} + LightColorValues LightCall::validate_() { auto *name = this->parent_->get_name().c_str(); auto traits = this->parent_->get_traits(); @@ -188,141 +201,108 @@ LightColorValues LightCall::validate_() { // Color mode check if (this->has_color_mode() && !traits.supports_color_mode(this->color_mode_)) { ESP_LOGW(TAG, "'%s' does not support color mode %s", name, LOG_STR_ARG(color_mode_to_human(this->color_mode_))); - this->set_flag_(FLAG_HAS_COLOR_MODE, false); + this->clear_flag_(FLAG_HAS_COLOR_MODE); } // Ensure there is always a color mode set if (!this->has_color_mode()) { this->color_mode_ = this->compute_color_mode_(); - this->set_flag_(FLAG_HAS_COLOR_MODE, true); + this->set_flag_(FLAG_HAS_COLOR_MODE); } auto color_mode = this->color_mode_; // Transform calls that use non-native parameters for the current mode. this->transform_parameters_(); - // Brightness exists check - if (this->has_brightness() && this->brightness_ > 0.0f && !(color_mode & ColorCapability::BRIGHTNESS)) { - log_feature_not_supported(name, LOG_STR("brightness")); - this->set_flag_(FLAG_HAS_BRIGHTNESS, false); - } - - // Transition length possible check - if (this->has_transition_() && this->transition_length_ != 0 && !(color_mode & ColorCapability::BRIGHTNESS)) { - log_feature_not_supported(name, LOG_STR("transitions")); - this->set_flag_(FLAG_HAS_TRANSITION, false); - } - - // Color brightness exists check - if (this->has_color_brightness() && this->color_brightness_ > 0.0f && !(color_mode & ColorCapability::RGB)) { - log_color_mode_not_supported(name, LOG_STR("RGB brightness")); - this->set_flag_(FLAG_HAS_COLOR_BRIGHTNESS, false); - } - - // RGB exists check - if ((this->has_red() && this->red_ > 0.0f) || (this->has_green() && this->green_ > 0.0f) || - (this->has_blue() && this->blue_ > 0.0f)) { - if (!(color_mode & ColorCapability::RGB)) { - log_color_mode_not_supported(name, LOG_STR("RGB color")); - this->set_flag_(FLAG_HAS_RED, false); - this->set_flag_(FLAG_HAS_GREEN, false); - this->set_flag_(FLAG_HAS_BLUE, false); - } - } - - // White value exists check - if (this->has_white() && this->white_ > 0.0f && - !(color_mode & ColorCapability::WHITE || color_mode & ColorCapability::COLD_WARM_WHITE)) { - log_color_mode_not_supported(name, LOG_STR("white value")); - this->set_flag_(FLAG_HAS_WHITE, false); - } - - // Color temperature exists check - if (this->has_color_temperature() && - !(color_mode & ColorCapability::COLOR_TEMPERATURE || color_mode & ColorCapability::COLD_WARM_WHITE)) { - log_color_mode_not_supported(name, LOG_STR("color temperature")); - this->set_flag_(FLAG_HAS_COLOR_TEMPERATURE, false); - } - - // Cold/warm white value exists check - if ((this->has_cold_white() && this->cold_white_ > 0.0f) || (this->has_warm_white() && this->warm_white_ > 0.0f)) { - if (!(color_mode & ColorCapability::COLD_WARM_WHITE)) { - log_color_mode_not_supported(name, LOG_STR("cold/warm white value")); - this->set_flag_(FLAG_HAS_COLD_WHITE, false); - this->set_flag_(FLAG_HAS_WARM_WHITE, false); - } - } - -#define VALIDATE_RANGE_(name_, upper_name, min, max) \ - if (this->has_##name_()) { \ - auto val = this->name_##_; \ - if (val < (min) || val > (max)) { \ - log_validation_warning(name, LOG_STR(upper_name), val, (min), (max)); \ - this->name_##_ = clamp(val, (min), (max)); \ - } \ - } -#define VALIDATE_RANGE(name, upper_name) VALIDATE_RANGE_(name, upper_name, 0.0f, 1.0f) - - // Range checks - VALIDATE_RANGE(brightness, "Brightness") - VALIDATE_RANGE(color_brightness, "Color brightness") - VALIDATE_RANGE(red, "Red") - VALIDATE_RANGE(green, "Green") - VALIDATE_RANGE(blue, "Blue") - VALIDATE_RANGE(white, "White") - VALIDATE_RANGE(cold_white, "Cold white") - VALIDATE_RANGE(warm_white, "Warm white") - VALIDATE_RANGE_(color_temperature, "Color temperature", traits.get_min_mireds(), traits.get_max_mireds()) - + // Business logic adjustments before validation // Flag whether an explicit turn off was requested, in which case we'll also stop the effect. bool explicit_turn_off_request = this->has_state() && !this->state_; // Turn off when brightness is set to zero, and reset brightness (so that it has nonzero brightness when turned on). if (this->has_brightness() && this->brightness_ == 0.0f) { this->state_ = false; - this->set_flag_(FLAG_HAS_STATE, true); + this->set_flag_(FLAG_HAS_STATE); this->brightness_ = 1.0f; } // Set color brightness to 100% if currently zero and a color is set. - if (this->has_red() || this->has_green() || this->has_blue()) { - if (!this->has_color_brightness() && this->parent_->remote_values.get_color_brightness() == 0.0f) { - this->color_brightness_ = 1.0f; - this->set_flag_(FLAG_HAS_COLOR_BRIGHTNESS, true); - } + if ((this->has_red() || this->has_green() || this->has_blue()) && !this->has_color_brightness() && + this->parent_->remote_values.get_color_brightness() == 0.0f) { + this->color_brightness_ = 1.0f; + this->set_flag_(FLAG_HAS_COLOR_BRIGHTNESS); } - // Create color values for the light with this call applied. + // Capability validation + if (this->has_brightness() && this->brightness_ > 0.0f && !(color_mode & ColorCapability::BRIGHTNESS)) + this->log_and_clear_unsupported_(FLAG_HAS_BRIGHTNESS, LOG_STR("brightness"), false); + + // Transition length possible check + if (this->has_transition_() && this->transition_length_ != 0 && !(color_mode & ColorCapability::BRIGHTNESS)) + this->log_and_clear_unsupported_(FLAG_HAS_TRANSITION, LOG_STR("transitions"), false); + + if (this->has_color_brightness() && this->color_brightness_ > 0.0f && !(color_mode & ColorCapability::RGB)) + this->log_and_clear_unsupported_(FLAG_HAS_COLOR_BRIGHTNESS, LOG_STR("RGB brightness"), true); + + // RGB exists check + if (((this->has_red() && this->red_ > 0.0f) || (this->has_green() && this->green_ > 0.0f) || + (this->has_blue() && this->blue_ > 0.0f)) && + !(color_mode & ColorCapability::RGB)) { + log_color_mode_not_supported(name, LOG_STR("RGB color")); + this->clear_flag_(FLAG_HAS_RED); + this->clear_flag_(FLAG_HAS_GREEN); + this->clear_flag_(FLAG_HAS_BLUE); + } + + // White value exists check + if (this->has_white() && this->white_ > 0.0f && + !(color_mode & ColorCapability::WHITE || color_mode & ColorCapability::COLD_WARM_WHITE)) + this->log_and_clear_unsupported_(FLAG_HAS_WHITE, LOG_STR("white value"), true); + + // Color temperature exists check + if (this->has_color_temperature() && + !(color_mode & ColorCapability::COLOR_TEMPERATURE || color_mode & ColorCapability::COLD_WARM_WHITE)) + this->log_and_clear_unsupported_(FLAG_HAS_COLOR_TEMPERATURE, LOG_STR("color temperature"), true); + + // Cold/warm white value exists check + if (((this->has_cold_white() && this->cold_white_ > 0.0f) || (this->has_warm_white() && this->warm_white_ > 0.0f)) && + !(color_mode & ColorCapability::COLD_WARM_WHITE)) { + log_color_mode_not_supported(name, LOG_STR("cold/warm white value")); + this->clear_flag_(FLAG_HAS_COLD_WHITE); + this->clear_flag_(FLAG_HAS_WARM_WHITE); + } + + // Create color values and validate+apply ranges in one step to eliminate duplicate checks auto v = this->parent_->remote_values; if (this->has_color_mode()) v.set_color_mode(this->color_mode_); if (this->has_state()) v.set_state(this->state_); - if (this->has_brightness()) - v.set_brightness(this->brightness_); - if (this->has_color_brightness()) - v.set_color_brightness(this->color_brightness_); - if (this->has_red()) - v.set_red(this->red_); - if (this->has_green()) - v.set_green(this->green_); - if (this->has_blue()) - v.set_blue(this->blue_); - if (this->has_white()) - v.set_white(this->white_); - if (this->has_color_temperature()) - v.set_color_temperature(this->color_temperature_); - if (this->has_cold_white()) - v.set_cold_white(this->cold_white_); - if (this->has_warm_white()) - v.set_warm_white(this->warm_white_); + +#define VALIDATE_AND_APPLY(field, setter, name_str, ...) \ + if (this->has_##field()) { \ + clamp_and_log_if_invalid(name, this->field##_, LOG_STR(name_str), ##__VA_ARGS__); \ + v.setter(this->field##_); \ + } + + VALIDATE_AND_APPLY(brightness, set_brightness, "Brightness") + VALIDATE_AND_APPLY(color_brightness, set_color_brightness, "Color brightness") + VALIDATE_AND_APPLY(red, set_red, "Red") + VALIDATE_AND_APPLY(green, set_green, "Green") + VALIDATE_AND_APPLY(blue, set_blue, "Blue") + VALIDATE_AND_APPLY(white, set_white, "White") + VALIDATE_AND_APPLY(cold_white, set_cold_white, "Cold white") + VALIDATE_AND_APPLY(warm_white, set_warm_white, "Warm white") + VALIDATE_AND_APPLY(color_temperature, set_color_temperature, "Color temperature", traits.get_min_mireds(), + traits.get_max_mireds()) + +#undef VALIDATE_AND_APPLY v.normalize_color(); // Flash length check if (this->has_flash_() && this->flash_length_ == 0) { - log_invalid_parameter(name, LOG_STR("flash length must be greater than zero")); - this->set_flag_(FLAG_HAS_FLASH, false); + log_invalid_parameter(name, LOG_STR("flash length must be >0")); + this->clear_flag_(FLAG_HAS_FLASH); } // validate transition length/flash length/effect not used at the same time @@ -330,42 +310,40 @@ LightColorValues LightCall::validate_() { // If effect is already active, remove effect start if (this->has_effect_() && this->effect_ == this->parent_->active_effect_index_) { - this->set_flag_(FLAG_HAS_EFFECT, false); + this->clear_flag_(FLAG_HAS_EFFECT); } // validate effect index if (this->has_effect_() && this->effect_ > this->parent_->effects_.size()) { ESP_LOGW(TAG, "'%s': invalid effect index %" PRIu32, name, this->effect_); - this->set_flag_(FLAG_HAS_EFFECT, false); + this->clear_flag_(FLAG_HAS_EFFECT); } if (this->has_effect_() && (this->has_transition_() || this->has_flash_())) { log_invalid_parameter(name, LOG_STR("effect cannot be used with transition/flash")); - this->set_flag_(FLAG_HAS_TRANSITION, false); - this->set_flag_(FLAG_HAS_FLASH, false); + this->clear_flag_(FLAG_HAS_TRANSITION); + this->clear_flag_(FLAG_HAS_FLASH); } if (this->has_flash_() && this->has_transition_()) { log_invalid_parameter(name, LOG_STR("flash cannot be used with transition")); - this->set_flag_(FLAG_HAS_TRANSITION, false); + this->clear_flag_(FLAG_HAS_TRANSITION); } if (!this->has_transition_() && !this->has_flash_() && (!this->has_effect_() || this->effect_ == 0) && supports_transition) { // nothing specified and light supports transitions, set default transition length this->transition_length_ = this->parent_->default_transition_length_; - this->set_flag_(FLAG_HAS_TRANSITION, true); + this->set_flag_(FLAG_HAS_TRANSITION); } if (this->has_transition_() && this->transition_length_ == 0) { // 0 transition is interpreted as no transition (instant change) - this->set_flag_(FLAG_HAS_TRANSITION, false); + this->clear_flag_(FLAG_HAS_TRANSITION); } - if (this->has_transition_() && !supports_transition) { - log_feature_not_supported(name, LOG_STR("transitions")); - this->set_flag_(FLAG_HAS_TRANSITION, false); - } + if (this->has_transition_() && !supports_transition) + this->log_and_clear_unsupported_(FLAG_HAS_TRANSITION, LOG_STR("transitions"), false); // If not a flash and turning the light off, then disable the light // Do not use light color values directly, so that effects can set 0% brightness @@ -374,17 +352,17 @@ LightColorValues LightCall::validate_() { if (!this->has_flash_() && !target_state) { if (this->has_effect_()) { log_invalid_parameter(name, LOG_STR("cannot start effect when turning off")); - this->set_flag_(FLAG_HAS_EFFECT, false); + this->clear_flag_(FLAG_HAS_EFFECT); } else if (this->parent_->active_effect_index_ != 0 && explicit_turn_off_request) { // Auto turn off effect this->effect_ = 0; - this->set_flag_(FLAG_HAS_EFFECT, true); + this->set_flag_(FLAG_HAS_EFFECT); } } // Disable saving for flashes if (this->has_flash_()) - this->set_flag_(FLAG_SAVE, false); + this->clear_flag_(FLAG_SAVE); return v; } @@ -418,12 +396,12 @@ void LightCall::transform_parameters_() { const float gamma = this->parent_->get_gamma_correct(); this->cold_white_ = gamma_uncorrect(cw_fraction / max_cw_ww, gamma); this->warm_white_ = gamma_uncorrect(ww_fraction / max_cw_ww, gamma); - this->set_flag_(FLAG_HAS_COLD_WHITE, true); - this->set_flag_(FLAG_HAS_WARM_WHITE, true); + this->set_flag_(FLAG_HAS_COLD_WHITE); + this->set_flag_(FLAG_HAS_WARM_WHITE); } if (this->has_white()) { this->brightness_ = this->white_; - this->set_flag_(FLAG_HAS_BRIGHTNESS, true); + this->set_flag_(FLAG_HAS_BRIGHTNESS); } } } @@ -630,7 +608,7 @@ LightCall &LightCall::set_effect(optional effect) { } LightCall &LightCall::set_effect(uint32_t effect_number) { this->effect_ = effect_number; - this->set_flag_(FLAG_HAS_EFFECT, true); + this->set_flag_(FLAG_HAS_EFFECT); return *this; } LightCall &LightCall::set_effect(optional effect_number) { diff --git a/esphome/components/light/light_call.h b/esphome/components/light/light_call.h index 7e04e1a767..d3a526b136 100644 --- a/esphome/components/light/light_call.h +++ b/esphome/components/light/light_call.h @@ -4,6 +4,10 @@ #include namespace esphome { + +// Forward declaration +struct LogString; + namespace light { class LightState; @@ -207,14 +211,14 @@ class LightCall { FLAG_SAVE = 1 << 15, }; - bool has_transition_() { return (this->flags_ & FLAG_HAS_TRANSITION) != 0; } - bool has_flash_() { return (this->flags_ & FLAG_HAS_FLASH) != 0; } - bool has_effect_() { return (this->flags_ & FLAG_HAS_EFFECT) != 0; } - bool get_publish_() { return (this->flags_ & FLAG_PUBLISH) != 0; } - bool get_save_() { return (this->flags_ & FLAG_SAVE) != 0; } + inline bool has_transition_() { return (this->flags_ & FLAG_HAS_TRANSITION) != 0; } + inline bool has_flash_() { return (this->flags_ & FLAG_HAS_FLASH) != 0; } + inline bool has_effect_() { return (this->flags_ & FLAG_HAS_EFFECT) != 0; } + inline bool get_publish_() { return (this->flags_ & FLAG_PUBLISH) != 0; } + inline bool get_save_() { return (this->flags_ & FLAG_SAVE) != 0; } - // Helper to set flag - void set_flag_(FieldFlags flag, bool value) { + // Helper to set flag - defaults to true for common case + void set_flag_(FieldFlags flag, bool value = true) { if (value) { this->flags_ |= flag; } else { @@ -222,6 +226,12 @@ class LightCall { } } + // Helper to clear flag - reduces code size for common case + void clear_flag_(FieldFlags flag) { this->flags_ &= ~flag; } + + // Helper to log unsupported feature and clear flag - reduces code duplication + void log_and_clear_unsupported_(FieldFlags flag, const LogString *feature, bool use_color_mode_log); + LightState *parent_; // Light state values - use flags_ to check if a value has been set. diff --git a/esphome/components/lm75b/__init__.py b/esphome/components/lm75b/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/esphome/components/lm75b/lm75b.cpp b/esphome/components/lm75b/lm75b.cpp new file mode 100644 index 0000000000..19398eda85 --- /dev/null +++ b/esphome/components/lm75b/lm75b.cpp @@ -0,0 +1,39 @@ +#include "lm75b.h" +#include "esphome/core/log.h" +#include "esphome/core/hal.h" + +namespace esphome { +namespace lm75b { + +static const char *const TAG = "lm75b"; + +void LM75BComponent::dump_config() { + ESP_LOGCONFIG(TAG, "LM75B:"); + LOG_I2C_DEVICE(this); + if (this->is_failed()) { + ESP_LOGE(TAG, "Setting up LM75B failed!"); + } + LOG_UPDATE_INTERVAL(this); + LOG_SENSOR(" ", "Temperature", this); +} + +void LM75BComponent::update() { + // Create a temporary buffer + uint8_t buff[2]; + if (this->read_register(LM75B_REG_TEMPERATURE, buff, 2) != i2c::ERROR_OK) { + this->status_set_warning(); + return; + } + // Obtain combined 16-bit value + int16_t raw_temperature = (buff[0] << 8) | buff[1]; + // Read the 11-bit raw temperature value + raw_temperature >>= 5; + // Publish the temperature in °C + this->publish_state(raw_temperature * 0.125); + if (this->status_has_warning()) { + this->status_clear_warning(); + } +} + +} // namespace lm75b +} // namespace esphome diff --git a/esphome/components/lm75b/lm75b.h b/esphome/components/lm75b/lm75b.h new file mode 100644 index 0000000000..79d9fa3f32 --- /dev/null +++ b/esphome/components/lm75b/lm75b.h @@ -0,0 +1,19 @@ +#pragma once + +#include "esphome/core/component.h" +#include "esphome/components/sensor/sensor.h" +#include "esphome/components/i2c/i2c.h" + +namespace esphome { +namespace lm75b { + +static const uint8_t LM75B_REG_TEMPERATURE = 0x00; + +class LM75BComponent : public PollingComponent, public i2c::I2CDevice, public sensor::Sensor { + public: + void dump_config() override; + void update() override; +}; + +} // namespace lm75b +} // namespace esphome diff --git a/esphome/components/lm75b/sensor.py b/esphome/components/lm75b/sensor.py new file mode 100644 index 0000000000..335446b62f --- /dev/null +++ b/esphome/components/lm75b/sensor.py @@ -0,0 +1,34 @@ +import esphome.codegen as cg +from esphome.components import i2c, sensor +import esphome.config_validation as cv +from esphome.const import ( + DEVICE_CLASS_TEMPERATURE, + STATE_CLASS_MEASUREMENT, + UNIT_CELSIUS, +) + +CODEOWNERS = ["@beormund"] +DEPENDENCIES = ["i2c"] + +lm75b_ns = cg.esphome_ns.namespace("lm75b") +LM75BComponent = lm75b_ns.class_( + "LM75BComponent", cg.PollingComponent, i2c.I2CDevice, sensor.Sensor +) + +CONFIG_SCHEMA = ( + sensor.sensor_schema( + LM75BComponent, + unit_of_measurement=UNIT_CELSIUS, + accuracy_decimals=3, + device_class=DEVICE_CLASS_TEMPERATURE, + state_class=STATE_CLASS_MEASUREMENT, + ) + .extend(cv.polling_component_schema("60s")) + .extend(i2c.i2c_device_schema(0x48)) +) + + +async def to_code(config): + var = await sensor.new_sensor(config) + await cg.register_component(var, config) + await i2c.register_i2c_device(var, config) diff --git a/esphome/components/lock/lock.h b/esphome/components/lock/lock.h index 04c4cd71cd..9737569921 100644 --- a/esphome/components/lock/lock.h +++ b/esphome/components/lock/lock.h @@ -5,7 +5,7 @@ #include "esphome/core/helpers.h" #include "esphome/core/log.h" #include "esphome/core/preferences.h" -#include +#include namespace esphome { namespace lock { @@ -44,16 +44,22 @@ class LockTraits { bool get_assumed_state() const { return this->assumed_state_; } void set_assumed_state(bool assumed_state) { this->assumed_state_ = assumed_state; } - bool supports_state(LockState state) const { return supported_states_.count(state); } - std::set get_supported_states() const { return supported_states_; } - void set_supported_states(std::set states) { supported_states_ = std::move(states); } - void add_supported_state(LockState state) { supported_states_.insert(state); } + bool supports_state(LockState state) const { return supported_states_mask_ & (1 << state); } + void set_supported_states(std::initializer_list states) { + supported_states_mask_ = 0; + for (auto state : states) { + supported_states_mask_ |= (1 << state); + } + } + uint8_t get_supported_states_mask() const { return supported_states_mask_; } + void set_supported_states_mask(uint8_t mask) { supported_states_mask_ = mask; } + void add_supported_state(LockState state) { supported_states_mask_ |= (1 << state); } protected: bool supports_open_{false}; bool requires_code_{false}; bool assumed_state_{false}; - std::set supported_states_ = {LOCK_STATE_NONE, LOCK_STATE_LOCKED, LOCK_STATE_UNLOCKED}; + uint8_t supported_states_mask_{(1 << LOCK_STATE_NONE) | (1 << LOCK_STATE_LOCKED) | (1 << LOCK_STATE_UNLOCKED)}; }; /** This class is used to encode all control actions on a lock device. diff --git a/esphome/components/logger/__init__.py b/esphome/components/logger/__init__.py index 2865355278..1d02073d27 100644 --- a/esphome/components/logger/__init__.py +++ b/esphome/components/logger/__init__.py @@ -95,6 +95,7 @@ DEFAULT = "DEFAULT" CONF_INITIAL_LEVEL = "initial_level" CONF_LOGGER_ID = "logger_id" +CONF_RUNTIME_TAG_LEVELS = "runtime_tag_levels" CONF_TASK_LOG_BUFFER_SIZE = "task_log_buffer_size" UART_SELECTION_ESP32 = { @@ -117,8 +118,6 @@ UART_SELECTION_LIBRETINY = { COMPONENT_RTL87XX: [DEFAULT, UART0, UART1, UART2], } -ESP_ARDUINO_UNSUPPORTED_USB_UARTS = [USB_SERIAL_JTAG] - UART_SELECTION_RP2040 = [USB_CDC, UART0, UART1] UART_SELECTION_NRF52 = [USB_CDC, UART0] @@ -153,13 +152,7 @@ is_log_level = cv.one_of(*LOG_LEVELS, upper=True) def uart_selection(value): if CORE.is_esp32: - if CORE.using_arduino and value.upper() in ESP_ARDUINO_UNSUPPORTED_USB_UARTS: - raise cv.Invalid(f"Arduino framework does not support {value}.") variant = get_esp32_variant() - if CORE.using_esp_idf and variant == VARIANT_ESP32C3 and value == USB_CDC: - raise cv.Invalid( - f"{value} is not supported for variant {variant} when using ESP-IDF." - ) if variant in UART_SELECTION_ESP32: return cv.one_of(*UART_SELECTION_ESP32[variant], upper=True)(value) if CORE.is_esp8266: @@ -226,14 +219,11 @@ CONFIG_SCHEMA = cv.All( esp8266=UART0, esp32=UART0, esp32_s2=USB_CDC, - esp32_s3_arduino=USB_CDC, - esp32_s3_idf=USB_SERIAL_JTAG, - esp32_c3_arduino=USB_CDC, - esp32_c3_idf=USB_SERIAL_JTAG, - esp32_c5_idf=USB_SERIAL_JTAG, - esp32_c6_arduino=USB_CDC, - esp32_c6_idf=USB_SERIAL_JTAG, - esp32_p4_idf=USB_SERIAL_JTAG, + esp32_s3=USB_SERIAL_JTAG, + esp32_c3=USB_SERIAL_JTAG, + esp32_c5=USB_SERIAL_JTAG, + esp32_c6=USB_SERIAL_JTAG, + esp32_p4=USB_SERIAL_JTAG, rp2040=USB_CDC, bk72xx=DEFAULT, ln882x=DEFAULT, @@ -260,6 +250,7 @@ CONFIG_SCHEMA = cv.All( } ), cv.Optional(CONF_INITIAL_LEVEL): is_log_level, + cv.Optional(CONF_RUNTIME_TAG_LEVELS, default=False): cv.boolean, cv.Optional(CONF_ON_MESSAGE): automation.validate_automation( { cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(LoggerMessageTrigger), @@ -302,8 +293,12 @@ async def to_code(config): ) cg.add(log.pre_setup()) - for tag, log_level in config[CONF_LOGS].items(): - cg.add(log.set_log_level(tag, LOG_LEVELS[log_level])) + # Enable runtime tag levels if logs are configured or explicitly enabled + logs_config = config[CONF_LOGS] + if logs_config or config[CONF_RUNTIME_TAG_LEVELS]: + cg.add_define("USE_LOGGER_RUNTIME_TAG_LEVELS") + for tag, log_level in logs_config.items(): + cg.add(log.set_log_level(tag, LOG_LEVELS[log_level])) cg.add_define("USE_LOGGER") this_severity = LOG_LEVEL_SEVERITY.index(level) @@ -346,15 +341,7 @@ async def to_code(config): if config.get(CONF_ESP8266_STORE_LOG_STRINGS_IN_FLASH): cg.add_build_flag("-DUSE_STORE_LOG_STR_IN_FLASH") - if CORE.using_arduino and config[CONF_HARDWARE_UART] == USB_CDC: - cg.add_build_flag("-DARDUINO_USB_CDC_ON_BOOT=1") - if CORE.is_esp32 and get_esp32_variant() in ( - VARIANT_ESP32C3, - VARIANT_ESP32C6, - ): - cg.add_build_flag("-DARDUINO_USB_MODE=1") - - if CORE.using_esp_idf: + if CORE.is_esp32: if config[CONF_HARDWARE_UART] == USB_CDC: add_idf_sdkconfig_option("CONFIG_ESP_CONSOLE_USB_CDC", True) elif config[CONF_HARDWARE_UART] == USB_SERIAL_JTAG: @@ -462,6 +449,7 @@ async def logger_set_level_to_code(config, action_id, template_arg, args): level = LOG_LEVELS[config[CONF_LEVEL]] logger = await cg.get_variable(config[CONF_LOGGER_ID]) if tag := config.get(CONF_TAG): + cg.add_define("USE_LOGGER_RUNTIME_TAG_LEVELS") text = str(cg.statement(logger.set_log_level(tag, level))) else: text = str(cg.statement(logger.set_log_level(level))) diff --git a/esphome/components/logger/logger.cpp b/esphome/components/logger/logger.cpp index 5f0e78fc0d..9a9bf89fe3 100644 --- a/esphome/components/logger/logger.cpp +++ b/esphome/components/logger/logger.cpp @@ -148,9 +148,11 @@ void Logger::log_vprintf_(uint8_t level, const char *tag, int line, const __Flas #endif // USE_STORE_LOG_STR_IN_FLASH inline uint8_t Logger::level_for(const char *tag) { +#ifdef USE_LOGGER_RUNTIME_TAG_LEVELS auto it = this->log_levels_.find(tag); if (it != this->log_levels_.end()) return it->second; +#endif return this->current_level_; } @@ -173,24 +175,8 @@ void Logger::init_log_buffer(size_t total_buffer_size) { } #endif -#ifndef USE_ZEPHYR -#if defined(USE_LOGGER_USB_CDC) || defined(USE_ESP32) -void Logger::loop() { -#if defined(USE_LOGGER_USB_CDC) && defined(USE_ARDUINO) - if (this->uart_ == UART_SELECTION_USB_CDC) { - static bool opened = false; - if (opened == Serial) { - return; - } - if (false == opened) { - App.schedule_dump_config(); - } - opened = !opened; - } -#endif - this->process_messages_(); -} -#endif +#ifdef USE_ESPHOME_TASK_LOG_BUFFER +void Logger::loop() { this->process_messages_(); } #endif void Logger::process_messages_() { @@ -236,7 +222,9 @@ void Logger::process_messages_() { } void Logger::set_baud_rate(uint32_t baud_rate) { this->baud_rate_ = baud_rate; } -void Logger::set_log_level(const std::string &tag, uint8_t log_level) { this->log_levels_[tag] = log_level; } +#ifdef USE_LOGGER_RUNTIME_TAG_LEVELS +void Logger::set_log_level(const char *tag, uint8_t log_level) { this->log_levels_[tag] = log_level; } +#endif #if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_ZEPHYR) UARTSelection Logger::get_uart() const { return this->uart_; } @@ -287,9 +275,11 @@ void Logger::dump_config() { } #endif +#ifdef USE_LOGGER_RUNTIME_TAG_LEVELS for (auto &it : this->log_levels_) { - ESP_LOGCONFIG(TAG, " Level for '%s': %s", it.first.c_str(), LOG_STR_ARG(LOG_LEVELS[it.second])); + ESP_LOGCONFIG(TAG, " Level for '%s': %s", it.first, LOG_STR_ARG(LOG_LEVELS[it.second])); } +#endif } void Logger::set_log_level(uint8_t level) { diff --git a/esphome/components/logger/logger.h b/esphome/components/logger/logger.h index a4cf5e3004..2099520049 100644 --- a/esphome/components/logger/logger.h +++ b/esphome/components/logger/logger.h @@ -16,18 +16,18 @@ #endif #ifdef USE_ARDUINO -#if defined(USE_ESP8266) || defined(USE_ESP32) +#if defined(USE_ESP8266) #include -#endif // USE_ESP8266 || USE_ESP32 +#endif // USE_ESP8266 #ifdef USE_RP2040 #include #include #endif // USE_RP2040 #endif // USE_ARDUINO -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #include -#endif // USE_ESP_IDF +#endif // USE_ESP32 #ifdef USE_ZEPHYR #include @@ -36,29 +36,38 @@ struct device; namespace esphome::logger { -// Color and letter constants for log levels -static const char *const LOG_LEVEL_COLORS[] = { - "", // NONE - ESPHOME_LOG_BOLD(ESPHOME_LOG_COLOR_RED), // ERROR - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_YELLOW), // WARNING - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_GREEN), // INFO - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_MAGENTA), // CONFIG - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_CYAN), // DEBUG - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_GRAY), // VERBOSE - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_WHITE), // VERY_VERBOSE +#ifdef USE_LOGGER_RUNTIME_TAG_LEVELS +// Comparison function for const char* keys in log_levels_ map +struct CStrCompare { + bool operator()(const char *a, const char *b) const { return strcmp(a, b) < 0; } +}; +#endif + +// ANSI color code last digit (30-38 range, store only last digit to save RAM) +static constexpr char LOG_LEVEL_COLOR_DIGIT[] = { + '\0', // NONE + '1', // ERROR (31 = red) + '3', // WARNING (33 = yellow) + '2', // INFO (32 = green) + '5', // CONFIG (35 = magenta) + '6', // DEBUG (36 = cyan) + '7', // VERBOSE (37 = gray) + '8', // VERY_VERBOSE (38 = white) }; -static const char *const LOG_LEVEL_LETTERS[] = { - "", // NONE - "E", // ERROR - "W", // WARNING - "I", // INFO - "C", // CONFIG - "D", // DEBUG - "V", // VERBOSE - "VV", // VERY_VERBOSE +static constexpr char LOG_LEVEL_LETTER_CHARS[] = { + '\0', // NONE + 'E', // ERROR + 'W', // WARNING + 'I', // INFO + 'C', // CONFIG + 'D', // DEBUG + 'V', // VERBOSE (VERY_VERBOSE uses two 'V's) }; +// Maximum header size: 35 bytes fixed + 32 bytes tag + 16 bytes thread name = 83 bytes (45 byte safety margin) +static constexpr uint16_t MAX_HEADER_SIZE = 128; + #if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_ZEPHYR) /** Enum for logging UART selection * @@ -110,19 +119,17 @@ class Logger : public Component { #ifdef USE_ESPHOME_TASK_LOG_BUFFER void init_log_buffer(size_t total_buffer_size); #endif -#if defined(USE_LOGGER_USB_CDC) || defined(USE_ESP32) || defined(USE_ZEPHYR) +#if defined(USE_ESPHOME_TASK_LOG_BUFFER) || (defined(USE_ZEPHYR) && defined(USE_LOGGER_USB_CDC)) void loop() override; #endif /// Manually set the baud rate for serial, set to 0 to disable. void set_baud_rate(uint32_t baud_rate); uint32_t get_baud_rate() const { return baud_rate_; } -#ifdef USE_ARDUINO +#if defined(USE_ARDUINO) && !defined(USE_ESP32) Stream *get_hw_serial() const { return hw_serial_; } #endif -#ifdef USE_ESP_IDF - uart_port_t get_uart_num() const { return uart_num_; } -#endif #ifdef USE_ESP32 + uart_port_t get_uart_num() const { return uart_num_; } void create_pthread_key() { pthread_key_create(&log_recursion_key_, nullptr); } #endif #if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_ZEPHYR) @@ -133,8 +140,10 @@ class Logger : public Component { /// Set the default log level for this logger. void set_log_level(uint8_t level); +#ifdef USE_LOGGER_RUNTIME_TAG_LEVELS /// Set the log level of the specified tag. - void set_log_level(const std::string &tag, uint8_t log_level); + void set_log_level(const char *tag, uint8_t log_level); +#endif uint8_t get_log_level() { return this->current_level_; } // ========== INTERNAL METHODS ========== @@ -217,14 +226,6 @@ class Logger : public Component { } } - // Format string to explicit buffer with varargs - inline void printf_to_buffer_(char *buffer, uint16_t *buffer_at, uint16_t buffer_size, const char *format, ...) { - va_list arg; - va_start(arg, format); - this->format_body_to_buffer_(buffer, buffer_at, buffer_size, format, arg); - va_end(arg); - } - #ifndef USE_HOST const LogString *get_uart_selection_(); #endif @@ -232,7 +233,7 @@ class Logger : public Component { // Group 4-byte aligned members first uint32_t baud_rate_; char *tx_buffer_{nullptr}; -#ifdef USE_ARDUINO +#if defined(USE_ARDUINO) && !defined(USE_ESP32) Stream *hw_serial_{nullptr}; #endif #if defined(USE_ZEPHYR) @@ -246,13 +247,13 @@ class Logger : public Component { // - Main task uses a dedicated member variable for efficiency // - Other tasks use pthread TLS with a dynamically created key via pthread_key_create pthread_key_t log_recursion_key_; // 4 bytes -#endif -#ifdef USE_ESP_IDF - uart_port_t uart_num_; // 4 bytes (enum defaults to int size) + uart_port_t uart_num_; // 4 bytes (enum defaults to int size) #endif // Large objects (internally aligned) - std::map log_levels_{}; +#ifdef USE_LOGGER_RUNTIME_TAG_LEVELS + std::map log_levels_{}; +#endif CallbackManager log_callback_{}; CallbackManager level_callback_{}; #ifdef USE_ESPHOME_TASK_LOG_BUFFER @@ -322,26 +323,76 @@ class Logger : public Component { } #endif + static inline void copy_string(char *buffer, uint16_t &pos, const char *str) { + const size_t len = strlen(str); + // Intentionally no null terminator, building larger string + memcpy(buffer + pos, str, len); // NOLINT(bugprone-not-null-terminated-result) + pos += len; + } + + static inline void write_ansi_color_for_level(char *buffer, uint16_t &pos, uint8_t level) { + if (level == 0) + return; + // Construct ANSI escape sequence: "\033[{bold};3{color}m" + // Example: "\033[1;31m" for ERROR (bold red) + buffer[pos++] = '\033'; + buffer[pos++] = '['; + buffer[pos++] = (level == 1) ? '1' : '0'; // Only ERROR is bold + buffer[pos++] = ';'; + buffer[pos++] = '3'; + buffer[pos++] = LOG_LEVEL_COLOR_DIGIT[level]; + buffer[pos++] = 'm'; + } + inline void HOT write_header_to_buffer_(uint8_t level, const char *tag, int line, const char *thread_name, char *buffer, uint16_t *buffer_at, uint16_t buffer_size) { - // Format header - // uint8_t level is already bounded 0-255, just ensure it's <= 7 - if (level > 7) - level = 7; + uint16_t pos = *buffer_at; + // Early return if insufficient space - intentionally don't update buffer_at to prevent partial writes + if (pos + MAX_HEADER_SIZE > buffer_size) + return; - const char *color = esphome::logger::LOG_LEVEL_COLORS[level]; - const char *letter = esphome::logger::LOG_LEVEL_LETTERS[level]; + // Construct: [LEVEL][tag:line]: + write_ansi_color_for_level(buffer, pos, level); + buffer[pos++] = '['; + if (level != 0) { + if (level >= 7) { + buffer[pos++] = 'V'; // VERY_VERBOSE = "VV" + buffer[pos++] = 'V'; + } else { + buffer[pos++] = LOG_LEVEL_LETTER_CHARS[level]; + } + } + buffer[pos++] = ']'; + buffer[pos++] = '['; + copy_string(buffer, pos, tag); + buffer[pos++] = ':'; + // Format line number without modulo operations (passed by value, safe to mutate) + if (line > 999) [[unlikely]] { + int thousands = line / 1000; + buffer[pos++] = '0' + thousands; + line -= thousands * 1000; + } + int hundreds = line / 100; + int remainder = line - hundreds * 100; + int tens = remainder / 10; + buffer[pos++] = '0' + hundreds; + buffer[pos++] = '0' + tens; + buffer[pos++] = '0' + (remainder - tens * 10); + buffer[pos++] = ']'; #if defined(USE_ESP32) || defined(USE_LIBRETINY) || defined(USE_ZEPHYR) if (thread_name != nullptr) { - // Non-main task with thread name - this->printf_to_buffer_(buffer, buffer_at, buffer_size, "%s[%s][%s:%03u]%s[%s]%s: ", color, letter, tag, line, - ESPHOME_LOG_BOLD(ESPHOME_LOG_COLOR_RED), thread_name, color); - return; + write_ansi_color_for_level(buffer, pos, 1); // Always use bold red for thread name + buffer[pos++] = '['; + copy_string(buffer, pos, thread_name); + buffer[pos++] = ']'; + write_ansi_color_for_level(buffer, pos, level); // Restore original color } #endif - // Main task or non ESP32/LibreTiny platform - this->printf_to_buffer_(buffer, buffer_at, buffer_size, "%s[%s][%s:%03u]: ", color, letter, tag, line); + + buffer[pos++] = ':'; + buffer[pos++] = ' '; + *buffer_at = pos; } inline void HOT format_body_to_buffer_(char *buffer, uint16_t *buffer_at, uint16_t buffer_size, const char *format, @@ -380,15 +431,7 @@ class Logger : public Component { // will be processed on the next main loop iteration since: // - disable_loop() takes effect immediately // - enable_loop_soon_any_context() sets a pending flag that's checked at loop start -#if defined(USE_LOGGER_USB_CDC) && defined(USE_ARDUINO) - // Only disable if not using USB CDC (which needs loop for connection detection) - if (this->uart_ != UART_SELECTION_USB_CDC) { - this->disable_loop(); - } -#else - // No USB CDC support, always safe to disable this->disable_loop(); -#endif } #endif }; diff --git a/esphome/components/logger/logger_esp32.cpp b/esphome/components/logger/logger_esp32.cpp index 6cb57c1540..7fc79e6f54 100644 --- a/esphome/components/logger/logger_esp32.cpp +++ b/esphome/components/logger/logger_esp32.cpp @@ -1,11 +1,8 @@ #ifdef USE_ESP32 #include "logger.h" -#if defined(USE_ESP32_FRAMEWORK_ARDUINO) || defined(USE_ESP_IDF) #include -#endif // USE_ESP32_FRAMEWORK_ARDUINO || USE_ESP_IDF -#ifdef USE_ESP_IDF #include #ifdef USE_LOGGER_USB_SERIAL_JTAG @@ -25,16 +22,12 @@ #include #include -#endif // USE_ESP_IDF - #include "esphome/core/log.h" namespace esphome::logger { static const char *const TAG = "logger"; -#ifdef USE_ESP_IDF - #ifdef USE_LOGGER_USB_SERIAL_JTAG static void init_usb_serial_jtag_() { setvbuf(stdin, NULL, _IONBF, 0); // Disable buffering on stdin @@ -89,42 +82,8 @@ void init_uart(uart_port_t uart_num, uint32_t baud_rate, int tx_buffer_size) { uart_driver_install(uart_num, uart_buffer_size, uart_buffer_size, 10, nullptr, 0); } -#endif // USE_ESP_IDF - void Logger::pre_setup() { if (this->baud_rate_ > 0) { -#ifdef USE_ARDUINO - switch (this->uart_) { - case UART_SELECTION_UART0: -#if ARDUINO_USB_CDC_ON_BOOT - this->hw_serial_ = &Serial0; - Serial0.begin(this->baud_rate_); -#else - this->hw_serial_ = &Serial; - Serial.begin(this->baud_rate_); -#endif - break; - case UART_SELECTION_UART1: - this->hw_serial_ = &Serial1; - Serial1.begin(this->baud_rate_); - break; -#ifdef USE_ESP32_VARIANT_ESP32 - case UART_SELECTION_UART2: - this->hw_serial_ = &Serial2; - Serial2.begin(this->baud_rate_); - break; -#endif - -#ifdef USE_LOGGER_USB_CDC - case UART_SELECTION_USB_CDC: - this->hw_serial_ = &Serial; - Serial.begin(this->baud_rate_); - break; -#endif - } -#endif // USE_ARDUINO - -#ifdef USE_ESP_IDF this->uart_num_ = UART_NUM_0; switch (this->uart_) { case UART_SELECTION_UART0: @@ -151,21 +110,17 @@ void Logger::pre_setup() { break; #endif } -#endif // USE_ESP_IDF } global_logger = this; -#if defined(USE_ESP_IDF) || defined(USE_ESP32_FRAMEWORK_ARDUINO) esp_log_set_vprintf(esp_idf_log_vprintf_); if (ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE) { esp_log_level_set("*", ESP_LOG_VERBOSE); } -#endif // USE_ESP_IDF || USE_ESP32_FRAMEWORK_ARDUINO ESP_LOGI(TAG, "Log initialized"); } -#ifdef USE_ESP_IDF void HOT Logger::write_msg_(const char *msg) { if ( #if defined(USE_LOGGER_USB_CDC) && !defined(USE_LOGGER_USB_SERIAL_JTAG) @@ -186,9 +141,6 @@ void HOT Logger::write_msg_(const char *msg) { uart_write_bytes(this->uart_num_, "\n", 1); } } -#else -void HOT Logger::write_msg_(const char *msg) { this->hw_serial_->println(msg); } -#endif const LogString *Logger::get_uart_selection_() { switch (this->uart_) { diff --git a/esphome/components/logger/logger_zephyr.cpp b/esphome/components/logger/logger_zephyr.cpp index 817ca168f8..fb0c7dcca3 100644 --- a/esphome/components/logger/logger_zephyr.cpp +++ b/esphome/components/logger/logger_zephyr.cpp @@ -12,8 +12,8 @@ namespace esphome::logger { static const char *const TAG = "logger"; -void Logger::loop() { #ifdef USE_LOGGER_USB_CDC +void Logger::loop() { if (this->uart_ != UART_SELECTION_USB_CDC || nullptr == this->uart_dev_) { return; } @@ -30,9 +30,8 @@ void Logger::loop() { App.schedule_dump_config(); } opened = !opened; -#endif - this->process_messages_(); } +#endif void Logger::pre_setup() { if (this->baud_rate_ > 0) { diff --git a/esphome/components/logger/select/logger_level_select.cpp b/esphome/components/logger/select/logger_level_select.cpp index d9c950ce3c..6d60a3ae47 100644 --- a/esphome/components/logger/select/logger_level_select.cpp +++ b/esphome/components/logger/select/logger_level_select.cpp @@ -3,11 +3,10 @@ namespace esphome::logger { void LoggerLevelSelect::publish_state(int level) { - auto value = this->at(level); - if (!value) { + const auto &option = this->at(level_to_index(level)); + if (!option) return; - } - Select::publish_state(value.value()); + Select::publish_state(option.value()); } void LoggerLevelSelect::setup() { @@ -16,10 +15,10 @@ void LoggerLevelSelect::setup() { } void LoggerLevelSelect::control(const std::string &value) { - auto level = this->index_of(value); - if (!level) + const auto index = this->index_of(value); + if (!index) return; - this->parent_->set_log_level(level.value()); + this->parent_->set_log_level(index_to_level(index.value())); } } // namespace esphome::logger diff --git a/esphome/components/logger/select/logger_level_select.h b/esphome/components/logger/select/logger_level_select.h index f31a6f6cdb..0631eca45d 100644 --- a/esphome/components/logger/select/logger_level_select.h +++ b/esphome/components/logger/select/logger_level_select.h @@ -3,11 +3,18 @@ #include "esphome/components/select/select.h" #include "esphome/core/component.h" #include "esphome/components/logger/logger.h" + namespace esphome::logger { class LoggerLevelSelect : public Component, public select::Select, public Parented { public: void publish_state(int level); void setup() override; void control(const std::string &value) override; + + protected: + // Convert log level to option index (skip CONFIG at level 4) + static uint8_t level_to_index(uint8_t level) { return (level > ESPHOME_LOG_LEVEL_CONFIG) ? level - 1 : level; } + // Convert option index to log level (skip CONFIG at level 4) + static uint8_t index_to_level(uint8_t index) { return (index >= ESPHOME_LOG_LEVEL_CONFIG) ? index + 1 : index; } }; } // namespace esphome::logger diff --git a/esphome/components/ltr501/ltr501.cpp b/esphome/components/ltr501/ltr501.cpp index b249d23666..be5a4ddccf 100644 --- a/esphome/components/ltr501/ltr501.cpp +++ b/esphome/components/ltr501/ltr501.cpp @@ -2,6 +2,7 @@ #include "esphome/core/application.h" #include "esphome/core/helpers.h" #include "esphome/core/log.h" +#include using esphome::i2c::ErrorCode; @@ -28,30 +29,30 @@ bool operator!=(const GainTimePair &lhs, const GainTimePair &rhs) { template T get_next(const T (&array)[size], const T val) { size_t i = 0; - size_t idx = -1; - while (idx == -1 && i < size) { + size_t idx = std::numeric_limits::max(); + while (idx == std::numeric_limits::max() && i < size) { if (array[i] == val) { idx = i; break; } i++; } - if (idx == -1 || i + 1 >= size) + if (idx == std::numeric_limits::max() || i + 1 >= size) return val; return array[i + 1]; } template T get_prev(const T (&array)[size], const T val) { size_t i = size - 1; - size_t idx = -1; - while (idx == -1 && i > 0) { + size_t idx = std::numeric_limits::max(); + while (idx == std::numeric_limits::max() && i > 0) { if (array[i] == val) { idx = i; break; } i--; } - if (idx == -1 || i == 0) + if (idx == std::numeric_limits::max() || i == 0) return val; return array[i - 1]; } diff --git a/esphome/components/ltr_als_ps/ltr_als_ps.cpp b/esphome/components/ltr_als_ps/ltr_als_ps.cpp index bf27c01e26..c3ea5848c8 100644 --- a/esphome/components/ltr_als_ps/ltr_als_ps.cpp +++ b/esphome/components/ltr_als_ps/ltr_als_ps.cpp @@ -2,6 +2,7 @@ #include "esphome/core/application.h" #include "esphome/core/helpers.h" #include "esphome/core/log.h" +#include using esphome::i2c::ErrorCode; @@ -14,30 +15,30 @@ static const uint8_t MAX_TRIES = 5; template T get_next(const T (&array)[size], const T val) { size_t i = 0; - size_t idx = -1; - while (idx == -1 && i < size) { + size_t idx = std::numeric_limits::max(); + while (idx == std::numeric_limits::max() && i < size) { if (array[i] == val) { idx = i; break; } i++; } - if (idx == -1 || i + 1 >= size) + if (idx == std::numeric_limits::max() || i + 1 >= size) return val; return array[i + 1]; } template T get_prev(const T (&array)[size], const T val) { size_t i = size - 1; - size_t idx = -1; - while (idx == -1 && i > 0) { + size_t idx = std::numeric_limits::max(); + while (idx == std::numeric_limits::max() && i > 0) { if (array[i] == val) { idx = i; break; } i--; } - if (idx == -1 || i == 0) + if (idx == std::numeric_limits::max() || i == 0) return val; return array[i - 1]; } diff --git a/esphome/components/matrix_keypad/matrix_keypad.h b/esphome/components/matrix_keypad/matrix_keypad.h index 8b309b42c2..258ab4fadc 100644 --- a/esphome/components/matrix_keypad/matrix_keypad.h +++ b/esphome/components/matrix_keypad/matrix_keypad.h @@ -29,9 +29,9 @@ class MatrixKeypad : public key_provider::KeyProvider, public Component { void set_columns(std::vector pins) { columns_ = std::move(pins); }; void set_rows(std::vector pins) { rows_ = std::move(pins); }; void set_keys(std::string keys) { keys_ = std::move(keys); }; - void set_debounce_time(int debounce_time) { debounce_time_ = debounce_time; }; - void set_has_diodes(int has_diodes) { has_diodes_ = has_diodes; }; - void set_has_pulldowns(int has_pulldowns) { has_pulldowns_ = has_pulldowns; }; + void set_debounce_time(uint32_t debounce_time) { debounce_time_ = debounce_time; }; + void set_has_diodes(bool has_diodes) { has_diodes_ = has_diodes; }; + void set_has_pulldowns(bool has_pulldowns) { has_pulldowns_ = has_pulldowns; }; void register_listener(MatrixKeypadListener *listener); void register_key_trigger(MatrixKeyTrigger *trig); @@ -40,7 +40,7 @@ class MatrixKeypad : public key_provider::KeyProvider, public Component { std::vector rows_; std::vector columns_; std::string keys_; - int debounce_time_ = 0; + uint32_t debounce_time_ = 0; bool has_diodes_{false}; bool has_pulldowns_{false}; int pressed_key_ = -1; diff --git a/esphome/components/max7219digit/max7219digit.cpp b/esphome/components/max7219digit/max7219digit.cpp index 9b9921d2f0..6df3c4d7c8 100644 --- a/esphome/components/max7219digit/max7219digit.cpp +++ b/esphome/components/max7219digit/max7219digit.cpp @@ -90,7 +90,7 @@ void MAX7219Component::loop() { } if (this->scroll_mode_ == ScrollMode::STOP) { - if (this->stepsleft_ + get_width_internal() == first_line_size + 1) { + if (static_cast(this->stepsleft_ + get_width_internal()) == first_line_size + 1) { if (millis_since_last_scroll < this->scroll_dwell_) { ESP_LOGVV(TAG, "Dwell time at end of string in case of stop at end. Step %d, since last scroll %d, dwell %d.", this->stepsleft_, millis_since_last_scroll, this->scroll_dwell_); diff --git a/esphome/components/mcp2515/mcp2515.cpp b/esphome/components/mcp2515/mcp2515.cpp index 23104f5aeb..1a17715315 100644 --- a/esphome/components/mcp2515/mcp2515.cpp +++ b/esphome/components/mcp2515/mcp2515.cpp @@ -20,6 +20,23 @@ bool MCP2515::setup_internal() { return false; if (this->set_bitrate_(this->bit_rate_, this->mcp_clock_) != canbus::ERROR_OK) return false; + + // setup hardware filter RXF0 accepting all standard CAN IDs + if (this->set_filter_(RXF::RXF0, false, 0) != canbus::ERROR_OK) { + return false; + } + if (this->set_filter_mask_(MASK::MASK0, false, 0) != canbus::ERROR_OK) { + return false; + } + + // setup hardware filter RXF1 accepting all extended CAN IDs + if (this->set_filter_(RXF::RXF1, true, 0) != canbus::ERROR_OK) { + return false; + } + if (this->set_filter_mask_(MASK::MASK1, true, 0) != canbus::ERROR_OK) { + return false; + } + if (this->set_mode_(this->mcp_mode_) != canbus::ERROR_OK) return false; uint8_t err_flags = this->get_error_flags_(); @@ -155,7 +172,7 @@ void MCP2515::prepare_id_(uint8_t *buffer, const bool extended, const uint32_t i canid = (uint16_t) (id >> 16); buffer[MCP_SIDL] = (uint8_t) (canid & 0x03); buffer[MCP_SIDL] += (uint8_t) ((canid & 0x1C) << 3); - buffer[MCP_SIDL] |= TXB_EXIDE_MASK; + buffer[MCP_SIDL] |= SIDL_EXIDE_MASK; buffer[MCP_SIDH] = (uint8_t) (canid >> 5); } else { buffer[MCP_SIDH] = (uint8_t) (canid >> 3); @@ -258,7 +275,7 @@ canbus::Error MCP2515::send_message(struct canbus::CanFrame *frame) { } } - return canbus::ERROR_FAILTX; + return canbus::ERROR_ALLTXBUSY; } canbus::Error MCP2515::read_message_(RXBn rxbn, struct canbus::CanFrame *frame) { @@ -272,7 +289,7 @@ canbus::Error MCP2515::read_message_(RXBn rxbn, struct canbus::CanFrame *frame) bool use_extended_id = false; bool remote_transmission_request = false; - if ((tbufdata[MCP_SIDL] & TXB_EXIDE_MASK) == TXB_EXIDE_MASK) { + if ((tbufdata[MCP_SIDL] & SIDL_EXIDE_MASK) == SIDL_EXIDE_MASK) { id = (id << 2) + (tbufdata[MCP_SIDL] & 0x03); id = (id << 8) + tbufdata[MCP_EID8]; id = (id << 8) + tbufdata[MCP_EID0]; @@ -315,6 +332,17 @@ canbus::Error MCP2515::read_message(struct canbus::CanFrame *frame) { rc = canbus::ERROR_NOMSG; } +#ifdef ESPHOME_LOG_HAS_DEBUG + uint8_t err = get_error_flags_(); + // The receive flowchart in the datasheet says that if rollover is set (BUKT), RX1OVR flag will be set + // once both buffers are full. However, the RX0OVR flag is actually set instead. + // We can just check for both though because it doesn't break anything. + if (err & (EFLG_RX0OVR | EFLG_RX1OVR)) { + ESP_LOGD(TAG, "receive buffer overrun"); + clear_rx_n_ovr_flags_(); + } +#endif + return rc; } diff --git a/esphome/components/mcp2515/mcp2515_defs.h b/esphome/components/mcp2515/mcp2515_defs.h index 2f5cf2a238..b33adcbba6 100644 --- a/esphome/components/mcp2515/mcp2515_defs.h +++ b/esphome/components/mcp2515/mcp2515_defs.h @@ -130,7 +130,9 @@ static const uint8_t CANSTAT_ICOD = 0x0E; static const uint8_t CNF3_SOF = 0x80; -static const uint8_t TXB_EXIDE_MASK = 0x08; +// applies to RXBn_SIDL, TXBn_SIDL and RXFn_SIDL +static const uint8_t SIDL_EXIDE_MASK = 0x08; + static const uint8_t DLC_MASK = 0x0F; static const uint8_t RTR_MASK = 0x40; diff --git a/esphome/components/md5/md5.cpp b/esphome/components/md5/md5.cpp index 21bd2e1cab..866f00eda4 100644 --- a/esphome/components/md5/md5.cpp +++ b/esphome/components/md5/md5.cpp @@ -39,32 +39,6 @@ void MD5Digest::add(const uint8_t *data, size_t len) { br_md5_update(&this->ctx_ void MD5Digest::calculate() { br_md5_out(&this->ctx_, this->digest_); } #endif // USE_RP2040 -void MD5Digest::get_bytes(uint8_t *output) { memcpy(output, this->digest_, 16); } - -void MD5Digest::get_hex(char *output) { - for (size_t i = 0; i < 16; i++) { - uint8_t byte = this->digest_[i]; - output[i * 2] = format_hex_char(byte >> 4); - output[i * 2 + 1] = format_hex_char(byte & 0x0F); - } -} - -bool MD5Digest::equals_bytes(const uint8_t *expected) { - for (size_t i = 0; i < 16; i++) { - if (expected[i] != this->digest_[i]) { - return false; - } - } - return true; -} - -bool MD5Digest::equals_hex(const char *expected) { - uint8_t parsed[16]; - if (!parse_hex(expected, parsed, 16)) - return false; - return equals_bytes(parsed); -} - } // namespace md5 } // namespace esphome #endif diff --git a/esphome/components/md5/md5.h b/esphome/components/md5/md5.h index be1df40423..b0da2c0a3b 100644 --- a/esphome/components/md5/md5.h +++ b/esphome/components/md5/md5.h @@ -3,6 +3,8 @@ #include "esphome/core/defines.h" #ifdef USE_MD5 +#include "esphome/core/hash_base.h" + #ifdef USE_ESP32 #include "esp_rom_md5.h" #define MD5_CTX_TYPE md5_context_t @@ -26,38 +28,26 @@ namespace esphome { namespace md5 { -class MD5Digest { +class MD5Digest : public HashBase { public: MD5Digest() = default; - ~MD5Digest() = default; + ~MD5Digest() override = default; /// Initialize a new MD5 digest computation. - void init(); + void init() override; /// Add bytes of data for the digest. - void add(const uint8_t *data, size_t len); - void add(const char *data, size_t len) { this->add((const uint8_t *) data, len); } + void add(const uint8_t *data, size_t len) override; + using HashBase::add; // Bring base class overload into scope /// Compute the digest, based on the provided data. - void calculate(); + void calculate() override; - /// Retrieve the MD5 digest as bytes. - /// The output must be able to hold 16 bytes or more. - void get_bytes(uint8_t *output); - - /// Retrieve the MD5 digest as hex characters. - /// The output must be able to hold 32 bytes or more. - void get_hex(char *output); - - /// Compare the digest against a provided byte-encoded digest (16 bytes). - bool equals_bytes(const uint8_t *expected); - - /// Compare the digest against a provided hex-encoded digest (32 bytes). - bool equals_hex(const char *expected); + /// Get the size of the hash in bytes (16 for MD5) + size_t get_size() const override { return 16; } protected: MD5_CTX_TYPE ctx_{}; - uint8_t digest_[16]; }; } // namespace md5 diff --git a/esphome/components/mdns/__init__.py b/esphome/components/mdns/__init__.py index a84fe5a249..3fa4d2ebef 100644 --- a/esphome/components/mdns/__init__.py +++ b/esphome/components/mdns/__init__.py @@ -17,6 +17,11 @@ from esphome.coroutine import CoroPriority CODEOWNERS = ["@esphome/core"] DEPENDENCIES = ["network"] +# Components that create mDNS services at runtime +# IMPORTANT: If you add a new component here, you must also update the corresponding +# #ifdef blocks in mdns_component.cpp compile_records_() method +COMPONENTS_WITH_MDNS_SERVICES = ("api", "prometheus", "web_server") + mdns_ns = cg.esphome_ns.namespace("mdns") MDNSComponent = mdns_ns.class_("MDNSComponent", cg.Component) MDNSTXTRecord = mdns_ns.struct("MDNSTXTRecord") @@ -56,7 +61,7 @@ CONFIG_SCHEMA = cv.All( def mdns_txt_record(key: str, value: str): return cg.StructInitializer( MDNSTXTRecord, - ("key", key), + ("key", cg.RawExpression(f"MDNS_STR({cg.safe_exp(key)})")), ("value", value), ) @@ -66,8 +71,8 @@ def mdns_service( ): return cg.StructInitializer( MDNSService, - ("service_type", service), - ("proto", proto), + ("service_type", cg.RawExpression(f"MDNS_STR({cg.safe_exp(service)})")), + ("proto", cg.RawExpression(f"MDNS_STR({cg.safe_exp(proto)})")), ("port", port), ("txt_records", txt_records), ) @@ -91,17 +96,25 @@ async def to_code(config): cg.add_define("USE_MDNS") - var = cg.new_Pvariable(config[CONF_ID]) - await cg.register_component(var, config) + # Calculate compile-time service count + service_count = sum( + 1 for key in COMPONENTS_WITH_MDNS_SERVICES if key in CORE.config + ) + len(config[CONF_SERVICES]) if config[CONF_SERVICES]: cg.add_define("USE_MDNS_EXTRA_SERVICES") + # Ensure at least 1 service (fallback service) + cg.add_define("MDNS_SERVICE_COUNT", max(1, service_count)) + + var = cg.new_Pvariable(config[CONF_ID]) + await cg.register_component(var, config) + for service in config[CONF_SERVICES]: txt = [ cg.StructInitializer( MDNSTXTRecord, - ("key", txt_key), + ("key", cg.RawExpression(f"MDNS_STR({cg.safe_exp(txt_key)})")), ("value", await cg.templatable(txt_value, [], cg.std_string)), ) for txt_key, txt_value in service[CONF_TXT].items() diff --git a/esphome/components/mdns/mdns_component.cpp b/esphome/components/mdns/mdns_component.cpp index 5d9788198f..8945053b7d 100644 --- a/esphome/components/mdns/mdns_component.cpp +++ b/esphome/components/mdns/mdns_component.cpp @@ -9,24 +9,21 @@ #include // Macro to define strings in PROGMEM on ESP8266, regular memory on other platforms #define MDNS_STATIC_CONST_CHAR(name, value) static const char name[] PROGMEM = value -// Helper to get string from PROGMEM - returns a temporary std::string +// Helper to convert PROGMEM string to std::string for TemplatableValue // Only define this function if we have services that will use it #if defined(USE_API) || defined(USE_PROMETHEUS) || defined(USE_WEBSERVER) || defined(USE_MDNS_EXTRA_SERVICES) -static std::string mdns_string_p(const char *src) { +static std::string mdns_str_value(PGM_P str) { char buf[64]; - strncpy_P(buf, src, sizeof(buf) - 1); + strncpy_P(buf, str, sizeof(buf) - 1); buf[sizeof(buf) - 1] = '\0'; return std::string(buf); } -#define MDNS_STR(name) mdns_string_p(name) -#else -// If no services are configured, we still need the fallback service but it uses string literals -#define MDNS_STR(name) std::string(name) +#define MDNS_STR_VALUE(name) mdns_str_value(name) #endif #else // On non-ESP8266 platforms, use regular const char* -#define MDNS_STATIC_CONST_CHAR(name, value) static constexpr const char *name = value -#define MDNS_STR(name) name +#define MDNS_STATIC_CONST_CHAR(name, value) static constexpr const char name[] = value +#define MDNS_STR_VALUE(name) std::string(name) #endif #ifdef USE_API @@ -74,32 +71,12 @@ MDNS_STATIC_CONST_CHAR(NETWORK_THREAD, "thread"); void MDNSComponent::compile_records_() { this->hostname_ = App.get_name(); - // Calculate exact capacity needed for services vector - size_t services_count = 0; -#ifdef USE_API - if (api::global_api_server != nullptr) { - services_count++; - } -#endif -#ifdef USE_PROMETHEUS - services_count++; -#endif -#ifdef USE_WEBSERVER - services_count++; -#endif -#ifdef USE_MDNS_EXTRA_SERVICES - services_count += this->services_extra_.size(); -#endif - // Reserve for fallback service if needed - if (services_count == 0) { - services_count = 1; - } - this->services_.reserve(services_count); + // IMPORTANT: The #ifdef blocks below must match COMPONENTS_WITH_MDNS_SERVICES + // in mdns/__init__.py. If you add a new service here, update both locations. #ifdef USE_API if (api::global_api_server != nullptr) { - this->services_.emplace_back(); - auto &service = this->services_.back(); + auto &service = this->services_.emplace_next(); service.service_type = MDNS_STR(SERVICE_ESPHOMELIB); service.proto = MDNS_STR(SERVICE_TCP); service.port = api::global_api_server->get_port(); @@ -138,31 +115,31 @@ void MDNSComponent::compile_records_() { txt_records.push_back({MDNS_STR(TXT_MAC), get_mac_address()}); #ifdef USE_ESP8266 - txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR(PLATFORM_ESP8266)}); + txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR_VALUE(PLATFORM_ESP8266)}); #elif defined(USE_ESP32) - txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR(PLATFORM_ESP32)}); + txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR_VALUE(PLATFORM_ESP32)}); #elif defined(USE_RP2040) - txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR(PLATFORM_RP2040)}); + txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR_VALUE(PLATFORM_RP2040)}); #elif defined(USE_LIBRETINY) - txt_records.emplace_back(MDNSTXTRecord{"platform", lt_cpu_get_model_name()}); + txt_records.push_back({MDNS_STR(TXT_PLATFORM), lt_cpu_get_model_name()}); #endif txt_records.push_back({MDNS_STR(TXT_BOARD), ESPHOME_BOARD}); #if defined(USE_WIFI) - txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR(NETWORK_WIFI)}); + txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR_VALUE(NETWORK_WIFI)}); #elif defined(USE_ETHERNET) - txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR(NETWORK_ETHERNET)}); + txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR_VALUE(NETWORK_ETHERNET)}); #elif defined(USE_OPENTHREAD) - txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR(NETWORK_THREAD)}); + txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR_VALUE(NETWORK_THREAD)}); #endif #ifdef USE_API_NOISE MDNS_STATIC_CONST_CHAR(NOISE_ENCRYPTION, "Noise_NNpsk0_25519_ChaChaPoly_SHA256"); if (api::global_api_server->get_noise_ctx()->has_psk()) { - txt_records.push_back({MDNS_STR(TXT_API_ENCRYPTION), MDNS_STR(NOISE_ENCRYPTION)}); + txt_records.push_back({MDNS_STR(TXT_API_ENCRYPTION), MDNS_STR_VALUE(NOISE_ENCRYPTION)}); } else { - txt_records.push_back({MDNS_STR(TXT_API_ENCRYPTION_SUPPORTED), MDNS_STR(NOISE_ENCRYPTION)}); + txt_records.push_back({MDNS_STR(TXT_API_ENCRYPTION_SUPPORTED), MDNS_STR_VALUE(NOISE_ENCRYPTION)}); } #endif @@ -178,34 +155,27 @@ void MDNSComponent::compile_records_() { #endif // USE_API #ifdef USE_PROMETHEUS - this->services_.emplace_back(); - auto &prom_service = this->services_.back(); + auto &prom_service = this->services_.emplace_next(); prom_service.service_type = MDNS_STR(SERVICE_PROMETHEUS); prom_service.proto = MDNS_STR(SERVICE_TCP); prom_service.port = USE_WEBSERVER_PORT; #endif #ifdef USE_WEBSERVER - this->services_.emplace_back(); - auto &web_service = this->services_.back(); + auto &web_service = this->services_.emplace_next(); web_service.service_type = MDNS_STR(SERVICE_HTTP); web_service.proto = MDNS_STR(SERVICE_TCP); web_service.port = USE_WEBSERVER_PORT; #endif -#ifdef USE_MDNS_EXTRA_SERVICES - this->services_.insert(this->services_.end(), this->services_extra_.begin(), this->services_extra_.end()); -#endif - #if !defined(USE_API) && !defined(USE_PROMETHEUS) && !defined(USE_WEBSERVER) && !defined(USE_MDNS_EXTRA_SERVICES) // Publish "http" service if not using native API or any other services // This is just to have *some* mDNS service so that .local resolution works - this->services_.emplace_back(); - auto &fallback_service = this->services_.back(); - fallback_service.service_type = "_http"; - fallback_service.proto = "_tcp"; + auto &fallback_service = this->services_.emplace_next(); + fallback_service.service_type = MDNS_STR(SERVICE_HTTP); + fallback_service.proto = MDNS_STR(SERVICE_TCP); fallback_service.port = USE_WEBSERVER_PORT; - fallback_service.txt_records.emplace_back(MDNSTXTRecord{"version", ESPHOME_VERSION}); + fallback_service.txt_records.push_back({MDNS_STR(TXT_VERSION), ESPHOME_VERSION}); #endif } @@ -214,21 +184,19 @@ void MDNSComponent::dump_config() { "mDNS:\n" " Hostname: %s", this->hostname_.c_str()); -#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERY_VERBOSE +#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE ESP_LOGV(TAG, " Services:"); for (const auto &service : this->services_) { - ESP_LOGV(TAG, " - %s, %s, %d", service.service_type.c_str(), service.proto.c_str(), + ESP_LOGV(TAG, " - %s, %s, %d", MDNS_STR_ARG(service.service_type), MDNS_STR_ARG(service.proto), const_cast &>(service.port).value()); for (const auto &record : service.txt_records) { - ESP_LOGV(TAG, " TXT: %s = %s", record.key.c_str(), + ESP_LOGV(TAG, " TXT: %s = %s", MDNS_STR_ARG(record.key), const_cast &>(record.value).value().c_str()); } } #endif } -std::vector MDNSComponent::get_services() { return this->services_; } - } // namespace mdns } // namespace esphome #endif diff --git a/esphome/components/mdns/mdns_component.h b/esphome/components/mdns/mdns_component.h index f87ef08bcd..b1f73fbb32 100644 --- a/esphome/components/mdns/mdns_component.h +++ b/esphome/components/mdns/mdns_component.h @@ -2,25 +2,41 @@ #include "esphome/core/defines.h" #ifdef USE_MDNS #include -#include #include "esphome/core/automation.h" #include "esphome/core/component.h" +#include "esphome/core/helpers.h" namespace esphome { namespace mdns { +// Helper struct that identifies strings that may be stored in flash storage (similar to LogString) +struct MDNSString; + +// Macro to cast string literals to MDNSString* (works on all platforms) +#define MDNS_STR(name) (reinterpret_cast(name)) + +#ifdef USE_ESP8266 +#include +#define MDNS_STR_ARG(s) ((PGM_P) (s)) +#else +#define MDNS_STR_ARG(s) (reinterpret_cast(s)) +#endif + +// Service count is calculated at compile time by Python codegen +// MDNS_SERVICE_COUNT will always be defined + struct MDNSTXTRecord { - std::string key; + const MDNSString *key; TemplatableValue value; }; struct MDNSService { // service name _including_ underscore character prefix // as defined in RFC6763 Section 7 - std::string service_type; + const MDNSString *service_type; // second label indicating protocol _including_ underscore character prefix // as defined in RFC6763 Section 7, like "_tcp" or "_udp" - std::string proto; + const MDNSString *proto; TemplatableValue port; std::vector txt_records; }; @@ -36,18 +52,15 @@ class MDNSComponent : public Component { float get_setup_priority() const override { return setup_priority::AFTER_CONNECTION; } #ifdef USE_MDNS_EXTRA_SERVICES - void add_extra_service(MDNSService service) { services_extra_.push_back(std::move(service)); } + void add_extra_service(MDNSService service) { this->services_.emplace_next() = std::move(service); } #endif - std::vector get_services(); + const StaticVector &get_services() const { return this->services_; } void on_shutdown() override; protected: -#ifdef USE_MDNS_EXTRA_SERVICES - std::vector services_extra_{}; -#endif - std::vector services_{}; + StaticVector services_{}; std::string hostname_; void compile_records_(); }; diff --git a/esphome/components/mdns/mdns_esp32.cpp b/esphome/components/mdns/mdns_esp32.cpp index ffd86afec1..40d305a1e6 100644 --- a/esphome/components/mdns/mdns_esp32.cpp +++ b/esphome/components/mdns/mdns_esp32.cpp @@ -29,23 +29,23 @@ void MDNSComponent::setup() { std::vector txt_records; for (const auto &record : service.txt_records) { mdns_txt_item_t it{}; - // dup strings to ensure the pointer is valid even after the record loop - it.key = strdup(record.key.c_str()); + // key is a compile-time string literal in flash, no need to strdup + it.key = MDNS_STR_ARG(record.key); + // value is a temporary from TemplatableValue, must strdup to keep it alive it.value = strdup(const_cast &>(record.value).value().c_str()); txt_records.push_back(it); } uint16_t port = const_cast &>(service.port).value(); - err = mdns_service_add(nullptr, service.service_type.c_str(), service.proto.c_str(), port, txt_records.data(), - txt_records.size()); + err = mdns_service_add(nullptr, MDNS_STR_ARG(service.service_type), MDNS_STR_ARG(service.proto), port, + txt_records.data(), txt_records.size()); // free records for (const auto &it : txt_records) { - delete it.key; // NOLINT(cppcoreguidelines-owning-memory) - delete it.value; // NOLINT(cppcoreguidelines-owning-memory) + free((void *) it.value); // NOLINT(cppcoreguidelines-no-malloc) } if (err != ESP_OK) { - ESP_LOGW(TAG, "Failed to register service %s: %s", service.service_type.c_str(), esp_err_to_name(err)); + ESP_LOGW(TAG, "Failed to register service %s: %s", MDNS_STR_ARG(service.service_type), esp_err_to_name(err)); } } } diff --git a/esphome/components/mdns/mdns_esp8266.cpp b/esphome/components/mdns/mdns_esp8266.cpp index 2c90d57021..f1c8909807 100644 --- a/esphome/components/mdns/mdns_esp8266.cpp +++ b/esphome/components/mdns/mdns_esp8266.cpp @@ -21,18 +21,18 @@ void MDNSComponent::setup() { // part of the wire protocol to have an underscore, and for example ESP-IDF // expects the underscore to be there, the ESP8266 implementation always adds // the underscore itself. - auto *proto = service.proto.c_str(); - while (*proto == '_') { + auto *proto = MDNS_STR_ARG(service.proto); + while (progmem_read_byte((const uint8_t *) proto) == '_') { proto++; } - auto *service_type = service.service_type.c_str(); - while (*service_type == '_') { + auto *service_type = MDNS_STR_ARG(service.service_type); + while (progmem_read_byte((const uint8_t *) service_type) == '_') { service_type++; } uint16_t port = const_cast &>(service.port).value(); - MDNS.addService(service_type, proto, port); + MDNS.addService(FPSTR(service_type), FPSTR(proto), port); for (const auto &record : service.txt_records) { - MDNS.addServiceTxt(service_type, proto, record.key.c_str(), + MDNS.addServiceTxt(FPSTR(service_type), FPSTR(proto), FPSTR(MDNS_STR_ARG(record.key)), const_cast &>(record.value).value().c_str()); } } diff --git a/esphome/components/mdns/mdns_libretiny.cpp b/esphome/components/mdns/mdns_libretiny.cpp index 7a41ec9dce..9010ca2bc6 100644 --- a/esphome/components/mdns/mdns_libretiny.cpp +++ b/esphome/components/mdns/mdns_libretiny.cpp @@ -21,18 +21,18 @@ void MDNSComponent::setup() { // part of the wire protocol to have an underscore, and for example ESP-IDF // expects the underscore to be there, the ESP8266 implementation always adds // the underscore itself. - auto *proto = service.proto.c_str(); + auto *proto = MDNS_STR_ARG(service.proto); while (*proto == '_') { proto++; } - auto *service_type = service.service_type.c_str(); + auto *service_type = MDNS_STR_ARG(service.service_type); while (*service_type == '_') { service_type++; } uint16_t port_ = const_cast &>(service.port).value(); MDNS.addService(service_type, proto, port_); for (const auto &record : service.txt_records) { - MDNS.addServiceTxt(service_type, proto, record.key.c_str(), + MDNS.addServiceTxt(service_type, proto, MDNS_STR_ARG(record.key), const_cast &>(record.value).value().c_str()); } } diff --git a/esphome/components/mdns/mdns_rp2040.cpp b/esphome/components/mdns/mdns_rp2040.cpp index 95894323f4..039453f501 100644 --- a/esphome/components/mdns/mdns_rp2040.cpp +++ b/esphome/components/mdns/mdns_rp2040.cpp @@ -21,18 +21,18 @@ void MDNSComponent::setup() { // part of the wire protocol to have an underscore, and for example ESP-IDF // expects the underscore to be there, the ESP8266 implementation always adds // the underscore itself. - auto *proto = service.proto.c_str(); + auto *proto = MDNS_STR_ARG(service.proto); while (*proto == '_') { proto++; } - auto *service_type = service.service_type.c_str(); + auto *service_type = MDNS_STR_ARG(service.service_type); while (*service_type == '_') { service_type++; } uint16_t port = const_cast &>(service.port).value(); MDNS.addService(service_type, proto, port); for (const auto &record : service.txt_records) { - MDNS.addServiceTxt(service_type, proto, record.key.c_str(), + MDNS.addServiceTxt(service_type, proto, MDNS_STR_ARG(record.key), const_cast &>(record.value).value().c_str()); } } diff --git a/esphome/components/mipi/__init__.py b/esphome/components/mipi/__init__.py index f670a5913d..7e687cabaa 100644 --- a/esphome/components/mipi/__init__.py +++ b/esphome/components/mipi/__init__.py @@ -343,11 +343,7 @@ class DriverChip: ) offset_height = native_height - height - offset_height # Swap default dimensions if swap_xy is set, or if rotation is 90/270 and we are not using a buffer - rotated = not requires_buffer(config) and config.get(CONF_ROTATION, 0) in ( - 90, - 270, - ) - if transform.get(CONF_SWAP_XY) is True or rotated: + if transform.get(CONF_SWAP_XY) is True: width, height = height, width offset_height, offset_width = offset_width, offset_height return width, height, offset_width, offset_height diff --git a/esphome/components/mipi_spi/display.py b/esphome/components/mipi_spi/display.py index e891e2daad..52b5b86fba 100644 --- a/esphome/components/mipi_spi/display.py +++ b/esphome/components/mipi_spi/display.py @@ -380,25 +380,41 @@ def get_instance(config): bus_type = BusTypes[bus_type] buffer_type = cg.uint8 if color_depth == 8 else cg.uint16 frac = denominator(config) - rotation = DISPLAY_ROTATIONS[ + rotation = ( 0 if model.rotation_as_transform(config) else config.get(CONF_ROTATION, 0) - ] + ) templateargs = [ buffer_type, bufferpixels, config[CONF_BYTE_ORDER] == "big_endian", display_pixel_mode, bus_type, - width, - height, - offset_width, - offset_height, ] # If a buffer is required, use MipiSpiBuffer, otherwise use MipiSpi if requires_buffer(config): - templateargs.append(rotation) - templateargs.append(frac) + templateargs.extend( + [ + width, + height, + offset_width, + offset_height, + DISPLAY_ROTATIONS[rotation], + frac, + ] + ) return MipiSpiBuffer, templateargs + # Swap height and width if the display is rotated 90 or 270 degrees in software + if rotation in (90, 270): + width, height = height, width + offset_width, offset_height = offset_height, offset_width + templateargs.extend( + [ + width, + height, + offset_width, + offset_height, + ] + ) return MipiSpi, templateargs diff --git a/esphome/components/mipi_spi/mipi_spi.h b/esphome/components/mipi_spi/mipi_spi.h index 00b861f71b..248d5b7104 100644 --- a/esphome/components/mipi_spi/mipi_spi.h +++ b/esphome/components/mipi_spi/mipi_spi.h @@ -340,7 +340,7 @@ class MipiSpi : public display::Display, this->write_cmd_addr_data(0, 0, 0, 0, ptr, w * h, 8); } } else { - for (size_t y = 0; y != h; y++) { + for (size_t y = 0; y != static_cast(h); y++) { if constexpr (BUS_TYPE == BUS_TYPE_SINGLE || BUS_TYPE == BUS_TYPE_SINGLE_16) { this->write_array(ptr, w); } else if constexpr (BUS_TYPE == BUS_TYPE_QUAD) { @@ -372,8 +372,8 @@ class MipiSpi : public display::Display, uint8_t dbuffer[DISPLAYPIXEL * 48]; uint8_t *dptr = dbuffer; auto stride = x_offset + w + x_pad; // stride in pixels - for (size_t y = 0; y != h; y++) { - for (size_t x = 0; x != w; x++) { + for (size_t y = 0; y != static_cast(h); y++) { + for (size_t x = 0; x != static_cast(w); x++) { auto color_val = ptr[y * stride + x]; if constexpr (DISPLAYPIXEL == PIXEL_MODE_18 && BUFFERPIXEL == PIXEL_MODE_16) { // 16 to 18 bit conversion diff --git a/esphome/components/mixer/speaker/mixer_speaker.cpp b/esphome/components/mixer/speaker/mixer_speaker.cpp index fc0517c7be..b0b64f5709 100644 --- a/esphome/components/mixer/speaker/mixer_speaker.cpp +++ b/esphome/components/mixer/speaker/mixer_speaker.cpp @@ -572,7 +572,7 @@ void MixerSpeaker::audio_mixer_task(void *params) { } } else { // Determine how many frames to mix - for (int i = 0; i < transfer_buffers_with_data.size(); ++i) { + for (size_t i = 0; i < transfer_buffers_with_data.size(); ++i) { const uint32_t frames_available_in_buffer = speakers_with_data[i]->get_audio_stream_info().bytes_to_frames(transfer_buffers_with_data[i]->available()); frames_to_mix = std::min(frames_to_mix, frames_available_in_buffer); @@ -581,7 +581,7 @@ void MixerSpeaker::audio_mixer_task(void *params) { audio::AudioStreamInfo primary_stream_info = speakers_with_data[0]->get_audio_stream_info(); // Mix two streams together - for (int i = 1; i < transfer_buffers_with_data.size(); ++i) { + for (size_t i = 1; i < transfer_buffers_with_data.size(); ++i) { mix_audio_samples(primary_buffer, primary_stream_info, reinterpret_cast(transfer_buffers_with_data[i]->get_buffer_start()), speakers_with_data[i]->get_audio_stream_info(), @@ -596,7 +596,7 @@ void MixerSpeaker::audio_mixer_task(void *params) { } // Update source transfer buffer lengths and add new audio durations to the source speaker pending playbacks - for (int i = 0; i < transfer_buffers_with_data.size(); ++i) { + for (size_t i = 0; i < transfer_buffers_with_data.size(); ++i) { transfer_buffers_with_data[i]->decrease_buffer_length( speakers_with_data[i]->get_audio_stream_info().frames_to_bytes(frames_to_mix)); speakers_with_data[i]->pending_playback_frames_ += frames_to_mix; diff --git a/esphome/components/mmc5603/mmc5603.cpp b/esphome/components/mmc5603/mmc5603.cpp index d712e2401d..f0d1044f3f 100644 --- a/esphome/components/mmc5603/mmc5603.cpp +++ b/esphome/components/mmc5603/mmc5603.cpp @@ -128,21 +128,21 @@ void MMC5603Component::update() { raw_x |= buffer[1] << 4; raw_x |= buffer[2] << 0; - const float x = 0.0625 * (raw_x - 524288); + const float x = 0.00625 * (raw_x - 524288); int32_t raw_y = 0; raw_y |= buffer[3] << 12; raw_y |= buffer[4] << 4; raw_y |= buffer[5] << 0; - const float y = 0.0625 * (raw_y - 524288); + const float y = 0.00625 * (raw_y - 524288); int32_t raw_z = 0; raw_z |= buffer[6] << 12; raw_z |= buffer[7] << 4; raw_z |= buffer[8] << 0; - const float z = 0.0625 * (raw_z - 524288); + const float z = 0.00625 * (raw_z - 524288); const float heading = atan2f(0.0f - x, y) * 180.0f / M_PI; ESP_LOGD(TAG, "Got x=%0.02fµT y=%0.02fµT z=%0.02fµT heading=%0.01f°", x, y, z, heading); diff --git a/esphome/components/modbus/modbus.cpp b/esphome/components/modbus/modbus.cpp index 6350f43ef6..20271b4bdb 100644 --- a/esphome/components/modbus/modbus.cpp +++ b/esphome/components/modbus/modbus.cpp @@ -66,7 +66,10 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { uint8_t data_offset = 3; // Per https://modbus.org/docs/Modbus_Application_Protocol_V1_1b3.pdf Ch 5 User-Defined function codes - if (((function_code >= 65) && (function_code <= 72)) || ((function_code >= 100) && (function_code <= 110))) { + if (((function_code >= FUNCTION_CODE_USER_DEFINED_SPACE_1_INIT) && + (function_code <= FUNCTION_CODE_USER_DEFINED_SPACE_1_END)) || + ((function_code >= FUNCTION_CODE_USER_DEFINED_SPACE_2_INIT) && + (function_code <= FUNCTION_CODE_USER_DEFINED_SPACE_2_END))) { // Handle user-defined function, since we don't know how big this ought to be, // ideally we should delegate the entire length detection to whatever handler is // installed, but wait, there is the CRC, and if we get a hit there is a good @@ -91,10 +94,14 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { } else { // data starts at 2 and length is 4 for read registers commands if (this->role == ModbusRole::SERVER) { - if (function_code == 0x1 || function_code == 0x3 || function_code == 0x4 || function_code == 0x6) { + if (function_code == ModbusFunctionCode::READ_COILS || + function_code == ModbusFunctionCode::READ_DISCRETE_INPUTS || + function_code == ModbusFunctionCode::READ_HOLDING_REGISTERS || + function_code == ModbusFunctionCode::READ_INPUT_REGISTERS || + function_code == ModbusFunctionCode::WRITE_SINGLE_REGISTER) { data_offset = 2; data_len = 4; - } else if (function_code == 0x10) { + } else if (function_code == ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS) { if (at < 6) { return true; } @@ -104,7 +111,10 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { } } else { // the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands - if (function_code == 0x5 || function_code == 0x06 || function_code == 0xF || function_code == 0x10) { + if (function_code == ModbusFunctionCode::WRITE_SINGLE_COIL || + function_code == ModbusFunctionCode::WRITE_SINGLE_REGISTER || + function_code == ModbusFunctionCode::WRITE_MULTIPLE_COILS || + function_code == ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS) { data_offset = 2; data_len = 4; } @@ -112,7 +122,7 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { // Error ( msb indicates error ) // response format: Byte[0] = device address, Byte[1] function code | 0x80 , Byte[2] exception code, Byte[3-4] crc - if ((function_code & 0x80) == 0x80) { + if ((function_code & FUNCTION_CODE_EXCEPTION_MASK) == FUNCTION_CODE_EXCEPTION_MASK) { data_offset = 2; data_len = 1; } @@ -143,10 +153,10 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { if (device->address_ == address) { found = true; // Is it an error response? - if ((function_code & 0x80) == 0x80) { + if ((function_code & FUNCTION_CODE_EXCEPTION_MASK) == FUNCTION_CODE_EXCEPTION_MASK) { ESP_LOGD(TAG, "Modbus error function code: 0x%X exception: %d", function_code, raw[2]); if (waiting_for_response != 0) { - device->on_modbus_error(function_code & 0x7F, raw[2]); + device->on_modbus_error(function_code & FUNCTION_CODE_MASK, raw[2]); } else { // Ignore modbus exception not related to a pending command ESP_LOGD(TAG, "Ignoring Modbus error - not expecting a response"); @@ -154,12 +164,14 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { continue; } if (this->role == ModbusRole::SERVER) { - if (function_code == 0x3 || function_code == 0x4) { + if (function_code == ModbusFunctionCode::READ_HOLDING_REGISTERS || + function_code == ModbusFunctionCode::READ_INPUT_REGISTERS) { device->on_modbus_read_registers(function_code, uint16_t(data[1]) | (uint16_t(data[0]) << 8), uint16_t(data[3]) | (uint16_t(data[2]) << 8)); continue; } - if (function_code == 0x6 || function_code == 0x10) { + if (function_code == ModbusFunctionCode::WRITE_SINGLE_REGISTER || + function_code == ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS) { device->on_modbus_write_registers(function_code, data); continue; } @@ -199,7 +211,7 @@ void Modbus::send(uint8_t address, uint8_t function_code, uint16_t start_address // Only check max number of registers for standard function codes // Some devices use non standard codes like 0x43 - if (number_of_entities > MAX_VALUES && function_code <= 0x10) { + if (number_of_entities > MAX_VALUES && function_code <= ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS) { ESP_LOGE(TAG, "send too many values %d max=%zu", number_of_entities, MAX_VALUES); return; } @@ -210,15 +222,17 @@ void Modbus::send(uint8_t address, uint8_t function_code, uint16_t start_address if (this->role == ModbusRole::CLIENT) { data.push_back(start_address >> 8); data.push_back(start_address >> 0); - if (function_code != 0x5 && function_code != 0x6) { + if (function_code != ModbusFunctionCode::WRITE_SINGLE_COIL && + function_code != ModbusFunctionCode::WRITE_SINGLE_REGISTER) { data.push_back(number_of_entities >> 8); data.push_back(number_of_entities >> 0); } } if (payload != nullptr) { - if (this->role == ModbusRole::SERVER || function_code == 0xF || function_code == 0x10) { // Write multiple - data.push_back(payload_len); // Byte count is required for write + if (this->role == ModbusRole::SERVER || function_code == ModbusFunctionCode::WRITE_MULTIPLE_COILS || + function_code == ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS) { // Write multiple + data.push_back(payload_len); // Byte count is required for write } else { payload_len = 2; // Write single register or coil } diff --git a/esphome/components/modbus/modbus.h b/esphome/components/modbus/modbus.h index ec35612690..fac74aaadf 100644 --- a/esphome/components/modbus/modbus.h +++ b/esphome/components/modbus/modbus.h @@ -3,6 +3,8 @@ #include "esphome/core/component.h" #include "esphome/components/uart/uart.h" +#include "esphome/components/modbus/modbus_definitions.h" + #include namespace esphome { @@ -65,12 +67,12 @@ class ModbusDevice { this->parent_->send(this->address_, function, start_address, number_of_entities, payload_len, payload); } void send_raw(const std::vector &payload) { this->parent_->send_raw(payload); } - void send_error(uint8_t function_code, uint8_t exception_code) { + void send_error(uint8_t function_code, ModbusExceptionCode exception_code) { std::vector error_response; error_response.reserve(3); error_response.push_back(this->address_); - error_response.push_back(function_code | 0x80); - error_response.push_back(exception_code); + error_response.push_back(function_code | FUNCTION_CODE_EXCEPTION_MASK); + error_response.push_back(static_cast(exception_code)); this->send_raw(error_response); } // If more than one device is connected block sending a new command before a response is received diff --git a/esphome/components/modbus/modbus_definitions.h b/esphome/components/modbus/modbus_definitions.h new file mode 100644 index 0000000000..07f101ae4c --- /dev/null +++ b/esphome/components/modbus/modbus_definitions.h @@ -0,0 +1,86 @@ +#pragma once + +#include "esphome/core/component.h" + +namespace esphome { +namespace modbus { + +/// Modbus definitions from specs: +/// https://modbus.org/docs/Modbus_Application_Protocol_V1_1b3.pdf +// 5 Function Code Categories +const uint8_t FUNCTION_CODE_USER_DEFINED_SPACE_1_INIT = 65; // 0x41 +const uint8_t FUNCTION_CODE_USER_DEFINED_SPACE_1_END = 72; // 0x48 + +const uint8_t FUNCTION_CODE_USER_DEFINED_SPACE_2_INIT = 100; // 0x64 +const uint8_t FUNCTION_CODE_USER_DEFINED_SPACE_2_END = 110; // 0x6E + +enum class ModbusFunctionCode : uint8_t { + CUSTOM = 0x00, + READ_COILS = 0x01, + READ_DISCRETE_INPUTS = 0x02, + READ_HOLDING_REGISTERS = 0x03, + READ_INPUT_REGISTERS = 0x04, + WRITE_SINGLE_COIL = 0x05, + WRITE_SINGLE_REGISTER = 0x06, + READ_EXCEPTION_STATUS = 0x07, // not implemented + DIAGNOSTICS = 0x08, // not implemented + GET_COMM_EVENT_COUNTER = 0x0B, // not implemented + GET_COMM_EVENT_LOG = 0x0C, // not implemented + WRITE_MULTIPLE_COILS = 0x0F, + WRITE_MULTIPLE_REGISTERS = 0x10, + REPORT_SERVER_ID = 0x11, // not implemented + READ_FILE_RECORD = 0x14, // not implemented + WRITE_FILE_RECORD = 0x15, // not implemented + MASK_WRITE_REGISTER = 0x16, // not implemented + READ_WRITE_MULTIPLE_REGISTERS = 0x17, // not implemented + READ_FIFO_QUEUE = 0x18, // not implemented +}; + +/*Allow comparison operators between ModbusFunctionCode and uint8_t*/ +inline bool operator==(ModbusFunctionCode lhs, uint8_t rhs) { return static_cast(lhs) == rhs; } +inline bool operator==(uint8_t lhs, ModbusFunctionCode rhs) { return lhs == static_cast(rhs); } +inline bool operator!=(ModbusFunctionCode lhs, uint8_t rhs) { return !(static_cast(lhs) == rhs); } +inline bool operator!=(uint8_t lhs, ModbusFunctionCode rhs) { return !(lhs == static_cast(rhs)); } +inline bool operator<(ModbusFunctionCode lhs, uint8_t rhs) { return static_cast(lhs) < rhs; } +inline bool operator<(uint8_t lhs, ModbusFunctionCode rhs) { return lhs < static_cast(rhs); } +inline bool operator<=(ModbusFunctionCode lhs, uint8_t rhs) { return static_cast(lhs) <= rhs; } +inline bool operator<=(uint8_t lhs, ModbusFunctionCode rhs) { return lhs <= static_cast(rhs); } +inline bool operator>(ModbusFunctionCode lhs, uint8_t rhs) { return static_cast(lhs) > rhs; } +inline bool operator>(uint8_t lhs, ModbusFunctionCode rhs) { return lhs > static_cast(rhs); } +inline bool operator>=(ModbusFunctionCode lhs, uint8_t rhs) { return static_cast(lhs) >= rhs; } +inline bool operator>=(uint8_t lhs, ModbusFunctionCode rhs) { return lhs >= static_cast(rhs); } + +// 4.3 MODBUS Data model +enum class ModbusRegisterType : uint8_t { + CUSTOM = 0x00, + COIL = 0x01, + DISCRETE_INPUT = 0x02, + HOLDING = 0x03, + READ = 0x04, +}; + +// 7 MODBUS Exception Responses: +const uint8_t FUNCTION_CODE_MASK = 0x7F; +const uint8_t FUNCTION_CODE_EXCEPTION_MASK = 0x80; + +enum class ModbusExceptionCode : uint8_t { + ILLEGAL_FUNCTION = 0x01, + ILLEGAL_DATA_ADDRESS = 0x02, + ILLEGAL_DATA_VALUE = 0x03, + SERVICE_DEVICE_FAILURE = 0x04, + ACKNOWLEDGE = 0x05, + SERVER_DEVICE_BUSY = 0x06, + MEMORY_PARITY_ERROR = 0x08, + GATEWAY_PATH_UNAVAILABLE = 0x0A, + GATEWAY_TARGET_DEVICE_FAILED_TO_RESPOND = 0x0B, +}; + +// 6.12 16 (0x10) Write Multiple registers: +const uint8_t MAX_NUM_OF_REGISTERS_TO_WRITE = 123; // 0x7B + +// 6.3 03 (0x03) Read Holding Registers +// 6.4 04 (0x04) Read Input Registers +const uint8_t MAX_NUM_OF_REGISTERS_TO_READ = 125; // 0x7D +/// End of Modbus definitions +} // namespace modbus +} // namespace esphome diff --git a/esphome/components/modbus_controller/__init__.py b/esphome/components/modbus_controller/__init__.py index 5ab82f5e17..28f3326c47 100644 --- a/esphome/components/modbus_controller/__init__.py +++ b/esphome/components/modbus_controller/__init__.py @@ -20,6 +20,7 @@ from .const import ( CONF_BYTE_OFFSET, CONF_COMMAND_THROTTLE, CONF_CUSTOM_COMMAND, + CONF_ENABLED, CONF_FORCE_NEW_RANGE, CONF_MAX_CMD_RETRIES, CONF_MODBUS_CONTROLLER_ID, @@ -28,8 +29,11 @@ from .const import ( CONF_ON_OFFLINE, CONF_ON_ONLINE, CONF_REGISTER_COUNT, + CONF_REGISTER_LAST_ADDRESS, CONF_REGISTER_TYPE, + CONF_REGISTER_VALUE, CONF_RESPONSE_SIZE, + CONF_SERVER_COURTESY_RESPONSE, CONF_SKIP_UPDATES, CONF_VALUE_TYPE, ) @@ -49,6 +53,7 @@ ModbusController = modbus_controller_ns.class_( ) SensorItem = modbus_controller_ns.struct("SensorItem") +ServerCourtesyResponse = modbus_controller_ns.struct("ServerCourtesyResponse") ServerRegister = modbus_controller_ns.struct("ServerRegister") ModbusFunctionCode_ns = modbus_controller_ns.namespace("ModbusFunctionCode") @@ -143,6 +148,14 @@ ModbusOfflineTrigger = modbus_controller_ns.class_( _LOGGER = logging.getLogger(__name__) +SERVER_COURTESY_RESPONSE_SCHEMA = cv.Schema( + { + cv.Optional(CONF_ENABLED, default=False): cv.boolean, + cv.Optional(CONF_REGISTER_LAST_ADDRESS, default=0xFFFF): cv.hex_uint16_t, + cv.Optional(CONF_REGISTER_VALUE, default=0): cv.hex_uint16_t, + } +) + ModbusServerRegisterSchema = cv.Schema( { cv.GenerateID(): cv.declare_id(ServerRegister), @@ -162,6 +175,7 @@ CONFIG_SCHEMA = cv.All( cv.Optional( CONF_COMMAND_THROTTLE, default="0ms" ): cv.positive_time_period_milliseconds, + cv.Optional(CONF_SERVER_COURTESY_RESPONSE): SERVER_COURTESY_RESPONSE_SCHEMA, cv.Optional(CONF_MAX_CMD_RETRIES, default=4): cv.positive_int, cv.Optional(CONF_OFFLINE_SKIP_UPDATES, default=0): cv.positive_int, cv.Optional( @@ -232,7 +246,7 @@ def validate_modbus_register(config): def _final_validate(config): - if CONF_SERVER_REGISTERS in config: + if CONF_SERVER_COURTESY_RESPONSE in config or CONF_SERVER_REGISTERS in config: return modbus.final_validate_modbus_device("modbus_controller", role="server")( config ) @@ -299,6 +313,20 @@ async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) cg.add(var.set_allow_duplicate_commands(config[CONF_ALLOW_DUPLICATE_COMMANDS])) cg.add(var.set_command_throttle(config[CONF_COMMAND_THROTTLE])) + if server_courtesy_response := config.get(CONF_SERVER_COURTESY_RESPONSE): + cg.add( + var.set_server_courtesy_response( + cg.StructInitializer( + ServerCourtesyResponse, + ("enabled", server_courtesy_response[CONF_ENABLED]), + ( + "register_last_address", + server_courtesy_response[CONF_REGISTER_LAST_ADDRESS], + ), + ("register_value", server_courtesy_response[CONF_REGISTER_VALUE]), + ) + ) + ) cg.add(var.set_max_cmd_retries(config[CONF_MAX_CMD_RETRIES])) cg.add(var.set_offline_skip_updates(config[CONF_OFFLINE_SKIP_UPDATES])) if CONF_SERVER_REGISTERS in config: diff --git a/esphome/components/modbus_controller/const.py b/esphome/components/modbus_controller/const.py index 4d39e48dcd..ee0b5fc633 100644 --- a/esphome/components/modbus_controller/const.py +++ b/esphome/components/modbus_controller/const.py @@ -2,6 +2,7 @@ CONF_ALLOW_DUPLICATE_COMMANDS = "allow_duplicate_commands" CONF_BITMASK = "bitmask" CONF_BYTE_OFFSET = "byte_offset" CONF_COMMAND_THROTTLE = "command_throttle" +CONF_ENABLED = "enabled" CONF_OFFLINE_SKIP_UPDATES = "offline_skip_updates" CONF_CUSTOM_COMMAND = "custom_command" CONF_FORCE_NEW_RANGE = "force_new_range" @@ -13,8 +14,11 @@ CONF_ON_ONLINE = "on_online" CONF_ON_OFFLINE = "on_offline" CONF_RAW_ENCODE = "raw_encode" CONF_REGISTER_COUNT = "register_count" +CONF_REGISTER_LAST_ADDRESS = "register_last_address" CONF_REGISTER_TYPE = "register_type" +CONF_REGISTER_VALUE = "register_value" CONF_RESPONSE_SIZE = "response_size" +CONF_SERVER_COURTESY_RESPONSE = "server_courtesy_response" CONF_SKIP_UPDATES = "skip_updates" CONF_USE_WRITE_MULTIPLE = "use_write_multiple" CONF_VALUE_TYPE = "value_type" diff --git a/esphome/components/modbus_controller/modbus_controller.cpp b/esphome/components/modbus_controller/modbus_controller.cpp index 0f3ddf920d..50bd9f45cb 100644 --- a/esphome/components/modbus_controller/modbus_controller.cpp +++ b/esphome/components/modbus_controller/modbus_controller.cpp @@ -112,6 +112,12 @@ void ModbusController::on_modbus_read_registers(uint8_t function_code, uint16_t "0x%X.", this->address_, function_code, start_address, number_of_registers); + if (number_of_registers == 0 || number_of_registers > modbus::MAX_NUM_OF_REGISTERS_TO_READ) { + ESP_LOGW(TAG, "Invalid number of registers %d. Sending exception response.", number_of_registers); + this->send_error(function_code, ModbusExceptionCode::ILLEGAL_DATA_ADDRESS); + return; + } + std::vector sixteen_bit_response; for (uint16_t current_address = start_address; current_address < start_address + number_of_registers;) { bool found = false; @@ -136,9 +142,21 @@ void ModbusController::on_modbus_read_registers(uint8_t function_code, uint16_t } if (!found) { - ESP_LOGW(TAG, "Could not match any register to address %02X. Sending exception response.", current_address); - send_error(function_code, 0x02); - return; + if (this->server_courtesy_response_.enabled && + (current_address <= this->server_courtesy_response_.register_last_address)) { + ESP_LOGD(TAG, + "Could not match any register to address 0x%02X, but default allowed. " + "Returning default value: %d.", + current_address, this->server_courtesy_response_.register_value); + sixteen_bit_response.push_back(this->server_courtesy_response_.register_value); + current_address += 1; // Just increment by 1, as the default response is a single register + } else { + ESP_LOGW(TAG, + "Could not match any register to address 0x%02X and default not allowed. Sending exception response.", + current_address); + this->send_error(function_code, ModbusExceptionCode::ILLEGAL_DATA_ADDRESS); + return; + } } } @@ -156,27 +174,27 @@ void ModbusController::on_modbus_write_registers(uint8_t function_code, const st uint16_t number_of_registers; uint16_t payload_offset; - if (function_code == 0x10) { + if (function_code == ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS) { number_of_registers = uint16_t(data[3]) | (uint16_t(data[2]) << 8); - if (number_of_registers == 0 || number_of_registers > 0x7B) { + if (number_of_registers == 0 || number_of_registers > modbus::MAX_NUM_OF_REGISTERS_TO_WRITE) { ESP_LOGW(TAG, "Invalid number of registers %d. Sending exception response.", number_of_registers); - send_error(function_code, 3); + this->send_error(function_code, ModbusExceptionCode::ILLEGAL_DATA_VALUE); return; } uint16_t payload_size = data[4]; if (payload_size != number_of_registers * 2) { ESP_LOGW(TAG, "Payload size of %d bytes is not 2 times the number of registers (%d). Sending exception response.", payload_size, number_of_registers); - send_error(function_code, 3); + this->send_error(function_code, ModbusExceptionCode::ILLEGAL_DATA_VALUE); return; } payload_offset = 5; - } else if (function_code == 0x06) { + } else if (function_code == ModbusFunctionCode::WRITE_SINGLE_REGISTER) { number_of_registers = 1; payload_offset = 2; } else { ESP_LOGW(TAG, "Invalid function code 0x%X. Sending exception response.", function_code); - send_error(function_code, 1); + this->send_error(function_code, ModbusExceptionCode::ILLEGAL_FUNCTION); return; } @@ -211,7 +229,7 @@ void ModbusController::on_modbus_write_registers(uint8_t function_code, const st if (!for_each_register([](ServerRegister *server_register, uint16_t offset) -> bool { return server_register->write_lambda != nullptr; })) { - send_error(function_code, 1); + this->send_error(function_code, ModbusExceptionCode::ILLEGAL_FUNCTION); return; } @@ -220,7 +238,7 @@ void ModbusController::on_modbus_write_registers(uint8_t function_code, const st int64_t number = payload_to_number(data, server_register->value_type, offset, 0xFFFFFFFF); return server_register->write_lambda(number); })) { - send_error(function_code, 4); + this->send_error(function_code, ModbusExceptionCode::SERVICE_DEVICE_FAILURE); return; } @@ -431,8 +449,15 @@ void ModbusController::dump_config() { "ModbusController:\n" " Address: 0x%02X\n" " Max Command Retries: %d\n" - " Offline Skip Updates: %d", - this->address_, this->max_cmd_retries_, this->offline_skip_updates_); + " Offline Skip Updates: %d\n" + " Server Courtesy Response:\n" + " Enabled: %s\n" + " Register Last Address: 0x%02X\n" + " Register Value: %d", + this->address_, this->max_cmd_retries_, this->offline_skip_updates_, + this->server_courtesy_response_.enabled ? "true" : "false", + this->server_courtesy_response_.register_last_address, this->server_courtesy_response_.register_value); + #if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE ESP_LOGCONFIG(TAG, "sensormap"); for (auto &it : this->sensorset_) { diff --git a/esphome/components/modbus_controller/modbus_controller.h b/esphome/components/modbus_controller/modbus_controller.h index a86ad1ccb5..6ed05715cb 100644 --- a/esphome/components/modbus_controller/modbus_controller.h +++ b/esphome/components/modbus_controller/modbus_controller.h @@ -16,35 +16,9 @@ namespace modbus_controller { class ModbusController; -enum class ModbusFunctionCode { - CUSTOM = 0x00, - READ_COILS = 0x01, - READ_DISCRETE_INPUTS = 0x02, - READ_HOLDING_REGISTERS = 0x03, - READ_INPUT_REGISTERS = 0x04, - WRITE_SINGLE_COIL = 0x05, - WRITE_SINGLE_REGISTER = 0x06, - READ_EXCEPTION_STATUS = 0x07, // not implemented - DIAGNOSTICS = 0x08, // not implemented - GET_COMM_EVENT_COUNTER = 0x0B, // not implemented - GET_COMM_EVENT_LOG = 0x0C, // not implemented - WRITE_MULTIPLE_COILS = 0x0F, - WRITE_MULTIPLE_REGISTERS = 0x10, - REPORT_SERVER_ID = 0x11, // not implemented - READ_FILE_RECORD = 0x14, // not implemented - WRITE_FILE_RECORD = 0x15, // not implemented - MASK_WRITE_REGISTER = 0x16, // not implemented - READ_WRITE_MULTIPLE_REGISTERS = 0x17, // not implemented - READ_FIFO_QUEUE = 0x18, // not implemented -}; - -enum class ModbusRegisterType : uint8_t { - CUSTOM = 0x0, - COIL = 0x01, - DISCRETE_INPUT = 0x02, - HOLDING = 0x03, - READ = 0x04, -}; +using modbus::ModbusFunctionCode; +using modbus::ModbusRegisterType; +using modbus::ModbusExceptionCode; enum class SensorValueType : uint8_t { RAW = 0x00, // variable length @@ -256,6 +230,12 @@ class SensorItem { bool force_new_range{false}; }; +struct ServerCourtesyResponse { + bool enabled{false}; + uint16_t register_last_address{0xFFFF}; + uint16_t register_value{0}; +}; + class ServerRegister { using ReadLambda = std::function; using WriteLambda = std::function; @@ -530,6 +510,12 @@ class ModbusController : public PollingComponent, public modbus::ModbusDevice { void set_max_cmd_retries(uint8_t max_cmd_retries) { this->max_cmd_retries_ = max_cmd_retries; } /// get how many times a command will be (re)sent if no response is received uint8_t get_max_cmd_retries() { return this->max_cmd_retries_; } + /// Called by esphome generated code to set the server courtesy response object + void set_server_courtesy_response(const ServerCourtesyResponse &server_courtesy_response) { + this->server_courtesy_response_ = server_courtesy_response; + } + /// Get the server courtesy response object + ServerCourtesyResponse get_server_courtesy_response() const { return this->server_courtesy_response_; } protected: /// parse sensormap_ and create range of sequential addresses @@ -572,6 +558,9 @@ class ModbusController : public PollingComponent, public modbus::ModbusDevice { CallbackManager online_callback_{}; /// Server offline callback CallbackManager offline_callback_{}; + /// Server courtesy response + ServerCourtesyResponse server_courtesy_response_{ + .enabled = false, .register_last_address = 0xFFFF, .register_value = 0}; }; /** Convert vector response payload to float. diff --git a/esphome/components/mpr121/mpr121.cpp b/esphome/components/mpr121/mpr121.cpp index 074bc79ea2..5a8a8e7205 100644 --- a/esphome/components/mpr121/mpr121.cpp +++ b/esphome/components/mpr121/mpr121.cpp @@ -11,47 +11,49 @@ namespace mpr121 { static const char *const TAG = "mpr121"; void MPR121Component::setup() { + this->disable_loop(); // soft reset device this->write_byte(MPR121_SOFTRESET, 0x63); - delay(100); // NOLINT - if (!this->write_byte(MPR121_ECR, 0x0)) { - this->error_code_ = COMMUNICATION_FAILED; - this->mark_failed(); - return; - } + this->set_timeout(100, [this]() { + if (!this->write_byte(MPR121_ECR, 0x0)) { + this->error_code_ = COMMUNICATION_FAILED; + this->mark_failed(); + return; + } + // set touch sensitivity for all 12 channels + for (auto *channel : this->channels_) { + channel->setup(); + } + this->write_byte(MPR121_MHDR, 0x01); + this->write_byte(MPR121_NHDR, 0x01); + this->write_byte(MPR121_NCLR, 0x0E); + this->write_byte(MPR121_FDLR, 0x00); - // set touch sensitivity for all 12 channels - for (auto *channel : this->channels_) { - channel->setup(); - } - this->write_byte(MPR121_MHDR, 0x01); - this->write_byte(MPR121_NHDR, 0x01); - this->write_byte(MPR121_NCLR, 0x0E); - this->write_byte(MPR121_FDLR, 0x00); + this->write_byte(MPR121_MHDF, 0x01); + this->write_byte(MPR121_NHDF, 0x05); + this->write_byte(MPR121_NCLF, 0x01); + this->write_byte(MPR121_FDLF, 0x00); - this->write_byte(MPR121_MHDF, 0x01); - this->write_byte(MPR121_NHDF, 0x05); - this->write_byte(MPR121_NCLF, 0x01); - this->write_byte(MPR121_FDLF, 0x00); + this->write_byte(MPR121_NHDT, 0x00); + this->write_byte(MPR121_NCLT, 0x00); + this->write_byte(MPR121_FDLT, 0x00); - this->write_byte(MPR121_NHDT, 0x00); - this->write_byte(MPR121_NCLT, 0x00); - this->write_byte(MPR121_FDLT, 0x00); + this->write_byte(MPR121_DEBOUNCE, 0); + // default, 16uA charge current + this->write_byte(MPR121_CONFIG1, 0x10); + // 0.5uS encoding, 1ms period + this->write_byte(MPR121_CONFIG2, 0x20); - this->write_byte(MPR121_DEBOUNCE, 0); - // default, 16uA charge current - this->write_byte(MPR121_CONFIG1, 0x10); - // 0.5uS encoding, 1ms period - this->write_byte(MPR121_CONFIG2, 0x20); + // Write the Electrode Configuration Register + // * Highest 2 bits is "Calibration Lock", which we set to a value corresponding to 5 bits. + // * The 2 bits below is "Proximity Enable" and are left at 0. + // * The 4 least significant bits control how many electrodes are enabled. Electrodes are enabled + // as a range, starting at 0 up to the highest channel index used. + this->write_byte(MPR121_ECR, 0x80 | (this->max_touch_channel_ + 1)); - // Write the Electrode Configuration Register - // * Highest 2 bits is "Calibration Lock", which we set to a value corresponding to 5 bits. - // * The 2 bits below is "Proximity Enable" and are left at 0. - // * The 4 least significant bits control how many electrodes are enabled. Electrodes are enabled - // as a range, starting at 0 up to the highest channel index used. - this->write_byte(MPR121_ECR, 0x80 | (this->max_touch_channel_ + 1)); - - this->flush_gpio_(); + this->flush_gpio_(); + this->enable_loop(); + }); } void MPR121Component::set_touch_debounce(uint8_t debounce) { @@ -73,9 +75,6 @@ void MPR121Component::dump_config() { case COMMUNICATION_FAILED: ESP_LOGE(TAG, ESP_LOG_MSG_COMM_FAIL); break; - case WRONG_CHIP_STATE: - ESP_LOGE(TAG, "MPR121 has wrong default value for CONFIG2?"); - break; case NONE: default: break; diff --git a/esphome/components/mpr121/mpr121.h b/esphome/components/mpr121/mpr121.h index eb2e2edc57..6dd2c38309 100644 --- a/esphome/components/mpr121/mpr121.h +++ b/esphome/components/mpr121/mpr121.h @@ -88,7 +88,6 @@ class MPR121Component : public Component, public i2c::I2CDevice { enum ErrorCode { NONE = 0, COMMUNICATION_FAILED, - WRONG_CHIP_STATE, } error_code_{NONE}; bool flush_gpio_(); diff --git a/esphome/components/nau7802/nau7802.cpp b/esphome/components/nau7802/nau7802.cpp index acdca03fdb..6a31b754f7 100644 --- a/esphome/components/nau7802/nau7802.cpp +++ b/esphome/components/nau7802/nau7802.cpp @@ -218,7 +218,7 @@ void NAU7802Sensor::dump_config() { void NAU7802Sensor::write_value_(uint8_t start_reg, size_t size, int32_t value) { uint8_t data[4]; - for (int i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { data[i] = 0xFF & (value >> (size - 1 - i) * 8); } this->write_register(start_reg, data, size); @@ -228,7 +228,7 @@ int32_t NAU7802Sensor::read_value_(uint8_t start_reg, size_t size) { uint8_t data[4]; this->read_register(start_reg, data, size); int32_t result = 0; - for (int i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { result |= data[i] << (size - 1 - i) * 8; } // extend sign bit diff --git a/esphome/components/network/__init__.py b/esphome/components/network/__init__.py index 9679961b15..1a74350c4c 100644 --- a/esphome/components/network/__init__.py +++ b/esphome/components/network/__init__.py @@ -47,9 +47,13 @@ async def to_code(config): cg.add_define( "USE_NETWORK_MIN_IPV6_ADDR_COUNT", config[CONF_MIN_IPV6_ADDR_COUNT] ) - if CORE.using_esp_idf: - add_idf_sdkconfig_option("CONFIG_LWIP_IPV6", enable_ipv6) - add_idf_sdkconfig_option("CONFIG_LWIP_IPV6_AUTOCONFIG", enable_ipv6) + if CORE.is_esp32: + if CORE.using_esp_idf: + add_idf_sdkconfig_option("CONFIG_LWIP_IPV6", enable_ipv6) + add_idf_sdkconfig_option("CONFIG_LWIP_IPV6_AUTOCONFIG", enable_ipv6) + else: + add_idf_sdkconfig_option("CONFIG_LWIP_IPV6", True) + add_idf_sdkconfig_option("CONFIG_LWIP_IPV6_AUTOCONFIG", True) elif enable_ipv6: cg.add_build_flag("-DCONFIG_LWIP_IPV6") cg.add_build_flag("-DCONFIG_LWIP_IPV6_AUTOCONFIG") diff --git a/esphome/components/nextion/display.py b/esphome/components/nextion/display.py index 4254ae45fe..ed6cd93027 100644 --- a/esphome/components/nextion/display.py +++ b/esphome/components/nextion/display.py @@ -153,10 +153,10 @@ async def to_code(config): if CONF_TFT_URL in config: cg.add_define("USE_NEXTION_TFT_UPLOAD") cg.add(var.set_tft_url(config[CONF_TFT_URL])) - if CORE.is_esp32 and CORE.using_arduino: - cg.add_library("NetworkClientSecure", None) - cg.add_library("HTTPClient", None) - elif CORE.is_esp32 and CORE.using_esp_idf: + if CORE.is_esp32: + if CORE.using_arduino: + cg.add_library("NetworkClientSecure", None) + cg.add_library("HTTPClient", None) esp32.add_idf_sdkconfig_option("CONFIG_ESP_TLS_INSECURE", True) esp32.add_idf_sdkconfig_option( "CONFIG_ESP_TLS_SKIP_SERVER_CERT_VERIFY", True diff --git a/esphome/components/nextion/nextion.cpp b/esphome/components/nextion/nextion.cpp index b348bc9920..0ce9d02e97 100644 --- a/esphome/components/nextion/nextion.cpp +++ b/esphome/components/nextion/nextion.cpp @@ -77,7 +77,7 @@ bool Nextion::check_connect_() { this->recv_ret_string_(response, 0, false); if (!response.empty() && response[0] == 0x1A) { // Swallow invalid variable name responses that may be caused by the above commands - ESP_LOGD(TAG, "0x1A error ignored (setup)"); + ESP_LOGV(TAG, "0x1A error ignored (setup)"); return false; } if (response.empty() || response.find("comok") == std::string::npos) { @@ -334,7 +334,7 @@ void Nextion::loop() { this->started_ms_ = App.get_loop_component_start_time(); if (this->started_ms_ + this->startup_override_ms_ < App.get_loop_component_start_time()) { - ESP_LOGD(TAG, "Manual ready set"); + ESP_LOGV(TAG, "Manual ready set"); this->connection_state_.nextion_reports_is_setup_ = true; } } @@ -544,7 +544,7 @@ void Nextion::process_nextion_commands_() { uint8_t page_id = to_process[0]; uint8_t component_id = to_process[1]; uint8_t touch_event = to_process[2]; // 0 -> release, 1 -> press - ESP_LOGD(TAG, "Touch %s: page %u comp %u", touch_event ? "PRESS" : "RELEASE", page_id, component_id); + ESP_LOGV(TAG, "Touch %s: page %u comp %u", touch_event ? "PRESS" : "RELEASE", page_id, component_id); for (auto *touch : this->touch_) { touch->process_touch(page_id, component_id, touch_event != 0); } @@ -559,7 +559,7 @@ void Nextion::process_nextion_commands_() { } uint8_t page_id = to_process[0]; - ESP_LOGD(TAG, "New page: %u", page_id); + ESP_LOGV(TAG, "New page: %u", page_id); this->page_callback_.call(page_id); break; } @@ -577,7 +577,7 @@ void Nextion::process_nextion_commands_() { const uint16_t x = (uint16_t(to_process[0]) << 8) | to_process[1]; const uint16_t y = (uint16_t(to_process[2]) << 8) | to_process[3]; const uint8_t touch_event = to_process[4]; // 0 -> release, 1 -> press - ESP_LOGD(TAG, "Touch %s at %u,%u", touch_event ? "PRESS" : "RELEASE", x, y); + ESP_LOGV(TAG, "Touch %s at %u,%u", touch_event ? "PRESS" : "RELEASE", x, y); break; } @@ -676,7 +676,7 @@ void Nextion::process_nextion_commands_() { } case 0x88: // system successful start up { - ESP_LOGD(TAG, "System start: %zu", to_process_length); + ESP_LOGV(TAG, "System start: %zu", to_process_length); this->connection_state_.nextion_reports_is_setup_ = true; break; } @@ -922,7 +922,7 @@ void Nextion::set_nextion_sensor_state(NextionQueueType queue_type, const std::s } void Nextion::set_nextion_text_state(const std::string &name, const std::string &state) { - ESP_LOGD(TAG, "State: %s='%s'", name.c_str(), state.c_str()); + ESP_LOGV(TAG, "State: %s='%s'", name.c_str(), state.c_str()); for (auto *sensor : this->textsensortype_) { if (name == sensor->get_variable_name()) { @@ -933,7 +933,7 @@ void Nextion::set_nextion_text_state(const std::string &name, const std::string } void Nextion::all_components_send_state_(bool force_update) { - ESP_LOGD(TAG, "Send states"); + ESP_LOGV(TAG, "Send states"); for (auto *binarysensortype : this->binarysensortype_) { if (force_update || binarysensortype->get_needs_to_send_update()) binarysensortype->send_state_to_nextion(); diff --git a/esphome/components/number/__init__.py b/esphome/components/number/__init__.py index c2cad2f7f1..76a7b05ea1 100644 --- a/esphome/components/number/__init__.py +++ b/esphome/components/number/__init__.py @@ -51,6 +51,7 @@ from esphome.const import ( DEVICE_CLASS_OZONE, DEVICE_CLASS_PH, DEVICE_CLASS_PM1, + DEVICE_CLASS_PM4, DEVICE_CLASS_PM10, DEVICE_CLASS_PM25, DEVICE_CLASS_POWER, @@ -116,6 +117,7 @@ DEVICE_CLASSES = [ DEVICE_CLASS_PM1, DEVICE_CLASS_PM10, DEVICE_CLASS_PM25, + DEVICE_CLASS_PM4, DEVICE_CLASS_POWER, DEVICE_CLASS_POWER_FACTOR, DEVICE_CLASS_PRECIPITATION, diff --git a/esphome/components/number/number_call.cpp b/esphome/components/number/number_call.cpp index 4219f85328..669dd65184 100644 --- a/esphome/components/number/number_call.cpp +++ b/esphome/components/number/number_call.cpp @@ -7,6 +7,17 @@ namespace number { static const char *const TAG = "number"; +// Helper functions to reduce code size for logging +void NumberCall::log_perform_warning_(const LogString *message) { + ESP_LOGW(TAG, "'%s': %s", this->parent_->get_name().c_str(), LOG_STR_ARG(message)); +} + +void NumberCall::log_perform_warning_value_range_(const LogString *comparison, const LogString *limit_type, float val, + float limit) { + ESP_LOGW(TAG, "'%s': %f %s %s %f", this->parent_->get_name().c_str(), val, LOG_STR_ARG(comparison), + LOG_STR_ARG(limit_type), limit); +} + NumberCall &NumberCall::set_value(float value) { return this->with_operation(NUMBER_OP_SET).with_value(value); } NumberCall &NumberCall::number_increment(bool cycle) { @@ -42,7 +53,7 @@ void NumberCall::perform() { const auto &traits = parent->traits; if (this->operation_ == NUMBER_OP_NONE) { - ESP_LOGW(TAG, "'%s' - NumberCall performed without selecting an operation", name); + this->log_perform_warning_(LOG_STR("No operation")); return; } @@ -51,28 +62,28 @@ void NumberCall::perform() { float max_value = traits.get_max_value(); if (this->operation_ == NUMBER_OP_SET) { - ESP_LOGD(TAG, "'%s' - Setting number value", name); + ESP_LOGD(TAG, "'%s': Setting value", name); if (!this->value_.has_value() || std::isnan(*this->value_)) { - ESP_LOGW(TAG, "'%s' - No value set for NumberCall", name); + this->log_perform_warning_(LOG_STR("No value")); return; } target_value = this->value_.value(); } else if (this->operation_ == NUMBER_OP_TO_MIN) { if (std::isnan(min_value)) { - ESP_LOGW(TAG, "'%s' - Can't set to min value through NumberCall: no min_value defined", name); + this->log_perform_warning_(LOG_STR("min undefined")); } else { target_value = min_value; } } else if (this->operation_ == NUMBER_OP_TO_MAX) { if (std::isnan(max_value)) { - ESP_LOGW(TAG, "'%s' - Can't set to max value through NumberCall: no max_value defined", name); + this->log_perform_warning_(LOG_STR("max undefined")); } else { target_value = max_value; } } else if (this->operation_ == NUMBER_OP_INCREMENT) { - ESP_LOGD(TAG, "'%s' - Increment number, with%s cycling", name, this->cycle_ ? "" : "out"); + ESP_LOGD(TAG, "'%s': Increment with%s cycling", name, this->cycle_ ? "" : "out"); if (!parent->has_state()) { - ESP_LOGW(TAG, "'%s' - Can't increment number through NumberCall: no active state to modify", name); + this->log_perform_warning_(LOG_STR("Can't increment, no state")); return; } auto step = traits.get_step(); @@ -85,9 +96,9 @@ void NumberCall::perform() { } } } else if (this->operation_ == NUMBER_OP_DECREMENT) { - ESP_LOGD(TAG, "'%s' - Decrement number, with%s cycling", name, this->cycle_ ? "" : "out"); + ESP_LOGD(TAG, "'%s': Decrement with%s cycling", name, this->cycle_ ? "" : "out"); if (!parent->has_state()) { - ESP_LOGW(TAG, "'%s' - Can't decrement number through NumberCall: no active state to modify", name); + this->log_perform_warning_(LOG_STR("Can't decrement, no state")); return; } auto step = traits.get_step(); @@ -102,15 +113,15 @@ void NumberCall::perform() { } if (target_value < min_value) { - ESP_LOGW(TAG, "'%s' - Value %f must not be less than minimum %f", name, target_value, min_value); + this->log_perform_warning_value_range_(LOG_STR("<"), LOG_STR("min"), target_value, min_value); return; } if (target_value > max_value) { - ESP_LOGW(TAG, "'%s' - Value %f must not be greater than maximum %f", name, target_value, max_value); + this->log_perform_warning_value_range_(LOG_STR(">"), LOG_STR("max"), target_value, max_value); return; } - ESP_LOGD(TAG, " New number value: %f", target_value); + ESP_LOGD(TAG, " New value: %f", target_value); this->parent_->control(target_value); } diff --git a/esphome/components/number/number_call.h b/esphome/components/number/number_call.h index bd50170be5..807207f0ec 100644 --- a/esphome/components/number/number_call.h +++ b/esphome/components/number/number_call.h @@ -1,6 +1,7 @@ #pragma once #include "esphome/core/helpers.h" +#include "esphome/core/log.h" #include "number_traits.h" namespace esphome { @@ -33,6 +34,10 @@ class NumberCall { NumberCall &with_cycle(bool cycle); protected: + void log_perform_warning_(const LogString *message); + void log_perform_warning_value_range_(const LogString *comparison, const LogString *limit_type, float val, + float limit); + Number *const parent_; NumberOperation operation_{NUMBER_OP_NONE}; optional value_; diff --git a/esphome/components/online_image/bmp_image.cpp b/esphome/components/online_image/bmp_image.cpp index f55c9f1813..676a2efca9 100644 --- a/esphome/components/online_image/bmp_image.cpp +++ b/esphome/components/online_image/bmp_image.cpp @@ -117,7 +117,8 @@ int HOT BmpDecoder::decode(uint8_t *buffer, size_t size) { this->paint_index_++; this->current_index_ += 3; index += 3; - if (x == this->width_ - 1 && this->padding_bytes_ > 0) { + size_t last_col = static_cast(this->width_) - 1; + if (x == last_col && this->padding_bytes_ > 0) { index += this->padding_bytes_; this->current_index_ += this->padding_bytes_; } diff --git a/esphome/components/online_image/jpeg_image.cpp b/esphome/components/online_image/jpeg_image.cpp index e5ee3dd8bf..10586091d5 100644 --- a/esphome/components/online_image/jpeg_image.cpp +++ b/esphome/components/online_image/jpeg_image.cpp @@ -25,8 +25,10 @@ static int draw_callback(JPEGDRAW *jpeg) { // to avoid crashing. App.feed_wdt(); size_t position = 0; - for (size_t y = 0; y < jpeg->iHeight; y++) { - for (size_t x = 0; x < jpeg->iWidth; x++) { + size_t height = static_cast(jpeg->iHeight); + size_t width = static_cast(jpeg->iWidth); + for (size_t y = 0; y < height; y++) { + for (size_t x = 0; x < width; x++) { auto rg = decode_value(jpeg->pPixels[position++]); auto ba = decode_value(jpeg->pPixels[position++]); Color color(rg[1], rg[0], ba[1], ba[0]); diff --git a/esphome/components/openthread/openthread.cpp b/esphome/components/openthread/openthread.cpp index 5b5c113f83..bc5dcadef6 100644 --- a/esphome/components/openthread/openthread.cpp +++ b/esphome/components/openthread/openthread.cpp @@ -143,11 +143,10 @@ void OpenThreadSrpComponent::setup() { return; } - // Copy the mdns services to our local instance so that the c_str pointers remain valid for the lifetime of this - // component - this->mdns_services_ = this->mdns_->get_services(); - ESP_LOGD(TAG, "Setting up SRP services. count = %d\n", this->mdns_services_.size()); - for (const auto &service : this->mdns_services_) { + // Get mdns services and copy their data (strings are copied with strdup below) + const auto &mdns_services = this->mdns_->get_services(); + ESP_LOGD(TAG, "Setting up SRP services. count = %d\n", mdns_services.size()); + for (const auto &service : mdns_services) { otSrpClientBuffersServiceEntry *entry = otSrpClientBuffersAllocateService(instance); if (!entry) { ESP_LOGW(TAG, "Failed to allocate service entry"); @@ -156,7 +155,7 @@ void OpenThreadSrpComponent::setup() { // Set service name char *string = otSrpClientBuffersGetServiceEntryServiceNameString(entry, &size); - std::string full_service = service.service_type + "." + service.proto; + std::string full_service = std::string(MDNS_STR_ARG(service.service_type)) + "." + MDNS_STR_ARG(service.proto); if (full_service.size() > size) { ESP_LOGW(TAG, "Service name too long: %s", full_service.c_str()); continue; @@ -182,7 +181,7 @@ void OpenThreadSrpComponent::setup() { for (size_t i = 0; i < service.txt_records.size(); i++) { const auto &txt = service.txt_records[i]; auto value = const_cast &>(txt.value).value(); - txt_entries[i].mKey = strdup(txt.key.c_str()); + txt_entries[i].mKey = MDNS_STR_ARG(txt.key); txt_entries[i].mValue = reinterpret_cast(strdup(value.c_str())); txt_entries[i].mValueLength = value.size(); } diff --git a/esphome/components/openthread/openthread.h b/esphome/components/openthread/openthread.h index a9aff78e56..5d139c633d 100644 --- a/esphome/components/openthread/openthread.h +++ b/esphome/components/openthread/openthread.h @@ -57,7 +57,6 @@ class OpenThreadSrpComponent : public Component { protected: esphome::mdns::MDNSComponent *mdns_{nullptr}; - std::vector mdns_services_; std::vector> memory_pool_; void *pool_alloc_(size_t size); }; diff --git a/esphome/components/ota/ota_backend.h b/esphome/components/ota/ota_backend.h index 372f24df5e..64ee0b9f7c 100644 --- a/esphome/components/ota/ota_backend.h +++ b/esphome/components/ota/ota_backend.h @@ -14,6 +14,7 @@ namespace ota { enum OTAResponseTypes { OTA_RESPONSE_OK = 0x00, OTA_RESPONSE_REQUEST_AUTH = 0x01, + OTA_RESPONSE_REQUEST_SHA256_AUTH = 0x02, OTA_RESPONSE_HEADER_OK = 0x40, OTA_RESPONSE_AUTH_OK = 0x41, diff --git a/esphome/components/packages/__init__.py b/esphome/components/packages/__init__.py index 2e7dc0e197..fdc75d995a 100644 --- a/esphome/components/packages/__init__.py +++ b/esphome/components/packages/__init__.py @@ -106,11 +106,13 @@ CONFIG_SCHEMA = cv.Any( ) -def _process_base_package(config: dict) -> dict: +def _process_base_package(config: dict, skip_update: bool = False) -> dict: + # When skip_update is True, use NEVER_REFRESH to prevent updates + actual_refresh = git.NEVER_REFRESH if skip_update else config[CONF_REFRESH] repo_dir, revert = git.clone_or_update( url=config[CONF_URL], ref=config.get(CONF_REF), - refresh=config[CONF_REFRESH], + refresh=actual_refresh, domain=DOMAIN, username=config.get(CONF_USERNAME), password=config.get(CONF_PASSWORD), @@ -180,16 +182,16 @@ def _process_base_package(config: dict) -> dict: return {"packages": packages} -def _process_package(package_config, config): +def _process_package(package_config, config, skip_update: bool = False): recursive_package = package_config if CONF_URL in package_config: - package_config = _process_base_package(package_config) + package_config = _process_base_package(package_config, skip_update) if isinstance(package_config, dict): - recursive_package = do_packages_pass(package_config) + recursive_package = do_packages_pass(package_config, skip_update) return merge_config(recursive_package, config) -def do_packages_pass(config: dict): +def do_packages_pass(config: dict, skip_update: bool = False): if CONF_PACKAGES not in config: return config packages = config[CONF_PACKAGES] @@ -198,10 +200,10 @@ def do_packages_pass(config: dict): if isinstance(packages, dict): for package_name, package_config in reversed(packages.items()): with cv.prepend_path(package_name): - config = _process_package(package_config, config) + config = _process_package(package_config, config, skip_update) elif isinstance(packages, list): for package_config in reversed(packages): - config = _process_package(package_config, config) + config = _process_package(package_config, config, skip_update) else: raise cv.Invalid( f"Packages must be a key to value mapping or list, got {type(packages)} instead" diff --git a/esphome/components/packet_transport/packet_transport.cpp b/esphome/components/packet_transport/packet_transport.cpp index b6ce24bc1b..8bde4ee505 100644 --- a/esphome/components/packet_transport/packet_transport.cpp +++ b/esphome/components/packet_transport/packet_transport.cpp @@ -270,6 +270,7 @@ void PacketTransport::add_binary_data_(uint8_t key, const char *id, bool data) { auto len = 1 + 1 + 1 + strlen(id); if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { this->flush_(); + this->init_data_(); } add(this->data_, key); add(this->data_, (uint8_t) data); @@ -284,6 +285,7 @@ void PacketTransport::add_data_(uint8_t key, const char *id, uint32_t data) { auto len = 4 + 1 + 1 + strlen(id); if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { this->flush_(); + this->init_data_(); } add(this->data_, key); add(this->data_, data); diff --git a/esphome/components/pid/pid_controller.cpp b/esphome/components/pid/pid_controller.cpp index 1a16f14542..5d7aecdb05 100644 --- a/esphome/components/pid/pid_controller.cpp +++ b/esphome/components/pid/pid_controller.cpp @@ -104,7 +104,7 @@ float PIDController::weighted_average_(std::deque &list, float new_value, list.push_front(new_value); // keep only 'samples' readings, by popping off the back of the list - while (list.size() > samples) + while (samples > 0 && list.size() > static_cast(samples)) list.pop_back(); // calculate and return the average of all values in the list diff --git a/esphome/components/prometheus/prometheus_handler.cpp b/esphome/components/prometheus/prometheus_handler.cpp index 2677860c7c..68ef18e5ce 100644 --- a/esphome/components/prometheus/prometheus_handler.cpp +++ b/esphome/components/prometheus/prometheus_handler.cpp @@ -110,21 +110,21 @@ std::string PrometheusHandler::relabel_name_(EntityBase *obj) { void PrometheusHandler::add_area_label_(AsyncResponseStream *stream, std::string &area) { if (!area.empty()) { - stream->print(F("\",area=\"")); + stream->print(ESPHOME_F("\",area=\"")); stream->print(area.c_str()); } } void PrometheusHandler::add_node_label_(AsyncResponseStream *stream, std::string &node) { if (!node.empty()) { - stream->print(F("\",node=\"")); + stream->print(ESPHOME_F("\",node=\"")); stream->print(node.c_str()); } } void PrometheusHandler::add_friendly_name_label_(AsyncResponseStream *stream, std::string &friendly_name) { if (!friendly_name.empty()) { - stream->print(F("\",friendly_name=\"")); + stream->print(ESPHOME_F("\",friendly_name=\"")); stream->print(friendly_name.c_str()); } } @@ -132,8 +132,8 @@ void PrometheusHandler::add_friendly_name_label_(AsyncResponseStream *stream, st // Type-specific implementation #ifdef USE_SENSOR void PrometheusHandler::sensor_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_sensor_value gauge\n")); - stream->print(F("#TYPE esphome_sensor_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_sensor_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_sensor_failed gauge\n")); } void PrometheusHandler::sensor_row_(AsyncResponseStream *stream, sensor::Sensor *obj, std::string &area, std::string &node, std::string &friendly_name) { @@ -141,37 +141,37 @@ void PrometheusHandler::sensor_row_(AsyncResponseStream *stream, sensor::Sensor return; if (!std::isnan(obj->state)) { // We have a valid value, output this value - stream->print(F("esphome_sensor_failed{id=\"")); + stream->print(ESPHOME_F("esphome_sensor_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_sensor_value{id=\"")); + stream->print(ESPHOME_F("esphome_sensor_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",unit=\"")); + stream->print(ESPHOME_F("\",unit=\"")); stream->print(obj->get_unit_of_measurement().c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(value_accuracy_to_string(obj->state, obj->get_accuracy_decimals()).c_str()); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } else { // Invalid state - stream->print(F("esphome_sensor_failed{id=\"")); + stream->print(ESPHOME_F("esphome_sensor_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 1\n")); + stream->print(ESPHOME_F("\"} 1\n")); } } #endif @@ -179,8 +179,8 @@ void PrometheusHandler::sensor_row_(AsyncResponseStream *stream, sensor::Sensor // Type-specific implementation #ifdef USE_BINARY_SENSOR void PrometheusHandler::binary_sensor_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_binary_sensor_value gauge\n")); - stream->print(F("#TYPE esphome_binary_sensor_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_binary_sensor_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_binary_sensor_failed gauge\n")); } void PrometheusHandler::binary_sensor_row_(AsyncResponseStream *stream, binary_sensor::BinarySensor *obj, std::string &area, std::string &node, std::string &friendly_name) { @@ -188,204 +188,204 @@ void PrometheusHandler::binary_sensor_row_(AsyncResponseStream *stream, binary_s return; if (obj->has_state()) { // We have a valid value, output this value - stream->print(F("esphome_binary_sensor_failed{id=\"")); + stream->print(ESPHOME_F("esphome_binary_sensor_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_binary_sensor_value{id=\"")); + stream->print(ESPHOME_F("esphome_binary_sensor_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->state); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } else { // Invalid state - stream->print(F("esphome_binary_sensor_failed{id=\"")); + stream->print(ESPHOME_F("esphome_binary_sensor_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 1\n")); + stream->print(ESPHOME_F("\"} 1\n")); } } #endif #ifdef USE_FAN void PrometheusHandler::fan_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_fan_value gauge\n")); - stream->print(F("#TYPE esphome_fan_failed gauge\n")); - stream->print(F("#TYPE esphome_fan_speed gauge\n")); - stream->print(F("#TYPE esphome_fan_oscillation gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_fan_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_fan_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_fan_speed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_fan_oscillation gauge\n")); } void PrometheusHandler::fan_row_(AsyncResponseStream *stream, fan::Fan *obj, std::string &area, std::string &node, std::string &friendly_name) { if (obj->is_internal() && !this->include_internal_) return; - stream->print(F("esphome_fan_failed{id=\"")); + stream->print(ESPHOME_F("esphome_fan_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_fan_value{id=\"")); + stream->print(ESPHOME_F("esphome_fan_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->state); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); // Speed if available if (obj->get_traits().supports_speed()) { - stream->print(F("esphome_fan_speed{id=\"")); + stream->print(ESPHOME_F("esphome_fan_speed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->speed); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } // Oscillation if available if (obj->get_traits().supports_oscillation()) { - stream->print(F("esphome_fan_oscillation{id=\"")); + stream->print(ESPHOME_F("esphome_fan_oscillation{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->oscillating); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } } #endif #ifdef USE_LIGHT void PrometheusHandler::light_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_light_state gauge\n")); - stream->print(F("#TYPE esphome_light_color gauge\n")); - stream->print(F("#TYPE esphome_light_effect_active gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_light_state gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_light_color gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_light_effect_active gauge\n")); } void PrometheusHandler::light_row_(AsyncResponseStream *stream, light::LightState *obj, std::string &area, std::string &node, std::string &friendly_name) { if (obj->is_internal() && !this->include_internal_) return; // State - stream->print(F("esphome_light_state{id=\"")); + stream->print(ESPHOME_F("esphome_light_state{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->remote_values.is_on()); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); // Brightness and RGBW light::LightColorValues color = obj->current_values; float brightness, r, g, b, w; color.as_brightness(&brightness); color.as_rgbw(&r, &g, &b, &w); - stream->print(F("esphome_light_color{id=\"")); + stream->print(ESPHOME_F("esphome_light_color{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",channel=\"brightness\"} ")); + stream->print(ESPHOME_F("\",channel=\"brightness\"} ")); stream->print(brightness); - stream->print(F("\n")); - stream->print(F("esphome_light_color{id=\"")); + stream->print(ESPHOME_F("\n")); + stream->print(ESPHOME_F("esphome_light_color{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",channel=\"r\"} ")); + stream->print(ESPHOME_F("\",channel=\"r\"} ")); stream->print(r); - stream->print(F("\n")); - stream->print(F("esphome_light_color{id=\"")); + stream->print(ESPHOME_F("\n")); + stream->print(ESPHOME_F("esphome_light_color{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",channel=\"g\"} ")); + stream->print(ESPHOME_F("\",channel=\"g\"} ")); stream->print(g); - stream->print(F("\n")); - stream->print(F("esphome_light_color{id=\"")); + stream->print(ESPHOME_F("\n")); + stream->print(ESPHOME_F("esphome_light_color{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",channel=\"b\"} ")); + stream->print(ESPHOME_F("\",channel=\"b\"} ")); stream->print(b); - stream->print(F("\n")); - stream->print(F("esphome_light_color{id=\"")); + stream->print(ESPHOME_F("\n")); + stream->print(ESPHOME_F("esphome_light_color{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",channel=\"w\"} ")); + stream->print(ESPHOME_F("\",channel=\"w\"} ")); stream->print(w); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); // Effect std::string effect = obj->get_effect_name(); if (effect == "None") { - stream->print(F("esphome_light_effect_active{id=\"")); + stream->print(ESPHOME_F("esphome_light_effect_active{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",effect=\"None\"} 0\n")); + stream->print(ESPHOME_F("\",effect=\"None\"} 0\n")); } else { - stream->print(F("esphome_light_effect_active{id=\"")); + stream->print(ESPHOME_F("esphome_light_effect_active{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",effect=\"")); + stream->print(ESPHOME_F("\",effect=\"")); stream->print(effect.c_str()); - stream->print(F("\"} 1\n")); + stream->print(ESPHOME_F("\"} 1\n")); } } #endif #ifdef USE_COVER void PrometheusHandler::cover_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_cover_value gauge\n")); - stream->print(F("#TYPE esphome_cover_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_cover_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_cover_failed gauge\n")); } void PrometheusHandler::cover_row_(AsyncResponseStream *stream, cover::Cover *obj, std::string &area, std::string &node, std::string &friendly_name) { @@ -393,118 +393,118 @@ void PrometheusHandler::cover_row_(AsyncResponseStream *stream, cover::Cover *ob return; if (!std::isnan(obj->position)) { // We have a valid value, output this value - stream->print(F("esphome_cover_failed{id=\"")); + stream->print(ESPHOME_F("esphome_cover_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_cover_value{id=\"")); + stream->print(ESPHOME_F("esphome_cover_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->position); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); if (obj->get_traits().get_supports_tilt()) { - stream->print(F("esphome_cover_tilt{id=\"")); + stream->print(ESPHOME_F("esphome_cover_tilt{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->tilt); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } } else { // Invalid state - stream->print(F("esphome_cover_failed{id=\"")); + stream->print(ESPHOME_F("esphome_cover_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 1\n")); + stream->print(ESPHOME_F("\"} 1\n")); } } #endif #ifdef USE_SWITCH void PrometheusHandler::switch_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_switch_value gauge\n")); - stream->print(F("#TYPE esphome_switch_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_switch_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_switch_failed gauge\n")); } void PrometheusHandler::switch_row_(AsyncResponseStream *stream, switch_::Switch *obj, std::string &area, std::string &node, std::string &friendly_name) { if (obj->is_internal() && !this->include_internal_) return; - stream->print(F("esphome_switch_failed{id=\"")); + stream->print(ESPHOME_F("esphome_switch_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_switch_value{id=\"")); + stream->print(ESPHOME_F("esphome_switch_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->state); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } #endif #ifdef USE_LOCK void PrometheusHandler::lock_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_lock_value gauge\n")); - stream->print(F("#TYPE esphome_lock_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_lock_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_lock_failed gauge\n")); } void PrometheusHandler::lock_row_(AsyncResponseStream *stream, lock::Lock *obj, std::string &area, std::string &node, std::string &friendly_name) { if (obj->is_internal() && !this->include_internal_) return; - stream->print(F("esphome_lock_failed{id=\"")); + stream->print(ESPHOME_F("esphome_lock_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_lock_value{id=\"")); + stream->print(ESPHOME_F("esphome_lock_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->state); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } #endif // Type-specific implementation #ifdef USE_TEXT_SENSOR void PrometheusHandler::text_sensor_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_text_sensor_value gauge\n")); - stream->print(F("#TYPE esphome_text_sensor_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_text_sensor_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_text_sensor_failed gauge\n")); } void PrometheusHandler::text_sensor_row_(AsyncResponseStream *stream, text_sensor::TextSensor *obj, std::string &area, std::string &node, std::string &friendly_name) { @@ -512,37 +512,37 @@ void PrometheusHandler::text_sensor_row_(AsyncResponseStream *stream, text_senso return; if (obj->has_state()) { // We have a valid value, output this value - stream->print(F("esphome_text_sensor_failed{id=\"")); + stream->print(ESPHOME_F("esphome_text_sensor_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_text_sensor_value{id=\"")); + stream->print(ESPHOME_F("esphome_text_sensor_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",value=\"")); + stream->print(ESPHOME_F("\",value=\"")); stream->print(obj->state.c_str()); - stream->print(F("\"} ")); - stream->print(F("1.0")); - stream->print(F("\n")); + stream->print(ESPHOME_F("\"} ")); + stream->print(ESPHOME_F("1.0")); + stream->print(ESPHOME_F("\n")); } else { // Invalid state - stream->print(F("esphome_text_sensor_failed{id=\"")); + stream->print(ESPHOME_F("esphome_text_sensor_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 1\n")); + stream->print(ESPHOME_F("\"} 1\n")); } } #endif @@ -550,8 +550,8 @@ void PrometheusHandler::text_sensor_row_(AsyncResponseStream *stream, text_senso // Type-specific implementation #ifdef USE_NUMBER void PrometheusHandler::number_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_number_value gauge\n")); - stream->print(F("#TYPE esphome_number_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_number_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_number_failed gauge\n")); } void PrometheusHandler::number_row_(AsyncResponseStream *stream, number::Number *obj, std::string &area, std::string &node, std::string &friendly_name) { @@ -559,43 +559,43 @@ void PrometheusHandler::number_row_(AsyncResponseStream *stream, number::Number return; if (!std::isnan(obj->state)) { // We have a valid value, output this value - stream->print(F("esphome_number_failed{id=\"")); + stream->print(ESPHOME_F("esphome_number_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_number_value{id=\"")); + stream->print(ESPHOME_F("esphome_number_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->state); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } else { // Invalid state - stream->print(F("esphome_number_failed{id=\"")); + stream->print(ESPHOME_F("esphome_number_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 1\n")); + stream->print(ESPHOME_F("\"} 1\n")); } } #endif #ifdef USE_SELECT void PrometheusHandler::select_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_select_value gauge\n")); - stream->print(F("#TYPE esphome_select_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_select_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_select_failed gauge\n")); } void PrometheusHandler::select_row_(AsyncResponseStream *stream, select::Select *obj, std::string &area, std::string &node, std::string &friendly_name) { @@ -603,105 +603,105 @@ void PrometheusHandler::select_row_(AsyncResponseStream *stream, select::Select return; if (obj->has_state()) { // We have a valid value, output this value - stream->print(F("esphome_select_failed{id=\"")); + stream->print(ESPHOME_F("esphome_select_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_select_value{id=\"")); + stream->print(ESPHOME_F("esphome_select_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",value=\"")); + stream->print(ESPHOME_F("\",value=\"")); stream->print(obj->state.c_str()); - stream->print(F("\"} ")); - stream->print(F("1.0")); - stream->print(F("\n")); + stream->print(ESPHOME_F("\"} ")); + stream->print(ESPHOME_F("1.0")); + stream->print(ESPHOME_F("\n")); } else { // Invalid state - stream->print(F("esphome_select_failed{id=\"")); + stream->print(ESPHOME_F("esphome_select_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 1\n")); + stream->print(ESPHOME_F("\"} 1\n")); } } #endif #ifdef USE_MEDIA_PLAYER void PrometheusHandler::media_player_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_media_player_state_value gauge\n")); - stream->print(F("#TYPE esphome_media_player_volume gauge\n")); - stream->print(F("#TYPE esphome_media_player_is_muted gauge\n")); - stream->print(F("#TYPE esphome_media_player_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_media_player_state_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_media_player_volume gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_media_player_is_muted gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_media_player_failed gauge\n")); } void PrometheusHandler::media_player_row_(AsyncResponseStream *stream, media_player::MediaPlayer *obj, std::string &area, std::string &node, std::string &friendly_name) { if (obj->is_internal() && !this->include_internal_) return; - stream->print(F("esphome_media_player_failed{id=\"")); + stream->print(ESPHOME_F("esphome_media_player_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_media_player_state_value{id=\"")); + stream->print(ESPHOME_F("esphome_media_player_state_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",value=\"")); + stream->print(ESPHOME_F("\",value=\"")); stream->print(media_player::media_player_state_to_string(obj->state)); - stream->print(F("\"} ")); - stream->print(F("1.0")); - stream->print(F("\n")); - stream->print(F("esphome_media_player_volume{id=\"")); + stream->print(ESPHOME_F("\"} ")); + stream->print(ESPHOME_F("1.0")); + stream->print(ESPHOME_F("\n")); + stream->print(ESPHOME_F("esphome_media_player_volume{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->volume); - stream->print(F("\n")); - stream->print(F("esphome_media_player_is_muted{id=\"")); + stream->print(ESPHOME_F("\n")); + stream->print(ESPHOME_F("esphome_media_player_is_muted{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); if (obj->is_muted()) { - stream->print(F("1.0")); + stream->print(ESPHOME_F("1.0")); } else { - stream->print(F("0.0")); + stream->print(ESPHOME_F("0.0")); } - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } #endif #ifdef USE_UPDATE void PrometheusHandler::update_entity_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_update_entity_state gauge\n")); - stream->print(F("#TYPE esphome_update_entity_info gauge\n")); - stream->print(F("#TYPE esphome_update_entity_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_update_entity_state gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_update_entity_info gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_update_entity_failed gauge\n")); } void PrometheusHandler::handle_update_state_(AsyncResponseStream *stream, update::UpdateState state) { @@ -730,168 +730,168 @@ void PrometheusHandler::update_entity_row_(AsyncResponseStream *stream, update:: return; if (obj->has_state()) { // We have a valid value, output this value - stream->print(F("esphome_update_entity_failed{id=\"")); + stream->print(ESPHOME_F("esphome_update_entity_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // First update state - stream->print(F("esphome_update_entity_state{id=\"")); + stream->print(ESPHOME_F("esphome_update_entity_state{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",value=\"")); + stream->print(ESPHOME_F("\",value=\"")); handle_update_state_(stream, obj->state); - stream->print(F("\"} ")); - stream->print(F("1.0")); - stream->print(F("\n")); + stream->print(ESPHOME_F("\"} ")); + stream->print(ESPHOME_F("1.0")); + stream->print(ESPHOME_F("\n")); // Next update info - stream->print(F("esphome_update_entity_info{id=\"")); + stream->print(ESPHOME_F("esphome_update_entity_info{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",current_version=\"")); + stream->print(ESPHOME_F("\",current_version=\"")); stream->print(obj->update_info.current_version.c_str()); - stream->print(F("\",latest_version=\"")); + stream->print(ESPHOME_F("\",latest_version=\"")); stream->print(obj->update_info.latest_version.c_str()); - stream->print(F("\",title=\"")); + stream->print(ESPHOME_F("\",title=\"")); stream->print(obj->update_info.title.c_str()); - stream->print(F("\"} ")); - stream->print(F("1.0")); - stream->print(F("\n")); + stream->print(ESPHOME_F("\"} ")); + stream->print(ESPHOME_F("1.0")); + stream->print(ESPHOME_F("\n")); } else { // Invalid state - stream->print(F("esphome_update_entity_failed{id=\"")); + stream->print(ESPHOME_F("esphome_update_entity_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 1\n")); + stream->print(ESPHOME_F("\"} 1\n")); } } #endif #ifdef USE_VALVE void PrometheusHandler::valve_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_valve_operation gauge\n")); - stream->print(F("#TYPE esphome_valve_failed gauge\n")); - stream->print(F("#TYPE esphome_valve_position gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_valve_operation gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_valve_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_valve_position gauge\n")); } void PrometheusHandler::valve_row_(AsyncResponseStream *stream, valve::Valve *obj, std::string &area, std::string &node, std::string &friendly_name) { if (obj->is_internal() && !this->include_internal_) return; - stream->print(F("esphome_valve_failed{id=\"")); + stream->print(ESPHOME_F("esphome_valve_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} 0\n")); + stream->print(ESPHOME_F("\"} 0\n")); // Data itself - stream->print(F("esphome_valve_operation{id=\"")); + stream->print(ESPHOME_F("esphome_valve_operation{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",operation=\"")); + stream->print(ESPHOME_F("\",operation=\"")); stream->print(valve::valve_operation_to_str(obj->current_operation)); - stream->print(F("\"} ")); - stream->print(F("1.0")); - stream->print(F("\n")); + stream->print(ESPHOME_F("\"} ")); + stream->print(ESPHOME_F("1.0")); + stream->print(ESPHOME_F("\n")); // Now see if position is supported if (obj->get_traits().get_supports_position()) { - stream->print(F("esphome_valve_position{id=\"")); + stream->print(ESPHOME_F("esphome_valve_position{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(obj->position); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } } #endif #ifdef USE_CLIMATE void PrometheusHandler::climate_type_(AsyncResponseStream *stream) { - stream->print(F("#TYPE esphome_climate_setting gauge\n")); - stream->print(F("#TYPE esphome_climate_value gauge\n")); - stream->print(F("#TYPE esphome_climate_failed gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_climate_setting gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_climate_value gauge\n")); + stream->print(ESPHOME_F("#TYPE esphome_climate_failed gauge\n")); } void PrometheusHandler::climate_setting_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, std::string &node, std::string &friendly_name, std::string &setting, const LogString *setting_value) { - stream->print(F("esphome_climate_setting{id=\"")); + stream->print(ESPHOME_F("esphome_climate_setting{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",category=\"")); + stream->print(ESPHOME_F("\",category=\"")); stream->print(setting.c_str()); - stream->print(F("\",setting_value=\"")); + stream->print(ESPHOME_F("\",setting_value=\"")); stream->print(LOG_STR_ARG(setting_value)); - stream->print(F("\"} ")); - stream->print(F("1.0")); - stream->print(F("\n")); + stream->print(ESPHOME_F("\"} ")); + stream->print(ESPHOME_F("1.0")); + stream->print(ESPHOME_F("\n")); } void PrometheusHandler::climate_value_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, std::string &node, std::string &friendly_name, std::string &category, std::string &climate_value) { - stream->print(F("esphome_climate_value{id=\"")); + stream->print(ESPHOME_F("esphome_climate_value{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",category=\"")); + stream->print(ESPHOME_F("\",category=\"")); stream->print(category.c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); stream->print(climate_value.c_str()); - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } void PrometheusHandler::climate_failed_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, std::string &node, std::string &friendly_name, std::string &category, bool is_failed_value) { - stream->print(F("esphome_climate_failed{id=\"")); + stream->print(ESPHOME_F("esphome_climate_failed{id=\"")); stream->print(relabel_id_(obj).c_str()); add_area_label_(stream, area); add_node_label_(stream, node); add_friendly_name_label_(stream, friendly_name); - stream->print(F("\",name=\"")); + stream->print(ESPHOME_F("\",name=\"")); stream->print(relabel_name_(obj).c_str()); - stream->print(F("\",category=\"")); + stream->print(ESPHOME_F("\",category=\"")); stream->print(category.c_str()); - stream->print(F("\"} ")); + stream->print(ESPHOME_F("\"} ")); if (is_failed_value) { - stream->print(F("1.0")); + stream->print(ESPHOME_F("1.0")); } else { - stream->print(F("0.0")); + stream->print(ESPHOME_F("0.0")); } - stream->print(F("\n")); + stream->print(ESPHOME_F("\n")); } void PrometheusHandler::climate_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, diff --git a/esphome/components/psram/__init__.py b/esphome/components/psram/__init__.py index b5c87ae5a8..6b85e7f720 100644 --- a/esphome/components/psram/__init__.py +++ b/esphome/components/psram/__init__.py @@ -62,6 +62,11 @@ SPIRAM_SPEEDS = { } +def supported() -> bool: + variant = get_esp32_variant() + return variant in SPIRAM_MODES + + def validate_psram_mode(config): esp32_config = fv.full_config.get()[PLATFORM_ESP32] if config[CONF_SPEED] == "120MHZ": @@ -95,7 +100,7 @@ def get_config_schema(config): variant = get_esp32_variant() speeds = [f"{s}MHZ" for s in SPIRAM_SPEEDS.get(variant, [])] if not speeds: - return cv.Invalid("PSRAM is not supported on this chip") + raise cv.Invalid("PSRAM is not supported on this chip") modes = SPIRAM_MODES[variant] return cv.Schema( { @@ -121,33 +126,30 @@ async def to_code(config): if config[CONF_MODE] == TYPE_OCTAL: cg.add_platformio_option("board_build.arduino.memory_type", "qio_opi") - if CORE.using_esp_idf: - add_idf_sdkconfig_option( - f"CONFIG_{get_esp32_variant().upper()}_SPIRAM_SUPPORT", True - ) - add_idf_sdkconfig_option("CONFIG_SOC_SPIRAM_SUPPORTED", True) - add_idf_sdkconfig_option("CONFIG_SPIRAM", True) - add_idf_sdkconfig_option("CONFIG_SPIRAM_USE", True) - add_idf_sdkconfig_option("CONFIG_SPIRAM_USE_CAPS_ALLOC", True) - add_idf_sdkconfig_option("CONFIG_SPIRAM_IGNORE_NOTFOUND", True) + add_idf_sdkconfig_option( + f"CONFIG_{get_esp32_variant().upper()}_SPIRAM_SUPPORT", True + ) + add_idf_sdkconfig_option("CONFIG_SOC_SPIRAM_SUPPORTED", True) + add_idf_sdkconfig_option("CONFIG_SPIRAM", True) + add_idf_sdkconfig_option("CONFIG_SPIRAM_USE", True) + add_idf_sdkconfig_option("CONFIG_SPIRAM_USE_CAPS_ALLOC", True) + add_idf_sdkconfig_option("CONFIG_SPIRAM_IGNORE_NOTFOUND", True) - add_idf_sdkconfig_option( - f"CONFIG_SPIRAM_MODE_{SDK_MODES[config[CONF_MODE]]}", True - ) + add_idf_sdkconfig_option(f"CONFIG_SPIRAM_MODE_{SDK_MODES[config[CONF_MODE]]}", True) - # Remove MHz suffix, convert to int - speed = int(config[CONF_SPEED][:-3]) - add_idf_sdkconfig_option(f"CONFIG_SPIRAM_SPEED_{speed}M", True) - add_idf_sdkconfig_option("CONFIG_SPIRAM_SPEED", speed) - if config[CONF_MODE] == TYPE_OCTAL and speed == 120: - add_idf_sdkconfig_option("CONFIG_ESPTOOLPY_FLASHFREQ_120M", True) - add_idf_sdkconfig_option("CONFIG_BOOTLOADER_FLASH_DC_AWARE", True) - if CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] >= cv.Version(5, 4, 0): - add_idf_sdkconfig_option( - "CONFIG_SPIRAM_TIMING_TUNING_POINT_VIA_TEMPERATURE_SENSOR", True - ) - if config[CONF_ENABLE_ECC]: - add_idf_sdkconfig_option("CONFIG_SPIRAM_ECC_ENABLE", True) + # Remove MHz suffix, convert to int + speed = int(config[CONF_SPEED][:-3]) + add_idf_sdkconfig_option(f"CONFIG_SPIRAM_SPEED_{speed}M", True) + add_idf_sdkconfig_option("CONFIG_SPIRAM_SPEED", speed) + if config[CONF_MODE] == TYPE_OCTAL and speed == 120: + add_idf_sdkconfig_option("CONFIG_ESPTOOLPY_FLASHFREQ_120M", True) + add_idf_sdkconfig_option("CONFIG_BOOTLOADER_FLASH_DC_AWARE", True) + if CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] >= cv.Version(5, 4, 0): + add_idf_sdkconfig_option( + "CONFIG_SPIRAM_TIMING_TUNING_POINT_VIA_TEMPERATURE_SENSOR", True + ) + if config[CONF_ENABLE_ECC]: + add_idf_sdkconfig_option("CONFIG_SPIRAM_ECC_ENABLE", True) cg.add_define("USE_PSRAM") diff --git a/esphome/components/qmc5883l/qmc5883l.cpp b/esphome/components/qmc5883l/qmc5883l.cpp index c9196f2469..d2041a2d52 100644 --- a/esphome/components/qmc5883l/qmc5883l.cpp +++ b/esphome/components/qmc5883l/qmc5883l.cpp @@ -8,6 +8,7 @@ namespace esphome { namespace qmc5883l { static const char *const TAG = "qmc5883l"; + static const uint8_t QMC5883L_ADDRESS = 0x0D; static const uint8_t QMC5883L_REGISTER_DATA_X_LSB = 0x00; @@ -32,6 +33,10 @@ void QMC5883LComponent::setup() { } delay(10); + if (this->drdy_pin_) { + this->drdy_pin_->setup(); + } + uint8_t control_1 = 0; control_1 |= 0b01 << 0; // MODE (Mode) -> 0b00=standby, 0b01=continuous control_1 |= this->datarate_ << 2; @@ -64,6 +69,7 @@ void QMC5883LComponent::setup() { high_freq_.start(); } } + void QMC5883LComponent::dump_config() { ESP_LOGCONFIG(TAG, "QMC5883L:"); LOG_I2C_DEVICE(this); @@ -77,11 +83,20 @@ void QMC5883LComponent::dump_config() { LOG_SENSOR(" ", "Z Axis", this->z_sensor_); LOG_SENSOR(" ", "Heading", this->heading_sensor_); LOG_SENSOR(" ", "Temperature", this->temperature_sensor_); + LOG_PIN(" DRDY Pin: ", this->drdy_pin_); } + float QMC5883LComponent::get_setup_priority() const { return setup_priority::DATA; } + void QMC5883LComponent::update() { i2c::ErrorCode err; uint8_t status = false; + + // If DRDY pin is configured and the data is not ready return. + if (this->drdy_pin_ && !this->drdy_pin_->digital_read()) { + return; + } + // Status byte gets cleared when data is read, so we have to read this first. // If status and two axes are desired, it's possible to save one byte of traffic by enabling // ROL_PNT in setup and reading 7 bytes starting at the status register. diff --git a/esphome/components/qmc5883l/qmc5883l.h b/esphome/components/qmc5883l/qmc5883l.h index 3202e37780..5ba7180e23 100644 --- a/esphome/components/qmc5883l/qmc5883l.h +++ b/esphome/components/qmc5883l/qmc5883l.h @@ -3,6 +3,7 @@ #include "esphome/core/component.h" #include "esphome/components/sensor/sensor.h" #include "esphome/components/i2c/i2c.h" +#include "esphome/core/hal.h" namespace esphome { namespace qmc5883l { @@ -33,6 +34,7 @@ class QMC5883LComponent : public PollingComponent, public i2c::I2CDevice { float get_setup_priority() const override; void update() override; + void set_drdy_pin(GPIOPin *pin) { drdy_pin_ = pin; } void set_datarate(QMC5883LDatarate datarate) { datarate_ = datarate; } void set_range(QMC5883LRange range) { range_ = range; } void set_oversampling(QMC5883LOversampling oversampling) { oversampling_ = oversampling; } @@ -51,6 +53,7 @@ class QMC5883LComponent : public PollingComponent, public i2c::I2CDevice { sensor::Sensor *z_sensor_{nullptr}; sensor::Sensor *heading_sensor_{nullptr}; sensor::Sensor *temperature_sensor_{nullptr}; + GPIOPin *drdy_pin_{nullptr}; enum ErrorCode { NONE = 0, COMMUNICATION_FAILED, diff --git a/esphome/components/qmc5883l/sensor.py b/esphome/components/qmc5883l/sensor.py index ade286cb9e..b79e370a05 100644 --- a/esphome/components/qmc5883l/sensor.py +++ b/esphome/components/qmc5883l/sensor.py @@ -1,8 +1,12 @@ +import logging + +from esphome import pins import esphome.codegen as cg from esphome.components import i2c, sensor import esphome.config_validation as cv from esphome.const import ( CONF_ADDRESS, + CONF_DATA_RATE, CONF_FIELD_STRENGTH_X, CONF_FIELD_STRENGTH_Y, CONF_FIELD_STRENGTH_Z, @@ -21,6 +25,10 @@ from esphome.const import ( UNIT_MICROTESLA, ) +_LOGGER = logging.getLogger(__name__) + +CONF_DRDY_PIN = "drdy_pin" + DEPENDENCIES = ["i2c"] qmc5883l_ns = cg.esphome_ns.namespace("qmc5883l") @@ -52,6 +60,18 @@ QMC5883LOversamplings = { } +def validate_config(config): + if ( + config[CONF_UPDATE_INTERVAL].total_milliseconds < 15 + and CONF_DRDY_PIN not in config + ): + _LOGGER.warning( + "[qmc5883l] 'update_interval' is less than 15ms and 'drdy_pin' is " + "not configured, this may result in I2C errors" + ) + return config + + def validate_enum(enum_values, units=None, int=True): _units = [] if units is not None: @@ -88,7 +108,7 @@ temperature_schema = sensor.sensor_schema( state_class=STATE_CLASS_MEASUREMENT, ) -CONFIG_SCHEMA = ( +CONFIG_SCHEMA = cv.All( cv.Schema( { cv.GenerateID(): cv.declare_id(QMC5883LComponent), @@ -104,29 +124,25 @@ CONFIG_SCHEMA = ( cv.Optional(CONF_FIELD_STRENGTH_Z): field_strength_schema, cv.Optional(CONF_HEADING): heading_schema, cv.Optional(CONF_TEMPERATURE): temperature_schema, + cv.Optional(CONF_DRDY_PIN): pins.gpio_input_pin_schema, + cv.Optional(CONF_DATA_RATE, default="200hz"): validate_enum( + QMC5883LDatarates, units=["hz", "Hz"] + ), } ) .extend(cv.polling_component_schema("60s")) - .extend(i2c.i2c_device_schema(0x0D)) + .extend(i2c.i2c_device_schema(0x0D)), + validate_config, ) -def auto_data_rate(config): - interval_sec = config[CONF_UPDATE_INTERVAL].total_milliseconds / 1000 - interval_hz = 1.0 / interval_sec - for datarate in sorted(QMC5883LDatarates.keys()): - if float(datarate) >= interval_hz: - return QMC5883LDatarates[datarate] - return QMC5883LDatarates[200] - - async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) await i2c.register_i2c_device(var, config) cg.add(var.set_oversampling(config[CONF_OVERSAMPLING])) - cg.add(var.set_datarate(auto_data_rate(config))) + cg.add(var.set_datarate(config[CONF_DATA_RATE])) cg.add(var.set_range(config[CONF_RANGE])) if CONF_FIELD_STRENGTH_X in config: sens = await sensor.new_sensor(config[CONF_FIELD_STRENGTH_X]) @@ -143,3 +159,6 @@ async def to_code(config): if CONF_TEMPERATURE in config: sens = await sensor.new_sensor(config[CONF_TEMPERATURE]) cg.add(var.set_temperature_sensor(sens)) + if CONF_DRDY_PIN in config: + pin = await cg.gpio_pin_expression(config[CONF_DRDY_PIN]) + cg.add(var.set_drdy_pin(pin)) diff --git a/esphome/components/remote_base/gobox_protocol.cpp b/esphome/components/remote_base/gobox_protocol.cpp index 54e0dff663..4f6de5e59e 100644 --- a/esphome/components/remote_base/gobox_protocol.cpp +++ b/esphome/components/remote_base/gobox_protocol.cpp @@ -10,8 +10,8 @@ constexpr uint32_t BIT_MARK_US = 580; // 70us seems like a safe time delta for constexpr uint32_t BIT_ONE_SPACE_US = 1640; constexpr uint32_t BIT_ZERO_SPACE_US = 545; constexpr uint64_t HEADER = 0b011001001100010uL; // 15 bits -constexpr uint64_t HEADER_SIZE = 15; -constexpr uint64_t CODE_SIZE = 17; +constexpr size_t HEADER_SIZE = 15; +constexpr size_t CODE_SIZE = 17; void GoboxProtocol::dump_timings_(const RawTimings &timings) const { ESP_LOGD(TAG, "Gobox: size=%u", timings.size()); @@ -39,7 +39,7 @@ void GoboxProtocol::encode(RemoteTransmitData *dst, const GoboxData &data) { } optional GoboxProtocol::decode(RemoteReceiveData src) { - if (src.size() < ((HEADER_SIZE + CODE_SIZE) * 2 + 1)) { + if (static_cast(src.size()) < ((HEADER_SIZE + CODE_SIZE) * 2 + 1)) { return {}; } diff --git a/esphome/components/remote_receiver/__init__.py b/esphome/components/remote_receiver/__init__.py index 9095016b55..cd2b440645 100644 --- a/esphome/components/remote_receiver/__init__.py +++ b/esphome/components/remote_receiver/__init__.py @@ -5,6 +5,8 @@ from esphome.config_helpers import filter_source_files_from_platform import esphome.config_validation as cv from esphome.const import ( CONF_BUFFER_SIZE, + CONF_CARRIER_DUTY_PERCENT, + CONF_CARRIER_FREQUENCY, CONF_CLOCK_RESOLUTION, CONF_DUMP, CONF_FILTER, @@ -149,6 +151,14 @@ CONFIG_SCHEMA = remote_base.validate_triggers( ), cv.boolean, ), + cv.SplitDefault(CONF_CARRIER_DUTY_PERCENT, esp32=100): cv.All( + cv.only_on_esp32, + cv.percentage_int, + cv.Range(min=1, max=100), + ), + cv.SplitDefault(CONF_CARRIER_FREQUENCY, esp32="0Hz"): cv.All( + cv.only_on_esp32, cv.frequency, cv.int_ + ), } ) .extend(cv.COMPONENT_SCHEMA) @@ -168,6 +178,8 @@ async def to_code(config): cg.add(var.set_clock_resolution(config[CONF_CLOCK_RESOLUTION])) if CONF_FILTER_SYMBOLS in config: cg.add(var.set_filter_symbols(config[CONF_FILTER_SYMBOLS])) + cg.add(var.set_carrier_duty_percent(config[CONF_CARRIER_DUTY_PERCENT])) + cg.add(var.set_carrier_frequency(config[CONF_CARRIER_FREQUENCY])) else: var = cg.new_Pvariable(config[CONF_ID], pin) @@ -196,8 +208,8 @@ FILTER_SOURCE_FILES = filter_source_files_from_platform( PlatformFramework.ESP32_ARDUINO, PlatformFramework.ESP32_IDF, }, - "remote_receiver_esp8266.cpp": {PlatformFramework.ESP8266_ARDUINO}, - "remote_receiver_libretiny.cpp": { + "remote_receiver.cpp": { + PlatformFramework.ESP8266_ARDUINO, PlatformFramework.BK72XX_ARDUINO, PlatformFramework.RTL87XX_ARDUINO, PlatformFramework.LN882X_ARDUINO, diff --git a/esphome/components/remote_receiver/remote_receiver_esp8266.cpp b/esphome/components/remote_receiver/remote_receiver.cpp similarity index 97% rename from esphome/components/remote_receiver/remote_receiver_esp8266.cpp rename to esphome/components/remote_receiver/remote_receiver.cpp index b8ac29a543..a8438e20d7 100644 --- a/esphome/components/remote_receiver/remote_receiver_esp8266.cpp +++ b/esphome/components/remote_receiver/remote_receiver.cpp @@ -3,12 +3,12 @@ #include "esphome/core/helpers.h" #include "esphome/core/log.h" -#ifdef USE_ESP8266 +#if defined(USE_LIBRETINY) || defined(USE_ESP8266) namespace esphome { namespace remote_receiver { -static const char *const TAG = "remote_receiver.esp8266"; +static const char *const TAG = "remote_receiver"; void IRAM_ATTR HOT RemoteReceiverComponentStore::gpio_intr(RemoteReceiverComponentStore *arg) { const uint32_t now = micros(); diff --git a/esphome/components/remote_receiver/remote_receiver.h b/esphome/components/remote_receiver/remote_receiver.h index 45e06e664a..3ddcf353c7 100644 --- a/esphome/components/remote_receiver/remote_receiver.h +++ b/esphome/components/remote_receiver/remote_receiver.h @@ -64,6 +64,8 @@ class RemoteReceiverComponent : public remote_base::RemoteReceiverBase, void set_filter_symbols(uint32_t filter_symbols) { this->filter_symbols_ = filter_symbols; } void set_receive_symbols(uint32_t receive_symbols) { this->receive_symbols_ = receive_symbols; } void set_with_dma(bool with_dma) { this->with_dma_ = with_dma; } + void set_carrier_duty_percent(uint8_t carrier_duty_percent) { this->carrier_duty_percent_ = carrier_duty_percent; } + void set_carrier_frequency(uint32_t carrier_frequency) { this->carrier_frequency_ = carrier_frequency; } #endif void set_buffer_size(uint32_t buffer_size) { this->buffer_size_ = buffer_size; } void set_filter_us(uint32_t filter_us) { this->filter_us_ = filter_us; } @@ -76,6 +78,8 @@ class RemoteReceiverComponent : public remote_base::RemoteReceiverBase, uint32_t filter_symbols_{0}; uint32_t receive_symbols_{0}; bool with_dma_{false}; + uint32_t carrier_frequency_{0}; + uint8_t carrier_duty_percent_{100}; esp_err_t error_code_{ESP_OK}; std::string error_string_{""}; #endif diff --git a/esphome/components/remote_receiver/remote_receiver_esp32.cpp b/esphome/components/remote_receiver/remote_receiver_esp32.cpp index 7e1bd3c457..49358eef3f 100644 --- a/esphome/components/remote_receiver/remote_receiver_esp32.cpp +++ b/esphome/components/remote_receiver/remote_receiver_esp32.cpp @@ -72,6 +72,21 @@ void RemoteReceiverComponent::setup() { return; } + if (this->carrier_frequency_ > 0 && 0 < this->carrier_duty_percent_ && this->carrier_duty_percent_ < 100) { + rmt_carrier_config_t carrier; + memset(&carrier, 0, sizeof(carrier)); + carrier.frequency_hz = this->carrier_frequency_; + carrier.duty_cycle = (float) this->carrier_duty_percent_ / 100.0f; + carrier.flags.polarity_active_low = this->pin_->is_inverted(); + error = rmt_apply_carrier(this->channel_, &carrier); + if (error != ESP_OK) { + this->error_code_ = error; + this->error_string_ = "in rmt_apply_carrier"; + this->mark_failed(); + return; + } + } + rmt_rx_event_callbacks_t callbacks; memset(&callbacks, 0, sizeof(callbacks)); callbacks.on_recv_done = rmt_callback; @@ -111,11 +126,13 @@ void RemoteReceiverComponent::dump_config() { " Filter symbols: %" PRIu32 "\n" " Receive symbols: %" PRIu32 "\n" " Tolerance: %" PRIu32 "%s\n" + " Carrier frequency: %" PRIu32 " hz\n" + " Carrier duty: %u%%\n" " Filter out pulses shorter than: %" PRIu32 " us\n" " Signal is done after %" PRIu32 " us of no changes", this->clock_resolution_, this->rmt_symbols_, this->filter_symbols_, this->receive_symbols_, this->tolerance_, (this->tolerance_mode_ == remote_base::TOLERANCE_MODE_TIME) ? " us" : "%", - this->filter_us_, this->idle_us_); + this->carrier_frequency_, this->carrier_duty_percent_, this->filter_us_, this->idle_us_); if (this->is_failed()) { ESP_LOGE(TAG, "Configuring RMT driver failed: %s (%s)", esp_err_to_name(this->error_code_), this->error_string_.c_str()); diff --git a/esphome/components/remote_receiver/remote_receiver_libretiny.cpp b/esphome/components/remote_receiver/remote_receiver_libretiny.cpp deleted file mode 100644 index 8d801b37d2..0000000000 --- a/esphome/components/remote_receiver/remote_receiver_libretiny.cpp +++ /dev/null @@ -1,125 +0,0 @@ -#include "remote_receiver.h" -#include "esphome/core/hal.h" -#include "esphome/core/helpers.h" -#include "esphome/core/log.h" - -#ifdef USE_LIBRETINY - -namespace esphome { -namespace remote_receiver { - -static const char *const TAG = "remote_receiver.libretiny"; - -void IRAM_ATTR HOT RemoteReceiverComponentStore::gpio_intr(RemoteReceiverComponentStore *arg) { - const uint32_t now = micros(); - // If the lhs is 1 (rising edge) we should write to an uneven index and vice versa - const uint32_t next = (arg->buffer_write_at + 1) % arg->buffer_size; - const bool level = arg->pin.digital_read(); - if (level != next % 2) - return; - - // If next is buffer_read, we have hit an overflow - if (next == arg->buffer_read_at) - return; - - const uint32_t last_change = arg->buffer[arg->buffer_write_at]; - const uint32_t time_since_change = now - last_change; - if (time_since_change <= arg->filter_us) - return; - - arg->buffer[arg->buffer_write_at = next] = now; -} - -void RemoteReceiverComponent::setup() { - this->pin_->setup(); - auto &s = this->store_; - s.filter_us = this->filter_us_; - s.pin = this->pin_->to_isr(); - s.buffer_size = this->buffer_size_; - - this->high_freq_.start(); - if (s.buffer_size % 2 != 0) { - // Make sure divisible by two. This way, we know that every 0bxxx0 index is a space and every 0bxxx1 index is a mark - s.buffer_size++; - } - - s.buffer = new uint32_t[s.buffer_size]; - void *buf = (void *) s.buffer; - memset(buf, 0, s.buffer_size * sizeof(uint32_t)); - - // First index is a space. - if (this->pin_->digital_read()) { - s.buffer_write_at = s.buffer_read_at = 1; - } else { - s.buffer_write_at = s.buffer_read_at = 0; - } - this->pin_->attach_interrupt(RemoteReceiverComponentStore::gpio_intr, &this->store_, gpio::INTERRUPT_ANY_EDGE); -} -void RemoteReceiverComponent::dump_config() { - ESP_LOGCONFIG(TAG, "Remote Receiver:"); - LOG_PIN(" Pin: ", this->pin_); - if (this->pin_->digital_read()) { - ESP_LOGW(TAG, "Remote Receiver Signal starts with a HIGH value. Usually this means you have to " - "invert the signal using 'inverted: True' in the pin schema!"); - } - ESP_LOGCONFIG(TAG, - " Buffer Size: %u\n" - " Tolerance: %u%s\n" - " Filter out pulses shorter than: %u us\n" - " Signal is done after %u us of no changes", - this->buffer_size_, this->tolerance_, - (this->tolerance_mode_ == remote_base::TOLERANCE_MODE_TIME) ? " us" : "%", this->filter_us_, - this->idle_us_); -} - -void RemoteReceiverComponent::loop() { - auto &s = this->store_; - - // copy write at to local variables, as it's volatile - const uint32_t write_at = s.buffer_write_at; - const uint32_t dist = (s.buffer_size + write_at - s.buffer_read_at) % s.buffer_size; - // signals must at least one rising and one leading edge - if (dist <= 1) - return; - const uint32_t now = micros(); - if (now - s.buffer[write_at] < this->idle_us_) { - // The last change was fewer than the configured idle time ago. - return; - } - - ESP_LOGVV(TAG, "read_at=%u write_at=%u dist=%u now=%u end=%u", s.buffer_read_at, write_at, dist, now, - s.buffer[write_at]); - - // Skip first value, it's from the previous idle level - s.buffer_read_at = (s.buffer_read_at + 1) % s.buffer_size; - uint32_t prev = s.buffer_read_at; - s.buffer_read_at = (s.buffer_read_at + 1) % s.buffer_size; - const uint32_t reserve_size = 1 + (s.buffer_size + write_at - s.buffer_read_at) % s.buffer_size; - this->temp_.clear(); - this->temp_.reserve(reserve_size); - int32_t multiplier = s.buffer_read_at % 2 == 0 ? 1 : -1; - - for (uint32_t i = 0; prev != write_at; i++) { - int32_t delta = s.buffer[s.buffer_read_at] - s.buffer[prev]; - if (uint32_t(delta) >= this->idle_us_) { - // already found a space longer than idle. There must have been two pulses - break; - } - - ESP_LOGVV(TAG, " i=%u buffer[%u]=%u - buffer[%u]=%u -> %d", i, s.buffer_read_at, s.buffer[s.buffer_read_at], prev, - s.buffer[prev], multiplier * delta); - this->temp_.push_back(multiplier * delta); - prev = s.buffer_read_at; - s.buffer_read_at = (s.buffer_read_at + 1) % s.buffer_size; - multiplier *= -1; - } - s.buffer_read_at = (s.buffer_size + s.buffer_read_at - 1) % s.buffer_size; - this->temp_.push_back(this->idle_us_ * multiplier); - - this->call_listeners_dumpers_(); -} - -} // namespace remote_receiver -} // namespace esphome - -#endif diff --git a/esphome/components/remote_transmitter/__init__.py b/esphome/components/remote_transmitter/__init__.py index e79437013f..cb98c017f1 100644 --- a/esphome/components/remote_transmitter/__init__.py +++ b/esphome/components/remote_transmitter/__init__.py @@ -131,8 +131,8 @@ FILTER_SOURCE_FILES = filter_source_files_from_platform( PlatformFramework.ESP32_ARDUINO, PlatformFramework.ESP32_IDF, }, - "remote_transmitter_esp8266.cpp": {PlatformFramework.ESP8266_ARDUINO}, - "remote_transmitter_libretiny.cpp": { + "remote_transmitter.cpp": { + PlatformFramework.ESP8266_ARDUINO, PlatformFramework.BK72XX_ARDUINO, PlatformFramework.RTL87XX_ARDUINO, PlatformFramework.LN882X_ARDUINO, diff --git a/esphome/components/remote_transmitter/remote_transmitter.cpp b/esphome/components/remote_transmitter/remote_transmitter.cpp index 425418ff39..347e9d9d33 100644 --- a/esphome/components/remote_transmitter/remote_transmitter.cpp +++ b/esphome/components/remote_transmitter/remote_transmitter.cpp @@ -2,10 +2,113 @@ #include "esphome/core/log.h" #include "esphome/core/application.h" +#if defined(USE_LIBRETINY) || defined(USE_ESP8266) + namespace esphome { namespace remote_transmitter { static const char *const TAG = "remote_transmitter"; +void RemoteTransmitterComponent::setup() { + this->pin_->setup(); + this->pin_->digital_write(false); +} + +void RemoteTransmitterComponent::dump_config() { + ESP_LOGCONFIG(TAG, + "Remote Transmitter:\n" + " Carrier Duty: %u%%", + this->carrier_duty_percent_); + LOG_PIN(" Pin: ", this->pin_); +} + +void RemoteTransmitterComponent::calculate_on_off_time_(uint32_t carrier_frequency, uint32_t *on_time_period, + uint32_t *off_time_period) { + if (carrier_frequency == 0) { + *on_time_period = 0; + *off_time_period = 0; + return; + } + uint32_t period = (1000000UL + carrier_frequency / 2) / carrier_frequency; // round(1000000/freq) + period = std::max(uint32_t(1), period); + *on_time_period = (period * this->carrier_duty_percent_) / 100; + *off_time_period = period - *on_time_period; +} + +void RemoteTransmitterComponent::await_target_time_() { + const uint32_t current_time = micros(); + if (this->target_time_ == 0) { + this->target_time_ = current_time; + } else if ((int32_t) (this->target_time_ - current_time) > 0) { +#if defined(USE_LIBRETINY) + // busy loop for libretiny is required (see the comment inside micros() in wiring.c) + while ((int32_t) (this->target_time_ - micros()) > 0) + ; +#else + delayMicroseconds(this->target_time_ - current_time); +#endif + } +} + +void RemoteTransmitterComponent::mark_(uint32_t on_time, uint32_t off_time, uint32_t usec) { + this->await_target_time_(); + this->pin_->digital_write(true); + + const uint32_t target = this->target_time_ + usec; + if (this->carrier_duty_percent_ < 100 && (on_time > 0 || off_time > 0)) { + while (true) { // Modulate with carrier frequency + this->target_time_ += on_time; + if ((int32_t) (this->target_time_ - target) >= 0) + break; + this->await_target_time_(); + this->pin_->digital_write(false); + + this->target_time_ += off_time; + if ((int32_t) (this->target_time_ - target) >= 0) + break; + this->await_target_time_(); + this->pin_->digital_write(true); + } + } + this->target_time_ = target; +} + +void RemoteTransmitterComponent::space_(uint32_t usec) { + this->await_target_time_(); + this->pin_->digital_write(false); + this->target_time_ += usec; +} + +void RemoteTransmitterComponent::digital_write(bool value) { this->pin_->digital_write(value); } + +void RemoteTransmitterComponent::send_internal(uint32_t send_times, uint32_t send_wait) { + ESP_LOGD(TAG, "Sending remote code"); + uint32_t on_time, off_time; + this->calculate_on_off_time_(this->temp_.get_carrier_frequency(), &on_time, &off_time); + this->target_time_ = 0; + this->transmit_trigger_->trigger(); + for (uint32_t i = 0; i < send_times; i++) { + InterruptLock lock; + for (int32_t item : this->temp_.get_data()) { + if (item > 0) { + const auto length = uint32_t(item); + this->mark_(on_time, off_time, length); + } else { + const auto length = uint32_t(-item); + this->space_(length); + } + App.feed_wdt(); + } + this->await_target_time_(); // wait for duration of last pulse + this->pin_->digital_write(false); + + if (i + 1 < send_times) + this->target_time_ += send_wait; + } + this->complete_trigger_->trigger(); +} + } // namespace remote_transmitter } // namespace esphome + +#endif diff --git a/esphome/components/remote_transmitter/remote_transmitter_esp8266.cpp b/esphome/components/remote_transmitter/remote_transmitter_esp8266.cpp deleted file mode 100644 index fdd4198773..0000000000 --- a/esphome/components/remote_transmitter/remote_transmitter_esp8266.cpp +++ /dev/null @@ -1,107 +0,0 @@ -#include "remote_transmitter.h" -#include "esphome/core/log.h" -#include "esphome/core/application.h" - -#ifdef USE_ESP8266 - -namespace esphome { -namespace remote_transmitter { - -static const char *const TAG = "remote_transmitter"; - -void RemoteTransmitterComponent::setup() { - this->pin_->setup(); - this->pin_->digital_write(false); -} - -void RemoteTransmitterComponent::dump_config() { - ESP_LOGCONFIG(TAG, - "Remote Transmitter:\n" - " Carrier Duty: %u%%", - this->carrier_duty_percent_); - LOG_PIN(" Pin: ", this->pin_); -} - -void RemoteTransmitterComponent::calculate_on_off_time_(uint32_t carrier_frequency, uint32_t *on_time_period, - uint32_t *off_time_period) { - if (carrier_frequency == 0) { - *on_time_period = 0; - *off_time_period = 0; - return; - } - uint32_t period = (1000000UL + carrier_frequency / 2) / carrier_frequency; // round(1000000/freq) - period = std::max(uint32_t(1), period); - *on_time_period = (period * this->carrier_duty_percent_) / 100; - *off_time_period = period - *on_time_period; -} - -void RemoteTransmitterComponent::await_target_time_() { - const uint32_t current_time = micros(); - if (this->target_time_ == 0) { - this->target_time_ = current_time; - } else if ((int32_t) (this->target_time_ - current_time) > 0) { - delayMicroseconds(this->target_time_ - current_time); - } -} - -void RemoteTransmitterComponent::mark_(uint32_t on_time, uint32_t off_time, uint32_t usec) { - this->await_target_time_(); - this->pin_->digital_write(true); - - const uint32_t target = this->target_time_ + usec; - if (this->carrier_duty_percent_ < 100 && (on_time > 0 || off_time > 0)) { - while (true) { // Modulate with carrier frequency - this->target_time_ += on_time; - if ((int32_t) (this->target_time_ - target) >= 0) - break; - this->await_target_time_(); - this->pin_->digital_write(false); - - this->target_time_ += off_time; - if ((int32_t) (this->target_time_ - target) >= 0) - break; - this->await_target_time_(); - this->pin_->digital_write(true); - } - } - this->target_time_ = target; -} - -void RemoteTransmitterComponent::space_(uint32_t usec) { - this->await_target_time_(); - this->pin_->digital_write(false); - this->target_time_ += usec; -} - -void RemoteTransmitterComponent::digital_write(bool value) { this->pin_->digital_write(value); } - -void RemoteTransmitterComponent::send_internal(uint32_t send_times, uint32_t send_wait) { - ESP_LOGD(TAG, "Sending remote code"); - uint32_t on_time, off_time; - this->calculate_on_off_time_(this->temp_.get_carrier_frequency(), &on_time, &off_time); - this->target_time_ = 0; - this->transmit_trigger_->trigger(); - for (uint32_t i = 0; i < send_times; i++) { - for (int32_t item : this->temp_.get_data()) { - if (item > 0) { - const auto length = uint32_t(item); - this->mark_(on_time, off_time, length); - } else { - const auto length = uint32_t(-item); - this->space_(length); - } - App.feed_wdt(); - } - this->await_target_time_(); // wait for duration of last pulse - this->pin_->digital_write(false); - - if (i + 1 < send_times) - this->target_time_ += send_wait; - } - this->complete_trigger_->trigger(); -} - -} // namespace remote_transmitter -} // namespace esphome - -#endif diff --git a/esphome/components/remote_transmitter/remote_transmitter_libretiny.cpp b/esphome/components/remote_transmitter/remote_transmitter_libretiny.cpp deleted file mode 100644 index 9ba850090d..0000000000 --- a/esphome/components/remote_transmitter/remote_transmitter_libretiny.cpp +++ /dev/null @@ -1,110 +0,0 @@ -#include "remote_transmitter.h" -#include "esphome/core/log.h" -#include "esphome/core/application.h" - -#ifdef USE_LIBRETINY - -namespace esphome { -namespace remote_transmitter { - -static const char *const TAG = "remote_transmitter"; - -void RemoteTransmitterComponent::setup() { - this->pin_->setup(); - this->pin_->digital_write(false); -} - -void RemoteTransmitterComponent::dump_config() { - ESP_LOGCONFIG(TAG, - "Remote Transmitter:\n" - " Carrier Duty: %u%%", - this->carrier_duty_percent_); - LOG_PIN(" Pin: ", this->pin_); -} - -void RemoteTransmitterComponent::calculate_on_off_time_(uint32_t carrier_frequency, uint32_t *on_time_period, - uint32_t *off_time_period) { - if (carrier_frequency == 0) { - *on_time_period = 0; - *off_time_period = 0; - return; - } - uint32_t period = (1000000UL + carrier_frequency / 2) / carrier_frequency; // round(1000000/freq) - period = std::max(uint32_t(1), period); - *on_time_period = (period * this->carrier_duty_percent_) / 100; - *off_time_period = period - *on_time_period; -} - -void RemoteTransmitterComponent::await_target_time_() { - const uint32_t current_time = micros(); - if (this->target_time_ == 0) { - this->target_time_ = current_time; - } else { - while ((int32_t) (this->target_time_ - micros()) > 0) { - // busy loop that ensures micros is constantly called - } - } -} - -void RemoteTransmitterComponent::mark_(uint32_t on_time, uint32_t off_time, uint32_t usec) { - this->await_target_time_(); - this->pin_->digital_write(true); - - const uint32_t target = this->target_time_ + usec; - if (this->carrier_duty_percent_ < 100 && (on_time > 0 || off_time > 0)) { - while (true) { // Modulate with carrier frequency - this->target_time_ += on_time; - if ((int32_t) (this->target_time_ - target) >= 0) - break; - this->await_target_time_(); - this->pin_->digital_write(false); - - this->target_time_ += off_time; - if ((int32_t) (this->target_time_ - target) >= 0) - break; - this->await_target_time_(); - this->pin_->digital_write(true); - } - } - this->target_time_ = target; -} - -void RemoteTransmitterComponent::space_(uint32_t usec) { - this->await_target_time_(); - this->pin_->digital_write(false); - this->target_time_ += usec; -} - -void RemoteTransmitterComponent::digital_write(bool value) { this->pin_->digital_write(value); } - -void RemoteTransmitterComponent::send_internal(uint32_t send_times, uint32_t send_wait) { - ESP_LOGD(TAG, "Sending remote code"); - uint32_t on_time, off_time; - this->calculate_on_off_time_(this->temp_.get_carrier_frequency(), &on_time, &off_time); - this->target_time_ = 0; - this->transmit_trigger_->trigger(); - for (uint32_t i = 0; i < send_times; i++) { - InterruptLock lock; - for (int32_t item : this->temp_.get_data()) { - if (item > 0) { - const auto length = uint32_t(item); - this->mark_(on_time, off_time, length); - } else { - const auto length = uint32_t(-item); - this->space_(length); - } - App.feed_wdt(); - } - this->await_target_time_(); // wait for duration of last pulse - this->pin_->digital_write(false); - - if (i + 1 < send_times) - this->target_time_ += send_wait; - } - this->complete_trigger_->trigger(); -} - -} // namespace remote_transmitter -} // namespace esphome - -#endif diff --git a/esphome/components/rp2040/__init__.py b/esphome/components/rp2040/__init__.py index 1ec38e0159..3a1ea16fa3 100644 --- a/esphome/components/rp2040/__init__.py +++ b/esphome/components/rp2040/__init__.py @@ -1,5 +1,5 @@ import logging -import os +from pathlib import Path from string import ascii_letters, digits import esphome.codegen as cg @@ -19,7 +19,7 @@ from esphome.const import ( ThreadModel, ) from esphome.core import CORE, CoroPriority, EsphomeError, coroutine_with_priority -from esphome.helpers import copy_file_if_changed, mkdir_p, read_file, write_file +from esphome.helpers import copy_file_if_changed, read_file, write_file_if_changed from .const import KEY_BOARD, KEY_PIO_FILES, KEY_RP2040, rp2040_ns @@ -221,18 +221,18 @@ def generate_pio_files() -> bool: if not files: return False for key, data in files.items(): - pio_path = CORE.relative_build_path(f"src/pio/{key}.pio") - mkdir_p(os.path.dirname(pio_path)) - write_file(pio_path, data) + pio_path = CORE.build_path / "src" / "pio" / f"{key}.pio" + pio_path.parent.mkdir(parents=True, exist_ok=True) + write_file_if_changed(pio_path, data) includes.append(f"pio/{key}.pio.h") - write_file( + write_file_if_changed( CORE.relative_build_path("src/pio_includes.h"), "#pragma once\n" + "\n".join([f'#include "{include}"' for include in includes]), ) - dir = os.path.dirname(__file__) - build_pio_file = os.path.join(dir, "build_pio.py.script") + dir = Path(__file__).parent + build_pio_file = dir / "build_pio.py.script" copy_file_if_changed( build_pio_file, CORE.relative_build_path("build_pio.py"), @@ -243,8 +243,8 @@ def generate_pio_files() -> bool: # Called by writer.py def copy_files(): - dir = os.path.dirname(__file__) - post_build_file = os.path.join(dir, "post_build.py.script") + dir = Path(__file__).parent + post_build_file = dir / "post_build.py.script" copy_file_if_changed( post_build_file, CORE.relative_build_path("post_build.py"), @@ -252,4 +252,4 @@ def copy_files(): if generate_pio_files(): path = CORE.relative_src_path("esphome.h") content = read_file(path).rstrip("\n") - write_file(path, content + '\n#include "pio_includes.h"\n') + write_file_if_changed(path, content + '\n#include "pio_includes.h"\n') diff --git a/esphome/components/rtttl/rtttl.cpp b/esphome/components/rtttl/rtttl.cpp index 5aedc74489..b79f27e2e5 100644 --- a/esphome/components/rtttl/rtttl.cpp +++ b/esphome/components/rtttl/rtttl.cpp @@ -215,7 +215,7 @@ void Rtttl::loop() { sample[x].right = 0; } - if (x >= SAMPLE_BUFFER_SIZE || this->samples_sent_ >= this->samples_count_) { + if (static_cast(x) >= SAMPLE_BUFFER_SIZE || this->samples_sent_ >= this->samples_count_) { break; } this->samples_sent_++; @@ -374,7 +374,7 @@ void Rtttl::loop() { this->last_note_ = millis(); } -#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_DEBUG +#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE static const LogString *state_to_string(State state) { switch (state) { case STATE_STOPPED: diff --git a/esphome/components/scd30/sensor.py b/esphome/components/scd30/sensor.py index 6981af4de9..194df8ec4f 100644 --- a/esphome/components/scd30/sensor.py +++ b/esphome/components/scd30/sensor.py @@ -66,7 +66,7 @@ CONFIG_SCHEMA = ( ), cv.Optional(CONF_AMBIENT_PRESSURE_COMPENSATION, default=0): cv.pressure, cv.Optional(CONF_TEMPERATURE_OFFSET): cv.All( - cv.temperature, + cv.temperature_delta, cv.float_range(min=0, max=655.35), ), cv.Optional(CONF_UPDATE_INTERVAL, default="60s"): cv.All( diff --git a/esphome/components/script/__init__.py b/esphome/components/script/__init__.py index ee1f6a4ad0..e8a8aa5671 100644 --- a/esphome/components/script/__init__.py +++ b/esphome/components/script/__init__.py @@ -124,7 +124,7 @@ async def to_code(config): template, func_args = parameters_to_template(conf[CONF_PARAMETERS]) trigger = cg.new_Pvariable(conf[CONF_ID], template) # Add a human-readable name to the script - cg.add(trigger.set_name(conf[CONF_ID].id)) + cg.add(trigger.set_name(cg.LogStringLiteral(conf[CONF_ID].id))) if CONF_MAX_RUNS in conf: cg.add(trigger.set_max_runs(conf[CONF_MAX_RUNS])) diff --git a/esphome/components/script/script.h b/esphome/components/script/script.h index b16bb53acc..b87402f52e 100644 --- a/esphome/components/script/script.h +++ b/esphome/components/script/script.h @@ -48,14 +48,14 @@ template class Script : public ScriptLogger, public Trigger void execute_tuple_(const std::tuple &tuple, seq /*unused*/) { this->execute(std::get(tuple)...); } - std::string name_; + const LogString *name_{nullptr}; }; /** A script type for which only a single instance at a time is allowed. @@ -68,7 +68,7 @@ template class SingleScript : public Script { void execute(Ts... x) override { if (this->is_action_running()) { this->esp_logw_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' is already running! (mode: single)"), - this->name_.c_str()); + LOG_STR_ARG(this->name_)); return; } @@ -85,7 +85,7 @@ template class RestartScript : public Script { public: void execute(Ts... x) override { if (this->is_action_running()) { - this->esp_logd_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' restarting (mode: restart)"), this->name_.c_str()); + this->esp_logd_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' restarting (mode: restart)"), LOG_STR_ARG(this->name_)); this->stop_action(); } @@ -105,12 +105,12 @@ template class QueueingScript : public Script, public Com // num_runs_ + 1 if (this->max_runs_ != 0 && this->num_runs_ + 1 >= this->max_runs_) { this->esp_logw_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' maximum number of queued runs exceeded!"), - this->name_.c_str()); + LOG_STR_ARG(this->name_)); return; } this->esp_logd_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' queueing new instance (mode: queued)"), - this->name_.c_str()); + LOG_STR_ARG(this->name_)); this->num_runs_++; this->var_queue_.push(std::make_tuple(x...)); return; @@ -157,7 +157,7 @@ template class ParallelScript : public Script { void execute(Ts... x) override { if (this->max_runs_ != 0 && this->automation_parent_->num_running() >= this->max_runs_) { this->esp_logw_(__LINE__, ESPHOME_LOG_FORMAT("Script '%s' maximum number of parallel runs exceeded!"), - this->name_.c_str()); + LOG_STR_ARG(this->name_)); return; } this->trigger(x...); diff --git a/esphome/components/select/select.cpp b/esphome/components/select/select.cpp index beb72aa320..16e8288ca1 100644 --- a/esphome/components/select/select.cpp +++ b/esphome/components/select/select.cpp @@ -34,11 +34,12 @@ size_t Select::size() const { optional Select::index_of(const std::string &option) const { const auto &options = traits.get_options(); - auto it = std::find(options.begin(), options.end(), option); - if (it == options.end()) { - return {}; + for (size_t i = 0; i < options.size(); i++) { + if (options[i] == option) { + return i; + } } - return std::distance(options.begin(), it); + return {}; } optional Select::active_index() const { diff --git a/esphome/components/select/select_call.cpp b/esphome/components/select/select_call.cpp index a8272f8622..dd398b4052 100644 --- a/esphome/components/select/select_call.cpp +++ b/esphome/components/select/select_call.cpp @@ -107,7 +107,7 @@ void SelectCall::perform() { } } - if (std::find(options.begin(), options.end(), target_value) == options.end()) { + if (!parent->has_option(target_value)) { ESP_LOGW(TAG, "'%s' - Option %s is not a valid option", name, target_value.c_str()); return; } diff --git a/esphome/components/sensirion_common/i2c_sensirion.cpp b/esphome/components/sensirion_common/i2c_sensirion.cpp index 22c4b0e53c..9eac6b4525 100644 --- a/esphome/components/sensirion_common/i2c_sensirion.cpp +++ b/esphome/components/sensirion_common/i2c_sensirion.cpp @@ -76,7 +76,8 @@ bool SensirionI2CDevice::write_command_(uint16_t command, CommandLen command_len temp[raw_idx++] = data[i] >> 8; #endif // Use MSB first since Sensirion devices use CRC-8 with MSB first - temp[raw_idx++] = crc8(&temp[raw_idx - 2], 2, 0xFF, CRC_POLYNOMIAL, true); + uint8_t crc = crc8(&temp[raw_idx - 2], 2, 0xFF, CRC_POLYNOMIAL, true); + temp[raw_idx++] = crc; } this->last_error_ = this->write(temp, raw_idx); return this->last_error_ == i2c::ERROR_OK; diff --git a/esphome/components/sensor/__init__.py b/esphome/components/sensor/__init__.py index fe9822b3ca..2b99f68ac0 100644 --- a/esphome/components/sensor/__init__.py +++ b/esphome/components/sensor/__init__.py @@ -74,6 +74,7 @@ from esphome.const import ( DEVICE_CLASS_OZONE, DEVICE_CLASS_PH, DEVICE_CLASS_PM1, + DEVICE_CLASS_PM4, DEVICE_CLASS_PM10, DEVICE_CLASS_PM25, DEVICE_CLASS_POWER, @@ -143,6 +144,7 @@ DEVICE_CLASSES = [ DEVICE_CLASS_PM1, DEVICE_CLASS_PM10, DEVICE_CLASS_PM25, + DEVICE_CLASS_PM4, DEVICE_CLASS_POWER, DEVICE_CLASS_POWER_FACTOR, DEVICE_CLASS_PRECIPITATION, diff --git a/esphome/components/sha256/__init__.py b/esphome/components/sha256/__init__.py new file mode 100644 index 0000000000..f07157416d --- /dev/null +++ b/esphome/components/sha256/__init__.py @@ -0,0 +1,22 @@ +import esphome.codegen as cg +import esphome.config_validation as cv +from esphome.core import CORE +from esphome.helpers import IS_MACOS +from esphome.types import ConfigType + +CODEOWNERS = ["@esphome/core"] + +sha256_ns = cg.esphome_ns.namespace("sha256") + +CONFIG_SCHEMA = cv.Schema({}) + + +async def to_code(config: ConfigType) -> None: + # Add OpenSSL library for host platform + if not CORE.is_host: + return + if IS_MACOS: + # macOS needs special handling for Homebrew OpenSSL + cg.add_build_flag("-I/opt/homebrew/opt/openssl/include") + cg.add_build_flag("-L/opt/homebrew/opt/openssl/lib") + cg.add_build_flag("-lcrypto") diff --git a/esphome/components/sha256/sha256.cpp b/esphome/components/sha256/sha256.cpp new file mode 100644 index 0000000000..32abbd739d --- /dev/null +++ b/esphome/components/sha256/sha256.cpp @@ -0,0 +1,116 @@ +#include "sha256.h" + +// Only compile SHA256 implementation on platforms that support it +#if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_HOST) + +#include "esphome/core/helpers.h" +#include + +namespace esphome::sha256 { + +#if defined(USE_ESP32) || defined(USE_LIBRETINY) + +// CRITICAL ESP32-S3 HARDWARE SHA ACCELERATION REQUIREMENTS: +// +// The ESP32-S3 uses hardware DMA for SHA acceleration. The mbedtls_sha256_context structure contains +// internal state that the DMA engine references. This imposes two critical constraints: +// +// 1. NO VARIABLE LENGTH ARRAYS (VLAs): VLAs corrupt the stack layout, causing the DMA engine to +// write to incorrect memory locations. This results in null pointer dereferences and crashes. +// ALWAYS use fixed-size arrays (e.g., char buf[65], not char buf[size+1]). +// +// 2. SAME STACK FRAME ONLY: The SHA256 object must be created and used entirely within the same +// function. NEVER pass the SHA256 object or HashBase pointer to another function. When the stack +// frame changes (function call/return), the DMA references become invalid and will produce +// truncated hash output (20 bytes instead of 32) or corrupt memory. +// +// CORRECT USAGE: +// void my_function() { +// sha256::SHA256 hasher; // Created locally +// hasher.init(); +// hasher.add(data, len); // Any size, no chunking needed +// hasher.calculate(); +// bool ok = hasher.equals_hex(expected); +// // hasher destroyed when function returns +// } +// +// INCORRECT USAGE (WILL FAIL ON ESP32-S3): +// void my_function() { +// sha256::SHA256 hasher; +// helper(&hasher); // WRONG: Passed to different stack frame +// } +// void helper(HashBase *h) { +// h->init(); // WRONG: Will produce truncated/corrupted output +// } + +SHA256::~SHA256() { mbedtls_sha256_free(&this->ctx_); } + +void SHA256::init() { + mbedtls_sha256_init(&this->ctx_); + mbedtls_sha256_starts(&this->ctx_, 0); // 0 = SHA256, not SHA224 +} + +void SHA256::add(const uint8_t *data, size_t len) { mbedtls_sha256_update(&this->ctx_, data, len); } + +void SHA256::calculate() { mbedtls_sha256_finish(&this->ctx_, this->digest_); } + +#elif defined(USE_ESP8266) || defined(USE_RP2040) + +SHA256::~SHA256() = default; + +void SHA256::init() { + br_sha256_init(&this->ctx_); + this->calculated_ = false; +} + +void SHA256::add(const uint8_t *data, size_t len) { br_sha256_update(&this->ctx_, data, len); } + +void SHA256::calculate() { + if (!this->calculated_) { + br_sha256_out(&this->ctx_, this->digest_); + this->calculated_ = true; + } +} + +#elif defined(USE_HOST) + +SHA256::~SHA256() { + if (this->ctx_) { + EVP_MD_CTX_free(this->ctx_); + } +} + +void SHA256::init() { + if (this->ctx_) { + EVP_MD_CTX_free(this->ctx_); + } + this->ctx_ = EVP_MD_CTX_new(); + EVP_DigestInit_ex(this->ctx_, EVP_sha256(), nullptr); + this->calculated_ = false; +} + +void SHA256::add(const uint8_t *data, size_t len) { + if (!this->ctx_) { + this->init(); + } + EVP_DigestUpdate(this->ctx_, data, len); +} + +void SHA256::calculate() { + if (!this->ctx_) { + this->init(); + } + if (!this->calculated_) { + unsigned int len = 32; + EVP_DigestFinal_ex(this->ctx_, this->digest_, &len); + this->calculated_ = true; + } +} + +#else +#error "SHA256 not supported on this platform" +#endif + +} // namespace esphome::sha256 + +#endif // Platform check diff --git a/esphome/components/sha256/sha256.h b/esphome/components/sha256/sha256.h new file mode 100644 index 0000000000..a2b62799e1 --- /dev/null +++ b/esphome/components/sha256/sha256.h @@ -0,0 +1,60 @@ +#pragma once + +#include "esphome/core/defines.h" + +// Only define SHA256 on platforms that support it +#if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_HOST) + +#include +#include +#include +#include "esphome/core/hash_base.h" + +#if defined(USE_ESP32) || defined(USE_LIBRETINY) +#include "mbedtls/sha256.h" +#elif defined(USE_ESP8266) || defined(USE_RP2040) +#include +#elif defined(USE_HOST) +#include +#else +#error "SHA256 not supported on this platform" +#endif + +namespace esphome::sha256 { + +class SHA256 : public esphome::HashBase { + public: + SHA256() = default; + ~SHA256() override; + + void init() override; + void add(const uint8_t *data, size_t len) override; + using HashBase::add; // Bring base class overload into scope + void add(const std::string &data) { this->add((const uint8_t *) data.c_str(), data.length()); } + + void calculate() override; + + /// Get the size of the hash in bytes (32 for SHA256) + size_t get_size() const override { return 32; } + + protected: +#if defined(USE_ESP32) || defined(USE_LIBRETINY) + // CRITICAL: The mbedtls context MUST be stack-allocated (not a pointer) for ESP32-S3 hardware SHA acceleration. + // The ESP32-S3 DMA engine references this structure's memory addresses. If the context is passed to another + // function (crossing stack frames) or if VLAs are present, the DMA operations will corrupt memory and produce + // truncated/incorrect hash results. + mbedtls_sha256_context ctx_{}; +#elif defined(USE_ESP8266) || defined(USE_RP2040) + br_sha256_context ctx_{}; + bool calculated_{false}; +#elif defined(USE_HOST) + EVP_MD_CTX *ctx_{nullptr}; + bool calculated_{false}; +#else +#error "SHA256 not supported on this platform" +#endif +}; + +} // namespace esphome::sha256 + +#endif // Platform check diff --git a/esphome/components/socket/lwip_raw_tcp_impl.cpp b/esphome/components/socket/lwip_raw_tcp_impl.cpp index 2d64a275df..3377682474 100644 --- a/esphome/components/socket/lwip_raw_tcp_impl.cpp +++ b/esphome/components/socket/lwip_raw_tcp_impl.cpp @@ -9,7 +9,7 @@ #include "lwip/tcp.h" #include #include -#include +#include #include "esphome/core/helpers.h" #include "esphome/core/log.h" @@ -50,12 +50,18 @@ class LWIPRawImpl : public Socket { errno = EBADF; return nullptr; } - if (accepted_sockets_.empty()) { + if (this->accepted_socket_count_ == 0) { errno = EWOULDBLOCK; return nullptr; } - std::unique_ptr sock = std::move(accepted_sockets_.front()); - accepted_sockets_.pop(); + // Take from front for FIFO ordering + std::unique_ptr sock = std::move(this->accepted_sockets_[0]); + // Shift remaining sockets forward + for (uint8_t i = 1; i < this->accepted_socket_count_; i++) { + this->accepted_sockets_[i - 1] = std::move(this->accepted_sockets_[i]); + } + this->accepted_socket_count_--; + LWIP_LOG("Connection accepted by application, queue size: %d", this->accepted_socket_count_); if (addr != nullptr) { sock->getpeername(addr, addrlen); } @@ -494,9 +500,18 @@ class LWIPRawImpl : public Socket { // nothing to do here, we just don't push it to the queue return ERR_OK; } + // Check if we've reached the maximum accept queue size + if (this->accepted_socket_count_ >= MAX_ACCEPTED_SOCKETS) { + LWIP_LOG("Rejecting connection, queue full (%d)", this->accepted_socket_count_); + // Abort the connection when queue is full + tcp_abort(newpcb); + // Must return ERR_ABRT since we called tcp_abort() + return ERR_ABRT; + } auto sock = make_unique(family_, newpcb); sock->init(); - accepted_sockets_.push(std::move(sock)); + this->accepted_sockets_[this->accepted_socket_count_++] = std::move(sock); + LWIP_LOG("Accepted connection, queue size: %d", this->accepted_socket_count_); return ERR_OK; } void err_fn(err_t err) { @@ -587,7 +602,20 @@ class LWIPRawImpl : public Socket { } struct tcp_pcb *pcb_; - std::queue> accepted_sockets_; + // Accept queue - holds incoming connections briefly until the event loop calls accept() + // This is NOT a connection pool - just a temporary queue between LWIP callbacks and the main loop + // 3 slots is plenty since connections are pulled out quickly by the event loop + // + // Memory analysis: std::array<3> vs original std::queue implementation: + // - std::queue uses std::deque internally which on 32-bit systems needs: + // 24 bytes (deque object) + 32+ bytes (map array) + heap allocations + // Total: ~56+ bytes minimum, plus heap fragmentation + // - std::array<3>: 12 bytes fixed (3 pointers × 4 bytes) + // Saves ~44+ bytes RAM per listening socket + avoids ALL heap allocations + // Used on ESP8266 and RP2040 (platforms using LWIP_TCP implementation) + static constexpr size_t MAX_ACCEPTED_SOCKETS = 3; + std::array, MAX_ACCEPTED_SOCKETS> accepted_sockets_; + uint8_t accepted_socket_count_ = 0; // Number of sockets currently in queue bool rx_closed_ = false; pbuf *rx_buf_ = nullptr; size_t rx_buf_offset_ = 0; diff --git a/esphome/components/sonoff_d1/sonoff_d1.cpp b/esphome/components/sonoff_d1/sonoff_d1.cpp index e3d55681c5..cd09f31dd7 100644 --- a/esphome/components/sonoff_d1/sonoff_d1.cpp +++ b/esphome/components/sonoff_d1/sonoff_d1.cpp @@ -50,7 +50,7 @@ static const char *const TAG = "sonoff_d1"; uint8_t SonoffD1Output::calc_checksum_(const uint8_t *cmd, const size_t len) { uint8_t crc = 0; - for (int i = 2; i < len - 1; i++) { + for (size_t i = 2; i < len - 1; i++) { crc += cmd[i]; } return crc; diff --git a/esphome/components/spi/__init__.py b/esphome/components/spi/__init__.py index 894c6d1878..d803ee66dc 100644 --- a/esphome/components/spi/__init__.py +++ b/esphome/components/spi/__init__.py @@ -276,9 +276,6 @@ def get_spi_interface(index): return ["&SPI", "&SPI1"][index] if index == 0: return "&SPI" - # Following code can't apply to C2, H2 or 8266 since they have only one SPI - if get_target_variant() in (VARIANT_ESP32S3, VARIANT_ESP32S2): - return "new SPIClass(FSPI)" return "new SPIClass(HSPI)" diff --git a/esphome/components/split_buffer/__init__.py b/esphome/components/split_buffer/__init__.py new file mode 100644 index 0000000000..be7472936f --- /dev/null +++ b/esphome/components/split_buffer/__init__.py @@ -0,0 +1,5 @@ +CODEOWNERS = ["@jesserockz"] + +# Allows split_buffer to be configured in yaml, to allow use of the C++ api. + +CONFIG_SCHEMA = {} diff --git a/esphome/components/split_buffer/split_buffer.cpp b/esphome/components/split_buffer/split_buffer.cpp new file mode 100644 index 0000000000..a710670a5d --- /dev/null +++ b/esphome/components/split_buffer/split_buffer.cpp @@ -0,0 +1,133 @@ +#include "split_buffer.h" + +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" + +namespace esphome::split_buffer { + +static constexpr const char *const TAG = "split_buffer"; + +SplitBuffer::~SplitBuffer() { this->free(); } + +bool SplitBuffer::init(size_t total_length) { + this->free(); // Clean up any existing allocation + + if (total_length == 0) { + return false; + } + + this->total_length_ = total_length; + size_t current_buffer_size = total_length; + + RAMAllocator ptr_allocator; + RAMAllocator allocator; + + // Try to allocate the entire buffer first + while (current_buffer_size > 0) { + // Calculate how many buffers we need of this size + size_t needed_buffers = (total_length + current_buffer_size - 1) / current_buffer_size; + + // Try to allocate array of buffer pointers + uint8_t **temp_buffers = ptr_allocator.allocate(needed_buffers); + if (temp_buffers == nullptr) { + // If we can't even allocate the pointer array, don't need to continue + ESP_LOGE(TAG, "Failed to allocate pointers"); + return false; + } + + // Initialize all pointers to null + for (size_t i = 0; i < needed_buffers; i++) { + temp_buffers[i] = nullptr; + } + + // Try to allocate all the buffers + bool allocation_success = true; + for (size_t i = 0; i < needed_buffers; i++) { + size_t this_buffer_size = current_buffer_size; + // Last buffer might be smaller if total_length is not divisible by current_buffer_size + if (i == needed_buffers - 1 && total_length % current_buffer_size != 0) { + this_buffer_size = total_length % current_buffer_size; + } + + temp_buffers[i] = allocator.allocate(this_buffer_size); + if (temp_buffers[i] == nullptr) { + allocation_success = false; + break; + } + + // Initialize buffer to zero + memset(temp_buffers[i], 0, this_buffer_size); + } + + if (allocation_success) { + // Success! Store the result + this->buffers_ = temp_buffers; + this->buffer_count_ = needed_buffers; + this->buffer_size_ = current_buffer_size; + ESP_LOGD(TAG, "Allocated %zu * %zu bytes - %zu bytes", this->buffer_count_, this->buffer_size_, + this->total_length_); + return true; + } + + // Allocation failed, clean up and try smaller buffers + for (size_t i = 0; i < needed_buffers; i++) { + if (temp_buffers[i] != nullptr) { + allocator.deallocate(temp_buffers[i], 0); + } + } + ptr_allocator.deallocate(temp_buffers, 0); + + // Halve the buffer size and try again + current_buffer_size = current_buffer_size / 2; + } + + ESP_LOGE(TAG, "Failed to allocate %zu bytes", total_length); + return false; +} + +void SplitBuffer::free() { + if (this->buffers_ != nullptr) { + RAMAllocator allocator; + for (size_t i = 0; i < this->buffer_count_; i++) { + if (this->buffers_[i] != nullptr) { + allocator.deallocate(this->buffers_[i], 0); + } + } + RAMAllocator ptr_allocator; + ptr_allocator.deallocate(this->buffers_, 0); + this->buffers_ = nullptr; + } + this->buffer_count_ = 0; + this->buffer_size_ = 0; + this->total_length_ = 0; +} + +uint8_t &SplitBuffer::operator[](size_t index) { + if (index >= this->total_length_) { + ESP_LOGE(TAG, "Out of bounds - %zu >= %zu", index, this->total_length_); + // Return reference to a static dummy byte to avoid crash + static uint8_t dummy = 0; + return dummy; + } + + size_t buffer_index = index / this->buffer_size_; + size_t offset_in_buffer = index - this->buffer_size_ * buffer_index; + + return this->buffers_[buffer_index][offset_in_buffer]; +} + +const uint8_t &SplitBuffer::operator[](size_t index) const { + if (index >= this->total_length_) { + ESP_LOGE(TAG, "Out of bounds - %zu >= %zu", index, this->total_length_); + // Return reference to a static dummy byte to avoid crash + static const uint8_t DUMMY = 0; + return DUMMY; + } + + size_t buffer_index = index / this->buffer_size_; + size_t offset_in_buffer = index - this->buffer_size_ * buffer_index; + + return this->buffers_[buffer_index][offset_in_buffer]; +} + +} // namespace esphome::split_buffer diff --git a/esphome/components/split_buffer/split_buffer.h b/esphome/components/split_buffer/split_buffer.h new file mode 100644 index 0000000000..c3490f3d6e --- /dev/null +++ b/esphome/components/split_buffer/split_buffer.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +namespace esphome::split_buffer { + +class SplitBuffer { + public: + SplitBuffer() = default; + ~SplitBuffer(); + + // Initialize the buffer with the desired total length + bool init(size_t total_length); + + // Free all allocated buffers + void free(); + + // Access operators + uint8_t &operator[](size_t index); + const uint8_t &operator[](size_t index) const; + + // Get the total length + size_t size() const { return this->total_length_; } + + // Get buffer information + size_t get_buffer_count() const { return this->buffer_count_; } + size_t get_buffer_size() const { return this->buffer_size_; } + + // Check if successfully initialized + bool is_valid() const { return this->buffers_ != nullptr && this->buffer_count_ > 0; } + + private: + uint8_t **buffers_{nullptr}; + size_t buffer_count_{0}; + size_t buffer_size_{0}; + size_t total_length_{0}; +}; + +} // namespace esphome::split_buffer diff --git a/esphome/components/sps30/sps30.cpp b/esphome/components/sps30/sps30.cpp index b99bf416d6..21a782e49a 100644 --- a/esphome/components/sps30/sps30.cpp +++ b/esphome/components/sps30/sps30.cpp @@ -52,17 +52,19 @@ void SPS30Component::setup() { } else { result = this->write_command(SPS30_CMD_SET_AUTOMATIC_CLEANING_INTERVAL_SECONDS); } - if (result) { - delay(20); - uint16_t secs[2]; - if (this->read_data(secs, 2)) { - this->fan_interval_ = secs[0] << 16 | secs[1]; - } - } - this->status_clear_warning(); - this->skipped_data_read_cycles_ = 0; - this->start_continuous_measurement_(); + this->set_timeout(20, [this, result]() { + if (result) { + uint16_t secs[2]; + if (this->read_data(secs, 2)) { + this->fan_interval_ = secs[0] << 16 | secs[1]; + } + } + this->status_clear_warning(); + this->skipped_data_read_cycles_ = 0; + this->start_continuous_measurement_(); + this->setup_complete_ = true; + }); }); } @@ -111,6 +113,8 @@ void SPS30Component::dump_config() { } void SPS30Component::update() { + if (!this->setup_complete_) + return; /// Check if warning flag active (sensor reconnected?) if (this->status_has_warning()) { ESP_LOGD(TAG, "Reconnecting"); diff --git a/esphome/components/sps30/sps30.h b/esphome/components/sps30/sps30.h index 461a770ab6..18847e16d9 100644 --- a/esphome/components/sps30/sps30.h +++ b/esphome/components/sps30/sps30.h @@ -30,9 +30,11 @@ class SPS30Component : public PollingComponent, public sensirion_common::Sensiri bool start_fan_cleaning(); protected: + bool setup_complete_{false}; uint16_t raw_firmware_version_; char serial_number_[17] = {0}; /// Terminating NULL character uint8_t skipped_data_read_cycles_ = 0; + bool start_continuous_measurement_(); enum ErrorCode : uint8_t { diff --git a/esphome/components/st7567_i2c/st7567_i2c.cpp b/esphome/components/st7567_i2c/st7567_i2c.cpp index 710e473b11..14c21d5148 100644 --- a/esphome/components/st7567_i2c/st7567_i2c.cpp +++ b/esphome/components/st7567_i2c/st7567_i2c.cpp @@ -50,8 +50,10 @@ void HOT I2CST7567::write_display_data() { static const size_t BLOCK_SIZE = 64; for (uint8_t x = 0; x < (uint8_t) this->get_width_internal(); x += BLOCK_SIZE) { + size_t remaining = static_cast(this->get_width_internal()) - x; + size_t chunk = remaining > BLOCK_SIZE ? BLOCK_SIZE : remaining; this->write_register(esphome::st7567_base::ST7567_SET_START_LINE, &buffer_[y * this->get_width_internal() + x], - this->get_width_internal() - x > BLOCK_SIZE ? BLOCK_SIZE : this->get_width_internal() - x); + chunk); } } } diff --git a/esphome/components/st7789v/st7789v.cpp b/esphome/components/st7789v/st7789v.cpp index 44f2293ac4..ade9c1126f 100644 --- a/esphome/components/st7789v/st7789v.cpp +++ b/esphome/components/st7789v/st7789v.cpp @@ -176,8 +176,9 @@ void ST7789V::write_display_data() { if (this->eightbitcolor_) { uint8_t temp_buffer[TEMP_BUFFER_SIZE]; size_t temp_index = 0; - for (int line = 0; line < this->get_buffer_length_(); line = line + this->get_width_internal()) { - for (int index = 0; index < this->get_width_internal(); ++index) { + size_t width = static_cast(this->get_width_internal()); + for (size_t line = 0; line < this->get_buffer_length_(); line += width) { + for (size_t index = 0; index < width; ++index) { auto color = display::ColorUtil::color_to_565( display::ColorUtil::to_color(this->buffer_[index + line], display::ColorOrder::COLOR_ORDER_RGB, display::ColorBitness::COLOR_BITNESS_332, true)); diff --git a/esphome/components/statsd/statsd.cpp b/esphome/components/statsd/statsd.cpp index 05f71c7b24..7729f36858 100644 --- a/esphome/components/statsd/statsd.cpp +++ b/esphome/components/statsd/statsd.cpp @@ -151,7 +151,7 @@ void StatsdComponent::send_(std::string *out) { int n_bytes = this->sock_->sendto(out->c_str(), out->length(), 0, reinterpret_cast(&this->destination_), sizeof(this->destination_)); - if (n_bytes != out->length()) { + if (n_bytes != static_cast(out->length())) { ESP_LOGE(TAG, "Failed to send UDP packed (%d of %d)", n_bytes, out->length()); } #endif diff --git a/esphome/components/substitutions/__init__.py b/esphome/components/substitutions/__init__.py index a96f56a045..1a1736aed1 100644 --- a/esphome/components/substitutions/__init__.py +++ b/esphome/components/substitutions/__init__.py @@ -4,7 +4,7 @@ from esphome import core from esphome.config_helpers import Extend, Remove, merge_config import esphome.config_validation as cv from esphome.const import CONF_SUBSTITUTIONS, VALID_SUBSTITUTIONS_CHARACTERS -from esphome.yaml_util import ESPHomeDataBase, make_data_base +from esphome.yaml_util import ESPHomeDataBase, ESPLiteralValue, make_data_base from .jinja import Jinja, JinjaStr, TemplateError, TemplateRuntimeError, has_jinja @@ -127,6 +127,8 @@ def _expand_substitutions(substitutions, value, path, jinja, ignore_missing): def _substitute_item(substitutions, item, path, jinja, ignore_missing): + if isinstance(item, ESPLiteralValue): + return None # do not substitute inside literal blocks if isinstance(item, list): for i, it in enumerate(item): sub = _substitute_item(substitutions, it, path + [i], jinja, ignore_missing) diff --git a/esphome/components/substitutions/jinja.py b/esphome/components/substitutions/jinja.py index c6e40a668d..e7164d8fff 100644 --- a/esphome/components/substitutions/jinja.py +++ b/esphome/components/substitutions/jinja.py @@ -1,9 +1,10 @@ +from ast import literal_eval import logging import math import re import jinja2 as jinja -from jinja2.nativetypes import NativeEnvironment +from jinja2.sandbox import SandboxedEnvironment TemplateError = jinja.TemplateError TemplateSyntaxError = jinja.TemplateSyntaxError @@ -70,7 +71,7 @@ class Jinja: """ def __init__(self, context_vars): - self.env = NativeEnvironment( + self.env = SandboxedEnvironment( trim_blocks=True, lstrip_blocks=True, block_start_string="<%", @@ -90,6 +91,15 @@ class Jinja: **SAFE_GLOBAL_FUNCTIONS, } + def safe_eval(self, expr): + try: + result = literal_eval(expr) + if not isinstance(result, str): + return result + except (ValueError, SyntaxError, MemoryError, TypeError): + pass + return expr + def expand(self, content_str): """ Renders a string that may contain Jinja expressions or statements @@ -106,7 +116,7 @@ class Jinja: override_vars = content_str.upvalues try: template = self.env.from_string(content_str) - result = template.render(override_vars) + result = self.safe_eval(template.render(override_vars)) if isinstance(result, Undefined): # This happens when the expression is simply an undefined variable. Jinja does not # raise an exception, instead we get "Undefined". diff --git a/esphome/components/sx126x/__init__.py b/esphome/components/sx126x/__init__.py index b6aeaf072c..370cd102d4 100644 --- a/esphome/components/sx126x/__init__.py +++ b/esphome/components/sx126x/__init__.py @@ -15,6 +15,10 @@ CONF_BANDWIDTH = "bandwidth" CONF_BITRATE = "bitrate" CONF_CODING_RATE = "coding_rate" CONF_CRC_ENABLE = "crc_enable" +CONF_CRC_INVERTED = "crc_inverted" +CONF_CRC_SIZE = "crc_size" +CONF_CRC_POLYNOMIAL = "crc_polynomial" +CONF_CRC_INITIAL = "crc_initial" CONF_DEVIATION = "deviation" CONF_DIO1_PIN = "dio1_pin" CONF_HW_VERSION = "hw_version" @@ -188,6 +192,14 @@ CONFIG_SCHEMA = ( cv.Required(CONF_BUSY_PIN): pins.internal_gpio_input_pin_schema, cv.Optional(CONF_CODING_RATE, default="CR_4_5"): cv.enum(CODING_RATE), cv.Optional(CONF_CRC_ENABLE, default=False): cv.boolean, + cv.Optional(CONF_CRC_INVERTED, default=True): cv.boolean, + cv.Optional(CONF_CRC_SIZE, default=2): cv.int_range(min=1, max=2), + cv.Optional(CONF_CRC_POLYNOMIAL, default=0x1021): cv.All( + cv.hex_int, cv.Range(min=0, max=0xFFFF) + ), + cv.Optional(CONF_CRC_INITIAL, default=0x1D0F): cv.All( + cv.hex_int, cv.Range(min=0, max=0xFFFF) + ), cv.Optional(CONF_DEVIATION, default=5000): cv.int_range(min=0, max=100000), cv.Required(CONF_DIO1_PIN): pins.internal_gpio_input_pin_schema, cv.Required(CONF_FREQUENCY): cv.int_range(min=137000000, max=1020000000), @@ -251,6 +263,10 @@ async def to_code(config): cg.add(var.set_shaping(config[CONF_SHAPING])) cg.add(var.set_bitrate(config[CONF_BITRATE])) cg.add(var.set_crc_enable(config[CONF_CRC_ENABLE])) + cg.add(var.set_crc_inverted(config[CONF_CRC_INVERTED])) + cg.add(var.set_crc_size(config[CONF_CRC_SIZE])) + cg.add(var.set_crc_polynomial(config[CONF_CRC_POLYNOMIAL])) + cg.add(var.set_crc_initial(config[CONF_CRC_INITIAL])) cg.add(var.set_payload_length(config[CONF_PAYLOAD_LENGTH])) cg.add(var.set_preamble_size(config[CONF_PREAMBLE_SIZE])) cg.add(var.set_preamble_detect(config[CONF_PREAMBLE_DETECT])) diff --git a/esphome/components/sx126x/sx126x.cpp b/esphome/components/sx126x/sx126x.cpp index f5393c478a..bb59f26b79 100644 --- a/esphome/components/sx126x/sx126x.cpp +++ b/esphome/components/sx126x/sx126x.cpp @@ -235,6 +235,16 @@ void SX126x::configure() { buf[7] = (fdev >> 0) & 0xFF; this->write_opcode_(RADIO_SET_MODULATIONPARAMS, buf, 8); + // set crc params + if (this->crc_enable_) { + buf[0] = this->crc_initial_ >> 8; + buf[1] = this->crc_initial_ & 0xFF; + this->write_register_(REG_CRC_INITIAL, buf, 2); + buf[0] = this->crc_polynomial_ >> 8; + buf[1] = this->crc_polynomial_ & 0xFF; + this->write_register_(REG_CRC_POLYNOMIAL, buf, 2); + } + // set packet params and sync word this->set_packet_params_(this->get_max_packet_size()); if (!this->sync_value_.empty()) { @@ -276,7 +286,11 @@ void SX126x::set_packet_params_(uint8_t payload_length) { buf[4] = 0x00; buf[5] = (this->payload_length_ > 0) ? 0x00 : 0x01; buf[6] = payload_length; - buf[7] = this->crc_enable_ ? 0x06 : 0x01; + if (this->crc_enable_) { + buf[7] = (this->crc_inverted_ ? 0x04 : 0x00) + (this->crc_size_ & 0x02); + } else { + buf[7] = 0x01; + } buf[8] = 0x00; this->write_opcode_(RADIO_SET_PACKETPARAMS, buf, 9); } diff --git a/esphome/components/sx126x/sx126x.h b/esphome/components/sx126x/sx126x.h index fd5c37942d..47d6449738 100644 --- a/esphome/components/sx126x/sx126x.h +++ b/esphome/components/sx126x/sx126x.h @@ -67,6 +67,10 @@ class SX126x : public Component, void set_busy_pin(InternalGPIOPin *busy_pin) { this->busy_pin_ = busy_pin; } void set_coding_rate(uint8_t coding_rate) { this->coding_rate_ = coding_rate; } void set_crc_enable(bool crc_enable) { this->crc_enable_ = crc_enable; } + void set_crc_inverted(bool crc_inverted) { this->crc_inverted_ = crc_inverted; } + void set_crc_size(uint8_t crc_size) { this->crc_size_ = crc_size; } + void set_crc_polynomial(uint16_t crc_polynomial) { this->crc_polynomial_ = crc_polynomial; } + void set_crc_initial(uint16_t crc_initial) { this->crc_initial_ = crc_initial; } void set_deviation(uint32_t deviation) { this->deviation_ = deviation; } void set_dio1_pin(InternalGPIOPin *dio1_pin) { this->dio1_pin_ = dio1_pin; } void set_frequency(uint32_t frequency) { this->frequency_ = frequency; } @@ -118,6 +122,11 @@ class SX126x : public Component, char version_[16]; SX126xBw bandwidth_{SX126X_BW_125000}; uint32_t bitrate_{0}; + bool crc_enable_{false}; + bool crc_inverted_{false}; + uint8_t crc_size_{0}; + uint16_t crc_polynomial_{0}; + uint16_t crc_initial_{0}; uint32_t deviation_{0}; uint32_t frequency_{0}; uint32_t payload_length_{0}; @@ -131,7 +140,6 @@ class SX126x : public Component, uint8_t shaping_{0}; uint8_t spreading_factor_{0}; int8_t pa_power_{0}; - bool crc_enable_{false}; bool rx_start_{false}; bool rf_switch_{false}; }; diff --git a/esphome/components/sx126x/sx126x_reg.h b/esphome/components/sx126x/sx126x_reg.h index 3b12d822b5..143f4a05da 100644 --- a/esphome/components/sx126x/sx126x_reg.h +++ b/esphome/components/sx126x/sx126x_reg.h @@ -53,6 +53,8 @@ enum SX126xOpCode : uint8_t { enum SX126xRegister : uint16_t { REG_VERSION_STRING = 0x0320, + REG_CRC_INITIAL = 0x06BC, + REG_CRC_POLYNOMIAL = 0x06BE, REG_GFSK_SYNCWORD = 0x06C0, REG_LORA_SYNCWORD = 0x0740, REG_OCP = 0x08E7, diff --git a/esphome/components/text_sensor/text_sensor.cpp b/esphome/components/text_sensor/text_sensor.cpp index 72b540b84c..17bf20466e 100644 --- a/esphome/components/text_sensor/text_sensor.cpp +++ b/esphome/components/text_sensor/text_sensor.cpp @@ -6,6 +6,22 @@ namespace text_sensor { static const char *const TAG = "text_sensor"; +void log_text_sensor(const char *tag, const char *prefix, const char *type, TextSensor *obj) { + if (obj == nullptr) { + return; + } + + ESP_LOGCONFIG(tag, "%s%s '%s'", prefix, type, obj->get_name().c_str()); + + if (!obj->get_device_class_ref().empty()) { + ESP_LOGCONFIG(tag, "%s Device Class: '%s'", prefix, obj->get_device_class_ref().c_str()); + } + + if (!obj->get_icon_ref().empty()) { + ESP_LOGCONFIG(tag, "%s Icon: '%s'", prefix, obj->get_icon_ref().c_str()); + } +} + void TextSensor::publish_state(const std::string &state) { this->raw_state = state; if (this->raw_callback_) { diff --git a/esphome/components/text_sensor/text_sensor.h b/esphome/components/text_sensor/text_sensor.h index 3ab88e2d91..abbea27b59 100644 --- a/esphome/components/text_sensor/text_sensor.h +++ b/esphome/components/text_sensor/text_sensor.h @@ -11,16 +11,9 @@ namespace esphome { namespace text_sensor { -#define LOG_TEXT_SENSOR(prefix, type, obj) \ - if ((obj) != nullptr) { \ - ESP_LOGCONFIG(TAG, "%s%s '%s'", prefix, LOG_STR_LITERAL(type), (obj)->get_name().c_str()); \ - if (!(obj)->get_device_class_ref().empty()) { \ - ESP_LOGCONFIG(TAG, "%s Device Class: '%s'", prefix, (obj)->get_device_class_ref().c_str()); \ - } \ - if (!(obj)->get_icon_ref().empty()) { \ - ESP_LOGCONFIG(TAG, "%s Icon: '%s'", prefix, (obj)->get_icon_ref().c_str()); \ - } \ - } +void log_text_sensor(const char *tag, const char *prefix, const char *type, TextSensor *obj); + +#define LOG_TEXT_SENSOR(prefix, type, obj) log_text_sensor(TAG, prefix, LOG_STR_LITERAL(type), obj) #define SUB_TEXT_SENSOR(name) \ protected: \ diff --git a/esphome/components/tormatic/tormatic_cover.cpp b/esphome/components/tormatic/tormatic_cover.cpp index be412d62a8..ef93964a28 100644 --- a/esphome/components/tormatic/tormatic_cover.cpp +++ b/esphome/components/tormatic/tormatic_cover.cpp @@ -251,7 +251,7 @@ void Tormatic::stop_at_target_() { // Read a GateStatus from the unit. The unit only sends messages in response to // status requests or commands, so a message needs to be sent first. optional Tormatic::read_gate_status_() { - if (this->available() < sizeof(MessageHeader)) { + if (this->available() < static_cast(sizeof(MessageHeader))) { return {}; } diff --git a/esphome/components/tuya/select/tuya_select.cpp b/esphome/components/tuya/select/tuya_select.cpp index 07b0ff2815..91ddbc77ec 100644 --- a/esphome/components/tuya/select/tuya_select.cpp +++ b/esphome/components/tuya/select/tuya_select.cpp @@ -50,7 +50,7 @@ void TuyaSelect::dump_config() { " Options are:", this->select_id_, this->is_int_ ? "int" : "enum"); auto options = this->traits.get_options(); - for (auto i = 0; i < this->mappings_.size(); i++) { + for (size_t i = 0; i < this->mappings_.size(); i++) { ESP_LOGCONFIG(TAG, " %i: %s", this->mappings_.at(i), options.at(i).c_str()); } } diff --git a/esphome/components/tuya/tuya.cpp b/esphome/components/tuya/tuya.cpp index 1443d10254..12b14be9ff 100644 --- a/esphome/components/tuya/tuya.cpp +++ b/esphome/components/tuya/tuya.cpp @@ -215,12 +215,37 @@ void Tuya::handle_command_(uint8_t command, uint8_t version, const uint8_t *buff this->send_empty_command_(TuyaCommandType::DATAPOINT_QUERY); } break; - case TuyaCommandType::WIFI_RESET: - ESP_LOGE(TAG, "WIFI_RESET is not handled"); - break; case TuyaCommandType::WIFI_SELECT: - ESP_LOGE(TAG, "WIFI_SELECT is not handled"); + case TuyaCommandType::WIFI_RESET: { + const bool is_select = (len >= 1); + // Send WIFI_SELECT ACK + TuyaCommand ack; + ack.cmd = is_select ? TuyaCommandType::WIFI_SELECT : TuyaCommandType::WIFI_RESET; + ack.payload.clear(); + this->send_command_(ack); + // Establish pairing mode for correct first WIFI_STATE byte, EZ (0x00) default + uint8_t first = 0x00; + const char *mode_str = "EZ"; + if (is_select && buffer[0] == 0x01) { + first = 0x01; + mode_str = "AP"; + } + // Send WIFI_STATE response, MCU exits pairing mode + TuyaCommand st; + st.cmd = TuyaCommandType::WIFI_STATE; + st.payload.resize(1); + st.payload[0] = first; + this->send_command_(st); + st.payload[0] = 0x02; + this->send_command_(st); + st.payload[0] = 0x03; + this->send_command_(st); + st.payload[0] = 0x04; + this->send_command_(st); + ESP_LOGI(TAG, "%s received (%s), replied with WIFI_STATE confirming connection established", + is_select ? "WIFI_SELECT" : "WIFI_RESET", mode_str); break; + } case TuyaCommandType::DATAPOINT_DELIVER: break; case TuyaCommandType::DATAPOINT_REPORT_ASYNC: diff --git a/esphome/components/uart/__init__.py b/esphome/components/uart/__init__.py index 7d4c6360fe..764576744f 100644 --- a/esphome/components/uart/__init__.py +++ b/esphome/components/uart/__init__.py @@ -1,3 +1,4 @@ +import math import re from esphome import automation, pins @@ -14,9 +15,9 @@ from esphome.const import ( CONF_DIRECTION, CONF_DUMMY_RECEIVER, CONF_DUMMY_RECEIVER_ID, + CONF_FLOW_CONTROL_PIN, CONF_ID, CONF_INVERT, - CONF_INVERTED, CONF_LAMBDA, CONF_NUMBER, CONF_PORT, @@ -39,9 +40,6 @@ uart_ns = cg.esphome_ns.namespace("uart") UARTComponent = uart_ns.class_("UARTComponent") IDFUARTComponent = uart_ns.class_("IDFUARTComponent", UARTComponent, cg.Component) -ESP32ArduinoUARTComponent = uart_ns.class_( - "ESP32ArduinoUARTComponent", UARTComponent, cg.Component -) ESP8266UartComponent = uart_ns.class_( "ESP8266UartComponent", UARTComponent, cg.Component ) @@ -53,7 +51,6 @@ HostUartComponent = uart_ns.class_("HostUartComponent", UARTComponent, cg.Compon NATIVE_UART_CLASSES = ( str(IDFUARTComponent), - str(ESP32ArduinoUARTComponent), str(ESP8266UartComponent), str(RP2040UartComponent), str(LibreTinyUARTComponent), @@ -119,20 +116,6 @@ def validate_rx_pin(value): return value -def validate_invert_esp32(config): - if ( - CORE.is_esp32 - and CORE.using_arduino - and CONF_TX_PIN in config - and CONF_RX_PIN in config - and config[CONF_TX_PIN][CONF_INVERTED] != config[CONF_RX_PIN][CONF_INVERTED] - ): - raise cv.Invalid( - "Different invert values for TX and RX pin are not supported for ESP32 when using Arduino." - ) - return config - - def validate_host_config(config): if CORE.is_host: if CONF_TX_PIN in config or CONF_RX_PIN in config: @@ -151,10 +134,7 @@ def _uart_declare_type(value): if CORE.is_esp8266: return cv.declare_id(ESP8266UartComponent)(value) if CORE.is_esp32: - if CORE.using_arduino: - return cv.declare_id(ESP32ArduinoUARTComponent)(value) - if CORE.using_esp_idf: - return cv.declare_id(IDFUARTComponent)(value) + return cv.declare_id(IDFUARTComponent)(value) if CORE.is_rp2040: return cv.declare_id(RP2040UartComponent)(value) if CORE.is_libretiny: @@ -174,6 +154,8 @@ UART_PARITY_OPTIONS = { CONF_STOP_BITS = "stop_bits" CONF_DATA_BITS = "data_bits" CONF_PARITY = "parity" +CONF_RX_FULL_THRESHOLD = "rx_full_threshold" +CONF_RX_TIMEOUT = "rx_timeout" UARTDirection = uart_ns.enum("UARTDirection") UART_DIRECTIONS = { @@ -241,8 +223,17 @@ CONFIG_SCHEMA = cv.All( cv.Required(CONF_BAUD_RATE): cv.int_range(min=1), cv.Optional(CONF_TX_PIN): pins.internal_gpio_output_pin_schema, cv.Optional(CONF_RX_PIN): validate_rx_pin, + cv.Optional(CONF_FLOW_CONTROL_PIN): cv.All( + cv.only_on_esp32, pins.internal_gpio_output_pin_schema + ), cv.Optional(CONF_PORT): cv.All(validate_port, cv.only_on(PLATFORM_HOST)), cv.Optional(CONF_RX_BUFFER_SIZE, default=256): cv.validate_bytes, + cv.Optional(CONF_RX_FULL_THRESHOLD): cv.All( + cv.only_on_esp32, cv.validate_bytes, cv.int_range(min=1, max=120) + ), + cv.SplitDefault(CONF_RX_TIMEOUT, esp32=2): cv.All( + cv.only_on_esp32, cv.validate_bytes, cv.int_range(min=0, max=92) + ), cv.Optional(CONF_STOP_BITS, default=1): cv.one_of(1, 2, int=True), cv.Optional(CONF_DATA_BITS, default=8): cv.int_range(min=5, max=8), cv.Optional(CONF_PARITY, default="NONE"): cv.enum( @@ -255,7 +246,6 @@ CONFIG_SCHEMA = cv.All( } ).extend(cv.COMPONENT_SCHEMA), cv.has_at_least_one_key(CONF_TX_PIN, CONF_RX_PIN, CONF_PORT), - validate_invert_esp32, validate_host_config, ) @@ -298,9 +288,27 @@ async def to_code(config): if CONF_RX_PIN in config: rx_pin = await cg.gpio_pin_expression(config[CONF_RX_PIN]) cg.add(var.set_rx_pin(rx_pin)) + if CONF_FLOW_CONTROL_PIN in config: + flow_control_pin = await cg.gpio_pin_expression(config[CONF_FLOW_CONTROL_PIN]) + cg.add(var.set_flow_control_pin(flow_control_pin)) if CONF_PORT in config: cg.add(var.set_name(config[CONF_PORT])) cg.add(var.set_rx_buffer_size(config[CONF_RX_BUFFER_SIZE])) + if CORE.is_esp32: + if CONF_RX_FULL_THRESHOLD not in config: + # Calculate rx_full_threshold to be 10ms + bytelength = config[CONF_DATA_BITS] + config[CONF_STOP_BITS] + 1 + if config[CONF_PARITY] != "NONE": + bytelength += 1 + config[CONF_RX_FULL_THRESHOLD] = max( + 1, + min( + 120, + math.floor((config[CONF_BAUD_RATE] / (bytelength * 1000 / 10)) - 1), + ), + ) + cg.add(var.set_rx_full_threshold(config[CONF_RX_FULL_THRESHOLD])) + cg.add(var.set_rx_timeout(config[CONF_RX_TIMEOUT])) cg.add(var.set_stop_bits(config[CONF_STOP_BITS])) cg.add(var.set_data_bits(config[CONF_DATA_BITS])) cg.add(var.set_parity(config[CONF_PARITY])) @@ -444,8 +452,10 @@ async def uart_write_to_code(config, action_id, template_arg, args): FILTER_SOURCE_FILES = filter_source_files_from_platform( { - "uart_component_esp32_arduino.cpp": {PlatformFramework.ESP32_ARDUINO}, - "uart_component_esp_idf.cpp": {PlatformFramework.ESP32_IDF}, + "uart_component_esp_idf.cpp": { + PlatformFramework.ESP32_IDF, + PlatformFramework.ESP32_ARDUINO, + }, "uart_component_esp8266.cpp": {PlatformFramework.ESP8266_ARDUINO}, "uart_component_host.cpp": {PlatformFramework.HOST_NATIVE}, "uart_component_rp2040.cpp": {PlatformFramework.RP2040_ARDUINO}, diff --git a/esphome/components/uart/uart.h b/esphome/components/uart/uart.h index dc6962fbae..e2912db122 100644 --- a/esphome/components/uart/uart.h +++ b/esphome/components/uart/uart.h @@ -18,6 +18,12 @@ class UARTDevice { void write_byte(uint8_t data) { this->parent_->write_byte(data); } + void set_rx_full_threshold(size_t rx_full_threshold) { this->parent_->set_rx_full_threshold(rx_full_threshold); } + void set_rx_full_threshold_ms(size_t time) { this->parent_->set_rx_full_threshold_ms(time); } + size_t get_rx_full_threshold() { return this->parent_->get_rx_full_threshold(); } + void set_rx_timeout(size_t rx_timeout) { this->parent_->set_rx_timeout(rx_timeout); } + size_t get_rx_timeout() { return this->parent_->get_rx_timeout(); } + void write_array(const uint8_t *data, size_t len) { this->parent_->write_array(data, len); } void write_array(const std::vector &data) { this->parent_->write_array(data); } template void write_array(const std::array &data) { diff --git a/esphome/components/uart/uart_component.cpp b/esphome/components/uart/uart_component.cpp index 09b8c975ab..8f670275d4 100644 --- a/esphome/components/uart/uart_component.cpp +++ b/esphome/components/uart/uart_component.cpp @@ -20,5 +20,13 @@ bool UARTComponent::check_read_timeout_(size_t len) { return true; } +void UARTComponent::set_rx_full_threshold_ms(uint8_t time) { + uint8_t bytelength = this->data_bits_ + this->stop_bits_ + 1; + if (this->parity_ != UARTParityOptions::UART_CONFIG_PARITY_NONE) + bytelength += 1; + int32_t val = clamp((this->baud_rate_ / (bytelength * 1000 / time)) - 1, 1, 120); + this->set_rx_full_threshold(val); +} + } // namespace uart } // namespace esphome diff --git a/esphome/components/uart/uart_component.h b/esphome/components/uart/uart_component.h index a57910c1a1..452688b3e9 100644 --- a/esphome/components/uart/uart_component.h +++ b/esphome/components/uart/uart_component.h @@ -6,6 +6,7 @@ #include "esphome/core/component.h" #include "esphome/core/hal.h" #include "esphome/core/log.h" +#include "esphome/core/helpers.h" #ifdef USE_UART_DEBUGGER #include "esphome/core/automation.h" #endif @@ -82,6 +83,10 @@ class UARTComponent { // @param rx_pin Pointer to the internal GPIO pin used for reception. void set_rx_pin(InternalGPIOPin *rx_pin) { this->rx_pin_ = rx_pin; } + // Sets the flow control pin for the UART bus. + // @param flow_control_pin Pointer to the internal GPIO pin used for flow control. + void set_flow_control_pin(InternalGPIOPin *flow_control_pin) { this->flow_control_pin_ = flow_control_pin; } + // Sets the size of the RX buffer. // @param rx_buffer_size Size of the RX buffer in bytes. void set_rx_buffer_size(size_t rx_buffer_size) { this->rx_buffer_size_ = rx_buffer_size; } @@ -90,6 +95,26 @@ class UARTComponent { // @return Size of the RX buffer in bytes. size_t get_rx_buffer_size() { return this->rx_buffer_size_; } + // Sets the RX FIFO full interrupt threshold. + // @param rx_full_threshold RX full interrupt threshold in bytes. + virtual void set_rx_full_threshold(size_t rx_full_threshold) {} + + // Sets the RX FIFO full interrupt threshold. + // @param time RX full interrupt threshold in ms. + void set_rx_full_threshold_ms(uint8_t time); + + // Gets the RX FIFO full interrupt threshold. + // @return RX full interrupt threshold in bytes. + size_t get_rx_full_threshold() { return this->rx_full_threshold_; } + + // Sets the RX timeout interrupt threshold. + // @param rx_timeout RX timeout interrupt threshold (unit: time of sending one byte). + virtual void set_rx_timeout(size_t rx_timeout) {} + + // Gets the RX timeout interrupt threshold. + // @return RX timeout interrupt threshold (unit: time of sending one byte). + size_t get_rx_timeout() { return this->rx_timeout_; } + // Sets the number of stop bits used in UART communication. // @param stop_bits Number of stop bits. void set_stop_bits(uint8_t stop_bits) { this->stop_bits_ = stop_bits; } @@ -161,7 +186,10 @@ class UARTComponent { InternalGPIOPin *tx_pin_; InternalGPIOPin *rx_pin_; + InternalGPIOPin *flow_control_pin_; size_t rx_buffer_size_; + size_t rx_full_threshold_{1}; + size_t rx_timeout_{0}; uint32_t baud_rate_; uint8_t stop_bits_; uint8_t data_bits_; diff --git a/esphome/components/uart/uart_component_esp32_arduino.cpp b/esphome/components/uart/uart_component_esp32_arduino.cpp deleted file mode 100644 index 4a1c326789..0000000000 --- a/esphome/components/uart/uart_component_esp32_arduino.cpp +++ /dev/null @@ -1,214 +0,0 @@ -#ifdef USE_ESP32_FRAMEWORK_ARDUINO -#include "uart_component_esp32_arduino.h" -#include "esphome/core/application.h" -#include "esphome/core/defines.h" -#include "esphome/core/helpers.h" -#include "esphome/core/log.h" - -#ifdef USE_LOGGER -#include "esphome/components/logger/logger.h" -#endif - -namespace esphome { -namespace uart { -static const char *const TAG = "uart.arduino_esp32"; - -static const uint32_t UART_PARITY_EVEN = 0 << 0; -static const uint32_t UART_PARITY_ODD = 1 << 0; -static const uint32_t UART_PARITY_ENABLE = 1 << 1; -static const uint32_t UART_NB_BIT_5 = 0 << 2; -static const uint32_t UART_NB_BIT_6 = 1 << 2; -static const uint32_t UART_NB_BIT_7 = 2 << 2; -static const uint32_t UART_NB_BIT_8 = 3 << 2; -static const uint32_t UART_NB_STOP_BIT_1 = 1 << 4; -static const uint32_t UART_NB_STOP_BIT_2 = 3 << 4; -static const uint32_t UART_TICK_APB_CLOCK = 1 << 27; - -uint32_t ESP32ArduinoUARTComponent::get_config() { - uint32_t config = 0; - - /* - * All bits numbers below come from - * framework-arduinoespressif32/cores/esp32/esp32-hal-uart.h - * And more specifically conf0 union in uart_dev_t. - * - * Below is bit used from conf0 union. - * : - * parity:0 0:even 1:odd - * parity_en:1 Set this bit to enable uart parity check. - * bit_num:2-4 0:5bits 1:6bits 2:7bits 3:8bits - * stop_bit_num:4-6 stop bit. 1:1bit 2:1.5bits 3:2bits - * tick_ref_always_on:27 select the clock.1:apb clock:ref_tick - */ - - if (this->parity_ == UART_CONFIG_PARITY_EVEN) { - config |= UART_PARITY_EVEN | UART_PARITY_ENABLE; - } else if (this->parity_ == UART_CONFIG_PARITY_ODD) { - config |= UART_PARITY_ODD | UART_PARITY_ENABLE; - } - - switch (this->data_bits_) { - case 5: - config |= UART_NB_BIT_5; - break; - case 6: - config |= UART_NB_BIT_6; - break; - case 7: - config |= UART_NB_BIT_7; - break; - case 8: - config |= UART_NB_BIT_8; - break; - } - - if (this->stop_bits_ == 1) { - config |= UART_NB_STOP_BIT_1; - } else { - config |= UART_NB_STOP_BIT_2; - } - - config |= UART_TICK_APB_CLOCK; - - return config; -} - -void ESP32ArduinoUARTComponent::setup() { - // Use Arduino HardwareSerial UARTs if all used pins match the ones - // preconfigured by the platform. For example if RX disabled but TX pin - // is 1 we still want to use Serial. - bool is_default_tx, is_default_rx; -#ifdef CONFIG_IDF_TARGET_ESP32C3 - is_default_tx = tx_pin_ == nullptr || tx_pin_->get_pin() == 21; - is_default_rx = rx_pin_ == nullptr || rx_pin_->get_pin() == 20; -#else - is_default_tx = tx_pin_ == nullptr || tx_pin_->get_pin() == 1; - is_default_rx = rx_pin_ == nullptr || rx_pin_->get_pin() == 3; -#endif - static uint8_t next_uart_num = 0; - if (is_default_tx && is_default_rx && next_uart_num == 0) { -#if ARDUINO_USB_CDC_ON_BOOT - this->hw_serial_ = &Serial0; -#else - this->hw_serial_ = &Serial; -#endif - next_uart_num++; - } else { -#ifdef USE_LOGGER - bool logger_uses_hardware_uart = true; - -#ifdef USE_LOGGER_USB_CDC - if (logger::global_logger->get_uart() == logger::UART_SELECTION_USB_CDC) { - // this is not a hardware UART, ignore it - logger_uses_hardware_uart = false; - } -#endif // USE_LOGGER_USB_CDC - -#ifdef USE_LOGGER_USB_SERIAL_JTAG - if (logger::global_logger->get_uart() == logger::UART_SELECTION_USB_SERIAL_JTAG) { - // this is not a hardware UART, ignore it - logger_uses_hardware_uart = false; - } -#endif // USE_LOGGER_USB_SERIAL_JTAG - - if (logger_uses_hardware_uart && logger::global_logger->get_baud_rate() > 0 && - logger::global_logger->get_uart() == next_uart_num) { - next_uart_num++; - } -#endif // USE_LOGGER - - if (next_uart_num >= SOC_UART_NUM) { - ESP_LOGW(TAG, "Maximum number of UART components created already."); - this->mark_failed(); - return; - } - - this->number_ = next_uart_num; - this->hw_serial_ = new HardwareSerial(next_uart_num++); // NOLINT(cppcoreguidelines-owning-memory) - } - - this->load_settings(false); -} - -void ESP32ArduinoUARTComponent::load_settings(bool dump_config) { - int8_t tx = this->tx_pin_ != nullptr ? this->tx_pin_->get_pin() : -1; - int8_t rx = this->rx_pin_ != nullptr ? this->rx_pin_->get_pin() : -1; - bool invert = false; - if (tx_pin_ != nullptr && tx_pin_->is_inverted()) - invert = true; - if (rx_pin_ != nullptr && rx_pin_->is_inverted()) - invert = true; - this->hw_serial_->setRxBufferSize(this->rx_buffer_size_); - this->hw_serial_->begin(this->baud_rate_, get_config(), rx, tx, invert); - if (dump_config) { - ESP_LOGCONFIG(TAG, "UART %u was reloaded.", this->number_); - this->dump_config(); - } -} - -void ESP32ArduinoUARTComponent::dump_config() { - ESP_LOGCONFIG(TAG, "UART Bus %d:", this->number_); - LOG_PIN(" TX Pin: ", tx_pin_); - LOG_PIN(" RX Pin: ", rx_pin_); - if (this->rx_pin_ != nullptr) { - ESP_LOGCONFIG(TAG, " RX Buffer Size: %u", this->rx_buffer_size_); - } - ESP_LOGCONFIG(TAG, - " Baud Rate: %u baud\n" - " Data Bits: %u\n" - " Parity: %s\n" - " Stop bits: %u", - this->baud_rate_, this->data_bits_, LOG_STR_ARG(parity_to_str(this->parity_)), this->stop_bits_); - this->check_logger_conflict(); -} - -void ESP32ArduinoUARTComponent::write_array(const uint8_t *data, size_t len) { - this->hw_serial_->write(data, len); -#ifdef USE_UART_DEBUGGER - for (size_t i = 0; i < len; i++) { - this->debug_callback_.call(UART_DIRECTION_TX, data[i]); - } -#endif -} - -bool ESP32ArduinoUARTComponent::peek_byte(uint8_t *data) { - if (!this->check_read_timeout_()) - return false; - *data = this->hw_serial_->peek(); - return true; -} - -bool ESP32ArduinoUARTComponent::read_array(uint8_t *data, size_t len) { - if (!this->check_read_timeout_(len)) - return false; - this->hw_serial_->readBytes(data, len); -#ifdef USE_UART_DEBUGGER - for (size_t i = 0; i < len; i++) { - this->debug_callback_.call(UART_DIRECTION_RX, data[i]); - } -#endif - return true; -} - -int ESP32ArduinoUARTComponent::available() { return this->hw_serial_->available(); } -void ESP32ArduinoUARTComponent::flush() { - ESP_LOGVV(TAG, " Flushing"); - this->hw_serial_->flush(); -} - -void ESP32ArduinoUARTComponent::check_logger_conflict() { -#ifdef USE_LOGGER - if (this->hw_serial_ == nullptr || logger::global_logger->get_baud_rate() == 0) { - return; - } - - if (this->hw_serial_ == logger::global_logger->get_hw_serial()) { - ESP_LOGW(TAG, " You're using the same serial port for logging and the UART component. Please " - "disable logging over the serial port by setting logger->baud_rate to 0."); - } -#endif -} - -} // namespace uart -} // namespace esphome -#endif // USE_ESP32_FRAMEWORK_ARDUINO diff --git a/esphome/components/uart/uart_component_esp32_arduino.h b/esphome/components/uart/uart_component_esp32_arduino.h deleted file mode 100644 index de17d9718b..0000000000 --- a/esphome/components/uart/uart_component_esp32_arduino.h +++ /dev/null @@ -1,60 +0,0 @@ -#pragma once - -#ifdef USE_ESP32_FRAMEWORK_ARDUINO - -#include -#include -#include -#include "esphome/core/component.h" -#include "esphome/core/hal.h" -#include "esphome/core/log.h" -#include "uart_component.h" - -namespace esphome { -namespace uart { - -class ESP32ArduinoUARTComponent : public UARTComponent, public Component { - public: - void setup() override; - void dump_config() override; - float get_setup_priority() const override { return setup_priority::BUS; } - - void write_array(const uint8_t *data, size_t len) override; - - bool peek_byte(uint8_t *data) override; - bool read_array(uint8_t *data, size_t len) override; - - int available() override; - void flush() override; - - uint32_t get_config(); - - HardwareSerial *get_hw_serial() { return this->hw_serial_; } - uint8_t get_hw_serial_number() { return this->number_; } - - /** - * Load the UART with the current settings. - * @param dump_config (Optional, default `true`): True for displaying new settings or - * false to change it quitely - * - * Example: - * ```cpp - * id(uart1).load_settings(); - * ``` - * - * This will load the current UART interface with the latest settings (baud_rate, parity, etc). - */ - void load_settings(bool dump_config) override; - void load_settings() override { this->load_settings(true); } - - protected: - void check_logger_conflict() override; - - HardwareSerial *hw_serial_{nullptr}; - uint8_t number_{0}; -}; - -} // namespace uart -} // namespace esphome - -#endif // USE_ESP32_FRAMEWORK_ARDUINO diff --git a/esphome/components/uart/uart_component_esp_idf.cpp b/esphome/components/uart/uart_component_esp_idf.cpp index 6bb4b16819..7530856b1e 100644 --- a/esphome/components/uart/uart_component_esp_idf.cpp +++ b/esphome/components/uart/uart_component_esp_idf.cpp @@ -1,4 +1,4 @@ -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #include "uart_component_esp_idf.h" #include @@ -90,6 +90,12 @@ void IDFUARTComponent::setup() { xSemaphoreTake(this->lock_, portMAX_DELAY); + this->load_settings(false); + + xSemaphoreGive(this->lock_); +} + +void IDFUARTComponent::load_settings(bool dump_config) { uart_config_t uart_config = this->get_config_(); esp_err_t err = uart_param_config(this->uart_num_, &uart_config); if (err != ESP_OK) { @@ -100,6 +106,7 @@ void IDFUARTComponent::setup() { int8_t tx = this->tx_pin_ != nullptr ? this->tx_pin_->get_pin() : -1; int8_t rx = this->rx_pin_ != nullptr ? this->rx_pin_->get_pin() : -1; + int8_t flow_control = this->flow_control_pin_ != nullptr ? this->flow_control_pin_->get_pin() : -1; uint32_t invert = 0; if (this->tx_pin_ != nullptr && this->tx_pin_->is_inverted()) @@ -114,13 +121,21 @@ void IDFUARTComponent::setup() { return; } - err = uart_set_pin(this->uart_num_, tx, rx, UART_PIN_NO_CHANGE, UART_PIN_NO_CHANGE); + err = uart_set_pin(this->uart_num_, tx, rx, flow_control, UART_PIN_NO_CHANGE); if (err != ESP_OK) { ESP_LOGW(TAG, "uart_set_pin failed: %s", esp_err_to_name(err)); this->mark_failed(); return; } + if (uart_is_driver_installed(this->uart_num_)) { + uart_driver_delete(this->uart_num_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "uart_driver_delete failed: %s", esp_err_to_name(err)); + this->mark_failed(); + return; + } + } err = uart_driver_install(this->uart_num_, /* UART RX ring buffer size. */ this->rx_buffer_size_, /* UART TX ring buffer size. If set to zero, driver will not use TX buffer, TX function will block task until all data have been sent out.*/ @@ -133,17 +148,29 @@ void IDFUARTComponent::setup() { return; } - xSemaphoreGive(this->lock_); -} - -void IDFUARTComponent::load_settings(bool dump_config) { - uart_config_t uart_config = this->get_config_(); - esp_err_t err = uart_param_config(this->uart_num_, &uart_config); + err = uart_set_rx_full_threshold(this->uart_num_, this->rx_full_threshold_); if (err != ESP_OK) { - ESP_LOGW(TAG, "uart_param_config failed: %s", esp_err_to_name(err)); + ESP_LOGW(TAG, "uart_set_rx_full_threshold failed: %s", esp_err_to_name(err)); this->mark_failed(); return; - } else if (dump_config) { + } + + err = uart_set_rx_timeout(this->uart_num_, this->rx_timeout_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "uart_set_rx_timeout failed: %s", esp_err_to_name(err)); + this->mark_failed(); + return; + } + + auto mode = this->flow_control_pin_ != nullptr ? UART_MODE_RS485_HALF_DUPLEX : UART_MODE_UART; + err = uart_set_mode(this->uart_num_, mode); + if (err != ESP_OK) { + ESP_LOGW(TAG, "uart_set_mode failed: %s", esp_err_to_name(err)); + this->mark_failed(); + return; + } + + if (dump_config) { ESP_LOGCONFIG(TAG, "UART %u was reloaded.", this->uart_num_); this->dump_config(); } @@ -153,8 +180,13 @@ void IDFUARTComponent::dump_config() { ESP_LOGCONFIG(TAG, "UART Bus %u:", this->uart_num_); LOG_PIN(" TX Pin: ", tx_pin_); LOG_PIN(" RX Pin: ", rx_pin_); + LOG_PIN(" Flow Control Pin: ", flow_control_pin_); if (this->rx_pin_ != nullptr) { - ESP_LOGCONFIG(TAG, " RX Buffer Size: %u", this->rx_buffer_size_); + ESP_LOGCONFIG(TAG, + " RX Buffer Size: %u\n" + " RX Full Threshold: %u\n" + " RX Timeout: %u", + this->rx_buffer_size_, this->rx_full_threshold_, this->rx_timeout_); } ESP_LOGCONFIG(TAG, " Baud Rate: %" PRIu32 " baud\n" @@ -165,6 +197,28 @@ void IDFUARTComponent::dump_config() { this->check_logger_conflict(); } +void IDFUARTComponent::set_rx_full_threshold(size_t rx_full_threshold) { + if (this->is_ready()) { + esp_err_t err = uart_set_rx_full_threshold(this->uart_num_, rx_full_threshold); + if (err != ESP_OK) { + ESP_LOGW(TAG, "uart_set_rx_full_threshold failed: %s", esp_err_to_name(err)); + return; + } + } + this->rx_full_threshold_ = rx_full_threshold; +} + +void IDFUARTComponent::set_rx_timeout(size_t rx_timeout) { + if (this->is_ready()) { + esp_err_t err = uart_set_rx_timeout(this->uart_num_, rx_timeout); + if (err != ESP_OK) { + ESP_LOGW(TAG, "uart_set_rx_timeout failed: %s", esp_err_to_name(err)); + return; + } + } + this->rx_timeout_ = rx_timeout; +} + void IDFUARTComponent::write_array(const uint8_t *data, size_t len) { xSemaphoreTake(this->lock_, portMAX_DELAY); uart_write_bytes(this->uart_num_, data, len); diff --git a/esphome/components/uart/uart_component_esp_idf.h b/esphome/components/uart/uart_component_esp_idf.h index 215641ebe2..a2ba2aa968 100644 --- a/esphome/components/uart/uart_component_esp_idf.h +++ b/esphome/components/uart/uart_component_esp_idf.h @@ -1,6 +1,6 @@ #pragma once -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #include #include "esphome/core/component.h" @@ -15,6 +15,9 @@ class IDFUARTComponent : public UARTComponent, public Component { void dump_config() override; float get_setup_priority() const override { return setup_priority::BUS; } + void set_rx_full_threshold(size_t rx_full_threshold) override; + void set_rx_timeout(size_t rx_timeout) override; + void write_array(const uint8_t *data, size_t len) override; bool peek_byte(uint8_t *data) override; @@ -55,4 +58,4 @@ class IDFUARTComponent : public UARTComponent, public Component { } // namespace uart } // namespace esphome -#endif // USE_ESP_IDF +#endif // USE_ESP32 diff --git a/esphome/components/uponor_smatrix/climate/uponor_smatrix_climate.cpp b/esphome/components/uponor_smatrix/climate/uponor_smatrix_climate.cpp index d7e672d8cf..19a9112c73 100644 --- a/esphome/components/uponor_smatrix/climate/uponor_smatrix_climate.cpp +++ b/esphome/components/uponor_smatrix/climate/uponor_smatrix_climate.cpp @@ -58,7 +58,7 @@ void UponorSmatrixClimate::control(const climate::ClimateCall &call) { } void UponorSmatrixClimate::on_device_data(const UponorSmatrixData *data, size_t data_len) { - for (int i = 0; i < data_len; i++) { + for (size_t i = 0; i < data_len; i++) { switch (data[i].id) { case UPONOR_ID_TARGET_TEMP_MIN: this->min_temperature_ = raw_to_celsius(data[i].value); diff --git a/esphome/components/uponor_smatrix/sensor/uponor_smatrix_sensor.cpp b/esphome/components/uponor_smatrix/sensor/uponor_smatrix_sensor.cpp index 452660dc14..a1d0db214f 100644 --- a/esphome/components/uponor_smatrix/sensor/uponor_smatrix_sensor.cpp +++ b/esphome/components/uponor_smatrix/sensor/uponor_smatrix_sensor.cpp @@ -18,7 +18,7 @@ void UponorSmatrixSensor::dump_config() { } void UponorSmatrixSensor::on_device_data(const UponorSmatrixData *data, size_t data_len) { - for (int i = 0; i < data_len; i++) { + for (size_t i = 0; i < data_len; i++) { switch (data[i].id) { case UPONOR_ID_ROOM_TEMP: if (this->temperature_sensor_ != nullptr) diff --git a/esphome/components/uponor_smatrix/uponor_smatrix.cpp b/esphome/components/uponor_smatrix/uponor_smatrix.cpp index a0017518bf..867305059f 100644 --- a/esphome/components/uponor_smatrix/uponor_smatrix.cpp +++ b/esphome/components/uponor_smatrix/uponor_smatrix.cpp @@ -122,7 +122,7 @@ bool UponorSmatrixComponent::parse_byte_(uint8_t byte) { // Decode packet payload data for easy access UponorSmatrixData data[data_len]; - for (int i = 0; i < data_len; i++) { + for (size_t i = 0; i < data_len; i++) { data[i].id = packet[(i * 3) + 4]; data[i].value = encode_uint16(packet[(i * 3) + 5], packet[(i * 3) + 6]); } @@ -135,7 +135,7 @@ bool UponorSmatrixComponent::parse_byte_(uint8_t byte) { // thermostat sending both room temperature and time information. bool found_temperature = false; bool found_time = false; - for (int i = 0; i < data_len; i++) { + for (size_t i = 0; i < data_len; i++) { if (data[i].id == UPONOR_ID_ROOM_TEMP) found_temperature = true; if (data[i].id == UPONOR_ID_DATETIME1) @@ -181,7 +181,7 @@ bool UponorSmatrixComponent::send(uint16_t device_address, const UponorSmatrixDa packet.push_back(device_address >> 8); packet.push_back(device_address >> 0); - for (int i = 0; i < data_len; i++) { + for (size_t i = 0; i < data_len; i++) { packet.push_back(data[i].id); packet.push_back(data[i].value >> 8); packet.push_back(data[i].value >> 0); diff --git a/esphome/components/usb_host/__init__.py b/esphome/components/usb_host/__init__.py index 0fe3310127..de734bf425 100644 --- a/esphome/components/usb_host/__init__.py +++ b/esphome/components/usb_host/__init__.py @@ -1,5 +1,6 @@ import esphome.codegen as cg from esphome.components.esp32 import ( + VARIANT_ESP32P4, VARIANT_ESP32S2, VARIANT_ESP32S3, add_idf_sdkconfig_option, @@ -47,7 +48,7 @@ CONFIG_SCHEMA = cv.All( } ), cv.only_with_esp_idf, - only_on_variant(supported=[VARIANT_ESP32S2, VARIANT_ESP32S3]), + only_on_variant(supported=[VARIANT_ESP32S2, VARIANT_ESP32S3, VARIANT_ESP32P4]), ) diff --git a/esphome/components/usb_host/usb_host.h b/esphome/components/usb_host/usb_host.h index c5466eb1f0..4f8d2ec9a8 100644 --- a/esphome/components/usb_host/usb_host.h +++ b/esphome/components/usb_host/usb_host.h @@ -1,18 +1,45 @@ #pragma once // Should not be needed, but it's required to pass CI clang-tidy checks -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) +#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32P4) #include "esphome/core/component.h" #include #include "usb/usb_host.h" - -#include +#include +#include +#include "esphome/core/lock_free_queue.h" +#include "esphome/core/event_pool.h" +#include namespace esphome { namespace usb_host { +// THREADING MODEL: +// This component uses a dedicated USB task for event processing to prevent data loss. +// - USB Task (high priority): Handles USB events, executes transfer callbacks +// - Main Loop Task: Initiates transfers, processes completion events +// +// Thread-safe communication: +// - Lock-free queues for USB task -> main loop events (SPSC pattern) +// - Lock-free TransferRequest pool using atomic bitmask (MCSP pattern) +// +// TransferRequest pool access pattern: +// - get_trq_() [allocate]: Called from BOTH USB task and main loop threads +// * USB task: via USB UART input callbacks that restart transfers immediately +// * Main loop: for output transfers and flow-controlled input restarts +// - release_trq() [deallocate]: Called from main loop thread only +// +// The multi-threaded allocation is intentional for performance: +// - USB task can immediately restart input transfers without context switching +// - Main loop controls backpressure by deciding when to restart after consuming data +// The atomic bitmask ensures thread-safe allocation without mutex blocking. + static const char *const TAG = "usb_host"; +// Forward declarations +struct TransferRequest; +class USBClient; + // constants for setup packet type static const uint8_t USB_RECIP_DEVICE = 0; static const uint8_t USB_RECIP_INTERFACE = 1; @@ -26,6 +53,10 @@ static const uint8_t USB_DIR_OUT = 0; static const size_t SETUP_PACKET_SIZE = 8; static const size_t MAX_REQUESTS = 16; // maximum number of outstanding requests possible. +static_assert(MAX_REQUESTS <= 16, "MAX_REQUESTS must be <= 16 to fit in uint16_t bitmask"); +static constexpr size_t USB_EVENT_QUEUE_SIZE = 32; // Size of event queue between USB task and main loop +static constexpr size_t USB_TASK_STACK_SIZE = 4096; // Stack size for USB task (same as ESP-IDF USB examples) +static constexpr UBaseType_t USB_TASK_PRIORITY = 5; // Higher priority than main loop (tskIDLE_PRIORITY + 5) // used to report a transfer status struct TransferStatus { @@ -49,6 +80,31 @@ struct TransferRequest { USBClient *client; }; +enum EventType : uint8_t { + EVENT_DEVICE_NEW, + EVENT_DEVICE_GONE, + EVENT_TRANSFER_COMPLETE, + EVENT_CONTROL_COMPLETE, +}; + +struct UsbEvent { + EventType type; + union { + struct { + uint8_t address; + } device_new; + struct { + usb_device_handle_t handle; + } device_gone; + struct { + TransferRequest *trq; + } transfer; + } data; + + // Required for EventPool - no cleanup needed for POD types + void release() {} +}; + // callback function type. enum ClientState { @@ -63,13 +119,7 @@ class USBClient : public Component { friend class USBHost; public: - USBClient(uint16_t vid, uint16_t pid) : vid_(vid), pid_(pid) { init_pool(); } - - void init_pool() { - this->trq_pool_.clear(); - for (size_t i = 0; i != MAX_REQUESTS; i++) - this->trq_pool_.push_back(&this->requests_[i]); - } + USBClient(uint16_t vid, uint16_t pid) : vid_(vid), pid_(pid), trq_in_use_(0) {} void setup() override; void loop() override; // setup must happen after the host bus has been setup @@ -84,12 +134,26 @@ class USBClient : public Component { bool control_transfer(uint8_t type, uint8_t request, uint16_t value, uint16_t index, const transfer_cb_t &callback, const std::vector &data = {}); + // Lock-free event queue and pool for USB task to main loop communication + // Must be public for access from static callbacks + LockFreeQueue event_queue; + EventPool event_pool; + protected: bool register_(); - TransferRequest *get_trq_(); + TransferRequest *get_trq_(); // Lock-free allocation using atomic bitmask (multi-consumer safe) virtual void disconnect(); virtual void on_connected() {} - virtual void on_disconnected() { this->init_pool(); } + virtual void on_disconnected() { + // Reset all requests to available (all bits to 0) + this->trq_in_use_.store(0); + } + + // USB task management + static void usb_task_fn(void *arg); + void usb_task_loop(); + + TaskHandle_t usb_task_handle_{nullptr}; usb_host_client_handle_t handle_{}; usb_device_handle_t device_handle_{}; @@ -97,7 +161,12 @@ class USBClient : public Component { int state_{USB_CLIENT_INIT}; uint16_t vid_{}; uint16_t pid_{}; - std::list trq_pool_{}; + // Lock-free pool management using atomic bitmask (no dynamic allocation) + // Bit i = 1: requests_[i] is in use, Bit i = 0: requests_[i] is available + // Supports multiple concurrent consumers (both threads can allocate) + // Single producer for deallocation (main loop only) + // Limited to 16 slots by uint16_t size (enforced by static_assert) + std::atomic trq_in_use_; TransferRequest requests_[MAX_REQUESTS]{}; }; class USBHost : public Component { diff --git a/esphome/components/usb_host/usb_host_client.cpp b/esphome/components/usb_host/usb_host_client.cpp index 4c0c12fa18..b26385a8ef 100644 --- a/esphome/components/usb_host/usb_host_client.cpp +++ b/esphome/components/usb_host/usb_host_client.cpp @@ -1,5 +1,5 @@ // Should not be needed, but it's required to pass CI clang-tidy checks -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) +#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32P4) #include "usb_host.h" #include "esphome/core/log.h" #include "esphome/core/hal.h" @@ -7,6 +7,7 @@ #include #include +#include namespace esphome { namespace usb_host { @@ -139,24 +140,40 @@ static std::string get_descriptor_string(const usb_str_desc_t *desc) { return {buffer}; } +// CALLBACK CONTEXT: USB task (called from usb_host_client_handle_events in USB task) static void client_event_cb(const usb_host_client_event_msg_t *event_msg, void *ptr) { auto *client = static_cast(ptr); + + // Allocate event from pool + UsbEvent *event = client->event_pool.allocate(); + if (event == nullptr) { + // No events available - increment counter for periodic logging + client->event_queue.increment_dropped_count(); + return; + } + + // Queue events to be processed in main loop switch (event_msg->event) { case USB_HOST_CLIENT_EVENT_NEW_DEV: { - auto addr = event_msg->new_dev.address; ESP_LOGD(TAG, "New device %d", event_msg->new_dev.address); - client->on_opened(addr); + event->type = EVENT_DEVICE_NEW; + event->data.device_new.address = event_msg->new_dev.address; break; } case USB_HOST_CLIENT_EVENT_DEV_GONE: { - client->on_removed(event_msg->dev_gone.dev_hdl); - ESP_LOGD(TAG, "Device gone %d", event_msg->new_dev.address); + ESP_LOGD(TAG, "Device gone"); + event->type = EVENT_DEVICE_GONE; + event->data.device_gone.handle = event_msg->dev_gone.dev_hdl; break; } default: ESP_LOGD(TAG, "Unknown event %d", event_msg->event); - break; + client->event_pool.release(event); + return; } + + // Push to lock-free queue (always succeeds since pool size == queue size) + client->event_queue.push(event); } void USBClient::setup() { usb_host_client_config_t config{.is_synchronous = false, @@ -169,13 +186,65 @@ void USBClient::setup() { this->mark_failed(); return; } - for (auto *trq : this->trq_pool_) { - usb_host_transfer_alloc(64, 0, &trq->transfer); - trq->client = this; + // Pre-allocate USB transfer buffers for all slots at startup + // This avoids any dynamic allocation during runtime + for (size_t i = 0; i < MAX_REQUESTS; i++) { + usb_host_transfer_alloc(64, 0, &this->requests_[i].transfer); + this->requests_[i].client = this; // Set once, never changes + } + + // Create and start USB task + xTaskCreate(usb_task_fn, "usb_task", + USB_TASK_STACK_SIZE, // Stack size + this, // Task parameter + USB_TASK_PRIORITY, // Priority (higher than main loop) + &this->usb_task_handle_); + + if (this->usb_task_handle_ == nullptr) { + ESP_LOGE(TAG, "Failed to create USB task"); + this->mark_failed(); + } +} + +void USBClient::usb_task_fn(void *arg) { + auto *client = static_cast(arg); + client->usb_task_loop(); +} + +void USBClient::usb_task_loop() { + while (true) { + usb_host_client_handle_events(this->handle_, portMAX_DELAY); } } void USBClient::loop() { + // Process any events from the USB task + UsbEvent *event; + while ((event = this->event_queue.pop()) != nullptr) { + switch (event->type) { + case EVENT_DEVICE_NEW: + this->on_opened(event->data.device_new.address); + break; + case EVENT_DEVICE_GONE: + this->on_removed(event->data.device_gone.handle); + break; + case EVENT_TRANSFER_COMPLETE: + case EVENT_CONTROL_COMPLETE: { + auto *trq = event->data.transfer.trq; + this->release_trq(trq); + break; + } + } + // Return event to pool for reuse + this->event_pool.release(event); + } + + // Log dropped events periodically + uint16_t dropped = this->event_queue.get_and_reset_dropped_count(); + if (dropped > 0) { + ESP_LOGW(TAG, "Dropped %u USB events due to queue overflow", dropped); + } + switch (this->state_) { case USB_CLIENT_OPEN: { int err; @@ -228,7 +297,6 @@ void USBClient::loop() { } default: - usb_host_client_handle_events(this->handle_, 0); break; } } @@ -245,6 +313,26 @@ void USBClient::on_removed(usb_device_handle_t handle) { } } +// Helper to queue transfer cleanup to main loop +static void queue_transfer_cleanup(TransferRequest *trq, EventType type) { + auto *client = trq->client; + + // Allocate event from pool + UsbEvent *event = client->event_pool.allocate(); + if (event == nullptr) { + // No events available - increment counter for periodic logging + client->event_queue.increment_dropped_count(); + return; + } + + event->type = type; + event->data.transfer.trq = trq; + + // Push to lock-free queue (always succeeds since pool size == queue size) + client->event_queue.push(event); +} + +// CALLBACK CONTEXT: USB task (called from usb_host_client_handle_events in USB task) static void control_callback(const usb_transfer_t *xfer) { auto *trq = static_cast(xfer->context); trq->status.error_code = xfer->status; @@ -252,22 +340,54 @@ static void control_callback(const usb_transfer_t *xfer) { trq->status.endpoint = xfer->bEndpointAddress; trq->status.data = xfer->data_buffer; trq->status.data_len = xfer->actual_num_bytes; - if (trq->callback != nullptr) + + // Execute callback in USB task context + if (trq->callback != nullptr) { trq->callback(trq->status); - trq->client->release_trq(trq); + } + + // Queue cleanup to main loop + queue_transfer_cleanup(trq, EVENT_CONTROL_COMPLETE); } +// THREAD CONTEXT: Called from both USB task and main loop threads (multi-consumer) +// - USB task: USB UART input callbacks restart transfers for immediate data reception +// - Main loop: Output transfers and flow-controlled input restarts after consuming data +// +// THREAD SAFETY: Lock-free using atomic compare-and-swap on bitmask +// This multi-threaded access is intentional for performance - USB task can +// immediately restart transfers without waiting for main loop scheduling. TransferRequest *USBClient::get_trq_() { - if (this->trq_pool_.empty()) { - ESP_LOGE(TAG, "Too many requests queued"); - return nullptr; + uint16_t mask = this->trq_in_use_.load(std::memory_order_relaxed); + + // Find first available slot (bit = 0) and try to claim it atomically + // We use a while loop to allow retrying the same slot after CAS failure + size_t i = 0; + while (i != MAX_REQUESTS) { + if (mask & (1U << i)) { + // Slot is in use, move to next slot + i++; + continue; + } + + // Slot i appears available, try to claim it atomically + uint16_t desired = mask | (1U << i); // Set bit i to mark as in-use + + if (this->trq_in_use_.compare_exchange_weak(mask, desired, std::memory_order_acquire, std::memory_order_relaxed)) { + // Successfully claimed slot i - prepare the TransferRequest + auto *trq = &this->requests_[i]; + trq->transfer->context = trq; + trq->transfer->device_handle = this->device_handle_; + return trq; + } + // CAS failed - another thread modified the bitmask + // mask was already updated by compare_exchange_weak with the current value + // No need to reload - the CAS already did that for us + i = 0; } - auto *trq = this->trq_pool_.front(); - this->trq_pool_.pop_front(); - trq->client = this; - trq->transfer->context = trq; - trq->transfer->device_handle = this->device_handle_; - return trq; + + ESP_LOGE(TAG, "All %d transfer slots in use", MAX_REQUESTS); + return nullptr; } void USBClient::disconnect() { this->on_disconnected(); @@ -280,6 +400,8 @@ void USBClient::disconnect() { this->device_addr_ = -1; } +// THREAD CONTEXT: Called from main loop thread only +// - Used for device configuration and control operations bool USBClient::control_transfer(uint8_t type, uint8_t request, uint16_t value, uint16_t index, const transfer_cb_t &callback, const std::vector &data) { auto *trq = this->get_trq_(); @@ -315,6 +437,7 @@ bool USBClient::control_transfer(uint8_t type, uint8_t request, uint16_t value, return true; } +// CALLBACK CONTEXT: USB task (called from usb_host_client_handle_events in USB task) static void transfer_callback(usb_transfer_t *xfer) { auto *trq = static_cast(xfer->context); trq->status.error_code = xfer->status; @@ -322,12 +445,21 @@ static void transfer_callback(usb_transfer_t *xfer) { trq->status.endpoint = xfer->bEndpointAddress; trq->status.data = xfer->data_buffer; trq->status.data_len = xfer->actual_num_bytes; - if (trq->callback != nullptr) + + // Always execute callback in USB task context + // Callbacks should be fast and non-blocking (e.g., copy data to queue) + if (trq->callback != nullptr) { trq->callback(trq->status); - trq->client->release_trq(trq); + } + + // Queue cleanup to main loop + queue_transfer_cleanup(trq, EVENT_TRANSFER_COMPLETE); } /** * Performs a transfer input operation. + * THREAD CONTEXT: Called from both USB task and main loop threads! + * - USB task: USB UART input callbacks call start_input() which calls this + * - Main loop: Initial setup and other components * * @param ep_address The endpoint address. * @param callback The callback function to be called when the transfer is complete. @@ -354,6 +486,9 @@ void USBClient::transfer_in(uint8_t ep_address, const transfer_cb_t &callback, u /** * Performs an output transfer operation. + * THREAD CONTEXT: Called from main loop thread only + * - USB UART output uses defer() to ensure main loop context + * - Modbus and other components call from loop() * * @param ep_address The endpoint address. * @param callback The callback function to be called when the transfer is complete. @@ -386,7 +521,28 @@ void USBClient::dump_config() { " Product id %04X", this->vid_, this->pid_); } -void USBClient::release_trq(TransferRequest *trq) { this->trq_pool_.push_back(trq); } +// THREAD CONTEXT: Only called from main loop thread (single producer for deallocation) +// - Via event processing when handling EVENT_TRANSFER_COMPLETE/EVENT_CONTROL_COMPLETE +// - Directly when transfer submission fails +// +// THREAD SAFETY: Lock-free using atomic AND to clear bit +// Single-producer pattern makes this simpler than allocation +void USBClient::release_trq(TransferRequest *trq) { + if (trq == nullptr) + return; + + // Calculate index from pointer arithmetic + size_t index = trq - this->requests_; + if (index >= MAX_REQUESTS) { + ESP_LOGE(TAG, "Invalid TransferRequest pointer"); + return; + } + + // Atomically clear bit i to mark slot as available + // fetch_and with inverted bitmask clears the bit atomically + uint16_t bit = 1U << index; + this->trq_in_use_.fetch_and(static_cast(~bit), std::memory_order_release); +} } // namespace usb_host } // namespace esphome diff --git a/esphome/components/usb_host/usb_host_component.cpp b/esphome/components/usb_host/usb_host_component.cpp index 682026a9c5..fb19239c73 100644 --- a/esphome/components/usb_host/usb_host_component.cpp +++ b/esphome/components/usb_host/usb_host_component.cpp @@ -1,5 +1,5 @@ // Should not be needed, but it's required to pass CI clang-tidy checks -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) +#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32P4) #include "usb_host.h" #include #include "esphome/core/log.h" diff --git a/esphome/components/usb_uart/__init__.py b/esphome/components/usb_uart/__init__.py index 6999b1b955..a852e1f78b 100644 --- a/esphome/components/usb_uart/__init__.py +++ b/esphome/components/usb_uart/__init__.py @@ -24,7 +24,6 @@ usb_uart_ns = cg.esphome_ns.namespace("usb_uart") USBUartComponent = usb_uart_ns.class_("USBUartComponent", Component) USBUartChannel = usb_uart_ns.class_("USBUartChannel", UARTComponent) - UARTParityOptions = usb_uart_ns.enum("UARTParityOptions") UART_PARITY_OPTIONS = { "NONE": UARTParityOptions.UART_CONFIG_PARITY_NONE, diff --git a/esphome/components/usb_uart/ch34x.cpp b/esphome/components/usb_uart/ch34x.cpp index 37cd33f841..889366b579 100644 --- a/esphome/components/usb_uart/ch34x.cpp +++ b/esphome/components/usb_uart/ch34x.cpp @@ -1,4 +1,4 @@ -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) +#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32P4) #include "usb_uart.h" #include "usb/usb_host.h" #include "esphome/core/log.h" @@ -16,12 +16,12 @@ using namespace bytebuffer; void USBUartTypeCH34X::enable_channels() { // enable the channels for (auto channel : this->channels_) { - if (!channel->initialised_) + if (!channel->initialised_.load()) continue; usb_host::transfer_cb_t callback = [=](const usb_host::TransferStatus &status) { if (!status.success) { ESP_LOGE(TAG, "Control transfer failed, status=%s", esp_err_to_name(status.error_code)); - channel->initialised_ = false; + channel->initialised_.store(false); } }; @@ -48,7 +48,7 @@ void USBUartTypeCH34X::enable_channels() { auto factor = static_cast(clk / baud_rate); if (factor == 0 || factor == 0xFF) { ESP_LOGE(TAG, "Invalid baud rate %" PRIu32, baud_rate); - channel->initialised_ = false; + channel->initialised_.store(false); continue; } if ((clk / factor - baud_rate) > (baud_rate - clk / (factor + 1))) diff --git a/esphome/components/usb_uart/cp210x.cpp b/esphome/components/usb_uart/cp210x.cpp index f7d60c307a..5fec0bed02 100644 --- a/esphome/components/usb_uart/cp210x.cpp +++ b/esphome/components/usb_uart/cp210x.cpp @@ -1,4 +1,4 @@ -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) +#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32P4) #include "usb_uart.h" #include "usb/usb_host.h" #include "esphome/core/log.h" @@ -100,12 +100,12 @@ std::vector USBUartTypeCP210X::parse_descriptors(usb_device_handle_t dev void USBUartTypeCP210X::enable_channels() { // enable the channels for (auto channel : this->channels_) { - if (!channel->initialised_) + if (!channel->initialised_.load()) continue; usb_host::transfer_cb_t callback = [=](const usb_host::TransferStatus &status) { if (!status.success) { ESP_LOGE(TAG, "Control transfer failed, status=%s", esp_err_to_name(status.error_code)); - channel->initialised_ = false; + channel->initialised_.store(false); } }; this->control_transfer(USB_VENDOR_IFC | usb_host::USB_DIR_OUT, IFC_ENABLE, 1, channel->index_, callback); diff --git a/esphome/components/usb_uart/usb_uart.cpp b/esphome/components/usb_uart/usb_uart.cpp index bf1c9086f1..29003e071e 100644 --- a/esphome/components/usb_uart/usb_uart.cpp +++ b/esphome/components/usb_uart/usb_uart.cpp @@ -1,5 +1,5 @@ // Should not be needed, but it's required to pass CI clang-tidy checks -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) +#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32P4) #include "usb_uart.h" #include "esphome/core/log.h" #include "esphome/components/uart/uart_debugger.h" @@ -130,7 +130,7 @@ size_t RingBuffer::pop(uint8_t *data, size_t len) { return len; } void USBUartChannel::write_array(const uint8_t *data, size_t len) { - if (!this->initialised_) { + if (!this->initialised_.load()) { ESP_LOGV(TAG, "Channel not initialised - write ignored"); return; } @@ -152,7 +152,7 @@ bool USBUartChannel::peek_byte(uint8_t *data) { return true; } bool USBUartChannel::read_array(uint8_t *data, size_t len) { - if (!this->initialised_) { + if (!this->initialised_.load()) { ESP_LOGV(TAG, "Channel not initialised - read ignored"); return false; } @@ -170,7 +170,34 @@ bool USBUartChannel::read_array(uint8_t *data, size_t len) { return status; } void USBUartComponent::setup() { USBClient::setup(); } -void USBUartComponent::loop() { USBClient::loop(); } +void USBUartComponent::loop() { + USBClient::loop(); + + // Process USB data from the lock-free queue + UsbDataChunk *chunk; + while ((chunk = this->usb_data_queue_.pop()) != nullptr) { + auto *channel = chunk->channel; + +#ifdef USE_UART_DEBUGGER + if (channel->debug_) { + uart::UARTDebug::log_hex(uart::UART_DIRECTION_RX, std::vector(chunk->data, chunk->data + chunk->length), + ','); // NOLINT() + } +#endif + + // Push data to ring buffer (now safe in main loop) + channel->input_buffer_.push(chunk->data, chunk->length); + + // Return chunk to pool for reuse + this->chunk_pool_.release(chunk); + } + + // Log dropped USB data periodically + uint16_t dropped = this->usb_data_queue_.get_and_reset_dropped_count(); + if (dropped > 0) { + ESP_LOGW(TAG, "Dropped %u USB data chunks due to buffer overflow", dropped); + } +} void USBUartComponent::dump_config() { USBClient::dump_config(); for (auto &channel : this->channels_) { @@ -187,49 +214,77 @@ void USBUartComponent::dump_config() { } } void USBUartComponent::start_input(USBUartChannel *channel) { - if (!channel->initialised_ || channel->input_started_ || - channel->input_buffer_.get_free_space() < channel->cdc_dev_.in_ep->wMaxPacketSize) + if (!channel->initialised_.load() || channel->input_started_.load()) return; + // THREAD CONTEXT: Called from both USB task and main loop threads + // - USB task: Immediate restart after successful transfer for continuous data flow + // - Main loop: Controlled restart after consuming data (backpressure mechanism) + // + // This dual-thread access is intentional for performance: + // - USB task restarts avoid context switch delays for high-speed data + // - Main loop restarts provide flow control when buffers are full + // + // The underlying transfer_in() uses lock-free atomic allocation from the + // TransferRequest pool, making this multi-threaded access safe const auto *ep = channel->cdc_dev_.in_ep; + // CALLBACK CONTEXT: This lambda is executed in USB task via transfer_callback auto callback = [this, channel](const usb_host::TransferStatus &status) { ESP_LOGV(TAG, "Transfer result: length: %u; status %X", status.data_len, status.error_code); if (!status.success) { ESP_LOGE(TAG, "Control transfer failed, status=%s", esp_err_to_name(status.error_code)); + // On failure, don't restart - let next read_array() trigger it + channel->input_started_.store(false); return; } -#ifdef USE_UART_DEBUGGER - if (channel->debug_) { - uart::UARTDebug::log_hex(uart::UART_DIRECTION_RX, - std::vector(status.data, status.data + status.data_len), ','); // NOLINT() - } -#endif - channel->input_started_ = false; - if (!channel->dummy_receiver_) { - for (size_t i = 0; i != status.data_len; i++) { - channel->input_buffer_.push(status.data[i]); + + if (!channel->dummy_receiver_ && status.data_len > 0) { + // Allocate a chunk from the pool + UsbDataChunk *chunk = this->chunk_pool_.allocate(); + if (chunk == nullptr) { + // No chunks available - queue is full or we're out of memory + this->usb_data_queue_.increment_dropped_count(); + // Mark input as not started so we can retry + channel->input_started_.store(false); + return; } + + // Copy data to chunk (this is fast, happens in USB task) + memcpy(chunk->data, status.data, status.data_len); + chunk->length = status.data_len; + chunk->channel = channel; + + // Push to lock-free queue for main loop processing + // Push always succeeds because pool size == queue size + this->usb_data_queue_.push(chunk); } - if (channel->input_buffer_.get_free_space() >= channel->cdc_dev_.in_ep->wMaxPacketSize) { - this->defer([this, channel] { this->start_input(channel); }); - } + + // On success, restart input immediately from USB task for performance + // The lock-free queue will handle backpressure + channel->input_started_.store(false); + this->start_input(channel); }; - channel->input_started_ = true; + channel->input_started_.store(true); this->transfer_in(ep->bEndpointAddress, callback, ep->wMaxPacketSize); } void USBUartComponent::start_output(USBUartChannel *channel) { - if (channel->output_started_) + // IMPORTANT: This function must only be called from the main loop! + // The output_buffer_ is not thread-safe and can only be accessed from main loop. + // USB callbacks use defer() to ensure this function runs in the correct context. + if (channel->output_started_.load()) return; if (channel->output_buffer_.is_empty()) { return; } const auto *ep = channel->cdc_dev_.out_ep; + // CALLBACK CONTEXT: This lambda is executed in USB task via transfer_callback auto callback = [this, channel](const usb_host::TransferStatus &status) { ESP_LOGV(TAG, "Output Transfer result: length: %u; status %X", status.data_len, status.error_code); - channel->output_started_ = false; + channel->output_started_.store(false); + // Defer restart to main loop (defer is thread-safe) this->defer([this, channel] { this->start_output(channel); }); }; - channel->output_started_ = true; + channel->output_started_.store(true); uint8_t data[ep->wMaxPacketSize]; auto len = channel->output_buffer_.pop(data, ep->wMaxPacketSize); this->transfer_out(ep->bEndpointAddress, callback, data, len); @@ -249,7 +304,8 @@ static void fix_mps(const usb_ep_desc_t *ep) { if (ep != nullptr) { auto *ep_mutable = const_cast(ep); if (ep->wMaxPacketSize > 64) { - ESP_LOGW(TAG, "Corrected MPS of EP %u from %u to 64", ep->bEndpointAddress, ep->wMaxPacketSize); + ESP_LOGW(TAG, "Corrected MPS of EP 0x%02X from %u to 64", static_cast(ep->bEndpointAddress & 0xFF), + ep->wMaxPacketSize); ep_mutable->wMaxPacketSize = 64; } } @@ -266,13 +322,13 @@ void USBUartTypeCdcAcm::on_connected() { for (auto *channel : this->channels_) { if (i == cdc_devs.size()) { ESP_LOGE(TAG, "No configuration found for channel %d", channel->index_); - this->status_set_warning(LOG_STR("No configuration found for channel")); + this->status_set_warning("No configuration found for channel"); break; } channel->cdc_dev_ = cdc_devs[i++]; fix_mps(channel->cdc_dev_.in_ep); fix_mps(channel->cdc_dev_.out_ep); - channel->initialised_ = true; + channel->initialised_.store(true); auto err = usb_host_interface_claim(this->handle_, this->device_handle_, channel->cdc_dev_.bulk_interface_number, 0); if (err != ESP_OK) { @@ -301,9 +357,9 @@ void USBUartTypeCdcAcm::on_disconnected() { usb_host_endpoint_flush(this->device_handle_, channel->cdc_dev_.notify_ep->bEndpointAddress); } usb_host_interface_release(this->handle_, this->device_handle_, channel->cdc_dev_.bulk_interface_number); - channel->initialised_ = false; - channel->input_started_ = false; - channel->output_started_ = false; + channel->initialised_.store(false); + channel->input_started_.store(false); + channel->output_started_.store(false); channel->input_buffer_.clear(); channel->output_buffer_.clear(); } @@ -312,10 +368,10 @@ void USBUartTypeCdcAcm::on_disconnected() { void USBUartTypeCdcAcm::enable_channels() { for (auto *channel : this->channels_) { - if (!channel->initialised_) + if (!channel->initialised_.load()) continue; - channel->input_started_ = false; - channel->output_started_ = false; + channel->input_started_.store(false); + channel->output_started_.store(false); this->start_input(channel); } } diff --git a/esphome/components/usb_uart/usb_uart.h b/esphome/components/usb_uart/usb_uart.h index a103c51add..a5e7905ac5 100644 --- a/esphome/components/usb_uart/usb_uart.h +++ b/esphome/components/usb_uart/usb_uart.h @@ -1,15 +1,19 @@ #pragma once -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) +#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32P4) #include "esphome/core/component.h" #include "esphome/core/helpers.h" #include "esphome/components/uart/uart_component.h" #include "esphome/components/usb_host/usb_host.h" +#include "esphome/core/lock_free_queue.h" +#include "esphome/core/event_pool.h" +#include namespace esphome { namespace usb_uart { class USBUartTypeCdcAcm; class USBUartComponent; +class USBUartChannel; static const char *const TAG = "usb_uart"; @@ -68,6 +72,17 @@ class RingBuffer { uint8_t *buffer_; }; +// Structure for queuing received USB data chunks +struct UsbDataChunk { + static constexpr size_t MAX_CHUNK_SIZE = 64; // USB packet size + uint8_t data[MAX_CHUNK_SIZE]; + uint8_t length; // Max 64 bytes, so uint8_t is sufficient + USBUartChannel *channel; + + // Required for EventPool - no cleanup needed for POD types + void release() {} +}; + class USBUartChannel : public uart::UARTComponent, public Parented { friend class USBUartComponent; friend class USBUartTypeCdcAcm; @@ -90,16 +105,20 @@ class USBUartChannel : public uart::UARTComponent, public Parenteddummy_receiver_ = dummy_receiver; } protected: - const uint8_t index_; + // Larger structures first for better alignment RingBuffer input_buffer_; RingBuffer output_buffer_; - UARTParityOptions parity_{UART_CONFIG_PARITY_NONE}; - bool input_started_{true}; - bool output_started_{true}; CdcEps cdc_dev_{}; + // Enum (likely 4 bytes) + UARTParityOptions parity_{UART_CONFIG_PARITY_NONE}; + // Group atomics together (each 1 byte) + std::atomic input_started_{true}; + std::atomic output_started_{true}; + std::atomic initialised_{false}; + // Group regular bytes together to minimize padding + const uint8_t index_; bool debug_{}; bool dummy_receiver_{}; - bool initialised_{}; }; class USBUartComponent : public usb_host::USBClient { @@ -115,6 +134,11 @@ class USBUartComponent : public usb_host::USBClient { void start_input(USBUartChannel *channel); void start_output(USBUartChannel *channel); + // Lock-free data transfer from USB task to main loop + static constexpr int USB_DATA_QUEUE_SIZE = 32; + LockFreeQueue usb_data_queue_; + EventPool chunk_pool_; + protected: std::vector channels_{}; }; diff --git a/esphome/components/valve/valve.cpp b/esphome/components/valve/valve.cpp index 0ee710fc02..b041fe8449 100644 --- a/esphome/components/valve/valve.cpp +++ b/esphome/components/valve/valve.cpp @@ -1,5 +1,6 @@ #include "valve.h" #include "esphome/core/log.h" +#include namespace esphome { namespace valve { diff --git a/esphome/components/veml7700/veml7700.cpp b/esphome/components/veml7700/veml7700.cpp index c3b601e288..eb286ba21b 100644 --- a/esphome/components/veml7700/veml7700.cpp +++ b/esphome/components/veml7700/veml7700.cpp @@ -1,6 +1,7 @@ #include "veml7700.h" #include "esphome/core/application.h" #include "esphome/core/log.h" +#include namespace esphome { namespace veml7700 { @@ -12,30 +13,30 @@ static float reduce_to_zero(float a, float b) { return (a > b) ? (a - b) : 0; } template T get_next(const T (&array)[size], const T val) { size_t i = 0; - size_t idx = -1; - while (idx == -1 && i < size) { + size_t idx = std::numeric_limits::max(); + while (idx == std::numeric_limits::max() && i < size) { if (array[i] == val) { idx = i; break; } i++; } - if (idx == -1 || i + 1 >= size) + if (idx == std::numeric_limits::max() || i + 1 >= size) return val; return array[i + 1]; } template T get_prev(const T (&array)[size], const T val) { size_t i = size - 1; - size_t idx = -1; - while (idx == -1 && i > 0) { + size_t idx = std::numeric_limits::max(); + while (idx == std::numeric_limits::max() && i > 0) { if (array[i] == val) { idx = i; break; } i--; } - if (idx == -1 || i == 0) + if (idx == std::numeric_limits::max() || i == 0) return val; return array[i - 1]; } diff --git a/esphome/components/version/version_text_sensor.cpp b/esphome/components/version/version_text_sensor.cpp index ed093595cc..65dbfd27cf 100644 --- a/esphome/components/version/version_text_sensor.cpp +++ b/esphome/components/version/version_text_sensor.cpp @@ -2,6 +2,7 @@ #include "esphome/core/log.h" #include "esphome/core/application.h" #include "esphome/core/version.h" +#include "esphome/core/helpers.h" namespace esphome { namespace version { @@ -12,7 +13,7 @@ void VersionTextSensor::setup() { if (this->hide_timestamp_) { this->publish_state(ESPHOME_VERSION); } else { - this->publish_state(ESPHOME_VERSION " " + App.get_compilation_time()); + this->publish_state(str_sprintf(ESPHOME_VERSION " %s", App.get_compilation_time().c_str())); } } float VersionTextSensor::get_setup_priority() const { return setup_priority::DATA; } diff --git a/esphome/components/voice_assistant/voice_assistant.cpp b/esphome/components/voice_assistant/voice_assistant.cpp index a0cf1a155b..7ece73994f 100644 --- a/esphome/components/voice_assistant/voice_assistant.cpp +++ b/esphome/components/voice_assistant/voice_assistant.cpp @@ -429,8 +429,9 @@ void VoiceAssistant::client_subscription(api::APIConnection *client, bool subscr if (this->api_client_ != nullptr) { ESP_LOGE(TAG, "Multiple API Clients attempting to connect to Voice Assistant"); - ESP_LOGE(TAG, "Current client: %s", this->api_client_->get_client_combined_info().c_str()); - ESP_LOGE(TAG, "New client: %s", client->get_client_combined_info().c_str()); + ESP_LOGE(TAG, "Current client: %s (%s)", this->api_client_->get_name().c_str(), + this->api_client_->get_peername().c_str()); + ESP_LOGE(TAG, "New client: %s (%s)", client->get_name().c_str(), client->get_peername().c_str()); return; } diff --git a/esphome/components/waveshare_epaper/waveshare_epaper.cpp b/esphome/components/waveshare_epaper/waveshare_epaper.cpp index 75c6b84b79..3510d157d6 100644 --- a/esphome/components/waveshare_epaper/waveshare_epaper.cpp +++ b/esphome/components/waveshare_epaper/waveshare_epaper.cpp @@ -2274,11 +2274,11 @@ void GDEW0154M09::clear_() { uint32_t pixsize = this->get_buffer_length_(); for (uint8_t j = 0; j < 2; j++) { this->command(CMD_DTM1_DATA_START_TRANS); - for (int count = 0; count < pixsize; count++) { + for (uint32_t count = 0; count < pixsize; count++) { this->data(0x00); } this->command(CMD_DTM2_DATA_START_TRANS2); - for (int count = 0; count < pixsize; count++) { + for (uint32_t count = 0; count < pixsize; count++) { this->data(0xff); } this->command(CMD_DISPLAY_REFRESH); @@ -2291,11 +2291,11 @@ void HOT GDEW0154M09::display() { this->init_internal_(); // "Mode 0 display" for now this->command(CMD_DTM1_DATA_START_TRANS); - for (int i = 0; i < this->get_buffer_length_(); i++) { + for (uint32_t i = 0; i < this->get_buffer_length_(); i++) { this->data(0xff); } this->command(CMD_DTM2_DATA_START_TRANS2); // write 'new' data to SRAM - for (int i = 0; i < this->get_buffer_length_(); i++) { + for (uint32_t i = 0; i < this->get_buffer_length_(); i++) { this->data(this->buffer_[i]); } this->command(CMD_DISPLAY_REFRESH); diff --git a/esphome/components/web_server/list_entities.cpp b/esphome/components/web_server/list_entities.cpp index fb02821760..3eb3764857 100644 --- a/esphome/components/web_server/list_entities.cpp +++ b/esphome/components/web_server/list_entities.cpp @@ -9,13 +9,12 @@ namespace esphome { namespace web_server { -#ifdef USE_ARDUINO +#ifdef USE_ESP32 +ListEntitiesIterator::ListEntitiesIterator(const WebServer *ws, AsyncEventSource *es) : web_server_(ws), events_(es) {} +#elif USE_ARDUINO ListEntitiesIterator::ListEntitiesIterator(const WebServer *ws, DeferredUpdateEventSource *es) : web_server_(ws), events_(es) {} #endif -#ifdef USE_ESP_IDF -ListEntitiesIterator::ListEntitiesIterator(const WebServer *ws, AsyncEventSource *es) : web_server_(ws), events_(es) {} -#endif ListEntitiesIterator::~ListEntitiesIterator() {} #ifdef USE_BINARY_SENSOR diff --git a/esphome/components/web_server/list_entities.h b/esphome/components/web_server/list_entities.h index ba81c70c86..43e1cc2544 100644 --- a/esphome/components/web_server/list_entities.h +++ b/esphome/components/web_server/list_entities.h @@ -5,25 +5,24 @@ #include "esphome/core/component.h" #include "esphome/core/component_iterator.h" namespace esphome { -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 namespace web_server_idf { class AsyncEventSource; } #endif namespace web_server { -#ifdef USE_ARDUINO +#if !defined(USE_ESP32) && defined(USE_ARDUINO) class DeferredUpdateEventSource; #endif class WebServer; class ListEntitiesIterator : public ComponentIterator { public: -#ifdef USE_ARDUINO - ListEntitiesIterator(const WebServer *ws, DeferredUpdateEventSource *es); -#endif -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 ListEntitiesIterator(const WebServer *ws, esphome::web_server_idf::AsyncEventSource *es); +#elif defined(USE_ARDUINO) + ListEntitiesIterator(const WebServer *ws, DeferredUpdateEventSource *es); #endif virtual ~ListEntitiesIterator(); #ifdef USE_BINARY_SENSOR @@ -90,11 +89,10 @@ class ListEntitiesIterator : public ComponentIterator { protected: const WebServer *web_server_; -#ifdef USE_ARDUINO - DeferredUpdateEventSource *events_; -#endif -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 esphome::web_server_idf::AsyncEventSource *events_; +#elif USE_ARDUINO + DeferredUpdateEventSource *events_; #endif }; diff --git a/esphome/components/web_server/ota/__init__.py b/esphome/components/web_server/ota/__init__.py index 22e56639e1..4a98db8877 100644 --- a/esphome/components/web_server/ota/__init__.py +++ b/esphome/components/web_server/ota/__init__.py @@ -29,5 +29,5 @@ async def to_code(config): await ota_to_code(var, config) await cg.register_component(var, config) cg.add_define("USE_WEBSERVER_OTA") - if CORE.using_esp_idf: + if CORE.is_esp32: add_idf_component(name="zorxx/multipart-parser", ref="1.0.1") diff --git a/esphome/components/web_server/ota/ota_web_server.cpp b/esphome/components/web_server/ota/ota_web_server.cpp index 672a9868c5..7929f3647f 100644 --- a/esphome/components/web_server/ota/ota_web_server.cpp +++ b/esphome/components/web_server/ota/ota_web_server.cpp @@ -17,6 +17,12 @@ #endif #endif // USE_ARDUINO +#if USE_ESP32 +using PlatformString = std::string; +#elif USE_ARDUINO +using PlatformString = String; +#endif + namespace esphome { namespace web_server { @@ -26,8 +32,8 @@ class OTARequestHandler : public AsyncWebHandler { public: OTARequestHandler(WebServerOTAComponent *parent) : parent_(parent) {} void handleRequest(AsyncWebServerRequest *request) override; - void handleUpload(AsyncWebServerRequest *request, const String &filename, size_t index, uint8_t *data, size_t len, - bool final) override; + void handleUpload(AsyncWebServerRequest *request, const PlatformString &filename, size_t index, uint8_t *data, + size_t len, bool final) override; bool canHandle(AsyncWebServerRequest *request) const override { // Check if this is an OTA update request bool is_ota_request = request->url() == "/update" && request->method() == HTTP_POST; @@ -100,7 +106,7 @@ void OTARequestHandler::ota_init_(const char *filename) { this->ota_success_ = false; } -void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const String &filename, size_t index, +void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const PlatformString &filename, size_t index, uint8_t *data, size_t len, bool final) { ota::OTAResponseTypes error_code = ota::OTA_RESPONSE_OK; diff --git a/esphome/components/web_server/web_server.cpp b/esphome/components/web_server/web_server.cpp index 290992b096..cfd5fc947b 100644 --- a/esphome/components/web_server/web_server.cpp +++ b/esphome/components/web_server/web_server.cpp @@ -8,7 +8,7 @@ #include "esphome/core/log.h" #include "esphome/core/util.h" -#ifdef USE_ARDUINO +#if !defined(USE_ESP32) && defined(USE_ARDUINO) #include "StreamString.h" #endif @@ -103,19 +103,19 @@ static UrlMatch match_url(const char *url_ptr, size_t url_len, bool only_domain) return match; } -#ifdef USE_ARDUINO +#if !defined(USE_ESP32) && defined(USE_ARDUINO) // helper for allowing only unique entries in the queue void DeferredUpdateEventSource::deq_push_back_with_dedup_(void *source, message_generator_t *message_generator) { DeferredEvent item(source, message_generator); - auto iter = std::find_if(this->deferred_queue_.begin(), this->deferred_queue_.end(), - [&item](const DeferredEvent &test) -> bool { return test == item; }); - - if (iter != this->deferred_queue_.end()) { - (*iter) = item; - } else { - this->deferred_queue_.push_back(item); + // Use range-based for loop instead of std::find_if to reduce template instantiation overhead and binary size + for (auto &event : this->deferred_queue_) { + if (event == item) { + event = item; + return; + } } + this->deferred_queue_.push_back(item); } void DeferredUpdateEventSource::process_deferred_queue_() { @@ -127,6 +127,10 @@ void DeferredUpdateEventSource::process_deferred_queue_() { deferred_queue_.erase(deferred_queue_.begin()); this->consecutive_send_failures_ = 0; // Reset failure count on successful send } else { + // NOTE: Similar logic exists in web_server_idf/web_server_idf.cpp in AsyncEventSourceResponse::process_buffer_() + // The implementations differ due to platform-specific APIs (DISCARDED vs HTTPD_SOCK_ERR_TIMEOUT, close() vs + // fd_.store(0)), but the failure counting and timeout logic should be kept in sync. If you change this logic, + // also update the ESP-IDF implementation. this->consecutive_send_failures_++; if (this->consecutive_send_failures_ >= MAX_CONSECUTIVE_SEND_FAILURES) { // Too many failures, connection is likely dead @@ -228,10 +232,11 @@ void DeferredUpdateEventSourceList::on_client_connect_(WebServer *ws, DeferredUp #ifdef USE_WEBSERVER_SORTING for (auto &group : ws->sorting_groups_) { - message = json::build_json([group](JsonObject root) { - root["name"] = group.second.name; - root["sorting_weight"] = group.second.weight; - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + root["name"] = group.second.name; + root["sorting_weight"] = group.second.weight; + message = builder.serialize(); // up to 31 groups should be able to be queued initially without defer source->try_send_nodefer(message.c_str(), "sorting_group"); @@ -265,17 +270,20 @@ void WebServer::set_js_include(const char *js_include) { this->js_include_ = js_ #endif std::string WebServer::get_config_json() { - return json::build_json([this](JsonObject root) { - root["title"] = App.get_friendly_name().empty() ? App.get_name() : App.get_friendly_name(); - root["comment"] = App.get_comment(); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + root["title"] = App.get_friendly_name().empty() ? App.get_name() : App.get_friendly_name(); + root["comment"] = App.get_comment(); #if defined(USE_WEBSERVER_OTA_DISABLED) || !defined(USE_WEBSERVER_OTA) - root["ota"] = false; // Note: USE_WEBSERVER_OTA_DISABLED only affects web_server, not captive_portal + root["ota"] = false; // Note: USE_WEBSERVER_OTA_DISABLED only affects web_server, not captive_portal #else - root["ota"] = true; + root["ota"] = true; #endif - root["log"] = this->expose_log_; - root["lang"] = "en"; - }); + root["log"] = this->expose_log_; + root["lang"] = "en"; + + return builder.serialize(); } void WebServer::setup() { @@ -293,7 +301,7 @@ void WebServer::setup() { } #endif -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 this->base_->add_handler(&this->events_); #endif this->base_->add_handler(this); @@ -377,11 +385,14 @@ void WebServer::handle_js_request(AsyncWebServerRequest *request) { #endif // Helper functions to reduce code size by avoiding macro expansion -static void set_json_id(JsonObject &root, EntityBase *obj, const std::string &id, JsonDetail start_config) { - root["id"] = id; +static void set_json_id(JsonObject &root, EntityBase *obj, const char *prefix, JsonDetail start_config) { + char id_buf[160]; // object_id can be up to 128 chars + prefix + dash + null + const auto &object_id = obj->get_object_id(); + snprintf(id_buf, sizeof(id_buf), "%s-%s", prefix, object_id.c_str()); + root["id"] = id_buf; if (start_config == DETAIL_ALL) { root["name"] = obj->get_name(); - root["icon"] = obj->get_icon(); + root["icon"] = obj->get_icon_ref(); root["entity_category"] = obj->get_entity_category(); bool is_disabled = obj->is_disabled_by_default(); if (is_disabled) @@ -389,17 +400,19 @@ static void set_json_id(JsonObject &root, EntityBase *obj, const std::string &id } } +// Keep as separate function even though only used once: reduces code size by ~48 bytes +// by allowing compiler to share code between template instantiations (bool, float, etc.) template -static void set_json_value(JsonObject &root, EntityBase *obj, const std::string &id, const T &value, +static void set_json_value(JsonObject &root, EntityBase *obj, const char *prefix, const T &value, JsonDetail start_config) { - set_json_id(root, obj, id, start_config); + set_json_id(root, obj, prefix, start_config); root["value"] = value; } template -static void set_json_icon_state_value(JsonObject &root, EntityBase *obj, const std::string &id, - const std::string &state, const T &value, JsonDetail start_config) { - set_json_value(root, obj, id, value, start_config); +static void set_json_icon_state_value(JsonObject &root, EntityBase *obj, const char *prefix, const std::string &state, + const T &value, JsonDetail start_config) { + set_json_value(root, obj, prefix, value, start_config); root["state"] = state; } @@ -435,22 +448,26 @@ std::string WebServer::sensor_all_json_generator(WebServer *web_server, void *so return web_server->sensor_json((sensor::Sensor *) (source), ((sensor::Sensor *) (source))->state, DETAIL_ALL); } std::string WebServer::sensor_json(sensor::Sensor *obj, float value, JsonDetail start_config) { - return json::build_json([this, obj, value, start_config](JsonObject root) { - std::string state; - if (std::isnan(value)) { - state = "NA"; - } else { - state = value_accuracy_to_string(value, obj->get_accuracy_decimals()); - if (!obj->get_unit_of_measurement().empty()) - state += " " + obj->get_unit_of_measurement(); - } - set_json_icon_state_value(root, obj, "sensor-" + obj->get_object_id(), state, value, start_config); - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - if (!obj->get_unit_of_measurement().empty()) - root["uom"] = obj->get_unit_of_measurement(); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + const auto uom_ref = obj->get_unit_of_measurement_ref(); + + // Build JSON directly inline + std::string state; + if (std::isnan(value)) { + state = "NA"; + } else { + state = value_accuracy_with_uom_to_string(value, obj->get_accuracy_decimals(), uom_ref); + } + set_json_icon_state_value(root, obj, "sensor", state, value, start_config); + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + if (!uom_ref.empty()) + root["uom"] = uom_ref; + } + + return builder.serialize(); } #endif @@ -483,12 +500,15 @@ std::string WebServer::text_sensor_all_json_generator(WebServer *web_server, voi } std::string WebServer::text_sensor_json(text_sensor::TextSensor *obj, const std::string &value, JsonDetail start_config) { - return json::build_json([this, obj, value, start_config](JsonObject root) { - set_json_icon_state_value(root, obj, "text_sensor-" + obj->get_object_id(), value, value, start_config); - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_icon_state_value(root, obj, "text_sensor", value, value, start_config); + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -553,13 +573,16 @@ std::string WebServer::switch_all_json_generator(WebServer *web_server, void *so return web_server->switch_json((switch_::Switch *) (source), ((switch_::Switch *) (source))->state, DETAIL_ALL); } std::string WebServer::switch_json(switch_::Switch *obj, bool value, JsonDetail start_config) { - return json::build_json([this, obj, value, start_config](JsonObject root) { - set_json_icon_state_value(root, obj, "switch-" + obj->get_object_id(), value ? "ON" : "OFF", value, start_config); - if (start_config == DETAIL_ALL) { - root["assumed_state"] = obj->assumed_state(); - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_icon_state_value(root, obj, "switch", value ? "ON" : "OFF", value, start_config); + if (start_config == DETAIL_ALL) { + root["assumed_state"] = obj->assumed_state(); + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -590,12 +613,15 @@ std::string WebServer::button_all_json_generator(WebServer *web_server, void *so return web_server->button_json((button::Button *) (source), DETAIL_ALL); } std::string WebServer::button_json(button::Button *obj, JsonDetail start_config) { - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_id(root, obj, "button-" + obj->get_object_id(), start_config); - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_id(root, obj, "button", start_config); + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -627,13 +653,15 @@ std::string WebServer::binary_sensor_all_json_generator(WebServer *web_server, v ((binary_sensor::BinarySensor *) (source))->state, DETAIL_ALL); } std::string WebServer::binary_sensor_json(binary_sensor::BinarySensor *obj, bool value, JsonDetail start_config) { - return json::build_json([this, obj, value, start_config](JsonObject root) { - set_json_icon_state_value(root, obj, "binary_sensor-" + obj->get_object_id(), value ? "ON" : "OFF", value, - start_config); - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_icon_state_value(root, obj, "binary_sensor", value ? "ON" : "OFF", value, start_config); + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -694,20 +722,22 @@ std::string WebServer::fan_all_json_generator(WebServer *web_server, void *sourc return web_server->fan_json((fan::Fan *) (source), DETAIL_ALL); } std::string WebServer::fan_json(fan::Fan *obj, JsonDetail start_config) { - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_icon_state_value(root, obj, "fan-" + obj->get_object_id(), obj->state ? "ON" : "OFF", obj->state, - start_config); - const auto traits = obj->get_traits(); - if (traits.supports_speed()) { - root["speed_level"] = obj->speed; - root["speed_count"] = traits.supported_speed_count(); - } - if (obj->get_traits().supports_oscillation()) - root["oscillation"] = obj->oscillating; - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_icon_state_value(root, obj, "fan", obj->state ? "ON" : "OFF", obj->state, start_config); + const auto traits = obj->get_traits(); + if (traits.supports_speed()) { + root["speed_level"] = obj->speed; + root["speed_count"] = traits.supported_speed_count(); + } + if (obj->get_traits().supports_oscillation()) + root["oscillation"] = obj->oscillating; + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -767,20 +797,23 @@ std::string WebServer::light_all_json_generator(WebServer *web_server, void *sou return web_server->light_json((light::LightState *) (source), DETAIL_ALL); } std::string WebServer::light_json(light::LightState *obj, JsonDetail start_config) { - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_id(root, obj, "light-" + obj->get_object_id(), start_config); - root["state"] = obj->remote_values.is_on() ? "ON" : "OFF"; + json::JsonBuilder builder; + JsonObject root = builder.root(); - light::LightJSONSchema::dump_json(*obj, root); - if (start_config == DETAIL_ALL) { - JsonArray opt = root["effects"].to(); - opt.add("None"); - for (auto const &option : obj->get_effects()) { - opt.add(option->get_name()); - } - this->add_sorting_info_(root, obj); + set_json_id(root, obj, "light", start_config); + root["state"] = obj->remote_values.is_on() ? "ON" : "OFF"; + + light::LightJSONSchema::dump_json(*obj, root); + if (start_config == DETAIL_ALL) { + JsonArray opt = root["effects"].to(); + opt.add("None"); + for (auto const &option : obj->get_effects()) { + opt.add(option->get_name()); } - }); + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -803,15 +836,28 @@ void WebServer::handle_cover_request(AsyncWebServerRequest *request, const UrlMa } auto call = obj->make_call(); - if (match.method_equals("open")) { - call.set_command_open(); - } else if (match.method_equals("close")) { - call.set_command_close(); - } else if (match.method_equals("stop")) { - call.set_command_stop(); - } else if (match.method_equals("toggle")) { - call.set_command_toggle(); - } else if (!match.method_equals("set")) { + + // Lookup table for cover methods + static const struct { + const char *name; + cover::CoverCall &(cover::CoverCall::*action)(); + } METHODS[] = { + {"open", &cover::CoverCall::set_command_open}, + {"close", &cover::CoverCall::set_command_close}, + {"stop", &cover::CoverCall::set_command_stop}, + {"toggle", &cover::CoverCall::set_command_toggle}, + }; + + bool found = false; + for (const auto &method : METHODS) { + if (match.method_equals(method.name)) { + (call.*method.action)(); + found = true; + break; + } + } + + if (!found && !match.method_equals("set")) { request->send(404); return; } @@ -839,19 +885,22 @@ std::string WebServer::cover_all_json_generator(WebServer *web_server, void *sou return web_server->cover_json((cover::Cover *) (source), DETAIL_ALL); } std::string WebServer::cover_json(cover::Cover *obj, JsonDetail start_config) { - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_icon_state_value(root, obj, "cover-" + obj->get_object_id(), obj->is_fully_closed() ? "CLOSED" : "OPEN", - obj->position, start_config); - root["current_operation"] = cover::cover_operation_to_str(obj->current_operation); + json::JsonBuilder builder; + JsonObject root = builder.root(); - if (obj->get_traits().get_supports_position()) - root["position"] = obj->position; - if (obj->get_traits().get_supports_tilt()) - root["tilt"] = obj->tilt; - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + set_json_icon_state_value(root, obj, "cover", obj->is_fully_closed() ? "CLOSED" : "OPEN", obj->position, + start_config); + root["current_operation"] = cover::cover_operation_to_str(obj->current_operation); + + if (obj->get_traits().get_supports_position()) + root["position"] = obj->position; + if (obj->get_traits().get_supports_tilt()) + root["tilt"] = obj->tilt; + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -894,31 +943,33 @@ std::string WebServer::number_all_json_generator(WebServer *web_server, void *so return web_server->number_json((number::Number *) (source), ((number::Number *) (source))->state, DETAIL_ALL); } std::string WebServer::number_json(number::Number *obj, float value, JsonDetail start_config) { - return json::build_json([this, obj, value, start_config](JsonObject root) { - set_json_id(root, obj, "number-" + obj->get_object_id(), start_config); - if (start_config == DETAIL_ALL) { - root["min_value"] = - value_accuracy_to_string(obj->traits.get_min_value(), step_to_accuracy_decimals(obj->traits.get_step())); - root["max_value"] = - value_accuracy_to_string(obj->traits.get_max_value(), step_to_accuracy_decimals(obj->traits.get_step())); - root["step"] = - value_accuracy_to_string(obj->traits.get_step(), step_to_accuracy_decimals(obj->traits.get_step())); - root["mode"] = (int) obj->traits.get_mode(); - if (!obj->traits.get_unit_of_measurement().empty()) - root["uom"] = obj->traits.get_unit_of_measurement(); - this->add_sorting_info_(root, obj); - } - if (std::isnan(value)) { - root["value"] = "\"NaN\""; - root["state"] = "NA"; - } else { - root["value"] = value_accuracy_to_string(value, step_to_accuracy_decimals(obj->traits.get_step())); - std::string state = value_accuracy_to_string(value, step_to_accuracy_decimals(obj->traits.get_step())); - if (!obj->traits.get_unit_of_measurement().empty()) - state += " " + obj->traits.get_unit_of_measurement(); - root["state"] = state; - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + const auto uom_ref = obj->traits.get_unit_of_measurement_ref(); + + set_json_id(root, obj, "number", start_config); + if (start_config == DETAIL_ALL) { + root["min_value"] = + value_accuracy_to_string(obj->traits.get_min_value(), step_to_accuracy_decimals(obj->traits.get_step())); + root["max_value"] = + value_accuracy_to_string(obj->traits.get_max_value(), step_to_accuracy_decimals(obj->traits.get_step())); + root["step"] = value_accuracy_to_string(obj->traits.get_step(), step_to_accuracy_decimals(obj->traits.get_step())); + root["mode"] = (int) obj->traits.get_mode(); + if (!uom_ref.empty()) + root["uom"] = uom_ref; + this->add_sorting_info_(root, obj); + } + if (std::isnan(value)) { + root["value"] = "\"NaN\""; + root["state"] = "NA"; + } else { + root["value"] = value_accuracy_to_string(value, step_to_accuracy_decimals(obj->traits.get_step())); + root["state"] = + value_accuracy_with_uom_to_string(value, step_to_accuracy_decimals(obj->traits.get_step()), uom_ref); + } + + return builder.serialize(); } #endif @@ -966,15 +1017,18 @@ std::string WebServer::date_all_json_generator(WebServer *web_server, void *sour return web_server->date_json((datetime::DateEntity *) (source), DETAIL_ALL); } std::string WebServer::date_json(datetime::DateEntity *obj, JsonDetail start_config) { - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_id(root, obj, "date-" + obj->get_object_id(), start_config); - std::string value = str_sprintf("%d-%02d-%02d", obj->year, obj->month, obj->day); - root["value"] = value; - root["state"] = value; - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_id(root, obj, "date", start_config); + std::string value = str_sprintf("%d-%02d-%02d", obj->year, obj->month, obj->day); + root["value"] = value; + root["state"] = value; + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif // USE_DATETIME_DATE @@ -1021,15 +1075,18 @@ std::string WebServer::time_all_json_generator(WebServer *web_server, void *sour return web_server->time_json((datetime::TimeEntity *) (source), DETAIL_ALL); } std::string WebServer::time_json(datetime::TimeEntity *obj, JsonDetail start_config) { - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_id(root, obj, "time-" + obj->get_object_id(), start_config); - std::string value = str_sprintf("%02d:%02d:%02d", obj->hour, obj->minute, obj->second); - root["value"] = value; - root["state"] = value; - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_id(root, obj, "time", start_config); + std::string value = str_sprintf("%02d:%02d:%02d", obj->hour, obj->minute, obj->second); + root["value"] = value; + root["state"] = value; + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif // USE_DATETIME_TIME @@ -1076,16 +1133,19 @@ std::string WebServer::datetime_all_json_generator(WebServer *web_server, void * return web_server->datetime_json((datetime::DateTimeEntity *) (source), DETAIL_ALL); } std::string WebServer::datetime_json(datetime::DateTimeEntity *obj, JsonDetail start_config) { - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_id(root, obj, "datetime-" + obj->get_object_id(), start_config); - std::string value = str_sprintf("%d-%02d-%02d %02d:%02d:%02d", obj->year, obj->month, obj->day, obj->hour, - obj->minute, obj->second); - root["value"] = value; - root["state"] = value; - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_id(root, obj, "datetime", start_config); + std::string value = + str_sprintf("%d-%02d-%02d %02d:%02d:%02d", obj->year, obj->month, obj->day, obj->hour, obj->minute, obj->second); + root["value"] = value; + root["state"] = value; + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif // USE_DATETIME_DATETIME @@ -1128,22 +1188,25 @@ std::string WebServer::text_all_json_generator(WebServer *web_server, void *sour return web_server->text_json((text::Text *) (source), ((text::Text *) (source))->state, DETAIL_ALL); } std::string WebServer::text_json(text::Text *obj, const std::string &value, JsonDetail start_config) { - return json::build_json([this, obj, value, start_config](JsonObject root) { - set_json_id(root, obj, "text-" + obj->get_object_id(), start_config); - root["min_length"] = obj->traits.get_min_length(); - root["max_length"] = obj->traits.get_max_length(); - root["pattern"] = obj->traits.get_pattern(); - if (obj->traits.get_mode() == text::TextMode::TEXT_MODE_PASSWORD) { - root["state"] = "********"; - } else { - root["state"] = value; - } - root["value"] = value; - if (start_config == DETAIL_ALL) { - root["mode"] = (int) obj->traits.get_mode(); - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_id(root, obj, "text", start_config); + root["min_length"] = obj->traits.get_min_length(); + root["max_length"] = obj->traits.get_max_length(); + root["pattern"] = obj->traits.get_pattern(); + if (obj->traits.get_mode() == text::TextMode::TEXT_MODE_PASSWORD) { + root["state"] = "********"; + } else { + root["state"] = value; + } + root["value"] = value; + if (start_config == DETAIL_ALL) { + root["mode"] = (int) obj->traits.get_mode(); + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -1186,21 +1249,24 @@ std::string WebServer::select_all_json_generator(WebServer *web_server, void *so return web_server->select_json((select::Select *) (source), ((select::Select *) (source))->state, DETAIL_ALL); } std::string WebServer::select_json(select::Select *obj, const std::string &value, JsonDetail start_config) { - return json::build_json([this, obj, value, start_config](JsonObject root) { - set_json_icon_state_value(root, obj, "select-" + obj->get_object_id(), value, value, start_config); - if (start_config == DETAIL_ALL) { - JsonArray opt = root["option"].to(); - for (auto &option : obj->traits.get_options()) { - opt.add(option); - } - this->add_sorting_info_(root, obj); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_icon_state_value(root, obj, "select", value, value, start_config); + if (start_config == DETAIL_ALL) { + JsonArray opt = root["option"].to(); + for (auto &option : obj->traits.get_options()) { + opt.add(option); } - }); + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif // Longest: HORIZONTAL -#define PSTR_LOCAL(mode_s) strncpy_P(buf, (PGM_P) ((mode_s)), 15) +#define PSTR_LOCAL(mode_s) ESPHOME_strncpy_P(buf, (ESPHOME_PGM_P) ((mode_s)), 15) #ifdef USE_CLIMATE void WebServer::on_climate_update(climate::Climate *obj) { @@ -1244,98 +1310,102 @@ void WebServer::handle_climate_request(AsyncWebServerRequest *request, const Url request->send(404); } std::string WebServer::climate_state_json_generator(WebServer *web_server, void *source) { + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson return web_server->climate_json((climate::Climate *) (source), DETAIL_STATE); } std::string WebServer::climate_all_json_generator(WebServer *web_server, void *source) { + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson return web_server->climate_json((climate::Climate *) (source), DETAIL_ALL); } std::string WebServer::climate_json(climate::Climate *obj, JsonDetail start_config) { // NOLINTBEGIN(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_id(root, obj, "climate-" + obj->get_object_id(), start_config); - const auto traits = obj->get_traits(); - int8_t target_accuracy = traits.get_target_temperature_accuracy_decimals(); - int8_t current_accuracy = traits.get_current_temperature_accuracy_decimals(); - char buf[16]; + json::JsonBuilder builder; + JsonObject root = builder.root(); + set_json_id(root, obj, "climate", start_config); + const auto traits = obj->get_traits(); + int8_t target_accuracy = traits.get_target_temperature_accuracy_decimals(); + int8_t current_accuracy = traits.get_current_temperature_accuracy_decimals(); + char buf[16]; - if (start_config == DETAIL_ALL) { - JsonArray opt = root["modes"].to(); - for (climate::ClimateMode m : traits.get_supported_modes()) - opt.add(PSTR_LOCAL(climate::climate_mode_to_string(m))); - if (!traits.get_supported_custom_fan_modes().empty()) { - JsonArray opt = root["fan_modes"].to(); - for (climate::ClimateFanMode m : traits.get_supported_fan_modes()) - opt.add(PSTR_LOCAL(climate::climate_fan_mode_to_string(m))); - } - - if (!traits.get_supported_custom_fan_modes().empty()) { - JsonArray opt = root["custom_fan_modes"].to(); - for (auto const &custom_fan_mode : traits.get_supported_custom_fan_modes()) - opt.add(custom_fan_mode); - } - if (traits.get_supports_swing_modes()) { - JsonArray opt = root["swing_modes"].to(); - for (auto swing_mode : traits.get_supported_swing_modes()) - opt.add(PSTR_LOCAL(climate::climate_swing_mode_to_string(swing_mode))); - } - if (traits.get_supports_presets() && obj->preset.has_value()) { - JsonArray opt = root["presets"].to(); - for (climate::ClimatePreset m : traits.get_supported_presets()) - opt.add(PSTR_LOCAL(climate::climate_preset_to_string(m))); - } - if (!traits.get_supported_custom_presets().empty() && obj->custom_preset.has_value()) { - JsonArray opt = root["custom_presets"].to(); - for (auto const &custom_preset : traits.get_supported_custom_presets()) - opt.add(custom_preset); - } - this->add_sorting_info_(root, obj); + if (start_config == DETAIL_ALL) { + JsonArray opt = root["modes"].to(); + for (climate::ClimateMode m : traits.get_supported_modes()) + opt.add(PSTR_LOCAL(climate::climate_mode_to_string(m))); + if (!traits.get_supported_custom_fan_modes().empty()) { + JsonArray opt = root["fan_modes"].to(); + for (climate::ClimateFanMode m : traits.get_supported_fan_modes()) + opt.add(PSTR_LOCAL(climate::climate_fan_mode_to_string(m))); } - bool has_state = false; - root["mode"] = PSTR_LOCAL(climate_mode_to_string(obj->mode)); - root["max_temp"] = value_accuracy_to_string(traits.get_visual_max_temperature(), target_accuracy); - root["min_temp"] = value_accuracy_to_string(traits.get_visual_min_temperature(), target_accuracy); - root["step"] = traits.get_visual_target_temperature_step(); - if (traits.get_supports_action()) { - root["action"] = PSTR_LOCAL(climate_action_to_string(obj->action)); - root["state"] = root["action"]; - has_state = true; - } - if (traits.get_supports_fan_modes() && obj->fan_mode.has_value()) { - root["fan_mode"] = PSTR_LOCAL(climate_fan_mode_to_string(obj->fan_mode.value())); - } - if (!traits.get_supported_custom_fan_modes().empty() && obj->custom_fan_mode.has_value()) { - root["custom_fan_mode"] = obj->custom_fan_mode.value().c_str(); - } - if (traits.get_supports_presets() && obj->preset.has_value()) { - root["preset"] = PSTR_LOCAL(climate_preset_to_string(obj->preset.value())); - } - if (!traits.get_supported_custom_presets().empty() && obj->custom_preset.has_value()) { - root["custom_preset"] = obj->custom_preset.value().c_str(); + if (!traits.get_supported_custom_fan_modes().empty()) { + JsonArray opt = root["custom_fan_modes"].to(); + for (auto const &custom_fan_mode : traits.get_supported_custom_fan_modes()) + opt.add(custom_fan_mode); } if (traits.get_supports_swing_modes()) { - root["swing_mode"] = PSTR_LOCAL(climate_swing_mode_to_string(obj->swing_mode)); + JsonArray opt = root["swing_modes"].to(); + for (auto swing_mode : traits.get_supported_swing_modes()) + opt.add(PSTR_LOCAL(climate::climate_swing_mode_to_string(swing_mode))); } - if (traits.get_supports_current_temperature()) { - if (!std::isnan(obj->current_temperature)) { - root["current_temperature"] = value_accuracy_to_string(obj->current_temperature, current_accuracy); - } else { - root["current_temperature"] = "NA"; - } + if (traits.get_supports_presets() && obj->preset.has_value()) { + JsonArray opt = root["presets"].to(); + for (climate::ClimatePreset m : traits.get_supported_presets()) + opt.add(PSTR_LOCAL(climate::climate_preset_to_string(m))); } - if (traits.get_supports_two_point_target_temperature()) { - root["target_temperature_low"] = value_accuracy_to_string(obj->target_temperature_low, target_accuracy); - root["target_temperature_high"] = value_accuracy_to_string(obj->target_temperature_high, target_accuracy); - if (!has_state) { - root["state"] = value_accuracy_to_string((obj->target_temperature_high + obj->target_temperature_low) / 2.0f, - target_accuracy); - } + if (!traits.get_supported_custom_presets().empty() && obj->custom_preset.has_value()) { + JsonArray opt = root["custom_presets"].to(); + for (auto const &custom_preset : traits.get_supported_custom_presets()) + opt.add(custom_preset); + } + this->add_sorting_info_(root, obj); + } + + bool has_state = false; + root["mode"] = PSTR_LOCAL(climate_mode_to_string(obj->mode)); + root["max_temp"] = value_accuracy_to_string(traits.get_visual_max_temperature(), target_accuracy); + root["min_temp"] = value_accuracy_to_string(traits.get_visual_min_temperature(), target_accuracy); + root["step"] = traits.get_visual_target_temperature_step(); + if (traits.get_supports_action()) { + root["action"] = PSTR_LOCAL(climate_action_to_string(obj->action)); + root["state"] = root["action"]; + has_state = true; + } + if (traits.get_supports_fan_modes() && obj->fan_mode.has_value()) { + root["fan_mode"] = PSTR_LOCAL(climate_fan_mode_to_string(obj->fan_mode.value())); + } + if (!traits.get_supported_custom_fan_modes().empty() && obj->custom_fan_mode.has_value()) { + root["custom_fan_mode"] = obj->custom_fan_mode.value().c_str(); + } + if (traits.get_supports_presets() && obj->preset.has_value()) { + root["preset"] = PSTR_LOCAL(climate_preset_to_string(obj->preset.value())); + } + if (!traits.get_supported_custom_presets().empty() && obj->custom_preset.has_value()) { + root["custom_preset"] = obj->custom_preset.value().c_str(); + } + if (traits.get_supports_swing_modes()) { + root["swing_mode"] = PSTR_LOCAL(climate_swing_mode_to_string(obj->swing_mode)); + } + if (traits.get_supports_current_temperature()) { + if (!std::isnan(obj->current_temperature)) { + root["current_temperature"] = value_accuracy_to_string(obj->current_temperature, current_accuracy); } else { - root["target_temperature"] = value_accuracy_to_string(obj->target_temperature, target_accuracy); - if (!has_state) - root["state"] = root["target_temperature"]; + root["current_temperature"] = "NA"; } - }); + } + if (traits.get_supports_two_point_target_temperature()) { + root["target_temperature_low"] = value_accuracy_to_string(obj->target_temperature_low, target_accuracy); + root["target_temperature_high"] = value_accuracy_to_string(obj->target_temperature_high, target_accuracy); + if (!has_state) { + root["state"] = value_accuracy_to_string((obj->target_temperature_high + obj->target_temperature_low) / 2.0f, + target_accuracy); + } + } else { + root["target_temperature"] = value_accuracy_to_string(obj->target_temperature, target_accuracy); + if (!has_state) + root["state"] = root["target_temperature"]; + } + + return builder.serialize(); // NOLINTEND(clang-analyzer-cplusplus.NewDeleteLeaks) } #endif @@ -1401,13 +1471,15 @@ std::string WebServer::lock_all_json_generator(WebServer *web_server, void *sour return web_server->lock_json((lock::Lock *) (source), ((lock::Lock *) (source))->state, DETAIL_ALL); } std::string WebServer::lock_json(lock::Lock *obj, lock::LockState value, JsonDetail start_config) { - return json::build_json([this, obj, value, start_config](JsonObject root) { - set_json_icon_state_value(root, obj, "lock-" + obj->get_object_id(), lock::lock_state_to_string(value), value, - start_config); - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_icon_state_value(root, obj, "lock", lock::lock_state_to_string(value), value, start_config); + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -1430,15 +1502,28 @@ void WebServer::handle_valve_request(AsyncWebServerRequest *request, const UrlMa } auto call = obj->make_call(); - if (match.method_equals("open")) { - call.set_command_open(); - } else if (match.method_equals("close")) { - call.set_command_close(); - } else if (match.method_equals("stop")) { - call.set_command_stop(); - } else if (match.method_equals("toggle")) { - call.set_command_toggle(); - } else if (!match.method_equals("set")) { + + // Lookup table for valve methods + static const struct { + const char *name; + valve::ValveCall &(valve::ValveCall::*action)(); + } METHODS[] = { + {"open", &valve::ValveCall::set_command_open}, + {"close", &valve::ValveCall::set_command_close}, + {"stop", &valve::ValveCall::set_command_stop}, + {"toggle", &valve::ValveCall::set_command_toggle}, + }; + + bool found = false; + for (const auto &method : METHODS) { + if (match.method_equals(method.name)) { + (call.*method.action)(); + found = true; + break; + } + } + + if (!found && !match.method_equals("set")) { request->send(404); return; } @@ -1464,17 +1549,20 @@ std::string WebServer::valve_all_json_generator(WebServer *web_server, void *sou return web_server->valve_json((valve::Valve *) (source), DETAIL_ALL); } std::string WebServer::valve_json(valve::Valve *obj, JsonDetail start_config) { - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_icon_state_value(root, obj, "valve-" + obj->get_object_id(), obj->is_fully_closed() ? "CLOSED" : "OPEN", - obj->position, start_config); - root["current_operation"] = valve::valve_operation_to_str(obj->current_operation); + json::JsonBuilder builder; + JsonObject root = builder.root(); - if (obj->get_traits().get_supports_position()) - root["position"] = obj->position; - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + set_json_icon_state_value(root, obj, "valve", obj->is_fully_closed() ? "CLOSED" : "OPEN", obj->position, + start_config); + root["current_operation"] = valve::valve_operation_to_str(obj->current_operation); + + if (obj->get_traits().get_supports_position()) + root["position"] = obj->position; + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -1499,17 +1587,28 @@ void WebServer::handle_alarm_control_panel_request(AsyncWebServerRequest *reques auto call = obj->make_call(); parse_string_param_(request, "code", call, &decltype(call)::set_code); - if (match.method_equals("disarm")) { - call.disarm(); - } else if (match.method_equals("arm_away")) { - call.arm_away(); - } else if (match.method_equals("arm_home")) { - call.arm_home(); - } else if (match.method_equals("arm_night")) { - call.arm_night(); - } else if (match.method_equals("arm_vacation")) { - call.arm_vacation(); - } else { + // Lookup table for alarm control panel methods + static const struct { + const char *name; + alarm_control_panel::AlarmControlPanelCall &(alarm_control_panel::AlarmControlPanelCall::*action)(); + } METHODS[] = { + {"disarm", &alarm_control_panel::AlarmControlPanelCall::disarm}, + {"arm_away", &alarm_control_panel::AlarmControlPanelCall::arm_away}, + {"arm_home", &alarm_control_panel::AlarmControlPanelCall::arm_home}, + {"arm_night", &alarm_control_panel::AlarmControlPanelCall::arm_night}, + {"arm_vacation", &alarm_control_panel::AlarmControlPanelCall::arm_vacation}, + }; + + bool found = false; + for (const auto &method : METHODS) { + if (match.method_equals(method.name)) { + (call.*method.action)(); + found = true; + break; + } + } + + if (!found) { request->send(404); return; } @@ -1533,14 +1632,17 @@ std::string WebServer::alarm_control_panel_all_json_generator(WebServer *web_ser std::string WebServer::alarm_control_panel_json(alarm_control_panel::AlarmControlPanel *obj, alarm_control_panel::AlarmControlPanelState value, JsonDetail start_config) { - return json::build_json([this, obj, value, start_config](JsonObject root) { - char buf[16]; - set_json_icon_state_value(root, obj, "alarm-control-panel-" + obj->get_object_id(), - PSTR_LOCAL(alarm_control_panel_state_to_string(value)), value, start_config); - if (start_config == DETAIL_ALL) { - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + char buf[16]; + set_json_icon_state_value(root, obj, "alarm-control-panel", PSTR_LOCAL(alarm_control_panel_state_to_string(value)), + value, start_config); + if (start_config == DETAIL_ALL) { + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -1577,20 +1679,23 @@ std::string WebServer::event_all_json_generator(WebServer *web_server, void *sou return web_server->event_json(event, get_event_type(event), DETAIL_ALL); } std::string WebServer::event_json(event::Event *obj, const std::string &event_type, JsonDetail start_config) { - return json::build_json([this, obj, event_type, start_config](JsonObject root) { - set_json_id(root, obj, "event-" + obj->get_object_id(), start_config); - if (!event_type.empty()) { - root["event_type"] = event_type; + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_id(root, obj, "event", start_config); + if (!event_type.empty()) { + root["event_type"] = event_type; + } + if (start_config == DETAIL_ALL) { + JsonArray event_types = root["event_types"].to(); + for (auto const &event_type : obj->get_event_types()) { + event_types.add(event_type); } - if (start_config == DETAIL_ALL) { - JsonArray event_types = root["event_types"].to(); - for (auto const &event_type : obj->get_event_types()) { - event_types.add(event_type); - } - root["device_class"] = obj->get_device_class(); - this->add_sorting_info_(root, obj); - } - }); + root["device_class"] = obj->get_device_class_ref(); + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); } #endif @@ -1637,25 +1742,30 @@ void WebServer::handle_update_request(AsyncWebServerRequest *request, const UrlM request->send(404); } std::string WebServer::update_state_json_generator(WebServer *web_server, void *source) { + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson return web_server->update_json((update::UpdateEntity *) (source), DETAIL_STATE); } std::string WebServer::update_all_json_generator(WebServer *web_server, void *source) { + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson return web_server->update_json((update::UpdateEntity *) (source), DETAIL_STATE); } std::string WebServer::update_json(update::UpdateEntity *obj, JsonDetail start_config) { // NOLINTBEGIN(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson - return json::build_json([this, obj, start_config](JsonObject root) { - set_json_id(root, obj, "update-" + obj->get_object_id(), start_config); - root["value"] = obj->update_info.latest_version; - root["state"] = update_state_to_string(obj->state); - if (start_config == DETAIL_ALL) { - root["current_version"] = obj->update_info.current_version; - root["title"] = obj->update_info.title; - root["summary"] = obj->update_info.summary; - root["release_url"] = obj->update_info.release_url; - this->add_sorting_info_(root, obj); - } - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + + set_json_id(root, obj, "update", start_config); + root["value"] = obj->update_info.latest_version; + root["state"] = update_state_to_string(obj->state); + if (start_config == DETAIL_ALL) { + root["current_version"] = obj->update_info.current_version; + root["title"] = obj->update_info.title; + root["summary"] = obj->update_info.summary; + root["release_url"] = obj->update_info.release_url; + this->add_sorting_info_(root, obj); + } + + return builder.serialize(); // NOLINTEND(clang-analyzer-cplusplus.NewDeleteLeaks) } #endif @@ -1664,24 +1774,24 @@ bool WebServer::canHandle(AsyncWebServerRequest *request) const { const auto &url = request->url(); const auto method = request->method(); - // Simple URL checks - if (url == "/") - return true; - -#ifdef USE_ARDUINO - if (url == "/events") - return true; + // Static URL checks + static const char *const STATIC_URLS[] = { + "/", +#if !defined(USE_ESP32) && defined(USE_ARDUINO) + "/events", #endif - #ifdef USE_WEBSERVER_CSS_INCLUDE - if (url == "/0.css") - return true; + "/0.css", #endif - #ifdef USE_WEBSERVER_JS_INCLUDE - if (url == "/0.js") - return true; + "/0.js", #endif + }; + + for (const auto &static_url : STATIC_URLS) { + if (url == static_url) + return true; + } #ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS if (method == HTTP_OPTIONS && request->hasHeader(HEADER_CORS_REQ_PNA)) @@ -1701,92 +1811,87 @@ bool WebServer::canHandle(AsyncWebServerRequest *request) const { if (!is_get_or_post) return false; - // GET-only components - if (is_get) { + // Use lookup tables for domain checks + static const char *const GET_ONLY_DOMAINS[] = { #ifdef USE_SENSOR - if (match.domain_equals("sensor")) - return true; + "sensor", #endif #ifdef USE_BINARY_SENSOR - if (match.domain_equals("binary_sensor")) - return true; + "binary_sensor", #endif #ifdef USE_TEXT_SENSOR - if (match.domain_equals("text_sensor")) - return true; + "text_sensor", #endif #ifdef USE_EVENT - if (match.domain_equals("event")) - return true; + "event", #endif - } + }; - // GET+POST components - if (is_get_or_post) { + static const char *const GET_POST_DOMAINS[] = { #ifdef USE_SWITCH - if (match.domain_equals("switch")) - return true; + "switch", #endif #ifdef USE_BUTTON - if (match.domain_equals("button")) - return true; + "button", #endif #ifdef USE_FAN - if (match.domain_equals("fan")) - return true; + "fan", #endif #ifdef USE_LIGHT - if (match.domain_equals("light")) - return true; + "light", #endif #ifdef USE_COVER - if (match.domain_equals("cover")) - return true; + "cover", #endif #ifdef USE_NUMBER - if (match.domain_equals("number")) - return true; + "number", #endif #ifdef USE_DATETIME_DATE - if (match.domain_equals("date")) - return true; + "date", #endif #ifdef USE_DATETIME_TIME - if (match.domain_equals("time")) - return true; + "time", #endif #ifdef USE_DATETIME_DATETIME - if (match.domain_equals("datetime")) - return true; + "datetime", #endif #ifdef USE_TEXT - if (match.domain_equals("text")) - return true; + "text", #endif #ifdef USE_SELECT - if (match.domain_equals("select")) - return true; + "select", #endif #ifdef USE_CLIMATE - if (match.domain_equals("climate")) - return true; + "climate", #endif #ifdef USE_LOCK - if (match.domain_equals("lock")) - return true; + "lock", #endif #ifdef USE_VALVE - if (match.domain_equals("valve")) - return true; + "valve", #endif #ifdef USE_ALARM_CONTROL_PANEL - if (match.domain_equals("alarm_control_panel")) - return true; + "alarm_control_panel", #endif #ifdef USE_UPDATE - if (match.domain_equals("update")) - return true; + "update", #endif + }; + + // Check GET-only domains + if (is_get) { + for (const auto &domain : GET_ONLY_DOMAINS) { + if (match.domain_equals(domain)) + return true; + } + } + + // Check GET+POST domains + if (is_get_or_post) { + for (const auto &domain : GET_POST_DOMAINS) { + if (match.domain_equals(domain)) + return true; + } } return false; @@ -1800,7 +1905,7 @@ void WebServer::handleRequest(AsyncWebServerRequest *request) { return; } -#ifdef USE_ARDUINO +#if !defined(USE_ESP32) && defined(USE_ARDUINO) if (url == "/events") { this->events_.add_new_client(this, request); return; diff --git a/esphome/components/web_server/web_server.h b/esphome/components/web_server/web_server.h index e42c35b32d..2e5d58d375 100644 --- a/esphome/components/web_server/web_server.h +++ b/esphome/components/web_server/web_server.h @@ -81,7 +81,7 @@ enum JsonDetail { DETAIL_ALL, DETAIL_STATE }; implemented in a more straightforward way for ESP-IDF. Arduino platform will eventually go away and this workaround can be forgotten. */ -#ifdef USE_ARDUINO +#if !defined(USE_ESP32) && defined(USE_ARDUINO) using message_generator_t = std::string(WebServer *, void *); class DeferredUpdateEventSourceList; @@ -164,7 +164,7 @@ class DeferredUpdateEventSourceList : public std::listjs_url_ = js_url; } void WebServer::handle_index_request(AsyncWebServerRequest *request) { AsyncResponseStream *stream = request->beginResponseStream("text/html"); const std::string &title = App.get_name(); - stream->print(F("")); + stream->print(ESPHOME_F("<!DOCTYPE html><html lang=\"en\"><head><meta charset=UTF-8><meta " + "name=viewport content=\"width=device-width, initial-scale=1,user-scalable=no\"><title>")); stream->print(title.c_str()); - stream->print(F("")); + stream->print(ESPHOME_F("")); #ifdef USE_WEBSERVER_CSS_INCLUDE - stream->print(F("")); + stream->print(ESPHOME_F("")); #endif if (strlen(this->css_url_) > 0) { - stream->print(F(R"(print(ESPHOME_F(R"(print(this->css_url_); - stream->print(F("\">")); + stream->print(ESPHOME_F("\">")); } - stream->print(F("")); - stream->print(F("

")); + stream->print(ESPHOME_F("")); + stream->print(ESPHOME_F("

")); stream->print(title.c_str()); - stream->print(F("

")); - stream->print(F("

States

")); + stream->print(ESPHOME_F("")); + stream->print(ESPHOME_F("

States

NameStateActions
")); #ifdef USE_SENSOR for (auto *obj : App.get_sensors()) { @@ -190,26 +190,28 @@ void WebServer::handle_index_request(AsyncWebServerRequest *request) { } #endif - stream->print(F("
NameStateActions

See ESPHome Web API for " - "REST API documentation.

")); + stream->print( + ESPHOME_F("

See ESPHome Web API for " + "REST API documentation.

")); #if defined(USE_WEBSERVER_OTA) && !defined(USE_WEBSERVER_OTA_DISABLED) // Show OTA form only if web_server OTA is not explicitly disabled // Note: USE_WEBSERVER_OTA_DISABLED only affects web_server, not captive_portal - stream->print(F("

OTA Update

")); + stream->print( + ESPHOME_F("

OTA Update

")); #endif - stream->print(F("

Debug Log

"));
+  stream->print(ESPHOME_F("

Debug Log

"));
 #ifdef USE_WEBSERVER_JS_INCLUDE
   if (this->js_include_ != nullptr) {
-    stream->print(F(""));
+    stream->print(ESPHOME_F(""));
   }
 #endif
   if (strlen(this->js_url_) > 0) {
-    stream->print(F(""));
+    stream->print(ESPHOME_F("\">"));
   }
-  stream->print(F("
")); + stream->print(ESPHOME_F("

")); request->send(stream); } diff --git a/esphome/components/web_server_base/__init__.py b/esphome/components/web_server_base/__init__.py index a82ec462d9..4cf76eba0e 100644 --- a/esphome/components/web_server_base/__init__.py +++ b/esphome/components/web_server_base/__init__.py @@ -9,10 +9,10 @@ DEPENDENCIES = ["network"] def AUTO_LOAD(): + if CORE.is_esp32: + return ["web_server_idf"] if CORE.using_arduino: return ["async_tcp"] - if CORE.using_esp_idf: - return ["web_server_idf"] return [] @@ -33,6 +33,9 @@ async def to_code(config): await cg.register_component(var, config) cg.add(cg.RawExpression(f"{web_server_base_ns}::global_web_server_base = {var}")) + if CORE.is_esp32: + return + if CORE.using_arduino: if CORE.is_esp32: cg.add_library("WiFi", None) diff --git a/esphome/components/web_server_base/web_server_base.h b/esphome/components/web_server_base/web_server_base.h index cfca776ee1..039a452d64 100644 --- a/esphome/components/web_server_base/web_server_base.h +++ b/esphome/components/web_server_base/web_server_base.h @@ -7,11 +7,31 @@ #include "esphome/core/component.h" -#ifdef USE_ARDUINO -#include -#elif USE_ESP_IDF +// Platform-agnostic macros for web server components +// On ESP32 (both Arduino and IDF): Use plain strings (no PROGMEM) +// On ESP8266: Use Arduino's F() macro for PROGMEM strings +#ifdef USE_ESP32 +#define ESPHOME_F(string_literal) (string_literal) +#define ESPHOME_PGM_P const char * +#define ESPHOME_strncpy_P strncpy +#else +// ESP8266 uses Arduino macros +#define ESPHOME_F(string_literal) F(string_literal) +#define ESPHOME_PGM_P PGM_P +#define ESPHOME_strncpy_P strncpy_P +#endif + +#if USE_ESP32 #include "esphome/core/hal.h" #include "esphome/components/web_server_idf/web_server_idf.h" +#else +#include +#endif + +#if USE_ESP32 +using PlatformString = std::string; +#elif USE_ARDUINO +using PlatformString = String; #endif namespace esphome { @@ -28,8 +48,8 @@ class MiddlewareHandler : public AsyncWebHandler { bool canHandle(AsyncWebServerRequest *request) const override { return next_->canHandle(request); } void handleRequest(AsyncWebServerRequest *request) override { next_->handleRequest(request); } - void handleUpload(AsyncWebServerRequest *request, const String &filename, size_t index, uint8_t *data, size_t len, - bool final) override { + void handleUpload(AsyncWebServerRequest *request, const PlatformString &filename, size_t index, uint8_t *data, + size_t len, bool final) override { next_->handleUpload(request, filename, index, data, len, final); } void handleBody(AsyncWebServerRequest *request, uint8_t *data, size_t len, size_t index, size_t total) override { @@ -65,8 +85,8 @@ class AuthMiddlewareHandler : public MiddlewareHandler { return; MiddlewareHandler::handleRequest(request); } - void handleUpload(AsyncWebServerRequest *request, const String &filename, size_t index, uint8_t *data, size_t len, - bool final) override { + void handleUpload(AsyncWebServerRequest *request, const PlatformString &filename, size_t index, uint8_t *data, + size_t len, bool final) override { if (!check_auth(request)) return; MiddlewareHandler::handleUpload(request, filename, index, data, len, final); diff --git a/esphome/components/web_server_idf/__init__.py b/esphome/components/web_server_idf/__init__.py index 506e1c5c13..74a9d657a6 100644 --- a/esphome/components/web_server_idf/__init__.py +++ b/esphome/components/web_server_idf/__init__.py @@ -5,7 +5,7 @@ CODEOWNERS = ["@dentra"] CONFIG_SCHEMA = cv.All( cv.Schema({}), - cv.only_with_esp_idf, + cv.only_on_esp32, ) diff --git a/esphome/components/web_server_idf/multipart.cpp b/esphome/components/web_server_idf/multipart.cpp index 8655226ab9..2092a41a8e 100644 --- a/esphome/components/web_server_idf/multipart.cpp +++ b/esphome/components/web_server_idf/multipart.cpp @@ -1,5 +1,5 @@ #include "esphome/core/defines.h" -#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +#if defined(USE_ESP32) && defined(USE_WEBSERVER_OTA) #include "multipart.h" #include "utils.h" #include "esphome/core/log.h" @@ -251,4 +251,4 @@ std::string str_trim(const std::string &str) { } // namespace web_server_idf } // namespace esphome -#endif // defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +#endif // defined(USE_ESP32) && defined(USE_WEBSERVER_OTA) diff --git a/esphome/components/web_server_idf/multipart.h b/esphome/components/web_server_idf/multipart.h index 967c72ffa5..8fbe90c4a0 100644 --- a/esphome/components/web_server_idf/multipart.h +++ b/esphome/components/web_server_idf/multipart.h @@ -1,6 +1,6 @@ #pragma once #include "esphome/core/defines.h" -#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +#if defined(USE_ESP32) && defined(USE_WEBSERVER_OTA) #include #include @@ -83,4 +83,4 @@ std::string str_trim(const std::string &str); } // namespace web_server_idf } // namespace esphome -#endif // defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +#endif // defined(USE_ESP32) && defined(USE_WEBSERVER_OTA) diff --git a/esphome/components/web_server_idf/utils.cpp b/esphome/components/web_server_idf/utils.cpp index ac5df90bb8..d5d34b520b 100644 --- a/esphome/components/web_server_idf/utils.cpp +++ b/esphome/components/web_server_idf/utils.cpp @@ -1,4 +1,4 @@ -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #include #include #include @@ -122,4 +122,4 @@ const char *stristr(const char *haystack, const char *needle) { } // namespace web_server_idf } // namespace esphome -#endif // USE_ESP_IDF +#endif // USE_ESP32 diff --git a/esphome/components/web_server_idf/utils.h b/esphome/components/web_server_idf/utils.h index 988b962d72..f70a5f0760 100644 --- a/esphome/components/web_server_idf/utils.h +++ b/esphome/components/web_server_idf/utils.h @@ -1,5 +1,5 @@ #pragma once -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #include #include @@ -24,4 +24,4 @@ const char *stristr(const char *haystack, const char *needle); } // namespace web_server_idf } // namespace esphome -#endif // USE_ESP_IDF +#endif // USE_ESP32 diff --git a/esphome/components/web_server_idf/web_server_idf.cpp b/esphome/components/web_server_idf/web_server_idf.cpp index 55b07c0f5e..b38c5fb92a 100644 --- a/esphome/components/web_server_idf/web_server_idf.cpp +++ b/esphome/components/web_server_idf/web_server_idf.cpp @@ -1,4 +1,4 @@ -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #include #include @@ -25,6 +25,10 @@ #include "esphome/components/web_server/list_entities.h" #endif // USE_WEBSERVER +// Include socket headers after Arduino headers to avoid IPADDR_NONE/INADDR_NONE macro conflicts +#include +#include + namespace esphome { namespace web_server_idf { @@ -46,6 +50,42 @@ DefaultHeaders default_headers_instance; DefaultHeaders &DefaultHeaders::Instance() { return default_headers_instance; } +namespace { +// Non-blocking send function to prevent watchdog timeouts when TCP buffers are full +/** + * Sends data on a socket in non-blocking mode. + * + * @param hd HTTP server handle (unused). + * @param sockfd Socket file descriptor. + * @param buf Buffer to send. + * @param buf_len Length of buffer. + * @param flags Flags for send(). + * @return + * - Number of bytes sent on success. + * - HTTPD_SOCK_ERR_INVALID if buf is nullptr. + * - HTTPD_SOCK_ERR_TIMEOUT if the send buffer is full (EAGAIN/EWOULDBLOCK). + * - HTTPD_SOCK_ERR_FAIL for other errors. + */ +int nonblocking_send(httpd_handle_t hd, int sockfd, const char *buf, size_t buf_len, int flags) { + if (buf == nullptr) { + return HTTPD_SOCK_ERR_INVALID; + } + + // Use MSG_DONTWAIT to prevent blocking when TCP send buffer is full + int ret = send(sockfd, buf, buf_len, flags | MSG_DONTWAIT); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // Buffer full - retry later + return HTTPD_SOCK_ERR_TIMEOUT; + } + // Real error + ESP_LOGD(TAG, "send error: errno %d", errno); + return HTTPD_SOCK_ERR_FAIL; + } + return ret; +} +} // namespace + void AsyncWebServer::end() { if (this->server_) { httpd_stop(this->server_); @@ -164,8 +204,8 @@ esp_err_t AsyncWebServer::request_handler_(AsyncWebServerRequest *request) const AsyncWebServerRequest::~AsyncWebServerRequest() { delete this->rsp_; - for (const auto &pair : this->params_) { - delete pair.second; // NOLINT(cppcoreguidelines-owning-memory) + for (auto *param : this->params_) { + delete param; // NOLINT(cppcoreguidelines-owning-memory) } } @@ -205,10 +245,22 @@ void AsyncWebServerRequest::redirect(const std::string &url) { } void AsyncWebServerRequest::init_response_(AsyncWebServerResponse *rsp, int code, const char *content_type) { - httpd_resp_set_status(*this, code == 200 ? HTTPD_200 - : code == 404 ? HTTPD_404 - : code == 409 ? HTTPD_409 - : to_string(code).c_str()); + // Set status code - use constants for common codes to avoid string allocation + const char *status = nullptr; + switch (code) { + case 200: + status = HTTPD_200; + break; + case 404: + status = HTTPD_404; + break; + case 409: + status = HTTPD_409; + break; + default: + break; + } + httpd_resp_set_status(*this, status == nullptr ? to_string(code).c_str() : status); if (content_type && *content_type) { httpd_resp_set_type(*this, content_type); @@ -265,11 +317,14 @@ void AsyncWebServerRequest::requestAuthentication(const char *realm) const { #endif AsyncWebParameter *AsyncWebServerRequest::getParam(const std::string &name) { - auto find = this->params_.find(name); - if (find != this->params_.end()) { - return find->second; + // Check cache first - only successful lookups are cached + for (auto *param : this->params_) { + if (param->name() == name) { + return param; + } } + // Look up value from query strings optional val = query_key_value(this->post_query_, name); if (!val.has_value()) { auto url_query = request_get_url_query(*this); @@ -278,11 +333,14 @@ AsyncWebParameter *AsyncWebServerRequest::getParam(const std::string &name) { } } - AsyncWebParameter *param = nullptr; - if (val.has_value()) { - param = new AsyncWebParameter(val.value()); // NOLINT(cppcoreguidelines-owning-memory) + // Don't cache misses to avoid wasting memory when handlers check for + // optional parameters that don't exist in the request + if (!val.has_value()) { + return nullptr; } - this->params_.insert({name, param}); + + auto *param = new AsyncWebParameter(name, val.value()); // NOLINT(cppcoreguidelines-owning-memory) + this->params_.push_back(param); return param; } @@ -317,8 +375,8 @@ AsyncEventSource::~AsyncEventSource() { } void AsyncEventSource::handleRequest(AsyncWebServerRequest *request) { - auto *rsp = // NOLINT(cppcoreguidelines-owning-memory) - new AsyncEventSourceResponse(request, this, this->web_server_); + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory,clang-analyzer-cplusplus.NewDeleteLeaks) + auto *rsp = new AsyncEventSourceResponse(request, this, this->web_server_); if (this->on_connect_) { this->on_connect_(rsp); } @@ -384,6 +442,9 @@ AsyncEventSourceResponse::AsyncEventSourceResponse(const AsyncWebServerRequest * this->hd_ = req->handle; this->fd_.store(httpd_req_to_sockfd(req)); + // Use non-blocking send to prevent watchdog timeouts when TCP buffers are full + httpd_sess_set_send_override(this->hd_, this->fd_.load(), nonblocking_send); + // Configure reconnect timeout and send config // this should always go through since the tcp send buffer is empty on connect std::string message = ws->get_config_json(); @@ -392,10 +453,11 @@ AsyncEventSourceResponse::AsyncEventSourceResponse(const AsyncWebServerRequest * #ifdef USE_WEBSERVER_SORTING for (auto &group : ws->sorting_groups_) { // NOLINTBEGIN(clang-analyzer-cplusplus.NewDeleteLeaks) false positive with ArduinoJson - message = json::build_json([group](JsonObject root) { - root["name"] = group.second.name; - root["sorting_weight"] = group.second.weight; - }); + json::JsonBuilder builder; + JsonObject root = builder.root(); + root["name"] = group.second.name; + root["sorting_weight"] = group.second.weight; + message = builder.serialize(); // NOLINTEND(clang-analyzer-cplusplus.NewDeleteLeaks) // a (very) large number of these should be able to be queued initially without defer @@ -458,15 +520,45 @@ void AsyncEventSourceResponse::process_buffer_() { return; } - int bytes_sent = httpd_socket_send(this->hd_, this->fd_.load(), event_buffer_.c_str() + event_bytes_sent_, - event_buffer_.size() - event_bytes_sent_, 0); - if (bytes_sent == HTTPD_SOCK_ERR_TIMEOUT || bytes_sent == HTTPD_SOCK_ERR_FAIL) { - // Socket error - just return, the connection will be closed by httpd - // and our destroy callback will be called + size_t remaining = event_buffer_.size() - event_bytes_sent_; + int bytes_sent = + httpd_socket_send(this->hd_, this->fd_.load(), event_buffer_.c_str() + event_bytes_sent_, remaining, 0); + if (bytes_sent == HTTPD_SOCK_ERR_TIMEOUT) { + // EAGAIN/EWOULDBLOCK - socket buffer full, try again later + // NOTE: Similar logic exists in web_server/web_server.cpp in DeferredUpdateEventSource::process_deferred_queue_() + // The implementations differ due to platform-specific APIs (HTTPD_SOCK_ERR_TIMEOUT vs DISCARDED, fd_.store(0) vs + // close()), but the failure counting and timeout logic should be kept in sync. If you change this logic, also + // update the Arduino implementation. + this->consecutive_send_failures_++; + if (this->consecutive_send_failures_ >= MAX_CONSECUTIVE_SEND_FAILURES) { + // Too many failures, connection is likely dead + ESP_LOGW(TAG, "Closing stuck EventSource connection after %" PRIu16 " failed sends", + this->consecutive_send_failures_); + this->fd_.store(0); // Mark for cleanup + this->deferred_queue_.clear(); + } return; } + if (bytes_sent == HTTPD_SOCK_ERR_FAIL) { + // Real socket error - connection will be closed by httpd and destroy callback will be called + return; + } + if (bytes_sent <= 0) { + // Unexpected error or zero bytes sent + ESP_LOGW(TAG, "Unexpected send result: %d", bytes_sent); + return; + } + + // Successful send - reset failure counter + this->consecutive_send_failures_ = 0; event_bytes_sent_ += bytes_sent; + // Log partial sends for debugging + if (event_bytes_sent_ < event_buffer_.size()) { + ESP_LOGV(TAG, "Partial send: %d/%zu bytes (total: %zu/%zu)", bytes_sent, remaining, event_bytes_sent_, + event_buffer_.size()); + } + if (event_bytes_sent_ == event_buffer_.size()) { event_buffer_.resize(0); event_bytes_sent_ = 0; @@ -669,4 +761,4 @@ esp_err_t AsyncWebServer::handle_multipart_upload_(httpd_req_t *r, const char *c } // namespace web_server_idf } // namespace esphome -#endif // !defined(USE_ESP_IDF) +#endif // !defined(USE_ESP32) diff --git a/esphome/components/web_server_idf/web_server_idf.h b/esphome/components/web_server_idf/web_server_idf.h index 76540ef232..bf93dcbd34 100644 --- a/esphome/components/web_server_idf/web_server_idf.h +++ b/esphome/components/web_server_idf/web_server_idf.h @@ -1,5 +1,5 @@ #pragma once -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #include "esphome/core/defines.h" #include @@ -22,18 +22,14 @@ class ListEntitiesIterator; #endif namespace web_server_idf { -#define F(string_literal) (string_literal) -#define PGM_P const char * -#define strncpy_P strncpy - -using String = std::string; - class AsyncWebParameter { public: - AsyncWebParameter(std::string value) : value_(std::move(value)) {} + AsyncWebParameter(std::string name, std::string value) : name_(std::move(name)), value_(std::move(value)) {} + const std::string &name() const { return this->name_; } const std::string &value() const { return this->value_; } protected: + std::string name_; std::string value_; }; @@ -174,7 +170,11 @@ class AsyncWebServerRequest { protected: httpd_req_t *req_; AsyncWebServerResponse *rsp_{}; - std::map params_; + // Use vector instead of map/unordered_map: most requests have 0-3 params, so linear search + // is faster than tree/hash overhead. AsyncWebParameter stores both name and value to avoid + // duplicate storage. Only successful lookups are cached to prevent cache pollution when + // handlers check for optional parameters that don't exist. + std::vector params_; std::string post_query_; AsyncWebServerRequest(httpd_req_t *req) : req_(req) {} AsyncWebServerRequest(httpd_req_t *req, std::string post_query) : req_(req), post_query_(std::move(post_query)) {} @@ -283,6 +283,8 @@ class AsyncEventSourceResponse { std::unique_ptr entities_iterator_; std::string event_buffer_{""}; size_t event_bytes_sent_; + uint16_t consecutive_send_failures_{0}; + static constexpr uint16_t MAX_CONSECUTIVE_SEND_FAILURES = 2500; // ~20 seconds at 125Hz loop rate }; using AsyncEventSourceClient = AsyncEventSourceResponse; @@ -341,4 +343,4 @@ class DefaultHeaders { using namespace esphome::web_server_idf; // NOLINT(google-global-names-in-headers) -#endif // !defined(USE_ESP_IDF) +#endif // !defined(USE_ESP32) diff --git a/esphome/components/wifi/__init__.py b/esphome/components/wifi/__init__.py index c63a12f879..a784123006 100644 --- a/esphome/components/wifi/__init__.py +++ b/esphome/components/wifi/__init__.py @@ -125,8 +125,8 @@ EAP_AUTH_SCHEMA = cv.All( cv.Optional(CONF_USERNAME): cv.string_strict, cv.Optional(CONF_PASSWORD): cv.string_strict, cv.Optional(CONF_CERTIFICATE_AUTHORITY): wpa2_eap.validate_certificate, - cv.SplitDefault(CONF_TTLS_PHASE_2, esp32_idf="mschapv2"): cv.All( - cv.enum(TTLS_PHASE_2), cv.only_with_esp_idf + cv.SplitDefault(CONF_TTLS_PHASE_2, esp32="mschapv2"): cv.All( + cv.enum(TTLS_PHASE_2), cv.only_on_esp32 ), cv.Inclusive( CONF_CERTIFICATE, "certificate_and_key" @@ -280,11 +280,11 @@ CONFIG_SCHEMA = cv.All( cv.SplitDefault(CONF_OUTPUT_POWER, esp8266=20.0): cv.All( cv.decibel, cv.float_range(min=8.5, max=20.5) ), - cv.SplitDefault(CONF_ENABLE_BTM, esp32_idf=False): cv.All( - cv.boolean, cv.only_with_esp_idf + cv.SplitDefault(CONF_ENABLE_BTM, esp32=False): cv.All( + cv.boolean, cv.only_on_esp32 ), - cv.SplitDefault(CONF_ENABLE_RRM, esp32_idf=False): cv.All( - cv.boolean, cv.only_with_esp_idf + cv.SplitDefault(CONF_ENABLE_RRM, esp32=False): cv.All( + cv.boolean, cv.only_on_esp32 ), cv.Optional(CONF_PASSIVE_SCAN, default=False): cv.boolean, cv.Optional("enable_mdns"): cv.invalid( @@ -402,7 +402,7 @@ async def to_code(config): add_idf_sdkconfig_option("CONFIG_LWIP_DHCPS", False) # Disable Enterprise WiFi support if no EAP is configured - if CORE.is_esp32 and CORE.using_esp_idf and not has_eap: + if CORE.is_esp32 and not has_eap: add_idf_sdkconfig_option("CONFIG_ESP_WIFI_ENTERPRISE_SUPPORT", False) cg.add(var.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT])) @@ -416,10 +416,10 @@ async def to_code(config): if CORE.is_esp8266: cg.add_library("ESP8266WiFi", None) - elif (CORE.is_esp32 and CORE.using_arduino) or CORE.is_rp2040: + elif CORE.is_rp2040: cg.add_library("WiFi", None) - if CORE.is_esp32 and CORE.using_esp_idf: + if CORE.is_esp32: if config[CONF_ENABLE_BTM] or config[CONF_ENABLE_RRM]: add_idf_sdkconfig_option("CONFIG_WPA_11KV_SUPPORT", True) cg.add_define("USE_WIFI_11KV_SUPPORT") @@ -506,8 +506,10 @@ async def wifi_set_sta_to_code(config, action_id, template_arg, args): FILTER_SOURCE_FILES = filter_source_files_from_platform( { - "wifi_component_esp32_arduino.cpp": {PlatformFramework.ESP32_ARDUINO}, - "wifi_component_esp_idf.cpp": {PlatformFramework.ESP32_IDF}, + "wifi_component_esp_idf.cpp": { + PlatformFramework.ESP32_IDF, + PlatformFramework.ESP32_ARDUINO, + }, "wifi_component_esp8266.cpp": {PlatformFramework.ESP8266_ARDUINO}, "wifi_component_libretiny.cpp": { PlatformFramework.BK72XX_ARDUINO, diff --git a/esphome/components/wifi/wifi_component.cpp b/esphome/components/wifi/wifi_component.cpp index 43ece636e5..2e083d4c68 100644 --- a/esphome/components/wifi/wifi_component.cpp +++ b/esphome/components/wifi/wifi_component.cpp @@ -1,9 +1,8 @@ #include "wifi_component.h" #ifdef USE_WIFI #include -#include -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #if (ESP_IDF_VERSION_MAJOR >= 5 && ESP_IDF_VERSION_MINOR >= 1) #include #else @@ -11,7 +10,7 @@ #endif #endif -#if defined(USE_ESP32) || defined(USE_ESP_IDF) +#if defined(USE_ESP32) #include #endif #ifdef USE_ESP8266 @@ -42,6 +41,25 @@ namespace wifi { static const char *const TAG = "wifi"; +#if defined(USE_ESP32) && defined(USE_WIFI_WPA2_EAP) && ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE +static const char *eap_phase2_to_str(esp_eap_ttls_phase2_types type) { + switch (type) { + case ESP_EAP_TTLS_PHASE2_PAP: + return "pap"; + case ESP_EAP_TTLS_PHASE2_CHAP: + return "chap"; + case ESP_EAP_TTLS_PHASE2_MSCHAP: + return "mschap"; + case ESP_EAP_TTLS_PHASE2_MSCHAPV2: + return "mschapv2"; + case ESP_EAP_TTLS_PHASE2_EAP: + return "eap"; + default: + return "unknown"; + } +} +#endif + float WiFiComponent::get_setup_priority() const { return setup_priority::WIFI; } void WiFiComponent::setup() { @@ -266,30 +284,34 @@ void WiFiComponent::setup_ap_config_() { std::string name = App.get_name(); if (name.length() > 32) { if (App.is_name_add_mac_suffix_enabled()) { - name.erase(name.begin() + 25, name.end() - 7); // Remove characters between 25 and the mac address + // Keep first 25 chars and last 7 chars (MAC suffix), remove middle + name.erase(25, name.length() - 32); } else { - name = name.substr(0, 32); + name.resize(32); } } this->ap_.set_ssid(name); } + this->ap_setup_ = this->wifi_start_ap_(this->ap_); + + auto ip_address = this->wifi_soft_ap_ip().str(); ESP_LOGCONFIG(TAG, "Setting up AP:\n" " AP SSID: '%s'\n" - " AP Password: '%s'", - this->ap_.get_ssid().c_str(), this->ap_.get_password().c_str()); - if (this->ap_.get_manual_ip().has_value()) { - auto manual = *this->ap_.get_manual_ip(); + " AP Password: '%s'\n" + " IP Address: %s", + this->ap_.get_ssid().c_str(), this->ap_.get_password().c_str(), ip_address.c_str()); + + auto manual_ip = this->ap_.get_manual_ip(); + if (manual_ip.has_value()) { ESP_LOGCONFIG(TAG, " AP Static IP: '%s'\n" " AP Gateway: '%s'\n" " AP Subnet: '%s'", - manual.static_ip.str().c_str(), manual.gateway.str().c_str(), manual.subnet.str().c_str()); + manual_ip->static_ip.str().c_str(), manual_ip->gateway.str().c_str(), + manual_ip->subnet.str().c_str()); } - this->ap_setup_ = this->wifi_start_ap_(this->ap_); - ESP_LOGCONFIG(TAG, " IP Address: %s", this->wifi_soft_ap_ip().str().c_str()); - if (!this->has_sta()) { this->state_ = WIFI_COMPONENT_STATE_AP; } @@ -312,9 +334,9 @@ void WiFiComponent::set_sta(const WiFiAP &ap) { } void WiFiComponent::clear_sta() { this->sta_.clear(); } void WiFiComponent::save_wifi_sta(const std::string &ssid, const std::string &password) { - SavedWifiSettings save{}; - snprintf(save.ssid, sizeof(save.ssid), "%s", ssid.c_str()); - snprintf(save.password, sizeof(save.password), "%s", password.c_str()); + SavedWifiSettings save{}; // zero-initialized - all bytes set to \0, guaranteeing null termination + strncpy(save.ssid, ssid.c_str(), sizeof(save.ssid) - 1); // max 32 chars, byte 32 remains \0 + strncpy(save.password, password.c_str(), sizeof(save.password) - 1); // max 64 chars, byte 64 remains \0 this->pref_.save(&save); // ensure it's written immediately global_preferences->sync(); @@ -331,8 +353,7 @@ void WiFiComponent::start_connecting(const WiFiAP &ap, bool two) { ESP_LOGV(TAG, "Connection Params:"); ESP_LOGV(TAG, " SSID: '%s'", ap.get_ssid().c_str()); if (ap.get_bssid().has_value()) { - bssid_t b = *ap.get_bssid(); - ESP_LOGV(TAG, " BSSID: %02X:%02X:%02X:%02X:%02X:%02X", b[0], b[1], b[2], b[3], b[4], b[5]); + ESP_LOGV(TAG, " BSSID: %s", format_mac_address_pretty(ap.get_bssid()->data()).c_str()); } else { ESP_LOGV(TAG, " BSSID: Not Set"); } @@ -344,15 +365,8 @@ void WiFiComponent::start_connecting(const WiFiAP &ap, bool two) { ESP_LOGV(TAG, " Identity: " LOG_SECRET("'%s'"), eap_config.identity.c_str()); ESP_LOGV(TAG, " Username: " LOG_SECRET("'%s'"), eap_config.username.c_str()); ESP_LOGV(TAG, " Password: " LOG_SECRET("'%s'"), eap_config.password.c_str()); -#ifdef USE_ESP_IDF -#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE - std::map phase2types = {{ESP_EAP_TTLS_PHASE2_PAP, "pap"}, - {ESP_EAP_TTLS_PHASE2_CHAP, "chap"}, - {ESP_EAP_TTLS_PHASE2_MSCHAP, "mschap"}, - {ESP_EAP_TTLS_PHASE2_MSCHAPV2, "mschapv2"}, - {ESP_EAP_TTLS_PHASE2_EAP, "eap"}}; - ESP_LOGV(TAG, " TTLS Phase 2: " LOG_SECRET("'%s'"), phase2types[eap_config.ttls_phase_2].c_str()); -#endif +#if defined(USE_ESP32) && defined(USE_WIFI_WPA2_EAP) && ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE + ESP_LOGV(TAG, " TTLS Phase 2: " LOG_SECRET("'%s'"), eap_phase2_to_str(eap_config.ttls_phase_2)); #endif bool ca_cert_present = eap_config.ca_cert != nullptr && strlen(eap_config.ca_cert); bool client_cert_present = eap_config.client_cert != nullptr && strlen(eap_config.client_cert); @@ -446,7 +460,6 @@ void WiFiComponent::print_connect_params_() { ESP_LOGCONFIG(TAG, " Disabled"); return; } - ESP_LOGCONFIG(TAG, " SSID: " LOG_SECRET("'%s'"), wifi_ssid().c_str()); for (auto &ip : wifi_sta_ip_addresses()) { if (ip.is_set()) { ESP_LOGCONFIG(TAG, " IP Address: %s", ip.str().c_str()); @@ -454,24 +467,23 @@ void WiFiComponent::print_connect_params_() { } int8_t rssi = wifi_rssi(); ESP_LOGCONFIG(TAG, - " BSSID: " LOG_SECRET("%02X:%02X:%02X:%02X:%02X:%02X") "\n" - " Hostname: '%s'\n" - " Signal strength: %d dB %s", - bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5], App.get_name().c_str(), rssi, - LOG_STR_ARG(get_signal_bars(rssi))); + " SSID: " LOG_SECRET("'%s'") "\n" + " BSSID: " LOG_SECRET("%s") "\n" + " Hostname: '%s'\n" + " Signal strength: %d dB %s\n" + " Channel: %" PRId32 "\n" + " Subnet: %s\n" + " Gateway: %s\n" + " DNS1: %s\n" + " DNS2: %s", + wifi_ssid().c_str(), format_mac_address_pretty(bssid.data()).c_str(), App.get_name().c_str(), rssi, + LOG_STR_ARG(get_signal_bars(rssi)), get_wifi_channel(), wifi_subnet_mask_().str().c_str(), + wifi_gateway_ip_().str().c_str(), wifi_dns_ip_(0).str().c_str(), wifi_dns_ip_(1).str().c_str()); #ifdef ESPHOME_LOG_HAS_VERBOSE if (this->selected_ap_.get_bssid().has_value()) { ESP_LOGV(TAG, " Priority: %.1f", this->get_sta_priority(*this->selected_ap_.get_bssid())); } #endif - ESP_LOGCONFIG(TAG, - " Channel: %" PRId32 "\n" - " Subnet: %s\n" - " Gateway: %s\n" - " DNS1: %s\n" - " DNS2: %s", - get_wifi_channel(), wifi_subnet_mask_().str().c_str(), wifi_gateway_ip_().str().c_str(), - wifi_dns_ip_(0).str().c_str(), wifi_dns_ip_(1).str().c_str()); #ifdef USE_WIFI_11KV_SUPPORT ESP_LOGCONFIG(TAG, " BTM: %s\n" @@ -557,6 +569,25 @@ static void insertion_sort_scan_results(std::vector &results) { } } +// Helper function to log scan results - marked noinline to prevent re-inlining into loop +__attribute__((noinline)) static void log_scan_result(const WiFiScanResult &res) { + char bssid_s[18]; + auto bssid = res.get_bssid(); + format_mac_addr_upper(bssid.data(), bssid_s); + + if (res.get_matches()) { + ESP_LOGI(TAG, "- '%s' %s" LOG_SECRET("(%s) ") "%s", res.get_ssid().c_str(), res.get_is_hidden() ? "(HIDDEN) " : "", + bssid_s, LOG_STR_ARG(get_signal_bars(res.get_rssi()))); + ESP_LOGD(TAG, + " Channel: %u\n" + " RSSI: %d dB", + res.get_channel(), res.get_rssi()); + } else { + ESP_LOGD(TAG, "- " LOG_SECRET("'%s'") " " LOG_SECRET("(%s) ") "%s", res.get_ssid().c_str(), bssid_s, + LOG_STR_ARG(get_signal_bars(res.get_rssi()))); + } +} + void WiFiComponent::check_scanning_finished() { if (!this->scan_done_) { if (millis() - this->action_started_ > 30000) { @@ -591,21 +622,7 @@ void WiFiComponent::check_scanning_finished() { insertion_sort_scan_results(this->scan_result_); for (auto &res : this->scan_result_) { - char bssid_s[18]; - auto bssid = res.get_bssid(); - format_mac_addr_upper(bssid.data(), bssid_s); - - if (res.get_matches()) { - ESP_LOGI(TAG, "- '%s' %s" LOG_SECRET("(%s) ") "%s", res.get_ssid().c_str(), - res.get_is_hidden() ? "(HIDDEN) " : "", bssid_s, LOG_STR_ARG(get_signal_bars(res.get_rssi()))); - ESP_LOGD(TAG, - " Channel: %u\n" - " RSSI: %d dB", - res.get_channel(), res.get_rssi()); - } else { - ESP_LOGD(TAG, "- " LOG_SECRET("'%s'") " " LOG_SECRET("(%s) ") "%s", res.get_ssid().c_str(), bssid_s, - LOG_STR_ARG(get_signal_bars(res.get_rssi()))); - } + log_scan_result(res); } if (!this->scan_result_[0].get_matches()) { diff --git a/esphome/components/wifi/wifi_component.h b/esphome/components/wifi/wifi_component.h index bbe1bbb874..ee62ec1a69 100644 --- a/esphome/components/wifi/wifi_component.h +++ b/esphome/components/wifi/wifi_component.h @@ -20,7 +20,7 @@ #include #endif -#if defined(USE_ESP_IDF) && defined(USE_WIFI_WPA2_EAP) +#if defined(USE_ESP32) && defined(USE_WIFI_WPA2_EAP) #if (ESP_IDF_VERSION_MAJOR >= 5) && (ESP_IDF_VERSION_MINOR >= 1) #include #else @@ -113,7 +113,7 @@ struct EAPAuth { const char *client_cert; const char *client_key; // used for EAP-TTLS -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 esp_eap_ttls_phase2_types ttls_phase_2; #endif }; @@ -199,7 +199,7 @@ enum WiFiPowerSaveMode : uint8_t { WIFI_POWER_SAVE_HIGH, }; -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 struct IDFWiFiEvent; #endif @@ -368,7 +368,7 @@ class WiFiComponent : public Component { void wifi_event_callback_(arduino_event_id_t event, arduino_event_info_t info); void wifi_scan_done_callback_(); #endif -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 void wifi_process_event_(IDFWiFiEvent *data); #endif diff --git a/esphome/components/wifi/wifi_component_esp32_arduino.cpp b/esphome/components/wifi/wifi_component_esp32_arduino.cpp deleted file mode 100644 index 89298e07c7..0000000000 --- a/esphome/components/wifi/wifi_component_esp32_arduino.cpp +++ /dev/null @@ -1,860 +0,0 @@ -#include "wifi_component.h" - -#ifdef USE_WIFI -#ifdef USE_ESP32_FRAMEWORK_ARDUINO - -#include -#include - -#include -#include -#ifdef USE_WIFI_WPA2_EAP -#include -#endif - -#ifdef USE_WIFI_AP -#include "dhcpserver/dhcpserver.h" -#endif // USE_WIFI_AP - -#include "lwip/apps/sntp.h" -#include "lwip/dns.h" -#include "lwip/err.h" - -#include "esphome/core/application.h" -#include "esphome/core/hal.h" -#include "esphome/core/helpers.h" -#include "esphome/core/log.h" -#include "esphome/core/util.h" - -namespace esphome { -namespace wifi { - -static const char *const TAG = "wifi_esp32"; - -static esp_netif_t *s_sta_netif = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) -#ifdef USE_WIFI_AP -static esp_netif_t *s_ap_netif = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) -#endif // USE_WIFI_AP - -static bool s_sta_connecting = false; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) - -void WiFiComponent::wifi_pre_setup_() { - uint8_t mac[6]; - if (has_custom_mac_address()) { - get_mac_address_raw(mac); - set_mac_address(mac); - } - auto f = std::bind(&WiFiComponent::wifi_event_callback_, this, std::placeholders::_1, std::placeholders::_2); - WiFi.onEvent(f); - WiFi.persistent(false); - // Make sure WiFi is in clean state before anything starts - this->wifi_mode_(false, false); -} - -bool WiFiComponent::wifi_mode_(optional sta, optional ap) { - wifi_mode_t current_mode = WiFiClass::getMode(); - bool current_sta = current_mode == WIFI_MODE_STA || current_mode == WIFI_MODE_APSTA; - bool current_ap = current_mode == WIFI_MODE_AP || current_mode == WIFI_MODE_APSTA; - - bool set_sta = sta.value_or(current_sta); - bool set_ap = ap.value_or(current_ap); - - wifi_mode_t set_mode; - if (set_sta && set_ap) { - set_mode = WIFI_MODE_APSTA; - } else if (set_sta && !set_ap) { - set_mode = WIFI_MODE_STA; - } else if (!set_sta && set_ap) { - set_mode = WIFI_MODE_AP; - } else { - set_mode = WIFI_MODE_NULL; - } - - if (current_mode == set_mode) - return true; - - if (set_sta && !current_sta) { - ESP_LOGV(TAG, "Enabling STA"); - } else if (!set_sta && current_sta) { - ESP_LOGV(TAG, "Disabling STA"); - } - if (set_ap && !current_ap) { - ESP_LOGV(TAG, "Enabling AP"); - } else if (!set_ap && current_ap) { - ESP_LOGV(TAG, "Disabling AP"); - } - - bool ret = WiFiClass::mode(set_mode); - - if (!ret) { - ESP_LOGW(TAG, "Setting mode failed"); - return false; - } - - // WiFiClass::mode above calls esp_netif_create_default_wifi_sta() and - // esp_netif_create_default_wifi_ap(), which creates the interfaces. - // s_sta_netif handle is set during ESPHOME_EVENT_ID_WIFI_STA_START event - -#ifdef USE_WIFI_AP - if (set_ap) - s_ap_netif = esp_netif_get_handle_from_ifkey("WIFI_AP_DEF"); -#endif - - return ret; -} - -bool WiFiComponent::wifi_sta_pre_setup_() { - if (!this->wifi_mode_(true, {})) - return false; - - WiFi.setAutoReconnect(false); - delay(10); - return true; -} - -bool WiFiComponent::wifi_apply_output_power_(float output_power) { - int8_t val = static_cast(output_power * 4); - return esp_wifi_set_max_tx_power(val) == ESP_OK; -} - -bool WiFiComponent::wifi_apply_power_save_() { - wifi_ps_type_t power_save; - switch (this->power_save_) { - case WIFI_POWER_SAVE_LIGHT: - power_save = WIFI_PS_MIN_MODEM; - break; - case WIFI_POWER_SAVE_HIGH: - power_save = WIFI_PS_MAX_MODEM; - break; - case WIFI_POWER_SAVE_NONE: - default: - power_save = WIFI_PS_NONE; - break; - } - return esp_wifi_set_ps(power_save) == ESP_OK; -} - -bool WiFiComponent::wifi_sta_connect_(const WiFiAP &ap) { - // enable STA - if (!this->wifi_mode_(true, {})) - return false; - - // https://docs.espressif.com/projects/esp-idf/en/latest/esp32/api-reference/network/esp_wifi.html#_CPPv417wifi_sta_config_t - wifi_config_t conf; - memset(&conf, 0, sizeof(conf)); - if (ap.get_ssid().size() > sizeof(conf.sta.ssid)) { - ESP_LOGE(TAG, "SSID too long"); - return false; - } - if (ap.get_password().size() > sizeof(conf.sta.password)) { - ESP_LOGE(TAG, "Password too long"); - return false; - } - memcpy(reinterpret_cast(conf.sta.ssid), ap.get_ssid().c_str(), ap.get_ssid().size()); - memcpy(reinterpret_cast(conf.sta.password), ap.get_password().c_str(), ap.get_password().size()); - - // The weakest authmode to accept in the fast scan mode - if (ap.get_password().empty()) { - conf.sta.threshold.authmode = WIFI_AUTH_OPEN; - } else { - conf.sta.threshold.authmode = WIFI_AUTH_WPA_WPA2_PSK; - } - -#ifdef USE_WIFI_WPA2_EAP - if (ap.get_eap().has_value()) { - conf.sta.threshold.authmode = WIFI_AUTH_WPA2_ENTERPRISE; - } -#endif - - if (ap.get_bssid().has_value()) { - conf.sta.bssid_set = true; - memcpy(conf.sta.bssid, ap.get_bssid()->data(), 6); - } else { - conf.sta.bssid_set = false; - } - if (ap.get_channel().has_value()) { - conf.sta.channel = *ap.get_channel(); - conf.sta.scan_method = WIFI_FAST_SCAN; - } else { - conf.sta.scan_method = WIFI_ALL_CHANNEL_SCAN; - } - // Listen interval for ESP32 station to receive beacon when WIFI_PS_MAX_MODEM is set. - // Units: AP beacon intervals. Defaults to 3 if set to 0. - conf.sta.listen_interval = 0; - - // Protected Management Frame - // Device will prefer to connect in PMF mode if other device also advertises PMF capability. - conf.sta.pmf_cfg.capable = true; - conf.sta.pmf_cfg.required = false; - - // note, we do our own filtering - // The minimum rssi to accept in the fast scan mode - conf.sta.threshold.rssi = -127; - - conf.sta.threshold.authmode = WIFI_AUTH_OPEN; - - wifi_config_t current_conf; - esp_err_t err; - err = esp_wifi_get_config(WIFI_IF_STA, ¤t_conf); - if (err != ERR_OK) { - ESP_LOGW(TAG, "esp_wifi_get_config failed: %s", esp_err_to_name(err)); - // can continue - } - - if (memcmp(¤t_conf, &conf, sizeof(wifi_config_t)) != 0) { // NOLINT - err = esp_wifi_disconnect(); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_wifi_disconnect failed: %s", esp_err_to_name(err)); - return false; - } - } - - err = esp_wifi_set_config(WIFI_IF_STA, &conf); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_wifi_set_config failed: %s", esp_err_to_name(err)); - return false; - } - - if (!this->wifi_sta_ip_config_(ap.get_manual_ip())) { - return false; - } - - // setup enterprise authentication if required -#ifdef USE_WIFI_WPA2_EAP - if (ap.get_eap().has_value()) { - // note: all certificates and keys have to be null terminated. Lengths are appended by +1 to include \0. - EAPAuth eap = ap.get_eap().value(); - err = esp_eap_client_set_identity((uint8_t *) eap.identity.c_str(), eap.identity.length()); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_eap_client_set_identity failed: %d", err); - } - int ca_cert_len = strlen(eap.ca_cert); - int client_cert_len = strlen(eap.client_cert); - int client_key_len = strlen(eap.client_key); - if (ca_cert_len) { - err = esp_eap_client_set_ca_cert((uint8_t *) eap.ca_cert, ca_cert_len + 1); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_eap_client_set_ca_cert failed: %d", err); - } - } - // workout what type of EAP this is - // validation is not required as the config tool has already validated it - if (client_cert_len && client_key_len) { - // if we have certs, this must be EAP-TLS - err = esp_eap_client_set_certificate_and_key((uint8_t *) eap.client_cert, client_cert_len + 1, - (uint8_t *) eap.client_key, client_key_len + 1, - (uint8_t *) eap.password.c_str(), strlen(eap.password.c_str())); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_eap_client_set_certificate_and_key failed: %d", err); - } - } else { - // in the absence of certs, assume this is username/password based - err = esp_eap_client_set_username((uint8_t *) eap.username.c_str(), eap.username.length()); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_eap_client_set_username failed: %d", err); - } - err = esp_eap_client_set_password((uint8_t *) eap.password.c_str(), eap.password.length()); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_eap_client_set_password failed: %d", err); - } - } - err = esp_wifi_sta_enterprise_enable(); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_wifi_sta_enterprise_enable failed: %d", err); - } - } -#endif // USE_WIFI_WPA2_EAP - - this->wifi_apply_hostname_(); - - s_sta_connecting = true; - - err = esp_wifi_connect(); - if (err != ESP_OK) { - ESP_LOGW(TAG, "esp_wifi_connect failed: %s", esp_err_to_name(err)); - return false; - } - - return true; -} - -bool WiFiComponent::wifi_sta_ip_config_(optional manual_ip) { - // enable STA - if (!this->wifi_mode_(true, {})) - return false; - - // Check if the STA interface is initialized before using it - if (s_sta_netif == nullptr) { - ESP_LOGW(TAG, "STA interface not initialized"); - return false; - } - - esp_netif_dhcp_status_t dhcp_status; - esp_err_t err = esp_netif_dhcpc_get_status(s_sta_netif, &dhcp_status); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_netif_dhcpc_get_status failed: %s", esp_err_to_name(err)); - return false; - } - - if (!manual_ip.has_value()) { - // sntp_servermode_dhcp lwip/sntp.c (Required to lock TCPIP core functionality!) - // https://github.com/esphome/issues/issues/6591 - // https://github.com/espressif/arduino-esp32/issues/10526 - { - LwIPLock lock; - // lwIP starts the SNTP client if it gets an SNTP server from DHCP. We don't need the time, and more importantly, - // the built-in SNTP client has a memory leak in certain situations. Disable this feature. - // https://github.com/esphome/issues/issues/2299 - sntp_servermode_dhcp(false); - } - - // No manual IP is set; use DHCP client - if (dhcp_status != ESP_NETIF_DHCP_STARTED) { - err = esp_netif_dhcpc_start(s_sta_netif); - if (err != ESP_OK) { - ESP_LOGV(TAG, "Starting DHCP client failed: %d", err); - } - return err == ESP_OK; - } - return true; - } - - esp_netif_ip_info_t info; // struct of ip4_addr_t with ip, netmask, gw - info.ip = manual_ip->static_ip; - info.gw = manual_ip->gateway; - info.netmask = manual_ip->subnet; - err = esp_netif_dhcpc_stop(s_sta_netif); - if (err != ESP_OK && err != ESP_ERR_ESP_NETIF_DHCP_ALREADY_STOPPED) { - ESP_LOGV(TAG, "Stopping DHCP client failed: %s", esp_err_to_name(err)); - } - - err = esp_netif_set_ip_info(s_sta_netif, &info); - if (err != ESP_OK) { - ESP_LOGV(TAG, "Setting manual IP info failed: %s", esp_err_to_name(err)); - } - - esp_netif_dns_info_t dns; - if (manual_ip->dns1.is_set()) { - dns.ip = manual_ip->dns1; - esp_netif_set_dns_info(s_sta_netif, ESP_NETIF_DNS_MAIN, &dns); - } - if (manual_ip->dns2.is_set()) { - dns.ip = manual_ip->dns2; - esp_netif_set_dns_info(s_sta_netif, ESP_NETIF_DNS_BACKUP, &dns); - } - - return true; -} - -network::IPAddresses WiFiComponent::wifi_sta_ip_addresses() { - if (!this->has_sta()) - return {}; - network::IPAddresses addresses; - esp_netif_ip_info_t ip; - esp_err_t err = esp_netif_get_ip_info(s_sta_netif, &ip); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_netif_get_ip_info failed: %s", esp_err_to_name(err)); - // TODO: do something smarter - // return false; - } else { - addresses[0] = network::IPAddress(&ip.ip); - } -#if USE_NETWORK_IPV6 - struct esp_ip6_addr if_ip6s[CONFIG_LWIP_IPV6_NUM_ADDRESSES]; - uint8_t count = 0; - count = esp_netif_get_all_ip6(s_sta_netif, if_ip6s); - assert(count <= CONFIG_LWIP_IPV6_NUM_ADDRESSES); - for (int i = 0; i < count; i++) { - addresses[i + 1] = network::IPAddress(&if_ip6s[i]); - } -#endif /* USE_NETWORK_IPV6 */ - return addresses; -} - -bool WiFiComponent::wifi_apply_hostname_() { - // setting is done in SYSTEM_EVENT_STA_START callback - return true; -} -const char *get_auth_mode_str(uint8_t mode) { - switch (mode) { - case WIFI_AUTH_OPEN: - return "OPEN"; - case WIFI_AUTH_WEP: - return "WEP"; - case WIFI_AUTH_WPA_PSK: - return "WPA PSK"; - case WIFI_AUTH_WPA2_PSK: - return "WPA2 PSK"; - case WIFI_AUTH_WPA_WPA2_PSK: - return "WPA/WPA2 PSK"; - case WIFI_AUTH_WPA2_ENTERPRISE: - return "WPA2 Enterprise"; - case WIFI_AUTH_WPA3_PSK: - return "WPA3 PSK"; - case WIFI_AUTH_WPA2_WPA3_PSK: - return "WPA2/WPA3 PSK"; - case WIFI_AUTH_WAPI_PSK: - return "WAPI PSK"; - default: - return "UNKNOWN"; - } -} - -using esphome_ip4_addr_t = esp_ip4_addr_t; - -std::string format_ip4_addr(const esphome_ip4_addr_t &ip) { - char buf[20]; - sprintf(buf, "%u.%u.%u.%u", uint8_t(ip.addr >> 0), uint8_t(ip.addr >> 8), uint8_t(ip.addr >> 16), - uint8_t(ip.addr >> 24)); - return buf; -} -const char *get_op_mode_str(uint8_t mode) { - switch (mode) { - case WIFI_OFF: - return "OFF"; - case WIFI_STA: - return "STA"; - case WIFI_AP: - return "AP"; - case WIFI_AP_STA: - return "AP+STA"; - default: - return "UNKNOWN"; - } -} -const char *get_disconnect_reason_str(uint8_t reason) { - switch (reason) { - case WIFI_REASON_AUTH_EXPIRE: - return "Auth Expired"; - case WIFI_REASON_AUTH_LEAVE: - return "Auth Leave"; - case WIFI_REASON_ASSOC_EXPIRE: - return "Association Expired"; - case WIFI_REASON_ASSOC_TOOMANY: - return "Too Many Associations"; - case WIFI_REASON_NOT_AUTHED: - return "Not Authenticated"; - case WIFI_REASON_NOT_ASSOCED: - return "Not Associated"; - case WIFI_REASON_ASSOC_LEAVE: - return "Association Leave"; - case WIFI_REASON_ASSOC_NOT_AUTHED: - return "Association not Authenticated"; - case WIFI_REASON_DISASSOC_PWRCAP_BAD: - return "Disassociate Power Cap Bad"; - case WIFI_REASON_DISASSOC_SUPCHAN_BAD: - return "Disassociate Supported Channel Bad"; - case WIFI_REASON_IE_INVALID: - return "IE Invalid"; - case WIFI_REASON_MIC_FAILURE: - return "Mic Failure"; - case WIFI_REASON_4WAY_HANDSHAKE_TIMEOUT: - return "4-Way Handshake Timeout"; - case WIFI_REASON_GROUP_KEY_UPDATE_TIMEOUT: - return "Group Key Update Timeout"; - case WIFI_REASON_IE_IN_4WAY_DIFFERS: - return "IE In 4-Way Handshake Differs"; - case WIFI_REASON_GROUP_CIPHER_INVALID: - return "Group Cipher Invalid"; - case WIFI_REASON_PAIRWISE_CIPHER_INVALID: - return "Pairwise Cipher Invalid"; - case WIFI_REASON_AKMP_INVALID: - return "AKMP Invalid"; - case WIFI_REASON_UNSUPP_RSN_IE_VERSION: - return "Unsupported RSN IE version"; - case WIFI_REASON_INVALID_RSN_IE_CAP: - return "Invalid RSN IE Cap"; - case WIFI_REASON_802_1X_AUTH_FAILED: - return "802.1x Authentication Failed"; - case WIFI_REASON_CIPHER_SUITE_REJECTED: - return "Cipher Suite Rejected"; - case WIFI_REASON_BEACON_TIMEOUT: - return "Beacon Timeout"; - case WIFI_REASON_NO_AP_FOUND: - return "AP Not Found"; - case WIFI_REASON_AUTH_FAIL: - return "Authentication Failed"; - case WIFI_REASON_ASSOC_FAIL: - return "Association Failed"; - case WIFI_REASON_HANDSHAKE_TIMEOUT: - return "Handshake Failed"; - case WIFI_REASON_CONNECTION_FAIL: - return "Connection Failed"; - case WIFI_REASON_AP_TSF_RESET: - return "AP TSF reset"; - case WIFI_REASON_ROAMING: - return "Station Roaming"; - case WIFI_REASON_ASSOC_COMEBACK_TIME_TOO_LONG: - return "Association comeback time too long"; - case WIFI_REASON_SA_QUERY_TIMEOUT: - return "SA query timeout"; - case WIFI_REASON_NO_AP_FOUND_W_COMPATIBLE_SECURITY: - return "No AP found with compatible security"; - case WIFI_REASON_NO_AP_FOUND_IN_AUTHMODE_THRESHOLD: - return "No AP found in auth mode threshold"; - case WIFI_REASON_NO_AP_FOUND_IN_RSSI_THRESHOLD: - return "No AP found in RSSI threshold"; - case WIFI_REASON_UNSPECIFIED: - default: - return "Unspecified"; - } -} - -void WiFiComponent::wifi_loop_() {} - -#define ESPHOME_EVENT_ID_WIFI_READY ARDUINO_EVENT_WIFI_READY -#define ESPHOME_EVENT_ID_WIFI_SCAN_DONE ARDUINO_EVENT_WIFI_SCAN_DONE -#define ESPHOME_EVENT_ID_WIFI_STA_START ARDUINO_EVENT_WIFI_STA_START -#define ESPHOME_EVENT_ID_WIFI_STA_STOP ARDUINO_EVENT_WIFI_STA_STOP -#define ESPHOME_EVENT_ID_WIFI_STA_CONNECTED ARDUINO_EVENT_WIFI_STA_CONNECTED -#define ESPHOME_EVENT_ID_WIFI_STA_DISCONNECTED ARDUINO_EVENT_WIFI_STA_DISCONNECTED -#define ESPHOME_EVENT_ID_WIFI_STA_AUTHMODE_CHANGE ARDUINO_EVENT_WIFI_STA_AUTHMODE_CHANGE -#define ESPHOME_EVENT_ID_WIFI_STA_GOT_IP ARDUINO_EVENT_WIFI_STA_GOT_IP -#define ESPHOME_EVENT_ID_WIFI_STA_GOT_IP6 ARDUINO_EVENT_WIFI_STA_GOT_IP6 -#define ESPHOME_EVENT_ID_WIFI_STA_LOST_IP ARDUINO_EVENT_WIFI_STA_LOST_IP -#define ESPHOME_EVENT_ID_WIFI_AP_START ARDUINO_EVENT_WIFI_AP_START -#define ESPHOME_EVENT_ID_WIFI_AP_STOP ARDUINO_EVENT_WIFI_AP_STOP -#define ESPHOME_EVENT_ID_WIFI_AP_STACONNECTED ARDUINO_EVENT_WIFI_AP_STACONNECTED -#define ESPHOME_EVENT_ID_WIFI_AP_STADISCONNECTED ARDUINO_EVENT_WIFI_AP_STADISCONNECTED -#define ESPHOME_EVENT_ID_WIFI_AP_STAIPASSIGNED ARDUINO_EVENT_WIFI_AP_STAIPASSIGNED -#define ESPHOME_EVENT_ID_WIFI_AP_PROBEREQRECVED ARDUINO_EVENT_WIFI_AP_PROBEREQRECVED -#define ESPHOME_EVENT_ID_WIFI_AP_GOT_IP6 ARDUINO_EVENT_WIFI_AP_GOT_IP6 -using esphome_wifi_event_id_t = arduino_event_id_t; -using esphome_wifi_event_info_t = arduino_event_info_t; - -void WiFiComponent::wifi_event_callback_(esphome_wifi_event_id_t event, esphome_wifi_event_info_t info) { - switch (event) { - case ESPHOME_EVENT_ID_WIFI_READY: { - ESP_LOGV(TAG, "Ready"); - break; - } - case ESPHOME_EVENT_ID_WIFI_SCAN_DONE: { - auto it = info.wifi_scan_done; - ESP_LOGV(TAG, "Scan done: status=%u number=%u scan_id=%u", it.status, it.number, it.scan_id); - - this->wifi_scan_done_callback_(); - break; - } - case ESPHOME_EVENT_ID_WIFI_STA_START: { - ESP_LOGV(TAG, "STA start"); - // apply hostname - s_sta_netif = esp_netif_get_handle_from_ifkey("WIFI_STA_DEF"); - esp_err_t err = esp_netif_set_hostname(s_sta_netif, App.get_name().c_str()); - if (err != ERR_OK) { - ESP_LOGW(TAG, "esp_netif_set_hostname failed: %s", esp_err_to_name(err)); - } - break; - } - case ESPHOME_EVENT_ID_WIFI_STA_STOP: { - ESP_LOGV(TAG, "STA stop"); - break; - } - case ESPHOME_EVENT_ID_WIFI_STA_CONNECTED: { - auto it = info.wifi_sta_connected; - char buf[33]; - memcpy(buf, it.ssid, it.ssid_len); - buf[it.ssid_len] = '\0'; - ESP_LOGV(TAG, "Connected ssid='%s' bssid=" LOG_SECRET("%s") " channel=%u, authmode=%s", buf, - format_mac_address_pretty(it.bssid).c_str(), it.channel, get_auth_mode_str(it.authmode)); -#if USE_NETWORK_IPV6 - this->set_timeout(100, [] { WiFi.enableIPv6(); }); -#endif /* USE_NETWORK_IPV6 */ - - break; - } - case ESPHOME_EVENT_ID_WIFI_STA_DISCONNECTED: { - auto it = info.wifi_sta_disconnected; - char buf[33]; - memcpy(buf, it.ssid, it.ssid_len); - buf[it.ssid_len] = '\0'; - if (it.reason == WIFI_REASON_NO_AP_FOUND) { - ESP_LOGW(TAG, "Disconnected ssid='%s' reason='Probe Request Unsuccessful'", buf); - } else { - ESP_LOGW(TAG, "Disconnected ssid='%s' bssid=" LOG_SECRET("%s") " reason='%s'", buf, - format_mac_address_pretty(it.bssid).c_str(), get_disconnect_reason_str(it.reason)); - } - - uint8_t reason = it.reason; - if (reason == WIFI_REASON_AUTH_EXPIRE || reason == WIFI_REASON_BEACON_TIMEOUT || - reason == WIFI_REASON_NO_AP_FOUND || reason == WIFI_REASON_ASSOC_FAIL || - reason == WIFI_REASON_HANDSHAKE_TIMEOUT) { - err_t err = esp_wifi_disconnect(); - if (err != ESP_OK) { - ESP_LOGV(TAG, "Disconnect failed: %s", esp_err_to_name(err)); - } - this->error_from_callback_ = true; - } - - s_sta_connecting = false; - break; - } - case ESPHOME_EVENT_ID_WIFI_STA_AUTHMODE_CHANGE: { - auto it = info.wifi_sta_authmode_change; - ESP_LOGV(TAG, "Authmode Change old=%s new=%s", get_auth_mode_str(it.old_mode), get_auth_mode_str(it.new_mode)); - // Mitigate CVE-2020-12638 - // https://lbsfilm.at/blog/wpa2-authenticationmode-downgrade-in-espressif-microprocessors - if (it.old_mode != WIFI_AUTH_OPEN && it.new_mode == WIFI_AUTH_OPEN) { - ESP_LOGW(TAG, "Potential Authmode downgrade detected, disconnecting"); - // we can't call retry_connect() from this context, so disconnect immediately - // and notify main thread with error_from_callback_ - err_t err = esp_wifi_disconnect(); - if (err != ESP_OK) { - ESP_LOGW(TAG, "Disconnect failed: %s", esp_err_to_name(err)); - } - this->error_from_callback_ = true; - } - break; - } - case ESPHOME_EVENT_ID_WIFI_STA_GOT_IP: { - auto it = info.got_ip.ip_info; - ESP_LOGV(TAG, "static_ip=%s gateway=%s", format_ip4_addr(it.ip).c_str(), format_ip4_addr(it.gw).c_str()); - this->got_ipv4_address_ = true; -#if USE_NETWORK_IPV6 - s_sta_connecting = this->num_ipv6_addresses_ < USE_NETWORK_MIN_IPV6_ADDR_COUNT; -#else - s_sta_connecting = false; -#endif /* USE_NETWORK_IPV6 */ - break; - } -#if USE_NETWORK_IPV6 - case ESPHOME_EVENT_ID_WIFI_STA_GOT_IP6: { - auto it = info.got_ip6.ip6_info; - ESP_LOGV(TAG, "IPv6 address=" IPV6STR, IPV62STR(it.ip)); - this->num_ipv6_addresses_++; - s_sta_connecting = !(this->got_ipv4_address_ & (this->num_ipv6_addresses_ >= USE_NETWORK_MIN_IPV6_ADDR_COUNT)); - break; - } -#endif /* USE_NETWORK_IPV6 */ - case ESPHOME_EVENT_ID_WIFI_STA_LOST_IP: { - ESP_LOGV(TAG, "Lost IP"); - this->got_ipv4_address_ = false; - break; - } - case ESPHOME_EVENT_ID_WIFI_AP_START: { - ESP_LOGV(TAG, "AP start"); - break; - } - case ESPHOME_EVENT_ID_WIFI_AP_STOP: { - ESP_LOGV(TAG, "AP stop"); - break; - } - case ESPHOME_EVENT_ID_WIFI_AP_STACONNECTED: { - auto it = info.wifi_sta_connected; - auto &mac = it.bssid; - ESP_LOGV(TAG, "AP client connected MAC=%s", format_mac_address_pretty(mac).c_str()); - break; - } - case ESPHOME_EVENT_ID_WIFI_AP_STADISCONNECTED: { - auto it = info.wifi_sta_disconnected; - auto &mac = it.bssid; - ESP_LOGV(TAG, "AP client disconnected MAC=%s", format_mac_address_pretty(mac).c_str()); - break; - } - case ESPHOME_EVENT_ID_WIFI_AP_STAIPASSIGNED: { - ESP_LOGV(TAG, "AP client assigned IP"); - break; - } - case ESPHOME_EVENT_ID_WIFI_AP_PROBEREQRECVED: { - auto it = info.wifi_ap_probereqrecved; - ESP_LOGVV(TAG, "AP receive Probe Request MAC=%s RSSI=%d", format_mac_address_pretty(it.mac).c_str(), it.rssi); - break; - } - default: - break; - } -} - -WiFiSTAConnectStatus WiFiComponent::wifi_sta_connect_status_() { - const auto status = WiFi.status(); - if (status == WL_CONNECT_FAILED || status == WL_CONNECTION_LOST) { - return WiFiSTAConnectStatus::ERROR_CONNECT_FAILED; - } - if (status == WL_NO_SSID_AVAIL) { - return WiFiSTAConnectStatus::ERROR_NETWORK_NOT_FOUND; - } - if (s_sta_connecting) { - return WiFiSTAConnectStatus::CONNECTING; - } - if (status == WL_CONNECTED) { - return WiFiSTAConnectStatus::CONNECTED; - } - return WiFiSTAConnectStatus::IDLE; -} -bool WiFiComponent::wifi_scan_start_(bool passive) { - // enable STA - if (!this->wifi_mode_(true, {})) - return false; - - // need to use WiFi because of WiFiScanClass allocations :( - int16_t err = WiFi.scanNetworks(true, true, passive, 200); - if (err != WIFI_SCAN_RUNNING) { - ESP_LOGV(TAG, "WiFi.scanNetworks failed: %d", err); - return false; - } - - return true; -} -void WiFiComponent::wifi_scan_done_callback_() { - this->scan_result_.clear(); - - int16_t num = WiFi.scanComplete(); - if (num < 0) - return; - - this->scan_result_.reserve(static_cast(num)); - for (int i = 0; i < num; i++) { - String ssid = WiFi.SSID(i); - wifi_auth_mode_t authmode = WiFi.encryptionType(i); - int32_t rssi = WiFi.RSSI(i); - uint8_t *bssid = WiFi.BSSID(i); - int32_t channel = WiFi.channel(i); - - WiFiScanResult scan({bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5]}, std::string(ssid.c_str()), - channel, rssi, authmode != WIFI_AUTH_OPEN, ssid.length() == 0); - this->scan_result_.push_back(scan); - } - WiFi.scanDelete(); - this->scan_done_ = true; -} - -#ifdef USE_WIFI_AP -bool WiFiComponent::wifi_ap_ip_config_(optional manual_ip) { - esp_err_t err; - - // enable AP - if (!this->wifi_mode_({}, true)) - return false; - - // Check if the AP interface is initialized before using it - if (s_ap_netif == nullptr) { - ESP_LOGW(TAG, "AP interface not initialized"); - return false; - } - - esp_netif_ip_info_t info; - if (manual_ip.has_value()) { - info.ip = manual_ip->static_ip; - info.gw = manual_ip->gateway; - info.netmask = manual_ip->subnet; - } else { - info.ip = network::IPAddress(192, 168, 4, 1); - info.gw = network::IPAddress(192, 168, 4, 1); - info.netmask = network::IPAddress(255, 255, 255, 0); - } - - err = esp_netif_dhcps_stop(s_ap_netif); - if (err != ESP_OK && err != ESP_ERR_ESP_NETIF_DHCP_ALREADY_STOPPED) { - ESP_LOGE(TAG, "esp_netif_dhcps_stop failed: %s", esp_err_to_name(err)); - return false; - } - - err = esp_netif_set_ip_info(s_ap_netif, &info); - if (err != ESP_OK) { - ESP_LOGE(TAG, "esp_netif_set_ip_info failed: %d", err); - return false; - } - - dhcps_lease_t lease; - lease.enable = true; - network::IPAddress start_address = network::IPAddress(&info.ip); - start_address += 99; - lease.start_ip = start_address; - ESP_LOGV(TAG, "DHCP server IP lease start: %s", start_address.str().c_str()); - start_address += 10; - lease.end_ip = start_address; - ESP_LOGV(TAG, "DHCP server IP lease end: %s", start_address.str().c_str()); - err = esp_netif_dhcps_option(s_ap_netif, ESP_NETIF_OP_SET, ESP_NETIF_REQUESTED_IP_ADDRESS, &lease, sizeof(lease)); - - if (err != ESP_OK) { - ESP_LOGE(TAG, "esp_netif_dhcps_option failed: %d", err); - return false; - } - - err = esp_netif_dhcps_start(s_ap_netif); - - if (err != ESP_OK) { - ESP_LOGE(TAG, "esp_netif_dhcps_start failed: %d", err); - return false; - } - - return true; -} - -bool WiFiComponent::wifi_start_ap_(const WiFiAP &ap) { - // enable AP - if (!this->wifi_mode_({}, true)) - return false; - - wifi_config_t conf; - memset(&conf, 0, sizeof(conf)); - if (ap.get_ssid().size() > sizeof(conf.ap.ssid)) { - ESP_LOGE(TAG, "AP SSID too long"); - return false; - } - memcpy(reinterpret_cast(conf.ap.ssid), ap.get_ssid().c_str(), ap.get_ssid().size()); - conf.ap.channel = ap.get_channel().value_or(1); - conf.ap.ssid_hidden = ap.get_ssid().size(); - conf.ap.max_connection = 5; - conf.ap.beacon_interval = 100; - - if (ap.get_password().empty()) { - conf.ap.authmode = WIFI_AUTH_OPEN; - *conf.ap.password = 0; - } else { - conf.ap.authmode = WIFI_AUTH_WPA2_PSK; - if (ap.get_password().size() > sizeof(conf.ap.password)) { - ESP_LOGE(TAG, "AP password too long"); - return false; - } - memcpy(reinterpret_cast(conf.ap.password), ap.get_password().c_str(), ap.get_password().size()); - } - - // pairwise cipher of SoftAP, group cipher will be derived using this. - conf.ap.pairwise_cipher = WIFI_CIPHER_TYPE_CCMP; - - esp_err_t err = esp_wifi_set_config(WIFI_IF_AP, &conf); - if (err != ESP_OK) { - ESP_LOGV(TAG, "esp_wifi_set_config failed: %d", err); - return false; - } - - yield(); - - if (!this->wifi_ap_ip_config_(ap.get_manual_ip())) { - ESP_LOGV(TAG, "wifi_ap_ip_config_ failed"); - return false; - } - - return true; -} - -network::IPAddress WiFiComponent::wifi_soft_ap_ip() { - esp_netif_ip_info_t ip; - esp_netif_get_ip_info(s_ap_netif, &ip); - return network::IPAddress(&ip.ip); -} -#endif // USE_WIFI_AP - -bool WiFiComponent::wifi_disconnect_() { return esp_wifi_disconnect(); } - -bssid_t WiFiComponent::wifi_bssid() { - bssid_t bssid{}; - uint8_t *raw_bssid = WiFi.BSSID(); - if (raw_bssid != nullptr) { - for (size_t i = 0; i < bssid.size(); i++) - bssid[i] = raw_bssid[i]; - } - return bssid; -} -std::string WiFiComponent::wifi_ssid() { return WiFi.SSID().c_str(); } -int8_t WiFiComponent::wifi_rssi() { return WiFi.RSSI(); } -int32_t WiFiComponent::get_wifi_channel() { return WiFi.channel(); } -network::IPAddress WiFiComponent::wifi_subnet_mask_() { return network::IPAddress(WiFi.subnetMask()); } -network::IPAddress WiFiComponent::wifi_gateway_ip_() { return network::IPAddress(WiFi.gatewayIP()); } -network::IPAddress WiFiComponent::wifi_dns_ip_(int num) { return network::IPAddress(WiFi.dnsIP(num)); } - -} // namespace wifi -} // namespace esphome - -#endif // USE_ESP32_FRAMEWORK_ARDUINO -#endif diff --git a/esphome/components/wifi/wifi_component_esp8266.cpp b/esphome/components/wifi/wifi_component_esp8266.cpp index ae1daed8b5..3b3b4b139c 100644 --- a/esphome/components/wifi/wifi_component_esp8266.cpp +++ b/esphome/components/wifi/wifi_component_esp8266.cpp @@ -301,7 +301,7 @@ bool WiFiComponent::wifi_sta_connect_(const WiFiAP &ap) { // if we have certs, this must be EAP-TLS ret = wifi_station_set_enterprise_cert_key((uint8_t *) eap.client_cert, client_cert_len + 1, (uint8_t *) eap.client_key, client_key_len + 1, - (uint8_t *) eap.password.c_str(), strlen(eap.password.c_str())); + (uint8_t *) eap.password.c_str(), eap.password.length()); if (ret) { ESP_LOGV(TAG, "esp_wifi_sta_wpa2_ent_set_cert_key failed: %d", ret); } diff --git a/esphome/components/wifi/wifi_component_esp_idf.cpp b/esphome/components/wifi/wifi_component_esp_idf.cpp index 31ee712a48..ccec800205 100644 --- a/esphome/components/wifi/wifi_component_esp_idf.cpp +++ b/esphome/components/wifi/wifi_component_esp_idf.cpp @@ -1,7 +1,7 @@ #include "wifi_component.h" #ifdef USE_WIFI -#ifdef USE_ESP_IDF +#ifdef USE_ESP32 #include #include @@ -27,6 +27,10 @@ #include "dhcpserver/dhcpserver.h" #endif // USE_WIFI_AP +#ifdef USE_CAPTIVE_PORTAL +#include "esphome/components/captive_portal/captive_portal.h" +#endif + #include "lwip/apps/sntp.h" #include "lwip/dns.h" #include "lwip/err.h" @@ -404,11 +408,11 @@ bool WiFiComponent::wifi_sta_connect_(const WiFiAP &ap) { #if (ESP_IDF_VERSION_MAJOR >= 5) && (ESP_IDF_VERSION_MINOR >= 1) err = esp_eap_client_set_certificate_and_key((uint8_t *) eap.client_cert, client_cert_len + 1, (uint8_t *) eap.client_key, client_key_len + 1, - (uint8_t *) eap.password.c_str(), strlen(eap.password.c_str())); + (uint8_t *) eap.password.c_str(), eap.password.length()); #else err = esp_wifi_sta_wpa2_ent_set_cert_key((uint8_t *) eap.client_cert, client_cert_len + 1, (uint8_t *) eap.client_key, client_key_len + 1, - (uint8_t *) eap.password.c_str(), strlen(eap.password.c_str())); + (uint8_t *) eap.password.c_str(), eap.password.length()); #endif if (err != ESP_OK) { ESP_LOGV(TAG, "set_cert_key failed %d", err); @@ -918,6 +922,22 @@ bool WiFiComponent::wifi_ap_ip_config_(optional manual_ip) { return false; } +#if defined(USE_CAPTIVE_PORTAL) && ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 4, 0) + // Configure DHCP Option 114 (Captive Portal URI) if captive portal is enabled + // This provides a standards-compliant way for clients to discover the captive portal + if (captive_portal::global_captive_portal != nullptr) { + static char captive_portal_uri[32]; + snprintf(captive_portal_uri, sizeof(captive_portal_uri), "http://%s", network::IPAddress(&info.ip).str().c_str()); + err = esp_netif_dhcps_option(s_ap_netif, ESP_NETIF_OP_SET, ESP_NETIF_CAPTIVEPORTAL_URI, captive_portal_uri, + strlen(captive_portal_uri)); + if (err != ESP_OK) { + ESP_LOGV(TAG, "Failed to set DHCP captive portal URI: %s", esp_err_to_name(err)); + } else { + ESP_LOGV(TAG, "DHCP Captive Portal URI set to: %s", captive_portal_uri); + } + } +#endif + err = esp_netif_dhcps_start(s_ap_netif); if (err != ESP_OK) { @@ -1050,5 +1070,5 @@ network::IPAddress WiFiComponent::wifi_dns_ip_(int num) { } // namespace wifi } // namespace esphome -#endif // USE_ESP_IDF +#endif // USE_ESP32 #endif diff --git a/esphome/components/wireguard/__init__.py b/esphome/components/wireguard/__init__.py index 8eff8e7b2a..50c7980215 100644 --- a/esphome/components/wireguard/__init__.py +++ b/esphome/components/wireguard/__init__.py @@ -118,7 +118,7 @@ async def to_code(config): # Workaround for crash on IDF 5+ # See https://github.com/trombik/esp_wireguard/issues/33#issuecomment-1568503651 - if CORE.using_esp_idf: + if CORE.is_esp32: add_idf_sdkconfig_option("CONFIG_LWIP_PPP_SUPPORT", True) # This flag is added here because the esp_wireguard library statically diff --git a/esphome/components/wts01/__init__.py b/esphome/components/wts01/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/esphome/components/wts01/sensor.py b/esphome/components/wts01/sensor.py new file mode 100644 index 0000000000..bf4f0262ad --- /dev/null +++ b/esphome/components/wts01/sensor.py @@ -0,0 +1,41 @@ +import esphome.codegen as cg +from esphome.components import sensor, uart +import esphome.config_validation as cv +from esphome.const import ( + DEVICE_CLASS_TEMPERATURE, + STATE_CLASS_MEASUREMENT, + UNIT_CELSIUS, +) + +CONF_WTS01_ID = "wts01_id" +CODEOWNERS = ["@alepee"] +DEPENDENCIES = ["uart"] + +wts01_ns = cg.esphome_ns.namespace("wts01") +WTS01Sensor = wts01_ns.class_( + "WTS01Sensor", cg.Component, uart.UARTDevice, sensor.Sensor +) + +CONFIG_SCHEMA = ( + sensor.sensor_schema( + WTS01Sensor, + unit_of_measurement=UNIT_CELSIUS, + accuracy_decimals=1, + device_class=DEVICE_CLASS_TEMPERATURE, + state_class=STATE_CLASS_MEASUREMENT, + ) + .extend(cv.COMPONENT_SCHEMA) + .extend(uart.UART_DEVICE_SCHEMA) +) + +FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema( + "wts01", + baud_rate=9600, + require_rx=True, +) + + +async def to_code(config): + var = await sensor.new_sensor(config) + await cg.register_component(var, config) + await uart.register_uart_device(var, config) diff --git a/esphome/components/wts01/wts01.cpp b/esphome/components/wts01/wts01.cpp new file mode 100644 index 0000000000..cb910d89cf --- /dev/null +++ b/esphome/components/wts01/wts01.cpp @@ -0,0 +1,91 @@ +#include "wts01.h" +#include "esphome/core/log.h" +#include + +namespace esphome { +namespace wts01 { + +constexpr uint8_t HEADER_1 = 0x55; +constexpr uint8_t HEADER_2 = 0x01; +constexpr uint8_t HEADER_3 = 0x01; +constexpr uint8_t HEADER_4 = 0x04; + +static const char *const TAG = "wts01"; + +void WTS01Sensor::loop() { + // Process all available data at once + while (this->available()) { + uint8_t c; + if (this->read_byte(&c)) { + this->handle_char_(c); + } + } +} + +void WTS01Sensor::dump_config() { LOG_SENSOR("", "WTS01 Sensor", this); } + +void WTS01Sensor::handle_char_(uint8_t c) { + // State machine for processing the header. Reset if something doesn't match. + if (this->buffer_pos_ == 0 && c != HEADER_1) { + return; + } + + if (this->buffer_pos_ == 1 && c != HEADER_2) { + this->buffer_pos_ = 0; + return; + } + + if (this->buffer_pos_ == 2 && c != HEADER_3) { + this->buffer_pos_ = 0; + return; + } + + if (this->buffer_pos_ == 3 && c != HEADER_4) { + this->buffer_pos_ = 0; + return; + } + + // Add byte to buffer + this->buffer_[this->buffer_pos_++] = c; + + // Process complete packet + if (this->buffer_pos_ >= PACKET_SIZE) { + this->process_packet_(); + this->buffer_pos_ = 0; + } +} + +void WTS01Sensor::process_packet_() { + // Based on Tasmota implementation + // Format: 55 01 01 04 01 11 16 12 95 + // header T Td Ck - T = Temperature, Td = Temperature decimal, Ck = Checksum + uint8_t calculated_checksum = 0; + for (uint8_t i = 0; i < PACKET_SIZE - 1; i++) { + calculated_checksum += this->buffer_[i]; + } + + uint8_t received_checksum = this->buffer_[PACKET_SIZE - 1]; + if (calculated_checksum != received_checksum) { + ESP_LOGW(TAG, "WTS01 Checksum doesn't match: 0x%02X != 0x%02X", received_checksum, calculated_checksum); + return; + } + + // Extract temperature value + int8_t temp = this->buffer_[6]; + int32_t sign = 1; + + // Handle negative temperatures + if (temp < 0) { + sign = -1; + } + + // Calculate temperature (temp + decimal/100) + float temperature = static_cast(temp) + (sign * static_cast(this->buffer_[7]) / 100.0f); + + ESP_LOGV(TAG, "Received new temperature: %.2f°C", temperature); + + this->publish_state(temperature); +} + +} // namespace wts01 +} // namespace esphome diff --git a/esphome/components/wts01/wts01.h b/esphome/components/wts01/wts01.h new file mode 100644 index 0000000000..298595a5d6 --- /dev/null +++ b/esphome/components/wts01/wts01.h @@ -0,0 +1,27 @@ +#pragma once + +#include "esphome/core/component.h" +#include "esphome/components/sensor/sensor.h" +#include "esphome/components/uart/uart.h" + +namespace esphome { +namespace wts01 { + +constexpr uint8_t PACKET_SIZE = 9; + +class WTS01Sensor : public sensor::Sensor, public uart::UARTDevice, public Component { + public: + void loop() override; + void dump_config() override; + float get_setup_priority() const override { return setup_priority::DATA; } + + protected: + uint8_t buffer_[PACKET_SIZE]; + uint8_t buffer_pos_{0}; + + void handle_char_(uint8_t c); + void process_packet_(); +}; + +} // namespace wts01 +} // namespace esphome diff --git a/esphome/components/zephyr/__init__.py b/esphome/components/zephyr/__init__.py index c698122030..ff4644163e 100644 --- a/esphome/components/zephyr/__init__.py +++ b/esphome/components/zephyr/__init__.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path from typing import TypedDict import esphome.codegen as cg @@ -48,7 +48,7 @@ class ZephyrData(TypedDict): bootloader: str prj_conf: dict[str, tuple[PrjConfValueType, bool]] overlay: str - extra_build_files: dict[str, str] + extra_build_files: dict[str, Path] pm_static: list[Section] user: dict[str, list[str]] @@ -93,7 +93,7 @@ def zephyr_add_overlay(content): zephyr_data()[KEY_OVERLAY] += content -def add_extra_build_file(filename: str, path: str) -> bool: +def add_extra_build_file(filename: str, path: Path) -> bool: """Add an extra build file to the project.""" extra_build_files = zephyr_data()[KEY_EXTRA_BUILD_FILES] if filename not in extra_build_files: @@ -102,7 +102,7 @@ def add_extra_build_file(filename: str, path: str) -> bool: return False -def add_extra_script(stage: str, filename: str, path: str): +def add_extra_script(stage: str, filename: str, path: Path) -> None: """Add an extra script to the project.""" key = f"{stage}:{filename}" if add_extra_build_file(filename, path): @@ -144,7 +144,7 @@ def zephyr_to_code(config): add_extra_script( "pre", "pre_build.py", - os.path.join(os.path.dirname(__file__), "pre_build.py.script"), + Path(__file__).parent / "pre_build.py.script", ) diff --git a/esphome/components/zwave_proxy/__init__.py b/esphome/components/zwave_proxy/__init__.py new file mode 100644 index 0000000000..d88f9f7041 --- /dev/null +++ b/esphome/components/zwave_proxy/__init__.py @@ -0,0 +1,43 @@ +import esphome.codegen as cg +from esphome.components import uart +import esphome.config_validation as cv +from esphome.const import CONF_ID, CONF_POWER_SAVE_MODE, CONF_WIFI +import esphome.final_validate as fv + +CODEOWNERS = ["@kbx81"] +DEPENDENCIES = ["api", "uart"] + +zwave_proxy_ns = cg.esphome_ns.namespace("zwave_proxy") +ZWaveProxy = zwave_proxy_ns.class_("ZWaveProxy", cg.Component, uart.UARTDevice) + + +def final_validate(config): + full_config = fv.full_config.get() + if (wifi_conf := full_config.get(CONF_WIFI)) and ( + wifi_conf.get(CONF_POWER_SAVE_MODE).lower() != "none" + ): + raise cv.Invalid( + f"{CONF_WIFI} {CONF_POWER_SAVE_MODE} must be set to 'none' when using Z-Wave proxy" + ) + + return config + + +CONFIG_SCHEMA = ( + cv.Schema( + { + cv.GenerateID(): cv.declare_id(ZWaveProxy), + } + ) + .extend(cv.COMPONENT_SCHEMA) + .extend(uart.UART_DEVICE_SCHEMA) +) + +FINAL_VALIDATE_SCHEMA = final_validate + + +async def to_code(config): + var = cg.new_Pvariable(config[CONF_ID]) + await cg.register_component(var, config) + await uart.register_uart_device(var, config) + cg.add_define("USE_ZWAVE_PROXY") diff --git a/esphome/components/zwave_proxy/zwave_proxy.cpp b/esphome/components/zwave_proxy/zwave_proxy.cpp new file mode 100644 index 0000000000..a26a9b2335 --- /dev/null +++ b/esphome/components/zwave_proxy/zwave_proxy.cpp @@ -0,0 +1,346 @@ +#include "zwave_proxy.h" +#include "esphome/components/api/api_server.h" +#include "esphome/core/application.h" +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" +#include "esphome/core/util.h" + +namespace esphome { +namespace zwave_proxy { + +static const char *const TAG = "zwave_proxy"; + +static constexpr uint8_t ZWAVE_COMMAND_GET_NETWORK_IDS = 0x20; +// GET_NETWORK_IDS response: [SOF][LENGTH][TYPE][CMD][HOME_ID(4)][NODE_ID][...] +static constexpr uint8_t ZWAVE_COMMAND_TYPE_RESPONSE = 0x01; // Response type field value +static constexpr uint8_t ZWAVE_MIN_GET_NETWORK_IDS_LENGTH = 9; // TYPE + CMD + HOME_ID(4) + NODE_ID + checksum +static constexpr uint32_t HOME_ID_TIMEOUT_MS = 100; // Timeout for waiting for home ID during setup + +static uint8_t calculate_frame_checksum(const uint8_t *data, uint8_t length) { + // Calculate Z-Wave frame checksum + // XOR all bytes between SOF and checksum position (exclusive) + // Initial value is 0xFF per Z-Wave protocol specification + uint8_t checksum = 0xFF; + for (uint8_t i = 1; i < length - 1; i++) { + checksum ^= data[i]; + } + return checksum; +} + +ZWaveProxy::ZWaveProxy() { global_zwave_proxy = this; } + +void ZWaveProxy::setup() { + this->setup_time_ = App.get_loop_component_start_time(); + this->send_simple_command_(ZWAVE_COMMAND_GET_NETWORK_IDS); +} + +float ZWaveProxy::get_setup_priority() const { + // Set up before API so home ID is ready when API starts + return setup_priority::BEFORE_CONNECTION; +} + +bool ZWaveProxy::can_proceed() { + // If we already have the home ID, we can proceed + if (this->home_id_ready_) { + return true; + } + + // Handle any pending responses + if (this->response_handler_()) { + ESP_LOGV(TAG, "Handled response during setup"); + } + + // Process UART data to check for home ID + this->process_uart_(); + + // Check if we got the home ID after processing + if (this->home_id_ready_) { + return true; + } + + // Wait up to HOME_ID_TIMEOUT_MS for home ID response + const uint32_t now = App.get_loop_component_start_time(); + if (now - this->setup_time_ > HOME_ID_TIMEOUT_MS) { + ESP_LOGW(TAG, "Timeout reading Home ID during setup"); + return true; // Proceed anyway after timeout + } + + return false; // Keep waiting +} + +void ZWaveProxy::loop() { + if (this->response_handler_()) { + ESP_LOGV(TAG, "Handled late response"); + } + if (this->api_connection_ != nullptr && (!this->api_connection_->is_connection_setup() || !api_is_connected())) { + ESP_LOGW(TAG, "Subscriber disconnected"); + this->api_connection_ = nullptr; // Unsubscribe if disconnected + } + + this->process_uart_(); + this->status_clear_warning(); +} + +void ZWaveProxy::process_uart_() { + while (this->available()) { + uint8_t byte; + if (!this->read_byte(&byte)) { + this->status_set_warning("UART read failed"); + return; + } + if (this->parse_byte_(byte)) { + // Check if this is a GET_NETWORK_IDS response frame + // Frame format: [SOF][LENGTH][TYPE][CMD][HOME_ID(4)][NODE_ID][...] + // We verify: + // - buffer_[0]: Start of frame marker (0x01) + // - buffer_[1]: Length field must be >= 9 to contain all required data + // - buffer_[2]: Command type (0x01 for response) + // - buffer_[3]: Command ID (0x20 for GET_NETWORK_IDS) + if (this->buffer_[3] == ZWAVE_COMMAND_GET_NETWORK_IDS && this->buffer_[2] == ZWAVE_COMMAND_TYPE_RESPONSE && + this->buffer_[1] >= ZWAVE_MIN_GET_NETWORK_IDS_LENGTH && this->buffer_[0] == ZWAVE_FRAME_TYPE_START) { + // Store the 4-byte Home ID, which starts at offset 4, and notify connected clients if it changed + // The frame parser has already validated the checksum and ensured all bytes are present + if (this->set_home_id(&this->buffer_[4])) { + this->send_homeid_changed_msg_(); + } + } + ESP_LOGV(TAG, "Sending to client: %s", YESNO(this->api_connection_ != nullptr)); + if (this->api_connection_ != nullptr) { + // Zero-copy: point directly to our buffer + this->outgoing_proto_msg_.data = this->buffer_.data(); + if (this->in_bootloader_) { + this->outgoing_proto_msg_.data_len = this->buffer_index_; + } else { + // If this is a data frame, use frame length indicator + 2 (for SoF + checksum), else assume 1 for ACK/NAK/CAN + this->outgoing_proto_msg_.data_len = this->buffer_[0] == ZWAVE_FRAME_TYPE_START ? this->buffer_[1] + 2 : 1; + } + this->api_connection_->send_message(this->outgoing_proto_msg_, api::ZWaveProxyFrame::MESSAGE_TYPE); + } + } + } +} + +void ZWaveProxy::dump_config() { + ESP_LOGCONFIG(TAG, + "Z-Wave Proxy:\n" + " Home ID: %s", + format_hex_pretty(this->home_id_.data(), this->home_id_.size(), ':', false).c_str()); +} + +void ZWaveProxy::api_connection_authenticated(api::APIConnection *conn) { + if (this->home_id_ready_) { + // If a client just authenticated & HomeID is ready, send the current HomeID + this->send_homeid_changed_msg_(conn); + } +} + +void ZWaveProxy::zwave_proxy_request(api::APIConnection *api_connection, api::enums::ZWaveProxyRequestType type) { + switch (type) { + case api::enums::ZWAVE_PROXY_REQUEST_TYPE_SUBSCRIBE: + if (this->api_connection_ != nullptr) { + ESP_LOGE(TAG, "Only one API subscription is allowed at a time"); + return; + } + this->api_connection_ = api_connection; + ESP_LOGV(TAG, "API connection is now subscribed"); + break; + case api::enums::ZWAVE_PROXY_REQUEST_TYPE_UNSUBSCRIBE: + if (this->api_connection_ != api_connection) { + ESP_LOGV(TAG, "API connection is not subscribed"); + return; + } + this->api_connection_ = nullptr; + break; + default: + ESP_LOGW(TAG, "Unknown request type: %d", type); + break; + } +} + +bool ZWaveProxy::set_home_id(const uint8_t *new_home_id) { + if (std::memcmp(this->home_id_.data(), new_home_id, this->home_id_.size()) == 0) { + ESP_LOGV(TAG, "Home ID unchanged"); + return false; // No change + } + std::memcpy(this->home_id_.data(), new_home_id, this->home_id_.size()); + ESP_LOGI(TAG, "Home ID: %s", format_hex_pretty(this->home_id_.data(), this->home_id_.size(), ':', false).c_str()); + this->home_id_ready_ = true; + return true; // Home ID was changed +} + +void ZWaveProxy::send_frame(const uint8_t *data, size_t length) { + if (length == 1 && data[0] == this->last_response_) { + ESP_LOGV(TAG, "Skipping sending duplicate response: 0x%02X", data[0]); + return; + } + ESP_LOGVV(TAG, "Sending: %s", format_hex_pretty(data, length).c_str()); + this->write_array(data, length); +} + +void ZWaveProxy::send_homeid_changed_msg_(api::APIConnection *conn) { + api::ZWaveProxyRequest msg; + msg.type = api::enums::ZWAVE_PROXY_REQUEST_TYPE_HOME_ID_CHANGE; + msg.data = this->home_id_.data(); + msg.data_len = this->home_id_.size(); + if (conn != nullptr) { + // Send to specific connection + conn->send_message(msg, api::ZWaveProxyRequest::MESSAGE_TYPE); + } else if (api::global_api_server != nullptr) { + // We could add code to manage a second subscription type, but, since this message is + // very infrequent and small, we simply send it to all clients + api::global_api_server->on_zwave_proxy_request(msg); + } +} + +void ZWaveProxy::send_simple_command_(const uint8_t command_id) { + // Send a simple Z-Wave command with no parameters + // Frame format: [SOF][LENGTH][TYPE][CMD][CHECKSUM] + // Where LENGTH=0x03 (3 bytes: TYPE + CMD + CHECKSUM) + uint8_t cmd[] = {0x01, 0x03, 0x00, command_id, 0x00}; + cmd[4] = calculate_frame_checksum(cmd, sizeof(cmd)); + this->send_frame(cmd, sizeof(cmd)); +} + +bool ZWaveProxy::parse_byte_(uint8_t byte) { + bool frame_completed = false; + // Basic parsing logic for received frames + switch (this->parsing_state_) { + case ZWAVE_PARSING_STATE_WAIT_START: + this->parse_start_(byte); + break; + case ZWAVE_PARSING_STATE_WAIT_LENGTH: + if (!byte) { + ESP_LOGW(TAG, "Invalid LENGTH: %u", byte); + this->parsing_state_ = ZWAVE_PARSING_STATE_SEND_NAK; + return false; + } + ESP_LOGVV(TAG, "Received LENGTH: %u", byte); + this->end_frame_after_ = this->buffer_index_ + byte; + ESP_LOGVV(TAG, "Calculated EOF: %u", this->end_frame_after_); + this->buffer_[this->buffer_index_++] = byte; + this->parsing_state_ = ZWAVE_PARSING_STATE_WAIT_TYPE; + break; + case ZWAVE_PARSING_STATE_WAIT_TYPE: + this->buffer_[this->buffer_index_++] = byte; + ESP_LOGVV(TAG, "Received TYPE: 0x%02X", byte); + this->parsing_state_ = ZWAVE_PARSING_STATE_WAIT_COMMAND_ID; + break; + case ZWAVE_PARSING_STATE_WAIT_COMMAND_ID: + this->buffer_[this->buffer_index_++] = byte; + ESP_LOGVV(TAG, "Received COMMAND ID: 0x%02X", byte); + this->parsing_state_ = ZWAVE_PARSING_STATE_WAIT_PAYLOAD; + break; + case ZWAVE_PARSING_STATE_WAIT_PAYLOAD: + this->buffer_[this->buffer_index_++] = byte; + ESP_LOGVV(TAG, "Received PAYLOAD: 0x%02X", byte); + if (this->buffer_index_ >= this->end_frame_after_) { + this->parsing_state_ = ZWAVE_PARSING_STATE_WAIT_CHECKSUM; + } + break; + case ZWAVE_PARSING_STATE_WAIT_CHECKSUM: { + this->buffer_[this->buffer_index_++] = byte; + auto checksum = calculate_frame_checksum(this->buffer_.data(), this->buffer_index_); + ESP_LOGVV(TAG, "CHECKSUM Received: 0x%02X - Calculated: 0x%02X", byte, checksum); + if (checksum != byte) { + ESP_LOGW(TAG, "Bad checksum: expected 0x%02X, got 0x%02X", checksum, byte); + this->parsing_state_ = ZWAVE_PARSING_STATE_SEND_NAK; + } else { + this->parsing_state_ = ZWAVE_PARSING_STATE_SEND_ACK; + ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(this->buffer_.data(), this->buffer_index_).c_str()); + frame_completed = true; + } + this->response_handler_(); + break; + } + case ZWAVE_PARSING_STATE_READ_BL_MENU: + this->buffer_[this->buffer_index_++] = byte; + if (!byte) { + this->parsing_state_ = ZWAVE_PARSING_STATE_WAIT_START; + frame_completed = true; + } + break; + case ZWAVE_PARSING_STATE_SEND_ACK: + case ZWAVE_PARSING_STATE_SEND_NAK: + break; // Should not happen, handled in loop() + default: + ESP_LOGW(TAG, "Bad parsing state; resetting"); + this->parsing_state_ = ZWAVE_PARSING_STATE_WAIT_START; + break; + } + return frame_completed; +} + +void ZWaveProxy::parse_start_(uint8_t byte) { + this->buffer_index_ = 0; + this->parsing_state_ = ZWAVE_PARSING_STATE_WAIT_START; + switch (byte) { + case ZWAVE_FRAME_TYPE_START: + ESP_LOGVV(TAG, "Received START"); + if (this->in_bootloader_) { + ESP_LOGD(TAG, "Exited bootloader mode"); + this->in_bootloader_ = false; + } + this->buffer_[this->buffer_index_++] = byte; + this->parsing_state_ = ZWAVE_PARSING_STATE_WAIT_LENGTH; + return; + case ZWAVE_FRAME_TYPE_BL_MENU: + ESP_LOGVV(TAG, "Received BL_MENU"); + if (!this->in_bootloader_) { + ESP_LOGD(TAG, "Entered bootloader mode"); + this->in_bootloader_ = true; + } + this->buffer_[this->buffer_index_++] = byte; + this->parsing_state_ = ZWAVE_PARSING_STATE_READ_BL_MENU; + return; + case ZWAVE_FRAME_TYPE_BL_BEGIN_UPLOAD: + ESP_LOGVV(TAG, "Received BL_BEGIN_UPLOAD"); + break; + case ZWAVE_FRAME_TYPE_ACK: + ESP_LOGVV(TAG, "Received ACK"); + break; + case ZWAVE_FRAME_TYPE_NAK: + ESP_LOGW(TAG, "Received NAK"); + break; + case ZWAVE_FRAME_TYPE_CAN: + ESP_LOGW(TAG, "Received CAN"); + break; + default: + ESP_LOGW(TAG, "Unrecognized START: 0x%02X", byte); + return; + } + // Forward response (ACK/NAK/CAN) back to client for processing + if (this->api_connection_ != nullptr) { + // Store single byte in buffer and point to it + this->buffer_[0] = byte; + this->outgoing_proto_msg_.data = this->buffer_.data(); + this->outgoing_proto_msg_.data_len = 1; + this->api_connection_->send_message(this->outgoing_proto_msg_, api::ZWaveProxyFrame::MESSAGE_TYPE); + } +} + +bool ZWaveProxy::response_handler_() { + switch (this->parsing_state_) { + case ZWAVE_PARSING_STATE_SEND_ACK: + this->last_response_ = ZWAVE_FRAME_TYPE_ACK; + break; + case ZWAVE_PARSING_STATE_SEND_CAN: + this->last_response_ = ZWAVE_FRAME_TYPE_CAN; + break; + case ZWAVE_PARSING_STATE_SEND_NAK: + this->last_response_ = ZWAVE_FRAME_TYPE_NAK; + break; + default: + return false; // No response handled + } + + ESP_LOGVV(TAG, "Sending %s (0x%02X)", this->last_response_ == ZWAVE_FRAME_TYPE_ACK ? "ACK" : "NAK/CAN", + this->last_response_); + this->write_byte(this->last_response_); + this->parsing_state_ = ZWAVE_PARSING_STATE_WAIT_START; + return true; +} + +ZWaveProxy *global_zwave_proxy = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + +} // namespace zwave_proxy +} // namespace esphome diff --git a/esphome/components/zwave_proxy/zwave_proxy.h b/esphome/components/zwave_proxy/zwave_proxy.h new file mode 100644 index 0000000000..20d9090d98 --- /dev/null +++ b/esphome/components/zwave_proxy/zwave_proxy.h @@ -0,0 +1,93 @@ +#pragma once + +#include "esphome/components/api/api_connection.h" +#include "esphome/components/api/api_pb2.h" +#include "esphome/core/component.h" +#include "esphome/core/helpers.h" +#include "esphome/components/uart/uart.h" + +#include + +namespace esphome { +namespace zwave_proxy { + +static constexpr size_t MAX_ZWAVE_FRAME_SIZE = 257; // Maximum Z-Wave frame size + +enum ZWaveResponseTypes : uint8_t { + ZWAVE_FRAME_TYPE_ACK = 0x06, + ZWAVE_FRAME_TYPE_CAN = 0x18, + ZWAVE_FRAME_TYPE_NAK = 0x15, + ZWAVE_FRAME_TYPE_START = 0x01, + ZWAVE_FRAME_TYPE_BL_MENU = 0x0D, + ZWAVE_FRAME_TYPE_BL_BEGIN_UPLOAD = 0x43, +}; + +enum ZWaveParsingState : uint8_t { + ZWAVE_PARSING_STATE_WAIT_START, + ZWAVE_PARSING_STATE_WAIT_LENGTH, + ZWAVE_PARSING_STATE_WAIT_TYPE, + ZWAVE_PARSING_STATE_WAIT_COMMAND_ID, + ZWAVE_PARSING_STATE_WAIT_PAYLOAD, + ZWAVE_PARSING_STATE_WAIT_CHECKSUM, + ZWAVE_PARSING_STATE_SEND_ACK, + ZWAVE_PARSING_STATE_SEND_CAN, + ZWAVE_PARSING_STATE_SEND_NAK, + ZWAVE_PARSING_STATE_READ_BL_MENU, +}; + +enum ZWaveProxyFeature : uint32_t { + FEATURE_ZWAVE_PROXY_ENABLED = 1 << 0, +}; + +class ZWaveProxy : public uart::UARTDevice, public Component { + public: + ZWaveProxy(); + + void setup() override; + void loop() override; + void dump_config() override; + float get_setup_priority() const override; + bool can_proceed() override; + + void api_connection_authenticated(api::APIConnection *conn); + void zwave_proxy_request(api::APIConnection *api_connection, api::enums::ZWaveProxyRequestType type); + api::APIConnection *get_api_connection() { return this->api_connection_; } + + uint32_t get_feature_flags() const { return ZWaveProxyFeature::FEATURE_ZWAVE_PROXY_ENABLED; } + uint32_t get_home_id() { + return encode_uint32(this->home_id_[0], this->home_id_[1], this->home_id_[2], this->home_id_[3]); + } + bool set_home_id(const uint8_t *new_home_id); // Store a new home ID. Returns true if it changed. + + void send_frame(const uint8_t *data, size_t length); + + protected: + void send_homeid_changed_msg_(api::APIConnection *conn = nullptr); + void send_simple_command_(uint8_t command_id); + bool parse_byte_(uint8_t byte); // Returns true if frame parsing was completed (a frame is ready in the buffer) + void parse_start_(uint8_t byte); + bool response_handler_(); + void process_uart_(); // Process all available UART data + + // Pre-allocated message - always ready to send + api::ZWaveProxyFrame outgoing_proto_msg_; + std::array buffer_; // Fixed buffer for incoming data + std::array home_id_{0, 0, 0, 0}; // Fixed buffer for home ID + + // Pointers and 32-bit values (aligned together) + api::APIConnection *api_connection_{nullptr}; // Current subscribed client + uint32_t setup_time_{0}; // Time when setup() was called + + // 8-bit values (grouped together to minimize padding) + uint8_t buffer_index_{0}; // Index for populating the data buffer + uint8_t end_frame_after_{0}; // Payload reception ends after this index + uint8_t last_response_{0}; // Last response type sent + ZWaveParsingState parsing_state_{ZWAVE_PARSING_STATE_WAIT_START}; + bool in_bootloader_{false}; // True if the device is detected to be in bootloader mode + bool home_id_ready_{false}; // True when home ID has been received from Z-Wave module +}; + +extern ZWaveProxy *global_zwave_proxy; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + +} // namespace zwave_proxy +} // namespace esphome diff --git a/esphome/config.py b/esphome/config.py index 90325cbf6e..10a5733575 100644 --- a/esphome/config.py +++ b/esphome/config.py @@ -32,7 +32,7 @@ from esphome.log import AnsiFore, color from esphome.types import ConfigFragmentType, ConfigType from esphome.util import OrderedDict, safe_print from esphome.voluptuous_schema import ExtraKeysInvalid -from esphome.yaml_util import ESPForceValue, ESPHomeDataBase, is_secret +from esphome.yaml_util import ESPHomeDataBase, ESPLiteralValue, is_secret _LOGGER = logging.getLogger(__name__) @@ -67,6 +67,31 @@ ConfigPath = list[str | int] path_context = contextvars.ContextVar("Config path") +def _add_auto_load_steps(result: Config, loads: list[str]) -> None: + """Add AutoLoadValidationStep for each component in loads that isn't already loaded.""" + for load in loads: + if load not in result: + result.add_validation_step(AutoLoadValidationStep(load)) + + +def _process_auto_load( + result: Config, platform: ComponentManifest, path: ConfigPath +) -> None: + # Process platform's AUTO_LOAD + auto_load = platform.auto_load + if isinstance(auto_load, list): + _add_auto_load_steps(result, auto_load) + elif callable(auto_load): + import inspect + + if inspect.signature(auto_load).parameters: + result.add_validation_step( + AddDynamicAutoLoadsValidationStep(path, platform) + ) + else: + _add_auto_load_steps(result, auto_load()) + + def _process_platform_config( result: Config, component_name: str, @@ -91,9 +116,7 @@ def _process_platform_config( CORE.loaded_platforms.add(f"{component_name}/{platform_name}") # Process platform's AUTO_LOAD - for load in platform.auto_load: - if load not in result: - result.add_validation_step(AutoLoadValidationStep(load)) + _process_auto_load(result, platform, path) # Add validation steps for the platform p_domain = f"{component_name}.{platform_name}" @@ -306,7 +329,7 @@ def recursive_check_replaceme(value): return cv.Schema([recursive_check_replaceme])(value) if isinstance(value, dict): return cv.Schema({cv.valid: recursive_check_replaceme})(value) - if isinstance(value, ESPForceValue): + if isinstance(value, ESPLiteralValue): pass if isinstance(value, str) and value == "REPLACEME": raise cv.Invalid( @@ -314,7 +337,7 @@ def recursive_check_replaceme(value): "Please make sure you have replaced all fields from the sample " "configuration.\n" "If you want to use the literal REPLACEME string, " - 'please use "!force REPLACEME"' + 'please use "!literal REPLACEME"' ) return value @@ -382,11 +405,15 @@ class LoadValidationStep(ConfigValidationStep): result.add_str_error(f"Component not found: {self.domain}", path) return CORE.loaded_integrations.add(self.domain) + # For platform components, normalize conf before creating MetadataValidationStep + if component.is_platform_component: + if not self.conf: + result[self.domain] = self.conf = [] + elif not isinstance(self.conf, list): + result[self.domain] = self.conf = [self.conf] # Process AUTO_LOAD - for load in component.auto_load: - if load not in result: - result.add_validation_step(AutoLoadValidationStep(load)) + _process_auto_load(result, component, path) result.add_validation_step( MetadataValidationStep([self.domain], self.domain, self.conf, component) @@ -399,12 +426,6 @@ class LoadValidationStep(ConfigValidationStep): # Remove this is as an output path result.remove_output_path([self.domain], self.domain) - # Ensure conf is a list - if not self.conf: - result[self.domain] = self.conf = [] - elif not isinstance(self.conf, list): - result[self.domain] = self.conf = [self.conf] - for i, p_config in enumerate(self.conf): path = [self.domain, i] # Construct temporary unknown output path @@ -618,6 +639,34 @@ class MetadataValidationStep(ConfigValidationStep): result.add_validation_step(FinalValidateValidationStep(self.path, self.comp)) +class AddDynamicAutoLoadsValidationStep(ConfigValidationStep): + """Add dynamic auto loads step. + + This step is used to auto-load components where one component can alter its + AUTO_LOAD based on its configuration. + """ + + # Has to happen after normal schema is validated and before final schema validation + priority = -5.0 + + def __init__(self, path: ConfigPath, comp: ComponentManifest) -> None: + self.path = path + self.comp = comp + + def run(self, result: Config) -> None: + if result.errors: + # If result already has errors, skip this step + return + + conf = result.get_nested_item(self.path) + with result.catch_error(self.path): + auto_load = self.comp.auto_load + if not callable(auto_load): + return + loads = auto_load(conf) + _add_auto_load_steps(result, loads) + + class SchemaValidationStep(ConfigValidationStep): """Schema validation step. @@ -846,7 +895,9 @@ class PinUseValidationCheck(ConfigValidationStep): def validate_config( - config: dict[str, Any], command_line_substitutions: dict[str, Any] + config: dict[str, Any], + command_line_substitutions: dict[str, Any], + skip_external_update: bool = False, ) -> Config: result = Config() @@ -859,7 +910,7 @@ def validate_config( result.add_output_path([CONF_PACKAGES], CONF_PACKAGES) try: - config = do_packages_pass(config) + config = do_packages_pass(config, skip_update=skip_external_update) except vol.Invalid as err: result.update(config) result.add_error(err) @@ -896,7 +947,7 @@ def validate_config( result.add_output_path([CONF_EXTERNAL_COMPONENTS], CONF_EXTERNAL_COMPONENTS) try: - do_external_components_pass(config) + do_external_components_pass(config, skip_update=skip_external_update) except vol.Invalid as err: result.update(config) result.add_error(err) @@ -1020,7 +1071,9 @@ class InvalidYAMLError(EsphomeError): self.base_exc = base_exc -def _load_config(command_line_substitutions: dict[str, Any]) -> Config: +def _load_config( + command_line_substitutions: dict[str, Any], skip_external_update: bool = False +) -> Config: """Load the configuration file.""" try: config = yaml_util.load_yaml(CORE.config_path) @@ -1028,7 +1081,7 @@ def _load_config(command_line_substitutions: dict[str, Any]) -> Config: raise InvalidYAMLError(e) from e try: - return validate_config(config, command_line_substitutions) + return validate_config(config, command_line_substitutions, skip_external_update) except EsphomeError: raise except Exception: @@ -1036,9 +1089,11 @@ def _load_config(command_line_substitutions: dict[str, Any]) -> Config: raise -def load_config(command_line_substitutions: dict[str, Any]) -> Config: +def load_config( + command_line_substitutions: dict[str, Any], skip_external_update: bool = False +) -> Config: try: - return _load_config(command_line_substitutions) + return _load_config(command_line_substitutions, skip_external_update) except vol.Invalid as err: raise EsphomeError(f"Error while parsing config: {err}") from err @@ -1178,10 +1233,10 @@ def strip_default_ids(config): return config -def read_config(command_line_substitutions): +def read_config(command_line_substitutions, skip_external_update=False): _LOGGER.info("Reading configuration %s...", CORE.config_path) try: - res = load_config(command_line_substitutions) + res = load_config(command_line_substitutions, skip_external_update) except EsphomeError as err: _LOGGER.error("Error while reading config: %s", err) return None diff --git a/esphome/config_validation.py b/esphome/config_validation.py index 866ed4f8aa..7aaba886e3 100644 --- a/esphome/config_validation.py +++ b/esphome/config_validation.py @@ -15,7 +15,7 @@ from ipaddress import ( ip_network, ) import logging -import os +from pathlib import Path import re from string import ascii_letters, digits import uuid as uuid_ @@ -1609,34 +1609,32 @@ def dimensions(value): return dimensions([match.group(1), match.group(2)]) -def directory(value): +def directory(value: object) -> Path: value = string(value) path = CORE.relative_config_path(value) - if not os.path.exists(path): + if not path.exists(): raise Invalid( - f"Could not find directory '{path}'. Please make sure it exists (full path: {os.path.abspath(path)})." + f"Could not find directory '{path}'. Please make sure it exists (full path: {path.resolve()})." ) - if not os.path.isdir(path): + if not path.is_dir(): raise Invalid( - f"Path '{path}' is not a directory (full path: {os.path.abspath(path)})." + f"Path '{path}' is not a directory (full path: {path.resolve()})." ) - return value + return path -def file_(value): +def file_(value: object) -> Path: value = string(value) path = CORE.relative_config_path(value) - if not os.path.exists(path): + if not path.exists(): raise Invalid( - f"Could not find file '{path}'. Please make sure it exists (full path: {os.path.abspath(path)})." + f"Could not find file '{path}'. Please make sure it exists (full path: {path.resolve()})." ) - if not os.path.isfile(path): - raise Invalid( - f"Path '{path}' is not a file (full path: {os.path.abspath(path)})." - ) - return value + if not path.is_file(): + raise Invalid(f"Path '{path}' is not a file (full path: {path.resolve()}).") + return path ENTITY_ID_CHARACTERS = "abcdefghijklmnopqrstuvwxyz0123456789_" diff --git a/esphome/const.py b/esphome/const.py index a7e1752a67..5f81790a10 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -4,7 +4,7 @@ from enum import Enum from esphome.enum import StrEnum -__version__ = "2025.9.3" +__version__ = "2025.10.0b1" ALLOWED_NAME_CHARS = "abcdefghijklmnopqrstuvwxyz0123456789-_" VALID_SUBSTITUTIONS_CHARACTERS = ( @@ -174,6 +174,7 @@ CONF_CALIBRATE_LINEAR = "calibrate_linear" CONF_CALIBRATION = "calibration" CONF_CAPACITANCE = "capacitance" CONF_CAPACITY = "capacity" +CONF_CAPTURE_RESPONSE = "capture_response" CONF_CARBON_MONOXIDE = "carbon_monoxide" CONF_CARRIER_DUTY_PERCENT = "carrier_duty_percent" CONF_CARRIER_FREQUENCY = "carrier_frequency" @@ -186,6 +187,7 @@ CONF_CHARACTERISTIC_UUID = "characteristic_uuid" CONF_CHECK = "check" CONF_CHIPSET = "chipset" CONF_CLEAN_SESSION = "clean_session" +CONF_CLEAR = "clear" CONF_CLEAR_IMPEDANCE = "clear_impedance" CONF_CLIENT_CERTIFICATE = "client_certificate" CONF_CLIENT_CERTIFICATE_KEY = "client_certificate_key" @@ -541,6 +543,7 @@ CONF_MANUAL_IP = "manual_ip" CONF_MANUFACTURER_ID = "manufacturer_id" CONF_MASK_DISTURBER = "mask_disturber" CONF_MAX_BRIGHTNESS = "max_brightness" +CONF_MAX_CONNECTIONS = "max_connections" CONF_MAX_COOLING_RUN_TIME = "max_cooling_run_time" CONF_MAX_CURRENT = "max_current" CONF_MAX_DURATION = "max_duration" @@ -670,9 +673,11 @@ CONF_ON_PRESET_SET = "on_preset_set" CONF_ON_PRESS = "on_press" CONF_ON_RAW_VALUE = "on_raw_value" CONF_ON_RELEASE = "on_release" +CONF_ON_RESPONSE = "on_response" CONF_ON_SHUTDOWN = "on_shutdown" CONF_ON_SPEED_SET = "on_speed_set" CONF_ON_STATE = "on_state" +CONF_ON_SUCCESS = "on_success" CONF_ON_TAG = "on_tag" CONF_ON_TAG_REMOVED = "on_tag_removed" CONF_ON_TIME = "on_time" @@ -815,6 +820,7 @@ CONF_RESET_DURATION = "reset_duration" CONF_RESET_PIN = "reset_pin" CONF_RESIZE = "resize" CONF_RESOLUTION = "resolution" +CONF_RESPONSE_TEMPLATE = "response_template" CONF_RESTART = "restart" CONF_RESTORE = "restore" CONF_RESTORE_MODE = "restore_mode" @@ -1167,7 +1173,7 @@ UNIT_KILOMETER = "km" UNIT_KILOMETER_PER_HOUR = "km/h" UNIT_KILOVOLT_AMPS = "kVA" UNIT_KILOVOLT_AMPS_HOURS = "kVAh" -UNIT_KILOVOLT_AMPS_REACTIVE = "kVAR" +UNIT_KILOVOLT_AMPS_REACTIVE = "kvar" UNIT_KILOVOLT_AMPS_REACTIVE_HOURS = "kvarh" UNIT_KILOWATT = "kW" UNIT_KILOWATT_HOURS = "kWh" @@ -1268,6 +1274,7 @@ DEVICE_CLASS_PLUG = "plug" DEVICE_CLASS_PM1 = "pm1" DEVICE_CLASS_PM10 = "pm10" DEVICE_CLASS_PM25 = "pm25" +DEVICE_CLASS_PM4 = "pm4" DEVICE_CLASS_POWER = "power" DEVICE_CLASS_POWER_FACTOR = "power_factor" DEVICE_CLASS_PRECIPITATION = "precipitation" diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index 89e3eff7d8..7ab8a3ba71 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -3,6 +3,7 @@ from contextlib import contextmanager import logging import math import os +from pathlib import Path import re from typing import TYPE_CHECKING @@ -39,6 +40,8 @@ from esphome.helpers import ensure_unique_string, get_str_env, is_ha_addon from esphome.util import OrderedDict if TYPE_CHECKING: + from esphome.address_cache import AddressCache + from ..cpp_generator import MockObj, MockObjClass, Statement from ..types import ConfigType, EntityMetadata @@ -381,7 +384,7 @@ class DocumentLocation: @classmethod def from_mark(cls, mark): - return cls(mark.name, mark.line, mark.column) + return cls(str(mark.name), mark.line, mark.column) def __str__(self): return f"{self.document} {self.line}:{self.column}" @@ -536,9 +539,9 @@ class EsphomeCore: # The first key to this dict should always be the integration name self.data = {} # The relative path to the configuration YAML - self.config_path: str | None = None + self.config_path: Path | None = None # The relative path to where all build files are stored - self.build_path: str | None = None + self.build_path: Path | None = None # The validated configuration, this is None until the config has been validated self.config: ConfigType | None = None # The pending tasks in the task queue (mostly for C++ generation) @@ -583,6 +586,8 @@ class EsphomeCore: self.id_classes = {} # The current component being processed during validation self.current_component: str | None = None + # Address cache for DNS and mDNS lookups from command line arguments + self.address_cache: AddressCache | None = None def reset(self): from esphome.pins import PIN_SCHEMA_REGISTRY @@ -610,6 +615,7 @@ class EsphomeCore: self.platform_counts = defaultdict(int) self.unique_ids = {} self.current_component = None + self.address_cache = None PIN_SCHEMA_REGISTRY.reset() @contextmanager @@ -659,43 +665,46 @@ class EsphomeCore: return None @property - def config_dir(self): - return os.path.abspath(os.path.dirname(self.config_path)) + def config_dir(self) -> Path: + if self.config_path.is_dir(): + return self.config_path.absolute() + return self.config_path.absolute().parent @property - def data_dir(self): + def data_dir(self) -> Path: if is_ha_addon(): - return os.path.join("/data") + return Path("/data") if "ESPHOME_DATA_DIR" in os.environ: - return get_str_env("ESPHOME_DATA_DIR", None) + return Path(get_str_env("ESPHOME_DATA_DIR", None)) return self.relative_config_path(".esphome") @property - def config_filename(self): - return os.path.basename(self.config_path) + def config_filename(self) -> str: + return self.config_path.name - def relative_config_path(self, *path): - path_ = os.path.expanduser(os.path.join(*path)) - return os.path.join(self.config_dir, path_) + def relative_config_path(self, *path: str | Path) -> Path: + path_ = Path(*path).expanduser() + return self.config_dir / path_ - def relative_internal_path(self, *path: str) -> str: - return os.path.join(self.data_dir, *path) + def relative_internal_path(self, *path: str | Path) -> Path: + path_ = Path(*path).expanduser() + return self.data_dir / path_ - def relative_build_path(self, *path): - path_ = os.path.expanduser(os.path.join(*path)) - return os.path.join(self.build_path, path_) + def relative_build_path(self, *path: str | Path) -> Path: + path_ = Path(*path).expanduser() + return self.build_path / path_ - def relative_src_path(self, *path): + def relative_src_path(self, *path: str | Path) -> Path: return self.relative_build_path("src", *path) - def relative_pioenvs_path(self, *path): + def relative_pioenvs_path(self, *path: str | Path) -> Path: return self.relative_build_path(".pioenvs", *path) - def relative_piolibdeps_path(self, *path): + def relative_piolibdeps_path(self, *path: str | Path) -> Path: return self.relative_build_path(".piolibdeps", *path) @property - def firmware_bin(self): + def firmware_bin(self) -> Path: if self.is_libretiny: return self.relative_pioenvs_path(self.name, "firmware.uf2") return self.relative_pioenvs_path(self.name, "firmware.bin") diff --git a/esphome/core/component.cpp b/esphome/core/component.cpp index ce4e2bf788..11d9501bb8 100644 --- a/esphome/core/component.cpp +++ b/esphome/core/component.cpp @@ -33,12 +33,22 @@ static const char *const TAG = "component"; // Using namespace-scope static to avoid guard variables (saves 16 bytes total) // This is safe because ESPHome is single-threaded during initialization namespace { +struct ComponentErrorMessage { + const Component *component; + const char *message; +}; + +struct ComponentPriorityOverride { + const Component *component; + float priority; +}; + // Error messages for failed components // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -std::unique_ptr>> component_error_messages; +std::unique_ptr> component_error_messages; // Setup priority overrides - freed after setup completes // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -std::unique_ptr>> setup_priority_overrides; +std::unique_ptr> setup_priority_overrides; } // namespace namespace setup_priority { @@ -134,9 +144,9 @@ void Component::call_dump_config() { // Look up error message from global vector const char *error_msg = nullptr; if (component_error_messages) { - for (const auto &pair : *component_error_messages) { - if (pair.first == this) { - error_msg = pair.second; + for (const auto &entry : *component_error_messages) { + if (entry.component == this) { + error_msg = entry.message; break; } } @@ -306,17 +316,17 @@ void Component::status_set_error(const char *message) { if (message != nullptr) { // Lazy allocate the error messages vector if needed if (!component_error_messages) { - component_error_messages = std::make_unique>>(); + component_error_messages = std::make_unique>(); } // Check if this component already has an error message - for (auto &pair : *component_error_messages) { - if (pair.first == this) { - pair.second = message; + for (auto &entry : *component_error_messages) { + if (entry.component == this) { + entry.message = message; return; } } // Add new error message - component_error_messages->emplace_back(this, message); + component_error_messages->emplace_back(ComponentErrorMessage{this, message}); } } void Component::status_clear_warning() { @@ -356,9 +366,9 @@ float Component::get_actual_setup_priority() const { // Check if there's an override in the global vector if (setup_priority_overrides) { // Linear search is fine for small n (typically < 5 overrides) - for (const auto &pair : *setup_priority_overrides) { - if (pair.first == this) { - return pair.second; + for (const auto &entry : *setup_priority_overrides) { + if (entry.component == this) { + return entry.priority; } } } @@ -367,21 +377,21 @@ float Component::get_actual_setup_priority() const { void Component::set_setup_priority(float priority) { // Lazy allocate the vector if needed if (!setup_priority_overrides) { - setup_priority_overrides = std::make_unique>>(); + setup_priority_overrides = std::make_unique>(); // Reserve some space to avoid reallocations (most configs have < 10 overrides) setup_priority_overrides->reserve(10); } // Check if this component already has an override - for (auto &pair : *setup_priority_overrides) { - if (pair.first == this) { - pair.second = priority; + for (auto &entry : *setup_priority_overrides) { + if (entry.component == this) { + entry.priority = priority; return; } } // Add new override - setup_priority_overrides->emplace_back(this, priority); + setup_priority_overrides->emplace_back(ComponentPriorityOverride{this, priority}); } bool Component::has_overridden_loop() const { diff --git a/esphome/core/component_iterator.h b/esphome/core/component_iterator.h index fdc30485bc..641d42898a 100644 --- a/esphome/core/component_iterator.h +++ b/esphome/core/component_iterator.h @@ -168,8 +168,9 @@ class ComponentIterator { UPDATE, #endif MAX, - } state_{IteratorState::NONE}; + }; uint16_t at_{0}; // Supports up to 65,535 entities per type + IteratorState state_{IteratorState::NONE}; bool include_internal_{false}; template diff --git a/esphome/core/config.py b/esphome/core/config.py index 87e529143d..7bf7f82a8b 100644 --- a/esphome/core/config.py +++ b/esphome/core/config.py @@ -136,21 +136,21 @@ def validate_ids_and_references(config: ConfigType) -> ConfigType: return config -def valid_include(value): +def valid_include(value: str) -> str: # Look for "<...>" includes if value.startswith("<") and value.endswith(">"): return value try: - return cv.directory(value) + return str(cv.directory(value)) except cv.Invalid: pass - value = cv.file_(value) - _, ext = os.path.splitext(value) + path = cv.file_(value) + ext = path.suffix if ext not in VALID_INCLUDE_EXTS: raise cv.Invalid( f"Include has invalid file extension {ext} - valid extensions are {', '.join(VALID_INCLUDE_EXTS)}" ) - return value + return str(path) def valid_project_name(value: str): @@ -311,9 +311,9 @@ def preload_core_config(config, result) -> str: CORE.data[KEY_CORE] = {} if CONF_BUILD_PATH not in conf: - build_path = get_str_env("ESPHOME_BUILD_PATH", "build") - conf[CONF_BUILD_PATH] = os.path.join(build_path, CORE.name) - CORE.build_path = CORE.relative_internal_path(conf[CONF_BUILD_PATH]) + build_path = Path(get_str_env("ESPHOME_BUILD_PATH", "build")) + conf[CONF_BUILD_PATH] = str(build_path / CORE.name) + CORE.build_path = CORE.data_dir / conf[CONF_BUILD_PATH] target_platforms = [] @@ -339,12 +339,12 @@ def preload_core_config(config, result) -> str: return target_platforms[0] -def include_file(path, basename): - parts = basename.split(os.path.sep) +def include_file(path: Path, basename: Path): + parts = basename.parts dst = CORE.relative_src_path(*parts) copy_file_if_changed(path, dst) - _, ext = os.path.splitext(path) + ext = path.suffix if ext in [".h", ".hpp", ".tcc"]: # Header, add include statement cg.add_global(cg.RawStatement(f'#include "{basename}"')) @@ -377,18 +377,18 @@ async def add_arduino_global_workaround(): @coroutine_with_priority(CoroPriority.FINAL) -async def add_includes(includes): +async def add_includes(includes: list[str]) -> None: # Add includes at the very end, so that the included files can access global variables for include in includes: path = CORE.relative_config_path(include) - if os.path.isdir(path): + if path.is_dir(): # Directory, copy tree for p in walk_files(path): - basename = os.path.relpath(p, os.path.dirname(path)) + basename = p.relative_to(path.parent) include_file(p, basename) else: # Copy file - basename = os.path.basename(path) + basename = Path(path.name) include_file(path, basename) diff --git a/esphome/core/defines.h b/esphome/core/defines.h index 6e8d5ed74c..2317c0ed32 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -48,6 +48,7 @@ #define USE_LIGHT #define USE_LOCK #define USE_LOGGER +#define USE_LOGGER_RUNTIME_TAG_LEVELS #define USE_LVGL #define USE_LVGL_ANIMIMG #define USE_LVGL_ARC @@ -82,6 +83,7 @@ #define USE_LVGL_TILEVIEW #define USE_LVGL_TOUCHSCREEN #define USE_MDNS +#define MDNS_SERVICE_COUNT 3 #define USE_MEDIA_PLAYER #define USE_NEXTION_TFT_UPLOAD #define USE_NUMBER @@ -100,6 +102,7 @@ #define USE_UART_DEBUGGER #define USE_UPDATE #define USE_VALVE +#define USE_ZWAVE_PROXY // Feature flags which do not work for zephyr #ifndef USE_ZEPHYR @@ -109,19 +112,26 @@ #define USE_API #define USE_API_CLIENT_CONNECTED_TRIGGER #define USE_API_CLIENT_DISCONNECTED_TRIGGER +#define USE_API_HOMEASSISTANT_ACTION_RESPONSES +#define USE_API_HOMEASSISTANT_ACTION_RESPONSES_JSON #define USE_API_HOMEASSISTANT_SERVICES #define USE_API_HOMEASSISTANT_STATES #define USE_API_NOISE #define USE_API_PLAINTEXT #define USE_API_SERVICES +#define API_MAX_SEND_QUEUE 8 #define USE_MD5 +#define USE_SHA256 #define USE_MQTT #define USE_NETWORK #define USE_ONLINE_IMAGE_BMP_SUPPORT #define USE_ONLINE_IMAGE_PNG_SUPPORT #define USE_ONLINE_IMAGE_JPEG_SUPPORT #define USE_OTA +#define USE_OTA_MD5 #define USE_OTA_PASSWORD +#define USE_OTA_SHA256 +#define ALLOW_OTA_DOWNGRADE_MD5 #define USE_OTA_STATE_CALLBACK #define USE_OTA_VERSION 2 #define USE_TIME_TIMEZONE @@ -151,11 +161,20 @@ #define BLUETOOTH_PROXY_ADVERTISEMENT_BATCH_SIZE 16 #define USE_CAPTIVE_PORTAL #define USE_ESP32_BLE +#define USE_ESP32_BLE_MAX_CONNECTIONS 3 #define USE_ESP32_BLE_CLIENT #define USE_ESP32_BLE_DEVICE #define USE_ESP32_BLE_SERVER #define USE_ESP32_BLE_UUID #define USE_ESP32_BLE_ADVERTISING +#define USE_ESP32_BLE_SERVER_SET_VALUE_ACTION +#define USE_ESP32_BLE_SERVER_DESCRIPTOR_SET_VALUE_ACTION +#define USE_ESP32_BLE_SERVER_NOTIFY_ACTION +#define USE_ESP32_BLE_SERVER_CHARACTERISTIC_ON_WRITE +#define USE_ESP32_BLE_SERVER_DESCRIPTOR_ON_WRITE +#define USE_ESP32_BLE_SERVER_ON_CONNECT +#define USE_ESP32_BLE_SERVER_ON_DISCONNECT +#define USE_ESP32_CAMERA_JPEG_ENCODER #define USE_I2C #define USE_IMPROV #define USE_MICROPHONE diff --git a/esphome/core/hash_base.h b/esphome/core/hash_base.h new file mode 100644 index 0000000000..c45c4df70b --- /dev/null +++ b/esphome/core/hash_base.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include +#include "esphome/core/helpers.h" + +namespace esphome { + +/// Base class for hash algorithms +class HashBase { + public: + virtual ~HashBase() = default; + + /// Initialize a new hash computation + virtual void init() = 0; + + /// Add bytes of data for the hash + virtual void add(const uint8_t *data, size_t len) = 0; + void add(const char *data, size_t len) { this->add((const uint8_t *) data, len); } + + /// Compute the hash based on provided data + virtual void calculate() = 0; + + /// Retrieve the hash as bytes + void get_bytes(uint8_t *output) { memcpy(output, this->digest_, this->get_size()); } + + /// Retrieve the hash as hex characters + void get_hex(char *output) { + for (size_t i = 0; i < this->get_size(); i++) { + uint8_t byte = this->digest_[i]; + output[i * 2] = format_hex_char(byte >> 4); + output[i * 2 + 1] = format_hex_char(byte & 0x0F); + } + } + + /// Compare the hash against a provided byte-encoded hash + bool equals_bytes(const uint8_t *expected) { return memcmp(this->digest_, expected, this->get_size()) == 0; } + + /// Compare the hash against a provided hex-encoded hash + bool equals_hex(const char *expected) { + uint8_t parsed[32]; // Fixed size for max hash (SHA256 = 32 bytes) + if (!parse_hex(expected, parsed, this->get_size())) { + return false; + } + return this->equals_bytes(parsed); + } + + /// Get the size of the hash in bytes (16 for MD5, 32 for SHA256) + virtual size_t get_size() const = 0; + + protected: + uint8_t digest_[32]; // Storage sized for max(MD5=16, SHA256=32) bytes +}; + +} // namespace esphome diff --git a/esphome/core/helpers.cpp b/esphome/core/helpers.cpp index f1560711ef..d4f6809776 100644 --- a/esphome/core/helpers.cpp +++ b/esphome/core/helpers.cpp @@ -3,6 +3,7 @@ #include "esphome/core/defines.h" #include "esphome/core/hal.h" #include "esphome/core/log.h" +#include "esphome/core/string_ref.h" #include #include @@ -348,17 +349,34 @@ ParseOnOffState parse_on_off(const char *str, const char *on, const char *off) { return PARSE_NONE; } -std::string value_accuracy_to_string(float value, int8_t accuracy_decimals) { +static inline void normalize_accuracy_decimals(float &value, int8_t &accuracy_decimals) { if (accuracy_decimals < 0) { auto multiplier = powf(10.0f, accuracy_decimals); value = roundf(value * multiplier) / multiplier; accuracy_decimals = 0; } +} + +std::string value_accuracy_to_string(float value, int8_t accuracy_decimals) { + normalize_accuracy_decimals(value, accuracy_decimals); char tmp[32]; // should be enough, but we should maybe improve this at some point. snprintf(tmp, sizeof(tmp), "%.*f", accuracy_decimals, value); return std::string(tmp); } +std::string value_accuracy_with_uom_to_string(float value, int8_t accuracy_decimals, StringRef unit_of_measurement) { + normalize_accuracy_decimals(value, accuracy_decimals); + // Buffer sized for float (up to ~15 chars) + space + typical UOM (usually <20 chars like "μS/cm") + // snprintf truncates safely if exceeded, though ESPHome UOMs are typically short + char tmp[64]; + if (unit_of_measurement.empty()) { + snprintf(tmp, sizeof(tmp), "%.*f", accuracy_decimals, value); + } else { + snprintf(tmp, sizeof(tmp), "%.*f %s", accuracy_decimals, value, unit_of_measurement.c_str()); + } + return std::string(tmp); +} + int8_t step_to_accuracy_decimals(float step) { // use printf %g to find number of digits based on temperature step char buf[32]; @@ -613,8 +631,6 @@ bool mac_address_is_valid(const uint8_t *mac) { if (mac[i] != 0) { is_all_zeros = false; } - } - for (uint8_t i = 0; i < 6; i++) { if (mac[i] != 0xFF) { is_all_ones = false; } diff --git a/esphome/core/helpers.h b/esphome/core/helpers.h index 21aa159b25..e06f2d15ef 100644 --- a/esphome/core/helpers.h +++ b/esphome/core/helpers.h @@ -45,6 +45,9 @@ namespace esphome { +// Forward declaration to avoid circular dependency with string_ref.h +class StringRef; + /// @name STL backports ///@{ @@ -82,6 +85,16 @@ template constexpr T byteswap(T n) { return m; } template<> constexpr uint8_t byteswap(uint8_t n) { return n; } +#ifdef USE_LIBRETINY +// LibreTiny's Beken framework redefines __builtin_bswap functions as non-constexpr +template<> inline uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); } +template<> inline uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); } +template<> inline uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); } +template<> inline int8_t byteswap(int8_t n) { return n; } +template<> inline int16_t byteswap(int16_t n) { return __builtin_bswap16(n); } +template<> inline int32_t byteswap(int32_t n) { return __builtin_bswap32(n); } +template<> inline int64_t byteswap(int64_t n) { return __builtin_bswap64(n); } +#else template<> constexpr uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); } template<> constexpr uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); } template<> constexpr uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); } @@ -89,6 +102,7 @@ template<> constexpr int8_t byteswap(int8_t n) { return n; } template<> constexpr int16_t byteswap(int16_t n) { return __builtin_bswap16(n); } template<> constexpr int32_t byteswap(int32_t n) { return __builtin_bswap32(n); } template<> constexpr int64_t byteswap(int64_t n) { return __builtin_bswap64(n); } +#endif ///@} @@ -116,6 +130,16 @@ template class StaticVector { } } + // Return reference to next element and increment count (with bounds checking) + T &emplace_next() { + if (count_ >= N) { + // Should never happen with proper size calculation + // Return reference to last element to avoid crash + return data_[N - 1]; + } + return data_[count_++]; + } + size_t size() const { return count_; } bool empty() const { return count_ == 0; } @@ -589,6 +613,8 @@ ParseOnOffState parse_on_off(const char *str, const char *on = nullptr, const ch /// Create a string from a value and an accuracy in decimals. std::string value_accuracy_to_string(float value, int8_t accuracy_decimals); +/// Create a string from a value, an accuracy in decimals, and a unit of measurement. +std::string value_accuracy_with_uom_to_string(float value, int8_t accuracy_decimals, StringRef unit_of_measurement); /// Derive accuracy in decimals from an increment step. int8_t step_to_accuracy_decimals(float step); diff --git a/esphome/core/scheduler.cpp b/esphome/core/scheduler.cpp index 71e2a00fbe..402084f306 100644 --- a/esphome/core/scheduler.cpp +++ b/esphome/core/scheduler.cpp @@ -118,7 +118,6 @@ void HOT Scheduler::set_timer_common_(Component *component, SchedulerItem::Type item->type = type; item->callback = std::move(func); // Initialize remove to false (though it should already be from constructor) - // Not using mark_item_removed_ helper since we're setting to false, not true #ifdef ESPHOME_THREAD_MULTI_ATOMICS item->remove.store(false, std::memory_order_relaxed); #else @@ -600,12 +599,7 @@ bool HOT Scheduler::cancel_item_locked_(Component *component, const char *name_c #ifndef ESPHOME_THREAD_SINGLE // Mark items in defer queue as cancelled (they'll be skipped when processed) if (type == SchedulerItem::TIMEOUT) { - for (auto &item : this->defer_queue_) { - if (this->matches_item_(item, component, name_cstr, type, match_retry)) { - this->mark_item_removed_(item.get()); - total_cancelled++; - } - } + total_cancelled += this->mark_matching_items_removed_(this->defer_queue_, component, name_cstr, type, match_retry); } #endif /* not ESPHOME_THREAD_SINGLE */ @@ -620,23 +614,13 @@ bool HOT Scheduler::cancel_item_locked_(Component *component, const char *name_c total_cancelled++; } // For other items in heap, we can only mark for removal (can't remove from middle of heap) - for (auto &item : this->items_) { - if (this->matches_item_(item, component, name_cstr, type, match_retry)) { - this->mark_item_removed_(item.get()); - total_cancelled++; - this->to_remove_++; // Track removals for heap items - } - } + size_t heap_cancelled = this->mark_matching_items_removed_(this->items_, component, name_cstr, type, match_retry); + total_cancelled += heap_cancelled; + this->to_remove_ += heap_cancelled; // Track removals for heap items } // Cancel items in to_add_ - for (auto &item : this->to_add_) { - if (this->matches_item_(item, component, name_cstr, type, match_retry)) { - this->mark_item_removed_(item.get()); - total_cancelled++; - // Don't track removals for to_add_ items - } - } + total_cancelled += this->mark_matching_items_removed_(this->to_add_, component, name_cstr, type, match_retry); return total_cancelled > 0; } diff --git a/esphome/core/scheduler.h b/esphome/core/scheduler.h index 885ee13754..2237915e07 100644 --- a/esphome/core/scheduler.h +++ b/esphome/core/scheduler.h @@ -280,19 +280,30 @@ class Scheduler { #endif } - // Helper to mark item for removal (platform-specific) + // Helper to mark matching items in a container as removed + // Returns the number of items marked for removal // For ESPHOME_THREAD_MULTI_NO_ATOMICS platforms, the caller must hold the scheduler lock before calling this // function. - void mark_item_removed_(SchedulerItem *item) { + template + size_t mark_matching_items_removed_(Container &container, Component *component, const char *name_cstr, + SchedulerItem::Type type, bool match_retry) { + size_t count = 0; + for (auto &item : container) { + if (this->matches_item_(item, component, name_cstr, type, match_retry)) { + // Mark item for removal (platform-specific) #ifdef ESPHOME_THREAD_MULTI_ATOMICS - // Multi-threaded with atomics: use atomic store - item->remove.store(true, std::memory_order_release); + // Multi-threaded with atomics: use atomic store + item->remove.store(true, std::memory_order_release); #else - // Single-threaded (ESPHOME_THREAD_SINGLE) or - // multi-threaded without atomics (ESPHOME_THREAD_MULTI_NO_ATOMICS): direct write - // For ESPHOME_THREAD_MULTI_NO_ATOMICS, caller MUST hold lock! - item->remove = true; + // Single-threaded (ESPHOME_THREAD_SINGLE) or + // multi-threaded without atomics (ESPHOME_THREAD_MULTI_NO_ATOMICS): direct write + // For ESPHOME_THREAD_MULTI_NO_ATOMICS, caller MUST hold lock! + item->remove = true; #endif + count++; + } + } + return count; } // Template helper to check if any item in a container matches our criteria diff --git a/esphome/core/string_ref.cpp b/esphome/core/string_ref.cpp deleted file mode 100644 index ce1e33cbb7..0000000000 --- a/esphome/core/string_ref.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "string_ref.h" - -namespace esphome { - -#ifdef USE_JSON - -// NOLINTNEXTLINE(readability-identifier-naming) -void convertToJson(const StringRef &src, JsonVariant dst) { dst.set(src.c_str()); } - -#endif // USE_JSON - -} // namespace esphome diff --git a/esphome/core/string_ref.h b/esphome/core/string_ref.h index c4320107e3..efaa17181d 100644 --- a/esphome/core/string_ref.h +++ b/esphome/core/string_ref.h @@ -130,7 +130,7 @@ inline std::string operator+(const StringRef &lhs, const char *rhs) { #ifdef USE_JSON // NOLINTNEXTLINE(readability-identifier-naming) -void convertToJson(const StringRef &src, JsonVariant dst); +inline void convertToJson(const StringRef &src, JsonVariant dst) { dst.set(src.c_str()); } #endif // USE_JSON } // namespace esphome diff --git a/esphome/core/time.cpp b/esphome/core/time.cpp index fe6f50158c..1285ec6448 100644 --- a/esphome/core/time.cpp +++ b/esphome/core/time.cpp @@ -77,7 +77,7 @@ bool ESPTime::strptime(const std::string &time_to_parse, ESPTime &esp_time) { &hour, // NOLINT &minute, // NOLINT &second, &num) == 6 && // NOLINT - num == time_to_parse.size()) { + num == static_cast(time_to_parse.size())) { esp_time.year = year; esp_time.month = month; esp_time.day_of_month = day; @@ -87,7 +87,7 @@ bool ESPTime::strptime(const std::string &time_to_parse, ESPTime &esp_time) { } else if (sscanf(time_to_parse.c_str(), "%04hu-%02hhu-%02hhu %02hhu:%02hhu %n", &year, &month, &day, // NOLINT &hour, // NOLINT &minute, &num) == 5 && // NOLINT - num == time_to_parse.size()) { + num == static_cast(time_to_parse.size())) { esp_time.year = year; esp_time.month = month; esp_time.day_of_month = day; @@ -95,17 +95,17 @@ bool ESPTime::strptime(const std::string &time_to_parse, ESPTime &esp_time) { esp_time.minute = minute; esp_time.second = 0; } else if (sscanf(time_to_parse.c_str(), "%02hhu:%02hhu:%02hhu %n", &hour, &minute, &second, &num) == 3 && // NOLINT - num == time_to_parse.size()) { + num == static_cast(time_to_parse.size())) { esp_time.hour = hour; esp_time.minute = minute; esp_time.second = second; } else if (sscanf(time_to_parse.c_str(), "%02hhu:%02hhu %n", &hour, &minute, &num) == 2 && // NOLINT - num == time_to_parse.size()) { + num == static_cast(time_to_parse.size())) { esp_time.hour = hour; esp_time.minute = minute; esp_time.second = 0; } else if (sscanf(time_to_parse.c_str(), "%04hu-%02hhu-%02hhu %n", &year, &month, &day, &num) == 3 && // NOLINT - num == time_to_parse.size()) { + num == static_cast(time_to_parse.size())) { esp_time.year = year; esp_time.month = month; esp_time.day_of_month = day; diff --git a/esphome/cpp_generator.py b/esphome/cpp_generator.py index 291592dd2b..b2022c7ae6 100644 --- a/esphome/cpp_generator.py +++ b/esphome/cpp_generator.py @@ -1,5 +1,5 @@ import abc -from collections.abc import Callable, Sequence +from collections.abc import Callable import inspect import math import re @@ -13,7 +13,6 @@ from esphome.core import ( HexInt, Lambda, Library, - TimePeriod, TimePeriodMicroseconds, TimePeriodMilliseconds, TimePeriodMinutes, @@ -21,35 +20,11 @@ from esphome.core import ( TimePeriodSeconds, ) from esphome.helpers import cpp_string_escape, indent_all_but_first_and_last +from esphome.types import Expression, SafeExpType, TemplateArgsType from esphome.util import OrderedDict from esphome.yaml_util import ESPHomeDataBase -class Expression(abc.ABC): - __slots__ = () - - @abc.abstractmethod - def __str__(self): - """ - Convert expression into C++ code - """ - - -SafeExpType = ( - Expression - | bool - | str - | str - | int - | float - | TimePeriod - | type[bool] - | type[int] - | type[float] - | Sequence[Any] -) - - class RawExpression(Expression): __slots__ = ("text",) @@ -575,7 +550,7 @@ def Pvariable(id_: ID, rhs: SafeExpType, type_: "MockObj" = None) -> "MockObj": return obj -def new_Pvariable(id_: ID, *args: SafeExpType) -> Pvariable: +def new_Pvariable(id_: ID, *args: SafeExpType) -> "MockObj": """Declare a new pointer variable in the code generation by calling it's constructor with the given arguments. @@ -681,7 +656,7 @@ async def get_variable_with_full_id(id_: ID) -> tuple[ID, "MockObj"]: async def process_lambda( value: Lambda, - parameters: list[tuple[SafeExpType, str]], + parameters: TemplateArgsType, capture: str = "=", return_type: SafeExpType = None, ) -> LambdaExpression | None: diff --git a/esphome/dashboard/const.py b/esphome/dashboard/const.py index db66cb5ead..ada5575d0e 100644 --- a/esphome/dashboard/const.py +++ b/esphome/dashboard/const.py @@ -1,9 +1,26 @@ from __future__ import annotations -EVENT_ENTRY_ADDED = "entry_added" -EVENT_ENTRY_REMOVED = "entry_removed" -EVENT_ENTRY_UPDATED = "entry_updated" -EVENT_ENTRY_STATE_CHANGED = "entry_state_changed" +from esphome.enum import StrEnum + + +class DashboardEvent(StrEnum): + """Dashboard WebSocket event types.""" + + # Server -> Client events (backend sends to frontend) + ENTRY_ADDED = "entry_added" + ENTRY_REMOVED = "entry_removed" + ENTRY_UPDATED = "entry_updated" + ENTRY_STATE_CHANGED = "entry_state_changed" + IMPORTABLE_DEVICE_ADDED = "importable_device_added" + IMPORTABLE_DEVICE_REMOVED = "importable_device_removed" + INITIAL_STATE = "initial_state" # Sent on WebSocket connection + PONG = "pong" # Response to client ping + + # Client -> Server events (frontend sends to backend) + PING = "ping" # WebSocket keepalive from client + REFRESH = "refresh" # Force backend to poll for changes + + MAX_EXECUTOR_WORKERS = 48 diff --git a/esphome/dashboard/core.py b/esphome/dashboard/core.py index 410ef0c29d..b9ec56cd00 100644 --- a/esphome/dashboard/core.py +++ b/esphome/dashboard/core.py @@ -7,13 +7,13 @@ from dataclasses import dataclass from functools import partial import json import logging -from pathlib import Path import threading from typing import Any from esphome.storage_json import ignored_devices_storage_path from ..zeroconf import DiscoveredImport +from .const import DashboardEvent from .dns import DNSCache from .entries import DashboardEntries from .settings import DashboardSettings @@ -31,7 +31,7 @@ MDNS_BOOTSTRAP_TIME = 7.5 class Event: """Dashboard Event.""" - event_type: str + event_type: DashboardEvent data: dict[str, Any] @@ -40,22 +40,24 @@ class EventBus: def __init__(self) -> None: """Initialize the Dashboard event bus.""" - self._listeners: dict[str, set[Callable[[Event], None]]] = {} + self._listeners: dict[DashboardEvent, set[Callable[[Event], None]]] = {} def async_add_listener( - self, event_type: str, listener: Callable[[Event], None] + self, event_type: DashboardEvent, listener: Callable[[Event], None] ) -> Callable[[], None]: """Add a listener to the event bus.""" self._listeners.setdefault(event_type, set()).add(listener) return partial(self._async_remove_listener, event_type, listener) def _async_remove_listener( - self, event_type: str, listener: Callable[[Event], None] + self, event_type: DashboardEvent, listener: Callable[[Event], None] ) -> None: """Remove a listener from the event bus.""" self._listeners[event_type].discard(listener) - def async_fire(self, event_type: str, event_data: dict[str, Any]) -> None: + def async_fire( + self, event_type: DashboardEvent, event_data: dict[str, Any] + ) -> None: """Fire an event.""" event = Event(event_type, event_data) @@ -108,7 +110,7 @@ class ESPHomeDashboard: await self.loop.run_in_executor(None, self.load_ignored_devices) def load_ignored_devices(self) -> None: - storage_path = Path(ignored_devices_storage_path()) + storage_path = ignored_devices_storage_path() try: with storage_path.open("r", encoding="utf-8") as f_handle: data = json.load(f_handle) @@ -117,7 +119,7 @@ class ESPHomeDashboard: pass def save_ignored_devices(self) -> None: - storage_path = Path(ignored_devices_storage_path()) + storage_path = ignored_devices_storage_path() with storage_path.open("w", encoding="utf-8") as f_handle: json.dump( {"ignored_devices": sorted(self.ignored_devices)}, indent=2, fp=f_handle diff --git a/esphome/dashboard/dns.py b/esphome/dashboard/dns.py index 98134062f4..58867f7bc1 100644 --- a/esphome/dashboard/dns.py +++ b/esphome/dashboard/dns.py @@ -28,6 +28,21 @@ class DNSCache: self._cache: dict[str, tuple[float, list[str] | Exception]] = {} self._ttl = ttl + def get_cached_addresses( + self, hostname: str, now_monotonic: float + ) -> list[str] | None: + """Get cached addresses without triggering resolution. + + Returns None if not in cache, list of addresses if found. + """ + # Normalize hostname for consistent lookups + normalized = hostname.rstrip(".").lower() + if expire_time_addresses := self._cache.get(normalized): + expire_time, addresses = expire_time_addresses + if expire_time > now_monotonic and not isinstance(addresses, Exception): + return addresses + return None + async def async_resolve( self, hostname: str, now_monotonic: float ) -> list[str] | Exception: diff --git a/esphome/dashboard/entries.py b/esphome/dashboard/entries.py index b138cfd272..95b8a7b2ae 100644 --- a/esphome/dashboard/entries.py +++ b/esphome/dashboard/entries.py @@ -5,20 +5,14 @@ from collections import defaultdict from dataclasses import dataclass from functools import lru_cache import logging -import os +from pathlib import Path from typing import TYPE_CHECKING, Any from esphome import const, util from esphome.enum import StrEnum from esphome.storage_json import StorageJSON, ext_storage_path -from .const import ( - DASHBOARD_COMMAND, - EVENT_ENTRY_ADDED, - EVENT_ENTRY_REMOVED, - EVENT_ENTRY_STATE_CHANGED, - EVENT_ENTRY_UPDATED, -) +from .const import DASHBOARD_COMMAND, DashboardEvent from .util.subprocess import async_run_system_command if TYPE_CHECKING: @@ -102,12 +96,12 @@ class DashboardEntries: # "path/to/file.yaml": DashboardEntry, # ... # } - self._entries: dict[str, DashboardEntry] = {} + self._entries: dict[Path, DashboardEntry] = {} self._loaded_entries = False self._update_lock = asyncio.Lock() self._name_to_entry: dict[str, set[DashboardEntry]] = defaultdict(set) - def get(self, path: str) -> DashboardEntry | None: + def get(self, path: Path) -> DashboardEntry | None: """Get an entry by path.""" return self._entries.get(path) @@ -192,7 +186,7 @@ class DashboardEntries: return entry.state = state self._dashboard.bus.async_fire( - EVENT_ENTRY_STATE_CHANGED, {"entry": entry, "state": state} + DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state} ) async def async_request_update_entries(self) -> None: @@ -260,22 +254,22 @@ class DashboardEntries: for entry in added: entries[entry.path] = entry name_to_entry[entry.name].add(entry) - bus.async_fire(EVENT_ENTRY_ADDED, {"entry": entry}) + bus.async_fire(DashboardEvent.ENTRY_ADDED, {"entry": entry}) for entry in removed: del entries[entry.path] name_to_entry[entry.name].discard(entry) - bus.async_fire(EVENT_ENTRY_REMOVED, {"entry": entry}) + bus.async_fire(DashboardEvent.ENTRY_REMOVED, {"entry": entry}) for entry in updated: if (original_name := original_names[entry]) != (current_name := entry.name): name_to_entry[original_name].discard(entry) name_to_entry[current_name].add(entry) - bus.async_fire(EVENT_ENTRY_UPDATED, {"entry": entry}) + bus.async_fire(DashboardEvent.ENTRY_UPDATED, {"entry": entry}) - def _get_path_to_cache_key(self) -> dict[str, DashboardCacheKeyType]: + def _get_path_to_cache_key(self) -> dict[Path, DashboardCacheKeyType]: """Return a dict of path to cache key.""" - path_to_cache_key: dict[str, DashboardCacheKeyType] = {} + path_to_cache_key: dict[Path, DashboardCacheKeyType] = {} # # The cache key is (inode, device, mtime, size) # which allows us to avoid locking since it ensures @@ -287,12 +281,12 @@ class DashboardEntries: for file in util.list_yaml_files([self._config_dir]): try: # Prefer the json storage path if it exists - stat = os.stat(ext_storage_path(os.path.basename(file))) + stat = ext_storage_path(file.name).stat() except OSError: try: # Fallback to the yaml file if the storage # file does not exist or could not be generated - stat = os.stat(file) + stat = file.stat() except OSError: # File was deleted, ignore continue @@ -329,10 +323,10 @@ class DashboardEntry: "_to_dict", ) - def __init__(self, path: str, cache_key: DashboardCacheKeyType) -> None: + def __init__(self, path: Path, cache_key: DashboardCacheKeyType) -> None: """Initialize the DashboardEntry.""" self.path = path - self.filename: str = os.path.basename(path) + self.filename: str = path.name self._storage_path = ext_storage_path(self.filename) self.cache_key = cache_key self.storage: StorageJSON | None = None @@ -365,7 +359,7 @@ class DashboardEntry: "loaded_integrations": sorted(self.loaded_integrations), "deployed_version": self.update_old, "current_version": self.update_new, - "path": self.path, + "path": str(self.path), "comment": self.comment, "address": self.address, "web_port": self.web_port, diff --git a/esphome/dashboard/models.py b/esphome/dashboard/models.py new file mode 100644 index 0000000000..47ddddd5ce --- /dev/null +++ b/esphome/dashboard/models.py @@ -0,0 +1,76 @@ +"""Data models and builders for the dashboard.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict + +if TYPE_CHECKING: + from esphome.zeroconf import DiscoveredImport + + from .core import ESPHomeDashboard + from .entries import DashboardEntry + + +class ImportableDeviceDict(TypedDict): + """Dictionary representation of an importable device.""" + + name: str + friendly_name: str | None + package_import_url: str + project_name: str + project_version: str + network: str + ignored: bool + + +class ConfiguredDeviceDict(TypedDict, total=False): + """Dictionary representation of a configured device.""" + + name: str + friendly_name: str | None + configuration: str + loaded_integrations: list[str] | None + deployed_version: str | None + current_version: str | None + path: str + comment: str | None + address: str | None + web_port: int | None + target_platform: str | None + + +class DeviceListResponse(TypedDict): + """Response for device list API.""" + + configured: list[ConfiguredDeviceDict] + importable: list[ImportableDeviceDict] + + +def build_importable_device_dict( + dashboard: ESPHomeDashboard, discovered: DiscoveredImport +) -> ImportableDeviceDict: + """Build the importable device dictionary.""" + return ImportableDeviceDict( + name=discovered.device_name, + friendly_name=discovered.friendly_name, + package_import_url=discovered.package_import_url, + project_name=discovered.project_name, + project_version=discovered.project_version, + network=discovered.network, + ignored=discovered.device_name in dashboard.ignored_devices, + ) + + +def build_device_list_response( + dashboard: ESPHomeDashboard, entries: list[DashboardEntry] +) -> DeviceListResponse: + """Build the device list response data.""" + configured = {entry.name for entry in entries} + return DeviceListResponse( + configured=[entry.to_dict() for entry in entries], + importable=[ + build_importable_device_dict(dashboard, res) + for res in dashboard.import_result.values() + if res.device_name not in configured + ], + ) diff --git a/esphome/dashboard/settings.py b/esphome/dashboard/settings.py index fa39b55016..35b67c0d23 100644 --- a/esphome/dashboard/settings.py +++ b/esphome/dashboard/settings.py @@ -27,7 +27,7 @@ class DashboardSettings: def __init__(self) -> None: """Initialize the dashboard settings.""" - self.config_dir: str = "" + self.config_dir: Path = None self.password_hash: str = "" self.username: str = "" self.using_password: bool = False @@ -45,10 +45,10 @@ class DashboardSettings: self.using_password = bool(password) if self.using_password: self.password_hash = password_hash(password) - self.config_dir = args.configuration - self.absolute_config_dir = Path(self.config_dir).resolve() + self.config_dir = Path(args.configuration) + self.absolute_config_dir = self.config_dir.resolve() self.verbose = args.verbose - CORE.config_path = os.path.join(self.config_dir, ".") + CORE.config_path = self.config_dir / "." @property def relative_url(self) -> str: @@ -81,9 +81,9 @@ class DashboardSettings: # Compare password in constant running time (to prevent timing attacks) return hmac.compare_digest(self.password_hash, password_hash(password)) - def rel_path(self, *args: Any) -> str: + def rel_path(self, *args: Any) -> Path: """Return a path relative to the ESPHome config folder.""" - joined_path = os.path.join(self.config_dir, *args) + joined_path = self.config_dir / Path(*args) # Raises ValueError if not relative to ESPHome config folder - Path(joined_path).resolve().relative_to(self.absolute_config_dir) + joined_path.resolve().relative_to(self.absolute_config_dir) return joined_path diff --git a/esphome/dashboard/status/mdns.py b/esphome/dashboard/status/mdns.py index f9ac7b4289..881340ab24 100644 --- a/esphome/dashboard/status/mdns.py +++ b/esphome/dashboard/status/mdns.py @@ -4,16 +4,21 @@ import asyncio import logging import typing +from zeroconf import AddressResolver, IPVersion + +from esphome.address_cache import normalize_hostname from esphome.zeroconf import ( ESPHOME_SERVICE_TYPE, AsyncEsphomeZeroconf, DashboardBrowser, DashboardImportDiscovery, DashboardStatus, + DiscoveredImport, ) -from ..const import SENTINEL +from ..const import SENTINEL, DashboardEvent from ..entries import DashboardEntry, EntryStateSource, bool_to_entry_state +from ..models import build_importable_device_dict if typing.TYPE_CHECKING: from ..core import ESPHomeDashboard @@ -50,6 +55,44 @@ class MDNSStatus: return await aiozc.async_resolve_host(host_name) return None + def get_cached_addresses(self, host_name: str) -> list[str] | None: + """Get cached addresses for a host without triggering resolution. + + Returns None if not in cache or no zeroconf available. + """ + if not self.aiozc: + _LOGGER.debug("No zeroconf instance available for %s", host_name) + return None + + # Normalize hostname and get the base name + normalized = normalize_hostname(host_name) + base_name = normalized.partition(".")[0] + + # Try to load from zeroconf cache without triggering resolution + resolver_name = f"{base_name}.local." + info = AddressResolver(resolver_name) + # Let zeroconf use its own current time for cache checking + if info.load_from_cache(self.aiozc.zeroconf): + addresses = info.parsed_scoped_addresses(IPVersion.All) + _LOGGER.debug("Found %s in zeroconf cache: %s", resolver_name, addresses) + return addresses + _LOGGER.debug("Not found in zeroconf cache: %s", resolver_name) + return None + + def _on_import_update(self, name: str, discovered: DiscoveredImport | None) -> None: + """Handle importable device updates.""" + if discovered is None: + # Device removed + self.dashboard.bus.async_fire( + DashboardEvent.IMPORTABLE_DEVICE_REMOVED, {"name": name} + ) + else: + # Device added + self.dashboard.bus.async_fire( + DashboardEvent.IMPORTABLE_DEVICE_ADDED, + {"device": build_importable_device_dict(self.dashboard, discovered)}, + ) + async def async_refresh_hosts(self) -> None: """Refresh the hosts to track.""" dashboard = self.dashboard @@ -106,7 +149,8 @@ class MDNSStatus: self._async_set_state(entry, result) stat = DashboardStatus(on_update) - imports = DashboardImportDiscovery() + + imports = DashboardImportDiscovery(self._on_import_update) dashboard.import_result = imports.import_state browser = DashboardBrowser( diff --git a/esphome/dashboard/util/file.py b/esphome/dashboard/util/file.py deleted file mode 100644 index bb263f9ad7..0000000000 --- a/esphome/dashboard/util/file.py +++ /dev/null @@ -1,63 +0,0 @@ -import logging -import os -from pathlib import Path -import tempfile - -_LOGGER = logging.getLogger(__name__) - - -def write_utf8_file( - filename: Path, - utf8_str: str, - private: bool = False, -) -> None: - """Write a file and rename it into place. - - Writes all or nothing. - """ - write_file(filename, utf8_str.encode("utf-8"), private) - - -# from https://github.com/home-assistant/core/blob/dev/homeassistant/util/file.py -def write_file( - filename: Path, - utf8_data: bytes, - private: bool = False, -) -> None: - """Write a file and rename it into place. - - Writes all or nothing. - """ - - tmp_filename = "" - missing_fchmod = False - try: - # Modern versions of Python tempfile create this file with mode 0o600 - with tempfile.NamedTemporaryFile( - mode="wb", dir=os.path.dirname(filename), delete=False - ) as fdesc: - fdesc.write(utf8_data) - tmp_filename = fdesc.name - if not private: - try: - os.fchmod(fdesc.fileno(), 0o644) - except AttributeError: - # os.fchmod is not available on Windows - missing_fchmod = True - - os.replace(tmp_filename, filename) - if missing_fchmod: - os.chmod(filename, 0o644) - finally: - if os.path.exists(tmp_filename): - try: - os.remove(tmp_filename) - except OSError as err: - # If we are cleaning up then something else went wrong, so - # we should suppress likely follow-on errors in the cleanup - _LOGGER.error( - "File replacement cleanup failed for %s while saving %s: %s", - tmp_filename, - filename, - err, - ) diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index e6c5fd3d84..a79c67c3d2 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -4,8 +4,10 @@ import asyncio import base64 import binascii from collections.abc import Callable, Iterable +import contextlib import datetime import functools +from functools import partial import gzip import hashlib import importlib @@ -49,10 +51,11 @@ from esphome.storage_json import ( from esphome.util import get_serial_ports, shlex_quote from esphome.yaml_util import FastestAvailableSafeLoader -from .const import DASHBOARD_COMMAND -from .core import DASHBOARD -from .entries import UNKNOWN_STATE, entry_state_to_bool -from .util.file import write_file +from ..helpers import write_file +from .const import DASHBOARD_COMMAND, DashboardEvent +from .core import DASHBOARD, ESPHomeDashboard, Event +from .entries import UNKNOWN_STATE, DashboardEntry, entry_state_to_bool +from .models import build_device_list_response from .util.subprocess import async_run_system_command from .util.text import friendly_name_slugify @@ -283,11 +286,23 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): def _stdout_thread(self) -> None: if not self._use_popen: return + line = b"" + cr = False while True: - data = self._proc.stdout.readline() + data = self._proc.stdout.read(1) if data: - data = data.replace(b"\r", b"") - self._queue.put_nowait(data) + if data == b"\r": + cr = True + elif data == b"\n": + self._queue.put_nowait(line + b"\n") + line = b"" + cr = False + elif cr: + self._queue.put_nowait(line + b"\r") + line = data + cr = False + else: + line += data if self._proc.poll() is not None: break self._proc.wait(1.0) @@ -314,6 +329,73 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): raise NotImplementedError +def build_cache_arguments( + entry: DashboardEntry | None, + dashboard: ESPHomeDashboard, + now: float, +) -> list[str]: + """Build cache arguments for passing to CLI. + + Args: + entry: Dashboard entry for the configuration + dashboard: Dashboard instance with cache access + now: Current monotonic time for DNS cache expiry checks + + Returns: + List of cache arguments to pass to CLI + """ + cache_args: list[str] = [] + + if not entry: + return cache_args + + _LOGGER.debug( + "Building cache for entry (address=%s, name=%s)", + entry.address, + entry.name, + ) + + def add_cache_entry(hostname: str, addresses: list[str], cache_type: str) -> None: + """Add a cache entry to the command arguments.""" + if not addresses: + return + normalized = hostname.rstrip(".").lower() + cache_args.extend( + [ + f"--{cache_type}-address-cache", + f"{normalized}={','.join(sort_ip_addresses(addresses))}", + ] + ) + + # Check entry.address for cached addresses + if use_address := entry.address: + if use_address.endswith(".local"): + # mDNS cache for .local addresses + if (mdns := dashboard.mdns_status) and ( + cached := mdns.get_cached_addresses(use_address) + ): + _LOGGER.debug("mDNS cache hit for %s: %s", use_address, cached) + add_cache_entry(use_address, cached, "mdns") + # DNS cache for non-.local addresses + elif cached := dashboard.dns_cache.get_cached_addresses(use_address, now): + _LOGGER.debug("DNS cache hit for %s: %s", use_address, cached) + add_cache_entry(use_address, cached, "dns") + + # Check entry.name if we haven't already cached via address + # For mDNS devices, entry.name typically doesn't have .local suffix + if entry.name and not use_address: + mdns_name = ( + f"{entry.name}.local" if not entry.name.endswith(".local") else entry.name + ) + if (mdns := dashboard.mdns_status) and ( + cached := mdns.get_cached_addresses(mdns_name) + ): + _LOGGER.debug("mDNS cache hit for %s: %s", mdns_name, cached) + add_cache_entry(mdns_name, cached, "mdns") + + return cache_args + + class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): """Base class for commands that require a port.""" @@ -326,52 +408,22 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): configuration = json_message["configuration"] config_file = settings.rel_path(configuration) port = json_message["port"] - addresses: list[str] = [] + + # Build cache arguments to pass to CLI + cache_args: list[str] = [] + if ( port == "OTA" # pylint: disable=too-many-boolean-expressions and (entry := entries.get(config_file)) and entry.loaded_integrations and "api" in entry.loaded_integrations ): - # First priority: entry.address AKA use_address - if ( - (use_address := entry.address) - and ( - address_list := await dashboard.dns_cache.async_resolve( - use_address, time.monotonic() - ) - ) - and not isinstance(address_list, Exception) - ): - addresses.extend(sort_ip_addresses(address_list)) + cache_args = build_cache_arguments(entry, dashboard, time.monotonic()) - # Second priority: mDNS - if ( - (mdns := dashboard.mdns_status) - and (address_list := await mdns.async_resolve_host(entry.name)) - and ( - new_addresses := [ - addr for addr in address_list if addr not in addresses - ] - ) - ): - # Use the IP address if available but only - # if the API is loaded and the device is online - # since MQTT logging will not work otherwise - addresses.extend(sort_ip_addresses(new_addresses)) - - if not addresses: - # If no address was found, use the port directly - # as otherwise they will get the chooser which - # does not work with the dashboard as there is no - # interactive way to get keyboard input - addresses = [port] - - device_args: list[str] = [ - arg for address in addresses for arg in ("--device", address) - ] - - return [*DASHBOARD_COMMAND, *args, config_file, *device_args] + # Cache arguments must come before the subcommand + cmd = [*DASHBOARD_COMMAND, *cache_args, *args, config_file, "--device", port] + _LOGGER.debug("Built command: %s", cmd) + return cmd class EsphomeLogsHandler(EsphomePortCommandWebSocket): @@ -442,6 +494,14 @@ class EsphomeCleanMqttHandler(EsphomeCommandWebSocket): return [*DASHBOARD_COMMAND, "clean-mqtt", config_file] +class EsphomeCleanAllHandler(EsphomeCommandWebSocket): + async def build_command(self, json_message: dict[str, Any]) -> list[str]: + clean_build_dir = json_message.get("clean_build_dir", True) + if clean_build_dir: + return [*DASHBOARD_COMMAND, "clean-all", settings.config_dir] + return [*DASHBOARD_COMMAND, "clean-all"] + + class EsphomeCleanHandler(EsphomeCommandWebSocket): async def build_command(self, json_message: dict[str, Any]) -> list[str]: config_file = settings.rel_path(json_message["configuration"]) @@ -463,6 +523,243 @@ class EsphomeUpdateAllHandler(EsphomeCommandWebSocket): return [*DASHBOARD_COMMAND, "update-all", settings.config_dir] +# Dashboard polling constants +DASHBOARD_POLL_INTERVAL = 2 # seconds +DASHBOARD_ENTRIES_UPDATE_INTERVAL = 10 # seconds +DASHBOARD_ENTRIES_UPDATE_ITERATIONS = ( + DASHBOARD_ENTRIES_UPDATE_INTERVAL // DASHBOARD_POLL_INTERVAL +) + + +class DashboardSubscriber: + """Manages dashboard event polling task lifecycle based on active subscribers.""" + + def __init__(self) -> None: + """Initialize the dashboard subscriber.""" + self._subscribers: set[DashboardEventsWebSocket] = set() + self._event_loop_task: asyncio.Task | None = None + self._refresh_event: asyncio.Event = asyncio.Event() + + def subscribe(self, subscriber: DashboardEventsWebSocket) -> Callable[[], None]: + """Subscribe to dashboard updates and start event loop if needed.""" + self._subscribers.add(subscriber) + if not self._event_loop_task or self._event_loop_task.done(): + self._event_loop_task = asyncio.create_task(self._event_loop()) + _LOGGER.info("Started dashboard event loop") + return partial(self._unsubscribe, subscriber) + + def _unsubscribe(self, subscriber: DashboardEventsWebSocket) -> None: + """Unsubscribe from dashboard updates and stop event loop if no subscribers.""" + self._subscribers.discard(subscriber) + if ( + not self._subscribers + and self._event_loop_task + and not self._event_loop_task.done() + ): + self._event_loop_task.cancel() + self._event_loop_task = None + _LOGGER.info("Stopped dashboard event loop - no subscribers") + + def request_refresh(self) -> None: + """Signal the polling loop to refresh immediately.""" + self._refresh_event.set() + + async def _event_loop(self) -> None: + """Run the event polling loop while there are subscribers.""" + dashboard = DASHBOARD + entries_update_counter = 0 + + while self._subscribers: + # Signal that we need ping updates (non-blocking) + dashboard.ping_request.set() + if settings.status_use_mqtt: + dashboard.mqtt_ping_request.set() + + # Check if it's time to update entries or if refresh was requested + entries_update_counter += 1 + if ( + entries_update_counter >= DASHBOARD_ENTRIES_UPDATE_ITERATIONS + or self._refresh_event.is_set() + ): + entries_update_counter = 0 + await dashboard.entries.async_request_update_entries() + # Clear the refresh event if it was set + self._refresh_event.clear() + + # Wait for either timeout or refresh event + try: + async with asyncio.timeout(DASHBOARD_POLL_INTERVAL): + await self._refresh_event.wait() + # If we get here, refresh was requested - continue loop immediately + except TimeoutError: + # Normal timeout - continue with regular polling + pass + + +# Global dashboard subscriber instance +DASHBOARD_SUBSCRIBER = DashboardSubscriber() + + +@websocket_class +class DashboardEventsWebSocket(tornado.websocket.WebSocketHandler): + """WebSocket handler for real-time dashboard events.""" + + _event_listeners: list[Callable[[], None]] | None = None + _dashboard_unsubscribe: Callable[[], None] | None = None + + async def get(self, *args: str, **kwargs: str) -> None: + """Handle WebSocket upgrade request.""" + if not is_authenticated(self): + self.set_status(401) + self.finish("Unauthorized") + return + await super().get(*args, **kwargs) + + async def open(self, *args: str, **kwargs: str) -> None: # pylint: disable=invalid-overridden-method + """Handle new WebSocket connection.""" + # Ensure messages are sent immediately to avoid + # a 200-500ms delay when nodelay is not set. + self.set_nodelay(True) + + # Update entries first + await DASHBOARD.entries.async_request_update_entries() + # Send initial state + self._send_initial_state() + # Subscribe to events + self._subscribe_to_events() + # Subscribe to dashboard updates + self._dashboard_unsubscribe = DASHBOARD_SUBSCRIBER.subscribe(self) + _LOGGER.debug("Dashboard status WebSocket opened") + + def _send_initial_state(self) -> None: + """Send initial device list and ping status.""" + entries = DASHBOARD.entries.async_all() + + # Send initial state + self._safe_send_message( + { + "event": DashboardEvent.INITIAL_STATE, + "data": { + "devices": build_device_list_response(DASHBOARD, entries), + "ping": { + entry.filename: entry_state_to_bool(entry.state) + for entry in entries + }, + }, + } + ) + + def _subscribe_to_events(self) -> None: + """Subscribe to dashboard events.""" + async_add_listener = DASHBOARD.bus.async_add_listener + # Subscribe to all events + self._event_listeners = [ + async_add_listener( + DashboardEvent.ENTRY_STATE_CHANGED, self._on_entry_state_changed + ), + async_add_listener( + DashboardEvent.ENTRY_ADDED, + self._make_entry_handler(DashboardEvent.ENTRY_ADDED), + ), + async_add_listener( + DashboardEvent.ENTRY_REMOVED, + self._make_entry_handler(DashboardEvent.ENTRY_REMOVED), + ), + async_add_listener( + DashboardEvent.ENTRY_UPDATED, + self._make_entry_handler(DashboardEvent.ENTRY_UPDATED), + ), + async_add_listener( + DashboardEvent.IMPORTABLE_DEVICE_ADDED, self._on_importable_added + ), + async_add_listener( + DashboardEvent.IMPORTABLE_DEVICE_REMOVED, + self._on_importable_removed, + ), + ] + + def _on_entry_state_changed(self, event: Event) -> None: + """Handle entry state change event.""" + entry = event.data["entry"] + state = event.data["state"] + self._safe_send_message( + { + "event": DashboardEvent.ENTRY_STATE_CHANGED, + "data": { + "filename": entry.filename, + "name": entry.name, + "state": entry_state_to_bool(state), + }, + } + ) + + def _make_entry_handler( + self, event_type: DashboardEvent + ) -> Callable[[Event], None]: + """Create an entry event handler.""" + + def handler(event: Event) -> None: + self._safe_send_message( + {"event": event_type, "data": {"device": event.data["entry"].to_dict()}} + ) + + return handler + + def _on_importable_added(self, event: Event) -> None: + """Handle importable device added event.""" + # Don't send if device is already configured + device_name = event.data.get("device", {}).get("name") + if device_name and DASHBOARD.entries.get_by_name(device_name): + return + self._safe_send_message( + {"event": DashboardEvent.IMPORTABLE_DEVICE_ADDED, "data": event.data} + ) + + def _on_importable_removed(self, event: Event) -> None: + """Handle importable device removed event.""" + self._safe_send_message( + {"event": DashboardEvent.IMPORTABLE_DEVICE_REMOVED, "data": event.data} + ) + + def _safe_send_message(self, message: dict[str, Any]) -> None: + """Send a message to the WebSocket client, ignoring closed errors.""" + with contextlib.suppress(tornado.websocket.WebSocketClosedError): + self.write_message(json.dumps(message)) + + def on_message(self, message: str) -> None: + """Handle incoming WebSocket messages.""" + _LOGGER.debug("WebSocket received message: %s", message) + try: + data = json.loads(message) + except json.JSONDecodeError as err: + _LOGGER.debug("Failed to parse WebSocket message: %s", err) + return + + event = data.get("event") + _LOGGER.debug("WebSocket message event: %s", event) + if event == DashboardEvent.PING: + # Send pong response for client ping + _LOGGER.debug("Received client ping, sending pong") + self._safe_send_message({"event": DashboardEvent.PONG}) + elif event == DashboardEvent.REFRESH: + # Signal the polling loop to refresh immediately + _LOGGER.debug("Received refresh request, signaling polling loop") + DASHBOARD_SUBSCRIBER.request_refresh() + + def on_close(self) -> None: + """Handle WebSocket close.""" + # Unsubscribe from dashboard updates + if self._dashboard_unsubscribe: + self._dashboard_unsubscribe() + self._dashboard_unsubscribe = None + + # Unsubscribe from events + for remove_listener in self._event_listeners or []: + remove_listener() + + _LOGGER.debug("Dashboard status WebSocket closed") + + class SerialPortRequestHandler(BaseHandler): @authenticated async def get(self) -> None: @@ -544,7 +841,7 @@ class WizardRequestHandler(BaseHandler): destination = settings.rel_path(filename) # Check if destination file already exists - if os.path.exists(destination): + if destination.exists(): self.set_status(409) # Conflict status code self.set_header("content-type", "application/json") self.write( @@ -761,10 +1058,9 @@ class DownloadBinaryRequestHandler(BaseHandler): "download", f"{storage_json.name}-{file_name}", ) - path = os.path.dirname(storage_json.firmware_bin_path) - path = os.path.join(path, file_name) + path = storage_json.firmware_bin_path.with_name(file_name) - if not Path(path).is_file(): + if not path.is_file(): args = ["esphome", "idedata", settings.rel_path(configuration)] rc, stdout, _ = await async_run_system_command(args) @@ -818,28 +1114,7 @@ class ListDevicesHandler(BaseHandler): await dashboard.entries.async_request_update_entries() entries = dashboard.entries.async_all() self.set_header("content-type", "application/json") - configured = {entry.name for entry in entries} - - self.write( - json.dumps( - { - "configured": [entry.to_dict() for entry in entries], - "importable": [ - { - "name": res.device_name, - "friendly_name": res.friendly_name, - "package_import_url": res.package_import_url, - "project_name": res.project_name, - "project_version": res.project_version, - "network": res.network, - "ignored": res.device_name in dashboard.ignored_devices, - } - for res in dashboard.import_result.values() - if res.device_name not in configured - ], - } - ) - ) + self.write(json.dumps(build_device_list_response(dashboard, entries))) class MainRequestHandler(BaseHandler): @@ -979,7 +1254,7 @@ class EditRequestHandler(BaseHandler): return filename = settings.rel_path(configuration) - if Path(filename).resolve().parent != settings.absolute_config_dir: + if filename.resolve().parent != settings.absolute_config_dir: self.send_error(404) return @@ -1002,10 +1277,6 @@ class EditRequestHandler(BaseHandler): self.set_status(404) return None - def _write_file(self, filename: str, content: bytes) -> None: - """Write a file with the given content.""" - write_file(filename, content) - @authenticated @bind_config async def post(self, configuration: str | None = None) -> None: @@ -1015,12 +1286,12 @@ class EditRequestHandler(BaseHandler): return filename = settings.rel_path(configuration) - if Path(filename).resolve().parent != settings.absolute_config_dir: + if filename.resolve().parent != settings.absolute_config_dir: self.send_error(404) return loop = asyncio.get_running_loop() - await loop.run_in_executor(None, self._write_file, filename, self.request.body) + await loop.run_in_executor(None, write_file, filename, self.request.body) # Ensure the StorageJSON is updated as well DASHBOARD.entries.async_schedule_storage_json_update(filename) self.set_status(200) @@ -1035,7 +1306,7 @@ class ArchiveRequestHandler(BaseHandler): archive_path = archive_storage_path() mkdir_p(archive_path) - shutil.move(config_file, os.path.join(archive_path, configuration)) + shutil.move(config_file, archive_path / configuration) storage_json = StorageJSON.load(storage_path) if storage_json is not None and storage_json.build_path: @@ -1049,7 +1320,7 @@ class UnArchiveRequestHandler(BaseHandler): def post(self, configuration: str | None = None) -> None: config_file = settings.rel_path(configuration) archive_path = archive_storage_path() - shutil.move(os.path.join(archive_path, configuration), config_file) + shutil.move(archive_path / configuration, config_file) class LoginHandler(BaseHandler): @@ -1136,7 +1407,7 @@ class SecretKeysRequestHandler(BaseHandler): for secret_filename in const.SECRETS_FILES: relative_filename = settings.rel_path(secret_filename) - if os.path.isfile(relative_filename): + if relative_filename.is_file(): filename = relative_filename break @@ -1169,16 +1440,17 @@ class JsonConfigRequestHandler(BaseHandler): @bind_config async def get(self, configuration: str | None = None) -> None: filename = settings.rel_path(configuration) - if not os.path.isfile(filename): + if not filename.is_file(): self.send_error(404) return - args = ["esphome", "config", filename, "--show-secrets"] + args = ["esphome", "config", str(filename), "--show-secrets"] - rc, stdout, _ = await async_run_system_command(args) + rc, stdout, stderr = await async_run_system_command(args) if rc != 0: - self.send_error(422) + self.set_status(422) + self.write(stderr) return data = yaml.load(stdout, Loader=SafeLoaderIgnoreUnknown) @@ -1187,7 +1459,7 @@ class JsonConfigRequestHandler(BaseHandler): self.finish() -def get_base_frontend_path() -> str: +def get_base_frontend_path() -> Path: if ENV_DEV not in os.environ: import esphome_dashboard @@ -1198,11 +1470,12 @@ def get_base_frontend_path() -> str: static_path += "/" # This path can be relative, so resolve against the root or else templates don't work - return os.path.abspath(os.path.join(os.getcwd(), static_path, "esphome_dashboard")) + path = Path(os.getcwd()) / static_path / "esphome_dashboard" + return path.resolve() -def get_static_path(*args: Iterable[str]) -> str: - return os.path.join(get_base_frontend_path(), "static", *args) +def get_static_path(*args: Iterable[str]) -> Path: + return get_base_frontend_path() / "static" / Path(*args) @functools.cache @@ -1219,8 +1492,7 @@ def get_static_file_url(name: str) -> str: return base.replace("index.js", esphome_dashboard.entrypoint()) path = get_static_path(name) - with open(path, "rb") as f_handle: - hash_ = hashlib.md5(f_handle.read()).hexdigest()[:8] + hash_ = hashlib.md5(path.read_bytes()).hexdigest()[:8] return f"{base}?hash={hash_}" @@ -1280,6 +1552,7 @@ def make_app(debug=get_bool_env(ENV_DEV)) -> tornado.web.Application: (f"{rel}compile", EsphomeCompileHandler), (f"{rel}validate", EsphomeValidateHandler), (f"{rel}clean-mqtt", EsphomeCleanMqttHandler), + (f"{rel}clean-all", EsphomeCleanAllHandler), (f"{rel}clean", EsphomeCleanHandler), (f"{rel}vscode", EsphomeVscodeHandler), (f"{rel}ace", EsphomeAceEditorHandler), @@ -1297,6 +1570,7 @@ def make_app(debug=get_bool_env(ENV_DEV)) -> tornado.web.Application: (f"{rel}wizard", WizardRequestHandler), (f"{rel}static/(.*)", StaticFileHandler, {"path": get_static_path()}), (f"{rel}devices", ListDevicesHandler), + (f"{rel}events", DashboardEventsWebSocket), (f"{rel}import", ImportRequestHandler), (f"{rel}secret_keys", SecretKeysRequestHandler), (f"{rel}json-config", JsonConfigRequestHandler), @@ -1320,7 +1594,7 @@ def start_web_server( """Start the web server listener.""" trash_path = trash_storage_path() - if os.path.exists(trash_path): + if trash_path.is_dir() and trash_path.exists(): _LOGGER.info("Renaming 'trash' folder to 'archive'") archive_path = archive_storage_path() shutil.move(trash_path, archive_path) diff --git a/esphome/espota2.py b/esphome/espota2.py index 3d25af985b..2712d00127 100644 --- a/esphome/espota2.py +++ b/esphome/espota2.py @@ -1,19 +1,23 @@ from __future__ import annotations +from collections.abc import Callable import gzip import hashlib import io import logging +from pathlib import Path import random import socket import sys import time +from typing import Any from esphome.core import EsphomeError from esphome.helpers import resolve_ip_address RESPONSE_OK = 0x00 RESPONSE_REQUEST_AUTH = 0x01 +RESPONSE_REQUEST_SHA256_AUTH = 0x02 RESPONSE_HEADER_OK = 0x40 RESPONSE_AUTH_OK = 0x41 @@ -44,6 +48,7 @@ OTA_VERSION_2_0 = 2 MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45] FEATURE_SUPPORTS_COMPRESSION = 0x01 +FEATURE_SUPPORTS_SHA256_AUTH = 0x02 UPLOAD_BLOCK_SIZE = 8192 @@ -51,6 +56,12 @@ UPLOAD_BUFFER_SIZE = UPLOAD_BLOCK_SIZE * 8 _LOGGER = logging.getLogger(__name__) +# Authentication method lookup table: response -> (hash_func, nonce_size, name) +_AUTH_METHODS: dict[int, tuple[Callable[..., Any], int, str]] = { + RESPONSE_REQUEST_SHA256_AUTH: (hashlib.sha256, 64, "SHA256"), + RESPONSE_REQUEST_AUTH: (hashlib.md5, 32, "MD5"), +} + class ProgressBar: def __init__(self): @@ -80,18 +91,43 @@ class OTAError(EsphomeError): pass -def recv_decode(sock, amount, decode=True): +def recv_decode( + sock: socket.socket, amount: int, decode: bool = True +) -> bytes | list[int]: + """Receive data from socket and optionally decode to list of integers. + + :param sock: Socket to receive data from. + :param amount: Number of bytes to receive. + :param decode: If True, convert bytes to list of integers, otherwise return raw bytes. + :return: List of integers if decode=True, otherwise raw bytes. + """ data = sock.recv(amount) if not decode: return data return list(data) -def receive_exactly(sock, amount, msg, expect, decode=True): - data = [] if decode else b"" +def receive_exactly( + sock: socket.socket, + amount: int, + msg: str, + expect: int | list[int] | None, + decode: bool = True, +) -> list[int] | bytes: + """Receive exactly the specified amount of data from socket with error checking. + + :param sock: Socket to receive data from. + :param amount: Exact number of bytes to receive. + :param msg: Description of what is being received for error messages. + :param expect: Expected response code(s) for validation, None to skip validation. + :param decode: If True, return list of integers, otherwise return raw bytes. + :return: List of integers if decode=True, otherwise raw bytes. + :raises OTAError: If receiving fails or response doesn't match expected. + """ + data: list[int] | bytes = [] if decode else b"" try: - data += recv_decode(sock, 1, decode=decode) + data += recv_decode(sock, 1, decode=decode) # type: ignore[operator] except OSError as err: raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err @@ -103,13 +139,19 @@ def receive_exactly(sock, amount, msg, expect, decode=True): while len(data) < amount: try: - data += recv_decode(sock, amount - len(data), decode=decode) + data += recv_decode(sock, amount - len(data), decode=decode) # type: ignore[operator] except OSError as err: raise OTAError(f"Error receiving {msg}: {err}") from err return data -def check_error(data, expect): +def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None: + """Check response data for error codes and validate against expected response. + + :param data: Response data from device (first byte is the response code). + :param expect: Expected response code(s), None to skip validation. + :raises OTAError: If an error code is detected or response doesn't match expected. + """ if not expect: return dat = data[0] @@ -124,7 +166,7 @@ def check_error(data, expect): raise OTAError("Error: Authentication invalid. Is the password correct?") if dat == RESPONSE_ERROR_WRITING_FLASH: raise OTAError( - "Error: Wring OTA data to flash memory failed. See USB logs for more " + "Error: Writing OTA data to flash memory failed. See USB logs for more " "information." ) if dat == RESPONSE_ERROR_UPDATE_END: @@ -176,7 +218,16 @@ def check_error(data, expect): raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}") -def send_check(sock, data, msg): +def send_check( + sock: socket.socket, data: list[int] | tuple[int, ...] | int | str | bytes, msg: str +) -> None: + """Send data to socket with error handling. + + :param sock: Socket to send data to. + :param data: Data to send (can be list/tuple of ints, single int, string, or bytes). + :param msg: Description of what is being sent for error messages. + :raises OTAError: If sending fails. + """ try: if isinstance(data, (list, tuple)): data = bytes(data) @@ -191,7 +242,7 @@ def send_check(sock, data, msg): def perform_ota( - sock: socket.socket, password: str, file_handle: io.IOBase, filename: str + sock: socket.socket, password: str, file_handle: io.IOBase, filename: Path ) -> None: file_contents = file_handle.read() file_size = len(file_contents) @@ -209,10 +260,14 @@ def perform_ota( f"Device uses unsupported OTA version {version}, this ESPHome supports {supported_versions}" ) - # Features - send_check(sock, FEATURE_SUPPORTS_COMPRESSION, "features") + # Features - send both compression and SHA256 auth support + features_to_send = FEATURE_SUPPORTS_COMPRESSION | FEATURE_SUPPORTS_SHA256_AUTH + send_check(sock, features_to_send, "features") features = receive_exactly( - sock, 1, "features", [RESPONSE_HEADER_OK, RESPONSE_SUPPORTS_COMPRESSION] + sock, + 1, + "features", + None, # Accept any response )[0] if features == RESPONSE_SUPPORTS_COMPRESSION: @@ -221,31 +276,52 @@ def perform_ota( else: upload_contents = file_contents - (auth,) = receive_exactly( - sock, 1, "auth", [RESPONSE_REQUEST_AUTH, RESPONSE_AUTH_OK] - ) - if auth == RESPONSE_REQUEST_AUTH: + def perform_auth( + sock: socket.socket, + password: str, + hash_func: Callable[..., Any], + nonce_size: int, + hash_name: str, + ) -> None: + """Perform challenge-response authentication using specified hash algorithm.""" if not password: raise OTAError("ESP requests password, but no password given!") - nonce = receive_exactly( - sock, 32, "authentication nonce", [], decode=False - ).decode() - _LOGGER.debug("Auth: Nonce is %s", nonce) - cnonce = hashlib.md5(str(random.random()).encode()).hexdigest() - _LOGGER.debug("Auth: CNonce is %s", cnonce) + + nonce_bytes = receive_exactly( + sock, nonce_size, f"{hash_name} authentication nonce", [], decode=False + ) + assert isinstance(nonce_bytes, bytes) + nonce = nonce_bytes.decode() + _LOGGER.debug("Auth: %s Nonce is %s", hash_name, nonce) + + # Generate cnonce + cnonce = hash_func(str(random.random()).encode()).hexdigest() + _LOGGER.debug("Auth: %s CNonce is %s", hash_name, cnonce) send_check(sock, cnonce, "auth cnonce") - result_md5 = hashlib.md5() - result_md5.update(password.encode("utf-8")) - result_md5.update(nonce.encode()) - result_md5.update(cnonce.encode()) - result = result_md5.hexdigest() - _LOGGER.debug("Auth: Result is %s", result) + # Calculate challenge response + hasher = hash_func() + hasher.update(password.encode("utf-8")) + hasher.update(nonce.encode()) + hasher.update(cnonce.encode()) + result = hasher.hexdigest() + _LOGGER.debug("Auth: %s Result is %s", hash_name, result) send_check(sock, result, "auth result") receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK) + (auth,) = receive_exactly( + sock, + 1, + "auth", + [RESPONSE_REQUEST_AUTH, RESPONSE_REQUEST_SHA256_AUTH, RESPONSE_AUTH_OK], + ) + + if auth != RESPONSE_AUTH_OK: + hash_func, nonce_size, hash_name = _AUTH_METHODS[auth] + perform_auth(sock, password, hash_func, nonce_size, hash_name) + # Set higher timeout during upload sock.settimeout(30.0) @@ -309,12 +385,16 @@ def perform_ota( def run_ota_impl_( - remote_host: str | list[str], remote_port: int, password: str, filename: str + remote_host: str | list[str], remote_port: int, password: str, filename: Path ) -> tuple[int, str | None]: + from esphome.core import CORE + # Handle both single host and list of hosts try: # Resolve all hosts at once for parallel DNS resolution - res = resolve_ip_address(remote_host, remote_port) + res = resolve_ip_address( + remote_host, remote_port, address_cache=CORE.address_cache + ) except EsphomeError as err: _LOGGER.error( "Error resolving IP address of %s. Is it connected to WiFi?", @@ -356,7 +436,7 @@ def run_ota_impl_( def run_ota( - remote_host: str | list[str], remote_port: int, password: str, filename: str + remote_host: str | list[str], remote_port: int, password: str, filename: Path ) -> tuple[int, str | None]: try: return run_ota_impl_(remote_host, remote_port, password, filename) diff --git a/esphome/external_files.py b/esphome/external_files.py index 057ff52f3f..80b54ebb2f 100644 --- a/esphome/external_files.py +++ b/esphome/external_files.py @@ -2,7 +2,6 @@ from __future__ import annotations from datetime import datetime import logging -import os from pathlib import Path import requests @@ -23,11 +22,11 @@ CONTENT_DISPOSITION = "content-disposition" TEMP_DIR = "temp" -def has_remote_file_changed(url, local_file_path): - if os.path.exists(local_file_path): +def has_remote_file_changed(url: str, local_file_path: Path) -> bool: + if local_file_path.exists(): _LOGGER.debug("has_remote_file_changed: File exists at %s", local_file_path) try: - local_modification_time = os.path.getmtime(local_file_path) + local_modification_time = local_file_path.stat().st_mtime local_modification_time_str = datetime.utcfromtimestamp( local_modification_time ).strftime("%a, %d %b %Y %H:%M:%S GMT") @@ -65,9 +64,9 @@ def has_remote_file_changed(url, local_file_path): return True -def is_file_recent(file_path: str, refresh: TimePeriodSeconds) -> bool: - if os.path.exists(file_path): - creation_time = os.path.getctime(file_path) +def is_file_recent(file_path: Path, refresh: TimePeriodSeconds) -> bool: + if file_path.exists(): + creation_time = file_path.stat().st_ctime current_time = datetime.now().timestamp() return current_time - creation_time <= refresh.total_seconds return False diff --git a/esphome/git.py b/esphome/git.py index 56aedd1519..62fe37a3fe 100644 --- a/esphome/git.py +++ b/esphome/git.py @@ -13,6 +13,9 @@ from esphome.core import CORE, TimePeriodSeconds _LOGGER = logging.getLogger(__name__) +# Special value to indicate never refresh +NEVER_REFRESH = TimePeriodSeconds(seconds=-1) + def run_git_command(cmd, cwd=None) -> str: _LOGGER.debug("Running git command: %s", " ".join(cmd)) @@ -85,6 +88,11 @@ def clone_or_update( else: # Check refresh needed + # Skip refresh if NEVER_REFRESH is specified + if refresh == NEVER_REFRESH: + _LOGGER.debug("Skipping update for %s (refresh disabled)", key) + return repo_dir, None + file_timestamp = Path(repo_dir / ".git" / "FETCH_HEAD") # On first clone, FETCH_HEAD does not exists if not file_timestamp.exists(): diff --git a/esphome/helpers.py b/esphome/helpers.py index 6beaa24a96..fb7b71775d 100644 --- a/esphome/helpers.py +++ b/esphome/helpers.py @@ -1,6 +1,5 @@ from __future__ import annotations -import codecs from contextlib import suppress import ipaddress import logging @@ -8,11 +7,16 @@ import os from pathlib import Path import platform import re +import shutil import tempfile +from typing import TYPE_CHECKING from urllib.parse import urlparse from esphome.const import __version__ as ESPHOME_VERSION +if TYPE_CHECKING: + from esphome.address_cache import AddressCache + # Type aliases for socket address information AddrInfo = tuple[ int, # family (AF_INET, AF_INET6, etc.) @@ -136,16 +140,16 @@ def run_system_command(*args): return rc, stdout, stderr -def mkdir_p(path): +def mkdir_p(path: Path): if not path: # Empty path - means create current dir return try: - os.makedirs(path) + path.mkdir(parents=True, exist_ok=True) except OSError as err: import errno - if err.errno == errno.EEXIST and os.path.isdir(path): + if err.errno == errno.EEXIST and path.is_dir(): pass else: from esphome.core import EsphomeError @@ -173,7 +177,24 @@ def addr_preference_(res: AddrInfo) -> int: return 1 -def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: +def _add_ip_addresses_to_addrinfo( + addresses: list[str], port: int, res: list[AddrInfo] +) -> None: + """Helper to add IP addresses to addrinfo results with error handling.""" + import socket + + for addr in addresses: + try: + res += socket.getaddrinfo( + addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST + ) + except OSError: + _LOGGER.debug("Failed to parse IP address '%s'", addr) + + +def resolve_ip_address( + host: str | list[str], port: int, address_cache: AddressCache | None = None +) -> list[AddrInfo]: import socket # There are five cases here. The host argument could be one of: @@ -194,47 +215,69 @@ def resolve_ip_address(host: str | list[str], port: int) -> list[AddrInfo]: hosts = [host] res: list[AddrInfo] = [] + + # Fast path: if all hosts are already IP addresses if all(is_ip_address(h) for h in hosts): - # Fast path: all are IP addresses, use socket.getaddrinfo with AI_NUMERICHOST - for addr in hosts: - try: - res += socket.getaddrinfo( - addr, port, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST - ) - except OSError: - _LOGGER.debug("Failed to parse IP address '%s'", addr) + _add_ip_addresses_to_addrinfo(hosts, port, res) # Sort by preference res.sort(key=addr_preference_) return res - from esphome.resolver import AsyncResolver + # Process hosts + cached_addresses: list[str] = [] + uncached_hosts: list[str] = [] + has_cache = address_cache is not None - resolver = AsyncResolver(hosts, port) - addr_infos = resolver.resolve() - # Convert aioesphomeapi AddrInfo to our format - for addr_info in addr_infos: - sockaddr = addr_info.sockaddr - if addr_info.family == socket.AF_INET6: - # IPv6 - sockaddr_tuple = ( - sockaddr.address, - sockaddr.port, - sockaddr.flowinfo, - sockaddr.scope_id, - ) + for h in hosts: + if is_ip_address(h): + if has_cache: + # If we have a cache, treat IPs as cached + cached_addresses.append(h) + else: + # If no cache, pass IPs through to resolver with hostnames + uncached_hosts.append(h) + elif address_cache and (cached := address_cache.get_addresses(h)): + # Found in cache + cached_addresses.extend(cached) else: - # IPv4 - sockaddr_tuple = (sockaddr.address, sockaddr.port) + # Not cached, need to resolve + if address_cache and address_cache.has_cache(): + _LOGGER.info("Host %s not in cache, will need to resolve", h) + uncached_hosts.append(h) - res.append( - ( - addr_info.family, - addr_info.type, - addr_info.proto, - "", # canonname - sockaddr_tuple, + # Process cached addresses (includes direct IPs and cached lookups) + _add_ip_addresses_to_addrinfo(cached_addresses, port, res) + + # If we have uncached hosts (only non-IP hostnames), resolve them + if uncached_hosts: + from esphome.resolver import AsyncResolver + + resolver = AsyncResolver(uncached_hosts, port) + addr_infos = resolver.resolve() + # Convert aioesphomeapi AddrInfo to our format + for addr_info in addr_infos: + sockaddr = addr_info.sockaddr + if addr_info.family == socket.AF_INET6: + # IPv6 + sockaddr_tuple = ( + sockaddr.address, + sockaddr.port, + sockaddr.flowinfo, + sockaddr.scope_id, + ) + else: + # IPv4 + sockaddr_tuple = (sockaddr.address, sockaddr.port) + + res.append( + ( + addr_info.family, + addr_info.type, + addr_info.proto, + "", # canonname + sockaddr_tuple, + ) ) - ) # Sort by preference res.sort(key=addr_preference_) @@ -256,14 +299,7 @@ def sort_ip_addresses(address_list: list[str]) -> list[str]: # First "resolve" all the IP addresses to getaddrinfo() tuples of the form # (family, type, proto, canonname, sockaddr) res: list[AddrInfo] = [] - for addr in address_list: - # This should always work as these are supposed to be IP addresses - try: - res += socket.getaddrinfo( - addr, 0, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST - ) - except OSError: - _LOGGER.info("Failed to parse IP address '%s'", addr) + _add_ip_addresses_to_addrinfo(address_list, 0, res) # Now use that information to sort them. res.sort(key=addr_preference_) @@ -295,16 +331,15 @@ def is_ha_addon(): return get_bool_env("ESPHOME_IS_HA_ADDON") -def walk_files(path): +def walk_files(path: Path): for root, _, files in os.walk(path): for name in files: - yield os.path.join(root, name) + yield Path(root) / name -def read_file(path): +def read_file(path: Path) -> str: try: - with codecs.open(path, "r", encoding="utf-8") as f_handle: - return f_handle.read() + return path.read_text(encoding="utf-8") except OSError as err: from esphome.core import EsphomeError @@ -315,13 +350,15 @@ def read_file(path): raise EsphomeError(f"Error reading file {path}: {err}") from err -def _write_file(path: Path | str, text: str | bytes): +def _write_file( + path: Path, + text: str | bytes, + private: bool = False, +) -> None: """Atomically writes `text` to the given path. Automatically creates all parent directories. """ - if not isinstance(path, Path): - path = Path(path) data = text if isinstance(text, str): data = text.encode() @@ -329,42 +366,54 @@ def _write_file(path: Path | str, text: str | bytes): directory = path.parent directory.mkdir(exist_ok=True, parents=True) - tmp_path = None + tmp_filename: Path | None = None + missing_fchmod = False try: + # Modern versions of Python tempfile create this file with mode 0o600 with tempfile.NamedTemporaryFile( mode="wb", dir=directory, delete=False ) as f_handle: - tmp_path = f_handle.name f_handle.write(data) - # Newer tempfile implementations create the file with mode 0o600 - os.chmod(tmp_path, 0o644) - # If destination exists, will be overwritten - os.replace(tmp_path, path) + tmp_filename = Path(f_handle.name) + + if not private: + try: + os.fchmod(f_handle.fileno(), 0o644) + except AttributeError: + # os.fchmod is not available on Windows + missing_fchmod = True + shutil.move(tmp_filename, path) + if missing_fchmod: + path.chmod(0o644) finally: - if tmp_path is not None and os.path.exists(tmp_path): + if tmp_filename and tmp_filename.exists(): try: - os.remove(tmp_path) + tmp_filename.unlink() except OSError as err: - _LOGGER.error("Write file cleanup failed: %s", err) + # If we are cleaning up then something else went wrong, so + # we should suppress likely follow-on errors in the cleanup + _LOGGER.error( + "File replacement cleanup failed for %s while saving %s: %s", + tmp_filename, + path, + err, + ) -def write_file(path: Path | str, text: str): +def write_file(path: Path, text: str | bytes, private: bool = False) -> None: try: - _write_file(path, text) + _write_file(path, text, private=private) except OSError as err: from esphome.core import EsphomeError raise EsphomeError(f"Could not write file at {path}") from err -def write_file_if_changed(path: Path | str, text: str) -> bool: +def write_file_if_changed(path: Path, text: str) -> bool: """Write text to the given path, but not if the contents match already. Returns true if the file was changed. """ - if not isinstance(path, Path): - path = Path(path) - src_content = None if path.is_file(): src_content = read_file(path) @@ -374,12 +423,10 @@ def write_file_if_changed(path: Path | str, text: str) -> bool: return True -def copy_file_if_changed(src: os.PathLike, dst: os.PathLike) -> None: - import shutil - +def copy_file_if_changed(src: Path, dst: Path) -> None: if file_compare(src, dst): return - mkdir_p(os.path.dirname(dst)) + dst.parent.mkdir(parents=True, exist_ok=True) try: shutil.copyfile(src, dst) except OSError as err: @@ -404,12 +451,12 @@ def list_starts_with(list_, sub): return len(sub) <= len(list_) and all(list_[i] == x for i, x in enumerate(sub)) -def file_compare(path1: os.PathLike, path2: os.PathLike) -> bool: +def file_compare(path1: Path, path2: Path) -> bool: """Return True if the files path1 and path2 have the same contents.""" import stat try: - stat1, stat2 = os.stat(path1), os.stat(path2) + stat1, stat2 = path1.stat(), path2.stat() except OSError: # File doesn't exist or another error -> not equal return False @@ -426,7 +473,7 @@ def file_compare(path1: os.PathLike, path2: os.PathLike) -> bool: bufsize = 8 * 1024 # Read files in blocks until a mismatch is found - with open(path1, "rb") as fh1, open(path2, "rb") as fh2: + with path1.open("rb") as fh1, path2.open("rb") as fh2: while True: blob1, blob2 = fh1.read(bufsize), fh2.read(bufsize) if blob1 != blob2: diff --git a/esphome/idf_component.yml b/esphome/idf_component.yml index 687efd2b49..1a6dc8b97d 100644 --- a/esphome/idf_component.yml +++ b/esphome/idf_component.yml @@ -19,3 +19,7 @@ dependencies: - if: "target in [esp32h2, esp32p4]" zorxx/multipart-parser: version: 1.0.1 + espressif/lan867x: + version: "2.0.0" + rules: + - if: "target in [esp32, esp32p4]" diff --git a/esphome/loader.py b/esphome/loader.py index 7b2472521a..387443c032 100644 --- a/esphome/loader.py +++ b/esphome/loader.py @@ -82,11 +82,10 @@ class ComponentManifest: return getattr(self.module, "CONFLICTS_WITH", []) @property - def auto_load(self) -> list[str]: - al = getattr(self.module, "AUTO_LOAD", []) - if callable(al): - return al() - return al + def auto_load( + self, + ) -> list[str] | Callable[[], list[str]] | Callable[[ConfigType], list[str]]: + return getattr(self.module, "AUTO_LOAD", []) @property def codeowners(self) -> list[str]: @@ -192,7 +191,7 @@ def install_custom_components_meta_finder(): install_meta_finder(custom_components_dir) -def _lookup_module(domain, exception): +def _lookup_module(domain: str, exception: bool) -> ComponentManifest | None: if domain in _COMPONENT_CACHE: return _COMPONENT_CACHE[domain] @@ -219,16 +218,16 @@ def _lookup_module(domain, exception): return manif -def get_component(domain, exception=False): +def get_component(domain: str, exception: bool = False) -> ComponentManifest | None: assert "." not in domain return _lookup_module(domain, exception) -def get_platform(domain, platform): +def get_platform(domain: str, platform: str) -> ComponentManifest | None: full = f"{platform}.{domain}" return _lookup_module(full, False) -_COMPONENT_CACHE = {} +_COMPONENT_CACHE: dict[str, ComponentManifest] = {} CORE_COMPONENTS_PATH = (Path(__file__).parent / "components").resolve() _COMPONENT_CACHE["esphome"] = ComponentManifest(esphome.core.config) diff --git a/esphome/platformio_api.py b/esphome/platformio_api.py index 267277ebe1..8b7b790829 100644 --- a/esphome/platformio_api.py +++ b/esphome/platformio_api.py @@ -18,23 +18,25 @@ def patch_structhash(): # removed/added. This might have unintended consequences, but this improves compile # times greatly when adding/removing components and a simple clean build solves # all issues - from os import makedirs - from os.path import getmtime, isdir, join - from platformio.run import cli, helpers def patched_clean_build_dir(build_dir, *args): from platformio import fs from platformio.project.helpers import get_project_dir - platformio_ini = join(get_project_dir(), "platformio.ini") + platformio_ini = Path(get_project_dir()) / "platformio.ini" + + build_dir = Path(build_dir) # if project's config is modified - if isdir(build_dir) and getmtime(platformio_ini) > getmtime(build_dir): + if ( + build_dir.is_dir() + and platformio_ini.stat().st_mtime > build_dir.stat().st_mtime + ): fs.rmtree(build_dir) - if not isdir(build_dir): - makedirs(build_dir) + if not build_dir.is_dir(): + build_dir.mkdir(parents=True) helpers.clean_build_dir = patched_clean_build_dir cli.clean_build_dir = patched_clean_build_dir @@ -70,14 +72,19 @@ FILTER_PLATFORMIO_LINES = [ r" - tool-esptool.* \(.*\)", r" - toolchain-.* \(.*\)", r"Creating BIN file .*", + r"Warning! Could not find file \".*.crt\"", + r"Warning! Arduino framework as an ESP-IDF component doesn't handle the `variant` field! The default `esp32` variant will be used.", + r"Warning: DEPRECATED: 'esptool.py' is deprecated. Please use 'esptool' instead. The '.py' suffix will be removed in a future major release.", + r"Warning: esp-idf-size exited with code 2", + r"esp_idf_size: error: unrecognized arguments: --ng", ] def run_platformio_cli(*args, **kwargs) -> str | int: os.environ["PLATFORMIO_FORCE_COLOR"] = "true" - os.environ["PLATFORMIO_BUILD_DIR"] = os.path.abspath(CORE.relative_pioenvs_path()) + os.environ["PLATFORMIO_BUILD_DIR"] = str(CORE.relative_pioenvs_path().absolute()) os.environ.setdefault( - "PLATFORMIO_LIBDEPS_DIR", os.path.abspath(CORE.relative_piolibdeps_path()) + "PLATFORMIO_LIBDEPS_DIR", str(CORE.relative_piolibdeps_path().absolute()) ) # Suppress Python syntax warnings from third-party scripts during compilation os.environ.setdefault("PYTHONWARNINGS", "ignore::SyntaxWarning") @@ -96,7 +103,7 @@ def run_platformio_cli(*args, **kwargs) -> str | int: def run_platformio_cli_run(config, verbose, *args, **kwargs) -> str | int: - command = ["run", "-d", CORE.build_path] + command = ["run", "-d", str(CORE.build_path)] if verbose: command += ["-v"] command += list(args) @@ -128,8 +135,8 @@ def _run_idedata(config): def _load_idedata(config): - platformio_ini = Path(CORE.relative_build_path("platformio.ini")) - temp_idedata = Path(CORE.relative_internal_path("idedata", f"{CORE.name}.json")) + platformio_ini = CORE.relative_build_path("platformio.ini") + temp_idedata = CORE.relative_internal_path("idedata", f"{CORE.name}.json") changed = False if ( @@ -299,7 +306,7 @@ def process_stacktrace(config, line, backtrace_state): @dataclass class FlashImage: - path: str + path: Path offset: str @@ -308,17 +315,17 @@ class IDEData: self.raw = raw @property - def firmware_elf_path(self): - return self.raw["prog_path"] + def firmware_elf_path(self) -> Path: + return Path(self.raw["prog_path"]) @property - def firmware_bin_path(self) -> str: - return str(Path(self.firmware_elf_path).with_suffix(".bin")) + def firmware_bin_path(self) -> Path: + return self.firmware_elf_path.with_suffix(".bin") @property def extra_flash_images(self) -> list[FlashImage]: return [ - FlashImage(path=entry["path"], offset=entry["offset"]) + FlashImage(path=Path(entry["path"]), offset=entry["offset"]) for entry in self.raw["extra"]["flash_images"] ] diff --git a/esphome/storage_json.py b/esphome/storage_json.py index b69dc2dd3f..d5423ab1c7 100644 --- a/esphome/storage_json.py +++ b/esphome/storage_json.py @@ -1,11 +1,11 @@ from __future__ import annotations import binascii -import codecs from datetime import datetime import json import logging import os +from pathlib import Path from esphome import const from esphome.const import CONF_DISABLED, CONF_MDNS @@ -16,30 +16,35 @@ from esphome.types import CoreType _LOGGER = logging.getLogger(__name__) -def storage_path() -> str: - return os.path.join(CORE.data_dir, "storage", f"{CORE.config_filename}.json") +def storage_path() -> Path: + return CORE.data_dir / "storage" / f"{CORE.config_filename}.json" -def ext_storage_path(config_filename: str) -> str: - return os.path.join(CORE.data_dir, "storage", f"{config_filename}.json") +def ext_storage_path(config_filename: str) -> Path: + return CORE.data_dir / "storage" / f"{config_filename}.json" -def esphome_storage_path() -> str: - return os.path.join(CORE.data_dir, "esphome.json") +def esphome_storage_path() -> Path: + return CORE.data_dir / "esphome.json" -def ignored_devices_storage_path() -> str: - return os.path.join(CORE.data_dir, "ignored-devices.json") +def ignored_devices_storage_path() -> Path: + return CORE.data_dir / "ignored-devices.json" -def trash_storage_path() -> str: +def trash_storage_path() -> Path: return CORE.relative_config_path("trash") -def archive_storage_path() -> str: +def archive_storage_path() -> Path: return CORE.relative_config_path("archive") +def _to_path_if_not_none(value: str | None) -> Path | None: + """Convert a string to Path if it's not None.""" + return Path(value) if value is not None else None + + class StorageJSON: def __init__( self, @@ -52,8 +57,8 @@ class StorageJSON: address: str, web_port: int | None, target_platform: str, - build_path: str | None, - firmware_bin_path: str | None, + build_path: Path | None, + firmware_bin_path: Path | None, loaded_integrations: set[str], loaded_platforms: set[str], no_mdns: bool, @@ -107,8 +112,8 @@ class StorageJSON: "address": self.address, "web_port": self.web_port, "esp_platform": self.target_platform, - "build_path": self.build_path, - "firmware_bin_path": self.firmware_bin_path, + "build_path": str(self.build_path), + "firmware_bin_path": str(self.firmware_bin_path), "loaded_integrations": sorted(self.loaded_integrations), "loaded_platforms": sorted(self.loaded_platforms), "no_mdns": self.no_mdns, @@ -176,8 +181,8 @@ class StorageJSON: ) @staticmethod - def _load_impl(path: str) -> StorageJSON | None: - with codecs.open(path, "r", encoding="utf-8") as f_handle: + def _load_impl(path: Path) -> StorageJSON | None: + with path.open("r", encoding="utf-8") as f_handle: storage = json.load(f_handle) storage_version = storage["storage_version"] name = storage.get("name") @@ -190,8 +195,8 @@ class StorageJSON: address = storage.get("address") web_port = storage.get("web_port") esp_platform = storage.get("esp_platform") - build_path = storage.get("build_path") - firmware_bin_path = storage.get("firmware_bin_path") + build_path = _to_path_if_not_none(storage.get("build_path")) + firmware_bin_path = _to_path_if_not_none(storage.get("firmware_bin_path")) loaded_integrations = set(storage.get("loaded_integrations", [])) loaded_platforms = set(storage.get("loaded_platforms", [])) no_mdns = storage.get("no_mdns", False) @@ -217,7 +222,7 @@ class StorageJSON: ) @staticmethod - def load(path: str) -> StorageJSON | None: + def load(path: Path) -> StorageJSON | None: try: return StorageJSON._load_impl(path) except Exception: # pylint: disable=broad-except @@ -268,7 +273,7 @@ class EsphomeStorageJSON: @staticmethod def _load_impl(path: str) -> EsphomeStorageJSON | None: - with codecs.open(path, "r", encoding="utf-8") as f_handle: + with Path(path).open("r", encoding="utf-8") as f_handle: storage = json.load(f_handle) storage_version = storage["storage_version"] cookie_secret = storage.get("cookie_secret") diff --git a/esphome/types.py b/esphome/types.py index 62499a953c..c474d0d076 100644 --- a/esphome/types.py +++ b/esphome/types.py @@ -1,8 +1,10 @@ """This helper module tracks commonly used types in the esphome python codebase.""" -from typing import TypedDict +import abc +from collections.abc import Sequence +from typing import Any, TypedDict -from esphome.core import ID, EsphomeCore, Lambda +from esphome.core import ID, EsphomeCore, Lambda, TimePeriod ConfigFragmentType = ( str @@ -20,6 +22,32 @@ CoreType = EsphomeCore ConfigPathType = str | int +class Expression(abc.ABC): + __slots__ = () + + @abc.abstractmethod + def __str__(self): + """ + Convert expression into C++ code + """ + + +SafeExpType = ( + Expression + | bool + | str + | int + | float + | TimePeriod + | type[bool] + | type[int] + | type[float] + | Sequence[Any] +) + +TemplateArgsType = list[tuple[SafeExpType, str]] + + class EntityMetadata(TypedDict): """Metadata stored for each entity to help with duplicate detection.""" diff --git a/esphome/util.py b/esphome/util.py index 23a66be4eb..d41800dc20 100644 --- a/esphome/util.py +++ b/esphome/util.py @@ -1,20 +1,30 @@ import collections +from collections.abc import Callable import io import logging -import os from pathlib import Path import re import subprocess import sys -from typing import Any +from typing import TYPE_CHECKING, Any from esphome import const _LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from esphome.config_validation import Schema + from esphome.cpp_generator import MockObjClass + class RegistryEntry: - def __init__(self, name, fun, type_id, schema): + def __init__( + self, + name: str, + fun: Callable[..., Any], + type_id: "MockObjClass", + schema: "Schema", + ): self.name = name self.fun = fun self.type_id = type_id @@ -39,8 +49,8 @@ class Registry(dict[str, RegistryEntry]): self.base_schema = base_schema or {} self.type_id_key = type_id_key - def register(self, name, type_id, schema): - def decorator(fun): + def register(self, name: str, type_id: "MockObjClass", schema: "Schema"): + def decorator(fun: Callable[..., Any]): self[name] = RegistryEntry(name, fun, type_id, schema) return fun @@ -48,8 +58,8 @@ class Registry(dict[str, RegistryEntry]): class SimpleRegistry(dict): - def register(self, name, data): - def decorator(fun): + def register(self, name: str, data: Any): + def decorator(fun: Callable[..., Any]): self[name] = (fun, data) return fun @@ -86,7 +96,10 @@ def safe_input(prompt=""): return input() -def shlex_quote(s): +def shlex_quote(s: str | Path) -> str: + # Convert Path objects to strings + if isinstance(s, Path): + s = str(s) if not s: return "''" if re.search(r"[^\w@%+=:,./-]", s) is None: @@ -272,25 +285,28 @@ class OrderedDict(collections.OrderedDict): return dict(self).__repr__() -def list_yaml_files(configs: list[str]) -> list[str]: - files: list[str] = [] +def list_yaml_files(configs: list[str | Path]) -> list[Path]: + files: list[Path] = [] for config in configs: - if os.path.isfile(config): + config = Path(config) + if not config.exists(): + raise FileNotFoundError(f"Config path '{config}' does not exist!") + if config.is_file(): files.append(config) else: - files.extend(os.path.join(config, p) for p in os.listdir(config)) + files.extend(config.glob("*")) files = filter_yaml_files(files) return sorted(files) -def filter_yaml_files(files: list[str]) -> list[str]: +def filter_yaml_files(files: list[Path]) -> list[Path]: return [ f for f in files if ( - os.path.splitext(f)[1] in (".yaml", ".yml") - and os.path.basename(f) not in ("secrets.yaml", "secrets.yml") - and not os.path.basename(f).startswith(".") + f.suffix in (".yaml", ".yml") + and f.name not in ("secrets.yaml", "secrets.yml") + and not f.name.startswith(".") ) ] diff --git a/esphome/vscode.py b/esphome/vscode.py index f5e2a20b97..53bb339a8e 100644 --- a/esphome/vscode.py +++ b/esphome/vscode.py @@ -2,7 +2,7 @@ from __future__ import annotations from io import StringIO import json -import os +from pathlib import Path from typing import Any from esphome.config import Config, _format_vol_invalid, validate_config @@ -67,24 +67,24 @@ def _read_file_content_from_json_on_stdin() -> str: return data["content"] -def _print_file_read_event(path: str) -> None: +def _print_file_read_event(path: Path) -> None: """Print a file read event.""" print( json.dumps( { "type": "read_file", - "path": path, + "path": str(path), } ) ) -def _request_and_get_stream_on_stdin(fname: str) -> StringIO: +def _request_and_get_stream_on_stdin(fname: Path) -> StringIO: _print_file_read_event(fname) return StringIO(_read_file_content_from_json_on_stdin()) -def _vscode_loader(fname: str) -> dict[str, Any]: +def _vscode_loader(fname: Path) -> dict[str, Any]: raw_yaml_stream = _request_and_get_stream_on_stdin(fname) # it is required to set the name on StringIO so document on start_mark # is set properly. Otherwise it is initialized with "" @@ -92,7 +92,7 @@ def _vscode_loader(fname: str) -> dict[str, Any]: return parse_yaml(fname, raw_yaml_stream, _vscode_loader) -def _ace_loader(fname: str) -> dict[str, Any]: +def _ace_loader(fname: Path) -> dict[str, Any]: raw_yaml_stream = _request_and_get_stream_on_stdin(fname) return parse_yaml(fname, raw_yaml_stream) @@ -120,10 +120,10 @@ def read_config(args): return CORE.vscode = True if args.ace: # Running from ESPHome Compiler dashboard, not vscode - CORE.config_path = os.path.join(args.configuration, data["file"]) + CORE.config_path = Path(args.configuration) / data["file"] loader = _ace_loader else: - CORE.config_path = data["file"] + CORE.config_path = Path(data["file"]) loader = _vscode_loader file_name = CORE.config_path diff --git a/esphome/wizard.py b/esphome/wizard.py index 3edf519816..97343eea99 100644 --- a/esphome/wizard.py +++ b/esphome/wizard.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path import random import string from typing import Literal, NotRequired, TypedDict, Unpack @@ -213,7 +213,7 @@ class WizardWriteKwargs(TypedDict): file_text: NotRequired[str] -def wizard_write(path: str, **kwargs: Unpack[WizardWriteKwargs]) -> bool: +def wizard_write(path: Path, **kwargs: Unpack[WizardWriteKwargs]) -> bool: from esphome.components.bk72xx import boards as bk72xx_boards from esphome.components.esp32 import boards as esp32_boards from esphome.components.esp8266 import boards as esp8266_boards @@ -256,13 +256,13 @@ def wizard_write(path: str, **kwargs: Unpack[WizardWriteKwargs]) -> bool: file_text = wizard_file(**kwargs) # Check if file already exists to prevent overwriting - if os.path.exists(path) and os.path.isfile(path): + if path.exists() and path.is_file(): safe_print(color(AnsiFore.RED, f'The file "{path}" already exists.')) return False write_file(path, file_text) storage = StorageJSON.from_wizard(name, name, f"{name}.local", hardware) - storage_path = ext_storage_path(os.path.basename(path)) + storage_path = ext_storage_path(path.name) storage.save(storage_path) return True @@ -301,7 +301,7 @@ def strip_accents(value: str) -> str: ) -def wizard(path: str) -> int: +def wizard(path: Path) -> int: from esphome.components.bk72xx import boards as bk72xx_boards from esphome.components.esp32 import boards as esp32_boards from esphome.components.esp8266 import boards as esp8266_boards @@ -309,14 +309,14 @@ def wizard(path: str) -> int: from esphome.components.rp2040 import boards as rp2040_boards from esphome.components.rtl87xx import boards as rtl87xx_boards - if not path.endswith(".yaml") and not path.endswith(".yml"): + if path.suffix not in (".yaml", ".yml"): safe_print( - f"Please make your configuration file {color(AnsiFore.CYAN, path)} have the extension .yaml or .yml" + f"Please make your configuration file {color(AnsiFore.CYAN, str(path))} have the extension .yaml or .yml" ) return 1 - if os.path.exists(path): + if path.exists(): safe_print( - f"Uh oh, it seems like {color(AnsiFore.CYAN, path)} already exists, please delete that file first or chose another configuration file." + f"Uh oh, it seems like {color(AnsiFore.CYAN, str(path))} already exists, please delete that file first or chose another configuration file." ) return 2 @@ -549,7 +549,7 @@ def wizard(path: str) -> int: safe_print() safe_print( color(AnsiFore.CYAN, "DONE! I've now written a new configuration file to ") - + color(AnsiFore.BOLD_CYAN, path) + + color(AnsiFore.BOLD_CYAN, str(path)) ) safe_print() safe_print("Next steps:") diff --git a/esphome/writer.py b/esphome/writer.py index 2a9c6a770d..b5cfd9b667 100644 --- a/esphome/writer.py +++ b/esphome/writer.py @@ -266,7 +266,7 @@ def generate_version_h(): def write_cpp(code_s): path = CORE.relative_src_path("main.cpp") - if os.path.isfile(path): + if path.is_file(): text = read_file(path) code_format = find_begin_end( text, CPP_AUTO_GENERATE_BEGIN, CPP_AUTO_GENERATE_END @@ -292,43 +292,79 @@ def write_cpp(code_s): def clean_cmake_cache(): pioenvs = CORE.relative_pioenvs_path() - if os.path.isdir(pioenvs): - pioenvs_cmake_path = CORE.relative_pioenvs_path(CORE.name, "CMakeCache.txt") - if os.path.isfile(pioenvs_cmake_path): + if pioenvs.is_dir(): + pioenvs_cmake_path = pioenvs / CORE.name / "CMakeCache.txt" + if pioenvs_cmake_path.is_file(): _LOGGER.info("Deleting %s", pioenvs_cmake_path) - os.remove(pioenvs_cmake_path) + pioenvs_cmake_path.unlink() def clean_build(): import shutil + # Allow skipping cache cleaning for integration tests + if os.environ.get("ESPHOME_SKIP_CLEAN_BUILD"): + _LOGGER.warning("Skipping build cleaning (ESPHOME_SKIP_CLEAN_BUILD set)") + return + pioenvs = CORE.relative_pioenvs_path() - if os.path.isdir(pioenvs): + if pioenvs.is_dir(): _LOGGER.info("Deleting %s", pioenvs) shutil.rmtree(pioenvs) piolibdeps = CORE.relative_piolibdeps_path() - if os.path.isdir(piolibdeps): + if piolibdeps.is_dir(): _LOGGER.info("Deleting %s", piolibdeps) shutil.rmtree(piolibdeps) dependencies_lock = CORE.relative_build_path("dependencies.lock") - if os.path.isfile(dependencies_lock): + if dependencies_lock.is_file(): _LOGGER.info("Deleting %s", dependencies_lock) - os.remove(dependencies_lock) + dependencies_lock.unlink() # Clean PlatformIO cache to resolve CMake compiler detection issues # This helps when toolchain paths change or get corrupted try: - from platformio.project.helpers import get_project_cache_dir + from platformio.project.config import ProjectConfig except ImportError: # PlatformIO is not available, skip cache cleaning pass else: - cache_dir = get_project_cache_dir() - if cache_dir and cache_dir.strip() and os.path.isdir(cache_dir): + config = ProjectConfig.get_instance() + cache_dir = Path(config.get("platformio", "cache_dir")) + if cache_dir.is_dir(): _LOGGER.info("Deleting PlatformIO cache %s", cache_dir) shutil.rmtree(cache_dir) +def clean_all(configuration: list[str]): + import shutil + + # Clean entire build dir + for dir in configuration: + build_dir = Path(dir) / ".esphome" + if build_dir.is_dir(): + _LOGGER.info("Cleaning %s", build_dir) + # Don't remove storage as it will cause the dashboard to regenerate all configs + for item in build_dir.iterdir(): + if item.is_file(): + item.unlink() + elif item.name != "storage" and item.is_dir(): + shutil.rmtree(item) + + # Clean PlatformIO project files + try: + from platformio.project.config import ProjectConfig + except ImportError: + # PlatformIO is not available, skip cleaning + pass + else: + config = ProjectConfig.get_instance() + for pio_dir in ["cache_dir", "packages_dir", "platforms_dir", "core_dir"]: + path = Path(config.get("platformio", pio_dir)) + if path.is_dir(): + _LOGGER.info("Deleting PlatformIO %s %s", pio_dir, path) + shutil.rmtree(path) + + GITIGNORE_CONTENT = """# Gitignore settings for ESPHome # This is an example and may include too much for your use-case. # You can modify this file to suit your needs. @@ -339,6 +375,5 @@ GITIGNORE_CONTENT = """# Gitignore settings for ESPHome def write_gitignore(): path = CORE.relative_config_path(".gitignore") - if not os.path.isfile(path): - with open(file=path, mode="w", encoding="utf-8") as f: - f.write(GITIGNORE_CONTENT) + if not path.is_file(): + path.write_text(GITIGNORE_CONTENT, encoding="utf-8") diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index f26bc0502d..359b72b48f 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import Callable -import fnmatch import functools import inspect from io import BytesIO, TextIOBase, TextIOWrapper @@ -9,6 +8,7 @@ from ipaddress import _BaseAddress, _BaseNetwork import logging import math import os +from pathlib import Path from typing import Any import uuid @@ -69,7 +69,7 @@ class ESPHomeDataBase: self._content_offset = database.content_offset -class ESPForceValue: +class ESPLiteralValue: pass @@ -109,7 +109,9 @@ def _add_data_ref(fn): class ESPHomeLoaderMixin: """Loader class that keeps track of line numbers.""" - def __init__(self, name: str, yaml_loader: Callable[[str], dict[str, Any]]) -> None: + def __init__( + self, name: Path, yaml_loader: Callable[[Path], dict[str, Any]] + ) -> None: """Initialize the loader.""" self.name = name self.yaml_loader = yaml_loader @@ -254,12 +256,8 @@ class ESPHomeLoaderMixin: f"Environment variable '{node.value}' not defined", node.start_mark ) - @property - def _directory(self) -> str: - return os.path.dirname(self.name) - - def _rel_path(self, *args: str) -> str: - return os.path.join(self._directory, *args) + def _rel_path(self, *args: str) -> Path: + return self.name.parent / Path(*args) @_add_data_ref def construct_secret(self, node: yaml.Node) -> str: @@ -269,8 +267,8 @@ class ESPHomeLoaderMixin: if self.name == CORE.config_path: raise e try: - main_config_dir = os.path.dirname(CORE.config_path) - main_secret_yml = os.path.join(main_config_dir, SECRET_YAML) + main_config_dir = CORE.config_path.parent + main_secret_yml = main_config_dir / SECRET_YAML secrets = self.yaml_loader(main_secret_yml) except EsphomeError as er: raise EsphomeError(f"{e}\n{er}") from er @@ -329,7 +327,7 @@ class ESPHomeLoaderMixin: files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml")) mapping = OrderedDict() for fname in files: - filename = os.path.splitext(os.path.basename(fname))[0] + filename = fname.stem mapping[filename] = self.yaml_loader(fname) return mapping @@ -350,9 +348,15 @@ class ESPHomeLoaderMixin: return Lambda(str(node.value)) @_add_data_ref - def construct_force(self, node: yaml.Node) -> ESPForceValue: - obj = self.construct_scalar(node) - return add_class_to_obj(obj, ESPForceValue) + def construct_literal(self, node: yaml.Node) -> ESPLiteralValue: + obj = None + if isinstance(node, yaml.ScalarNode): + obj = self.construct_scalar(node) + elif isinstance(node, yaml.SequenceNode): + obj = self.construct_sequence(node) + elif isinstance(node, yaml.MappingNode): + obj = self.construct_mapping(node) + return add_class_to_obj(obj, ESPLiteralValue) @_add_data_ref def construct_extend(self, node: yaml.Node) -> Extend: @@ -369,8 +373,8 @@ class ESPHomeLoader(ESPHomeLoaderMixin, FastestAvailableSafeLoader): def __init__( self, stream: TextIOBase | BytesIO, - name: str, - yaml_loader: Callable[[str], dict[str, Any]], + name: Path, + yaml_loader: Callable[[Path], dict[str, Any]], ) -> None: FastestAvailableSafeLoader.__init__(self, stream) ESPHomeLoaderMixin.__init__(self, name, yaml_loader) @@ -382,8 +386,8 @@ class ESPHomePurePythonLoader(ESPHomeLoaderMixin, PurePythonLoader): def __init__( self, stream: TextIOBase | BytesIO, - name: str, - yaml_loader: Callable[[str], dict[str, Any]], + name: Path, + yaml_loader: Callable[[Path], dict[str, Any]], ) -> None: PurePythonLoader.__init__(self, stream) ESPHomeLoaderMixin.__init__(self, name, yaml_loader) @@ -409,29 +413,29 @@ for _loader in (ESPHomeLoader, ESPHomePurePythonLoader): "!include_dir_merge_named", _loader.construct_include_dir_merge_named ) _loader.add_constructor("!lambda", _loader.construct_lambda) - _loader.add_constructor("!force", _loader.construct_force) + _loader.add_constructor("!literal", _loader.construct_literal) _loader.add_constructor("!extend", _loader.construct_extend) _loader.add_constructor("!remove", _loader.construct_remove) -def load_yaml(fname: str, clear_secrets: bool = True) -> Any: +def load_yaml(fname: Path, clear_secrets: bool = True) -> Any: if clear_secrets: _SECRET_VALUES.clear() _SECRET_CACHE.clear() return _load_yaml_internal(fname) -def _load_yaml_internal(fname: str) -> Any: +def _load_yaml_internal(fname: Path) -> Any: """Load a YAML file.""" try: - with open(fname, encoding="utf-8") as f_handle: + with fname.open(encoding="utf-8") as f_handle: return parse_yaml(fname, f_handle) except (UnicodeDecodeError, OSError) as err: raise EsphomeError(f"Error reading file {fname}: {err}") from err def parse_yaml( - file_name: str, file_handle: TextIOWrapper, yaml_loader=_load_yaml_internal + file_name: Path, file_handle: TextIOWrapper, yaml_loader=_load_yaml_internal ) -> Any: """Parse a YAML file.""" try: @@ -483,9 +487,9 @@ def substitute_vars(config, vars): def _load_yaml_internal_with_type( loader_type: type[ESPHomeLoader] | type[ESPHomePurePythonLoader], - fname: str, + fname: Path, content: TextIOWrapper, - yaml_loader: Any, + yaml_loader: Callable[[Path], dict[str, Any]], ) -> Any: """Load a YAML file.""" loader = loader_type(content, fname, yaml_loader) @@ -512,13 +516,14 @@ def _is_file_valid(name: str) -> bool: return not name.startswith(".") -def _find_files(directory, pattern): +def _find_files(directory: Path, pattern): """Recursively load files in a directory.""" - for root, dirs, files in os.walk(directory, topdown=True): + for root, dirs, files in os.walk(directory): dirs[:] = [d for d in dirs if _is_file_valid(d)] - for basename in files: - if _is_file_valid(basename) and fnmatch.fnmatch(basename, pattern): - filename = os.path.join(root, basename) + for f in files: + filename = Path(f) + if _is_file_valid(f) and filename.match(pattern): + filename = Path(root) / filename yield filename @@ -627,3 +632,4 @@ ESPHomeDumper.add_multi_representer(TimePeriod, ESPHomeDumper.represent_stringif ESPHomeDumper.add_multi_representer(Lambda, ESPHomeDumper.represent_lambda) ESPHomeDumper.add_multi_representer(core.ID, ESPHomeDumper.represent_id) ESPHomeDumper.add_multi_representer(uuid.UUID, ESPHomeDumper.represent_stringify) +ESPHomeDumper.add_multi_representer(Path, ESPHomeDumper.represent_stringify) diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index fa496b3488..dc4ca77eb4 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -68,8 +68,11 @@ class DashboardBrowser(AsyncServiceBrowser): class DashboardImportDiscovery: - def __init__(self) -> None: + def __init__( + self, on_update: Callable[[str, DiscoveredImport | None], None] | None = None + ) -> None: self.import_state: dict[str, DiscoveredImport] = {} + self.on_update = on_update def browser_callback( self, @@ -85,7 +88,9 @@ class DashboardImportDiscovery: state_change, ) if state_change == ServiceStateChange.Removed: - self.import_state.pop(name, None) + removed = self.import_state.pop(name, None) + if removed and self.on_update: + self.on_update(name, None) return if state_change == ServiceStateChange.Updated and name not in self.import_state: @@ -139,7 +144,7 @@ class DashboardImportDiscovery: if friendly_name is not None: friendly_name = friendly_name.decode() - self.import_state[name] = DiscoveredImport( + discovered = DiscoveredImport( friendly_name=friendly_name, device_name=node_name, package_import_url=import_url, @@ -147,6 +152,10 @@ class DashboardImportDiscovery: project_version=project_version, network=network, ) + is_new = name not in self.import_state + self.import_state[name] = discovered + if is_new and self.on_update: + self.on_update(name, discovered) def update_device_mdns(self, node_name: str, version: str): storage_path = ext_storage_path(node_name + ".yaml") diff --git a/platformio.ini b/platformio.ini index d97607fac5..44b466a2b3 100644 --- a/platformio.ini +++ b/platformio.ini @@ -72,7 +72,6 @@ lib_deps = SPI ; spi (Arduino built-in) Wire ; i2c (Arduino built-int) heman/AsyncMqttClient-esphome@1.0.0 ; mqtt - ESP32Async/ESPAsyncWebServer@3.7.8 ; web_server_base fastled/FastLED@3.9.16 ; fastled_base freekode/TM1651@1.0.1 ; tm1651 glmnet/Dsmr@0.7 ; dsmr @@ -107,6 +106,7 @@ lib_deps = ESP8266WiFi ; wifi (Arduino built-in) Update ; ota (Arduino built-in) ESP32Async/ESPAsyncTCP@2.0.0 ; async_tcp + ESP32Async/ESPAsyncWebServer@3.7.8 ; web_server_base makuna/NeoPixelBus@2.7.3 ; neopixelbus ESP8266HTTPClient ; http_request (Arduino built-in) ESP8266mDNS ; mdns (Arduino built-in) @@ -129,7 +129,7 @@ platform = https://github.com/pioarduino/platform-espressif32/releases/download/ platform_packages = pioarduino/framework-arduinoespressif32@https://github.com/espressif/arduino-esp32/releases/download/3.2.1/esp32-3.2.1.zip -framework = arduino +framework = arduino, espidf ; Arduino as an ESP-IDF component lib_deps = ; order matters with lib-deps; some of the libs in common:arduino.lib_deps ; don't declare built-in libraries as dependencies, so they have to be declared first @@ -147,7 +147,7 @@ lib_deps = makuna/NeoPixelBus@2.8.0 ; neopixelbus esphome/ESP32-audioI2S@2.3.0 ; i2s_audio droscy/esp_wireguard@0.4.2 ; wireguard - esphome/esp-audio-libs@1.1.4 ; audio + esphome/esp-audio-libs@2.0.1 ; audio build_flags = ${common:arduino.build_flags} @@ -170,7 +170,7 @@ lib_deps = ${common:idf.lib_deps} droscy/esp_wireguard@0.4.2 ; wireguard kahrendt/ESPMicroSpeechFeatures@1.1.0 ; micro_wake_word - esphome/esp-audio-libs@1.1.4 ; audio + esphome/esp-audio-libs@2.0.1 ; audio build_flags = ${common:idf.build_flags} -Wno-nonnull-compare @@ -193,6 +193,7 @@ platform_packages = framework = arduino lib_deps = ${common:arduino.lib_deps} + ESP32Async/ESPAsyncWebServer@3.7.8 ; web_server_base build_flags = ${common:arduino.build_flags} -DUSE_RP2040 @@ -207,7 +208,8 @@ platform = libretiny@1.9.1 framework = arduino lib_compat_mode = soft lib_deps = - droscy/esp_wireguard@0.4.2 ; wireguard + ESP32Async/ESPAsyncWebServer@3.7.8 ; web_server_base + droscy/esp_wireguard@0.4.2 ; wireguard build_flags = ${common:arduino.build_flags} -DUSE_LIBRETINY @@ -274,6 +276,7 @@ build_unflags = [env:esp32-arduino-tidy] extends = common:esp32-arduino board = esp32dev +board_build.esp-idf.sdkconfig_path = .temp/sdkconfig-esp32-arduino-tidy build_flags = ${common:esp32-arduino.build_flags} ${flags:clangtidy.build_flags} diff --git a/pyproject.toml b/pyproject.toml index 4943c48eb0..b7b4a48d7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "esphome" -license = {text = "MIT"} +license = "MIT" description = "ESPHome is a system to configure your microcontrollers by simple yet powerful configuration files and control them remotely through Home Automation systems." readme = "README.md" authors = [ @@ -15,7 +15,6 @@ classifiers = [ "Environment :: Console", "Intended Audience :: Developers", "Intended Audience :: End Users/Desktop", - "License :: OSI Approved :: MIT License", "Programming Language :: C++", "Programming Language :: Python :: 3", "Topic :: Home Automation", diff --git a/requirements.txt b/requirements.txt index 296485bdae..7ff4a6eeb2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ cryptography==45.0.1 voluptuous==0.15.2 -PyYAML==6.0.2 +PyYAML==6.0.3 paho-mqtt==1.6.1 colorama==0.4.6 icmplib==3.0.4 @@ -9,13 +9,14 @@ tzlocal==5.3.1 # from time tzdata>=2021.1 # from time pyserial==3.5 platformio==6.1.18 # When updating platformio, also update /docker/Dockerfile -esptool==5.0.2 +esptool==5.1.0 click==8.1.7 esphome-dashboard==20250904.0 -aioesphomeapi==40.2.1 -zeroconf==0.147.2 +aioesphomeapi==41.13.0 +zeroconf==0.148.0 puremagic==1.30 ruamel.yaml==0.18.15 # dashboard_import +ruamel.yaml.clib==0.2.12 # dashboard_import esphome-glyphsets==0.2.0 pillow==10.4.0 cairosvg==2.8.2 diff --git a/requirements_test.txt b/requirements_test.txt index bae9246768..76a305367a 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,14 +1,14 @@ -pylint==3.3.8 +pylint==3.3.9 flake8==7.3.0 # also change in .pre-commit-config.yaml when updating -ruff==0.12.12 # also change in .pre-commit-config.yaml when updating +ruff==0.14.0 # also change in .pre-commit-config.yaml when updating pyupgrade==3.20.0 # also change in .pre-commit-config.yaml when updating pre-commit # Unit tests pytest==8.4.2 pytest-cov==7.0.0 -pytest-mock==3.15.0 -pytest-asyncio==1.1.0 +pytest-mock==3.15.1 +pytest-asyncio==1.2.0 pytest-xdist==3.8.0 asyncmock==0.4.2 hypothesis==6.92.1 diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index 205bac4937..487c187372 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -3,7 +3,6 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import IntEnum -import os from pathlib import Path import re from subprocess import call @@ -354,12 +353,33 @@ def create_field_type_info( return FixedArrayRepeatedType(field, size_define) return RepeatedTypeInfo(field) - # Check for fixed_array_size option on bytes fields - if ( - field.type == 12 - and (fixed_size := get_field_opt(field, pb.fixed_array_size)) is not None - ): - return FixedArrayBytesType(field, fixed_size) + # Check for mutually exclusive options on bytes fields + if field.type == 12: + has_pointer_to_buffer = get_field_opt(field, pb.pointer_to_buffer, False) + fixed_size = get_field_opt(field, pb.fixed_array_size, None) + + if has_pointer_to_buffer and fixed_size is not None: + raise ValueError( + f"Field '{field.name}' has both pointer_to_buffer and fixed_array_size. " + "These options are mutually exclusive. Use pointer_to_buffer for zero-copy " + "or fixed_array_size for traditional array storage." + ) + + if has_pointer_to_buffer: + # Zero-copy pointer approach - no size needed, will use size_t for length + return PointerToBytesBufferType(field, None) + + if fixed_size is not None: + # Traditional fixed array approach with copy + return FixedArrayBytesType(field, fixed_size) + + # Check for pointer_to_buffer option on string fields + if field.type == 9: + has_pointer_to_buffer = get_field_opt(field, pb.pointer_to_buffer, False) + + if has_pointer_to_buffer: + # Zero-copy pointer approach for strings + return PointerToBytesBufferType(field, None) # Special handling for bytes fields if field.type == 12: @@ -819,6 +839,91 @@ class BytesType(TypeInfo): return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical bytes +class PointerToBytesBufferType(TypeInfo): + """Type for bytes fields that use pointer_to_buffer option for zero-copy.""" + + @classmethod + def can_use_dump_field(cls) -> bool: + return False + + def __init__( + self, field: descriptor.FieldDescriptorProto, size: int | None = None + ) -> None: + super().__init__(field) + # Size is not used for pointer_to_buffer - we always use size_t for length + self.array_size = 0 + + @property + def cpp_type(self) -> str: + return "const uint8_t*" + + @property + def default_value(self) -> str: + return "nullptr" + + @property + def reference_type(self) -> str: + return "const uint8_t*" + + @property + def const_reference_type(self) -> str: + return "const uint8_t*" + + @property + def public_content(self) -> list[str]: + # Use uint16_t for length - max packet size is well below 65535 + # Add pointer and length fields + return [ + f"const uint8_t* {self.field_name}{{nullptr}};", + f"uint16_t {self.field_name}_len{{0}};", + ] + + @property + def encode_content(self) -> str: + return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len);" + + @property + def decode_length_content(self) -> str | None: + # Decode directly stores the pointer to avoid allocation + return f"""case {self.number}: {{ + // Use raw data directly to avoid allocation + this->{self.field_name} = value.data(); + this->{self.field_name}_len = value.size(); + break; + }}""" + + @property + def decode_length(self) -> str | None: + # This is handled in decode_length_content + return None + + @property + def wire_type(self) -> WireType: + """Get the wire type for this bytes field.""" + return WireType.LENGTH_DELIMITED # Uses wire type 2 + + def dump(self, name: str) -> str: + return ( + f"format_hex_pretty(this->{self.field_name}, this->{self.field_name}_len)" + ) + + @property + def dump_content(self) -> str: + # Custom dump that doesn't use dump_field template + return ( + f'out.append(" {self.name}: ");\n' + + f"out.append({self.dump(self.field_name)});\n" + + 'out.append("\\n");' + ) + + def get_size_calculation(self, name: str, force: bool = False) -> str: + return f"size.add_length({self.number}, this->{self.field_name}_len);" + + def get_estimated_size(self) -> int: + # field ID + length varint + typical data (assume small for pointer fields) + return self.calculate_field_id_size() + 2 + 16 + + class FixedArrayBytesType(TypeInfo): """Special type for fixed-size byte arrays.""" @@ -848,10 +953,17 @@ class FixedArrayBytesType(TypeInfo): @property def public_content(self) -> list[str]: + len_type = ( + "uint8_t" + if self.array_size <= 255 + else "uint16_t" + if self.array_size <= 65535 + else "size_t" + ) # Add both the array and length fields return [ f"uint8_t {self.field_name}[{self.array_size}]{{}};", - f"uint8_t {self.field_name}_len{{0}};", + f"{len_type} {self.field_name}_len{{0}};", ] @property @@ -1743,13 +1855,16 @@ def build_message_type( # Add estimated size constant estimated_size = calculate_message_estimated_size(desc) - # Validate that estimated_size fits in uint8_t - if estimated_size > 255: - raise ValueError( - f"Estimated size {estimated_size} for {desc.name} exceeds uint8_t maximum (255)" - ) + # Use a type appropriate for estimated_size + estimated_size_type = ( + "uint8_t" + if estimated_size <= 255 + else "uint16_t" + if estimated_size <= 65535 + else "size_t" + ) public_content.append( - f"static constexpr uint8_t ESTIMATED_SIZE = {estimated_size};" + f"static constexpr {estimated_size_type} ESTIMATED_SIZE = {estimated_size};" ) # Add message_name method inline in header @@ -2606,6 +2721,10 @@ static const char *const TAG = "api.service"; hpp_protected = "" cpp += "\n" + # Build a mapping of message input types to their authentication requirements + message_auth_map: dict[str, bool] = {} + message_conn_map: dict[str, bool] = {} + m = serv.method[0] for m in serv.method: func = m.name @@ -2617,6 +2736,10 @@ static const char *const TAG = "api.service"; needs_conn = get_opt(m, pb.needs_setup_connection, True) needs_auth = get_opt(m, pb.needs_authentication, True) + # Store authentication requirements for message types + message_auth_map[inp] = needs_auth + message_conn_map[inp] = needs_conn + ifdef = message_ifdef_map.get(inp, ifdefs.get(inp)) if ifdef is not None: @@ -2634,33 +2757,14 @@ static const char *const TAG = "api.service"; cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n" - # Start with authentication/connection check if needed - if needs_auth or needs_conn: - # Determine which check to use - if needs_auth: - check_func = "this->check_authenticated_()" - else: - check_func = "this->check_connection_setup_()" - - if is_void: - # For void methods, just wrap with auth check - body = f"if ({check_func}) {{\n" - body += f" this->{func}(msg);\n" - body += "}\n" - else: - # For non-void methods, combine auth check and send response check - body = f"if ({check_func} && !this->send_{func}_response(msg)) {{\n" - body += " this->on_fatal_error();\n" - body += "}\n" + # No authentication check here - it's done in read_message + body = "" + if is_void: + body += f"this->{func}(msg);\n" else: - # No auth check needed, just call the handler - body = "" - if is_void: - body += f"this->{func}(msg);\n" - else: - body += f"if (!this->send_{func}_response(msg)) {{\n" - body += " this->on_fatal_error();\n" - body += "}\n" + body += f"if (!this->send_{func}_response(msg)) {{\n" + body += " this->on_fatal_error();\n" + body += "}\n" cpp += indent(body) + "\n" + "}\n" @@ -2669,6 +2773,65 @@ static const char *const TAG = "api.service"; hpp_protected += "#endif\n" cpp += "#endif\n" + # Generate optimized read_message with authentication checking + # Categorize messages by their authentication requirements + no_conn_ids: set[int] = set() + conn_only_ids: set[int] = set() + + for id_, (_, _, case_msg_name) in cases: + if case_msg_name in message_auth_map: + needs_auth = message_auth_map[case_msg_name] + needs_conn = message_conn_map[case_msg_name] + + if not needs_conn: + no_conn_ids.add(id_) + elif not needs_auth: + conn_only_ids.add(id_) + + # Generate override if we have messages that skip checks + if no_conn_ids or conn_only_ids: + # Helper to generate case statements with ifdefs + def generate_cases(ids: set[int], comment: str) -> str: + result = "" + for id_ in sorted(ids): + _, ifdef, msg_name = RECEIVE_CASES[id_] + if ifdef: + result += f"#ifdef {ifdef}\n" + result += f" case {msg_name}::MESSAGE_TYPE: {comment}\n" + if ifdef: + result += "#endif\n" + return result + + hpp_protected += " void read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n" + + cpp += f"\nvoid {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n" + cpp += " // Check authentication/connection requirements for messages\n" + cpp += " switch (msg_type) {\n" + + # Messages that don't need any checks + if no_conn_ids: + cpp += generate_cases(no_conn_ids, "// No setup required") + cpp += " break; // Skip all checks for these messages\n" + + # Messages that only need connection setup + if conn_only_ids: + cpp += generate_cases(conn_only_ids, "// Connection setup only") + cpp += " if (!this->check_connection_setup_()) {\n" + cpp += " return; // Connection not setup\n" + cpp += " }\n" + cpp += " break;\n" + + cpp += " default:\n" + cpp += " // All other messages require authentication (which includes connection check)\n" + cpp += " if (!this->check_authenticated_()) {\n" + cpp += " return; // Authentication failed\n" + cpp += " }\n" + cpp += " break;\n" + cpp += " }\n\n" + cpp += " // Call base implementation to process the message\n" + cpp += f" {class_name}Base::read_message(msg_size, msg_type, msg_data);\n" + cpp += "}\n" + hpp += " protected:\n" hpp += hpp_protected hpp += "};\n" @@ -2694,8 +2857,8 @@ static const char *const TAG = "api.service"; import clang_format def exec_clang_format(path: Path) -> None: - clang_format_path = os.path.join( - os.path.dirname(clang_format.__file__), "data", "bin", "clang-format" + clang_format_path = ( + Path(clang_format.__file__).parent / "data" / "bin" / "clang-format" ) call([clang_format_path, "-i", path]) diff --git a/script/build_codeowners.py b/script/build_codeowners.py index 27ea82611b..10ca1295b7 100755 --- a/script/build_codeowners.py +++ b/script/build_codeowners.py @@ -39,7 +39,7 @@ esphome/core/* @esphome/core parts = [BASE] # Fake some directory so that get_component works -CORE.config_path = str(root) +CORE.config_path = root CORE.data[KEY_CORE] = {KEY_TARGET_FRAMEWORK: None, KEY_TARGET_PLATFORM: None} codeowners = defaultdict(list) diff --git a/script/build_language_schema.py b/script/build_language_schema.py index ff6e898902..1ffe3c2873 100755 --- a/script/build_language_schema.py +++ b/script/build_language_schema.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 import argparse -import glob import inspect import json import os +from pathlib import Path import re import voluptuous as vol @@ -70,14 +70,14 @@ def get_component_names(): component_names = ["esphome", "sensor", "esp32", "esp8266"] skip_components = [] - for d in os.listdir(CORE_COMPONENTS_PATH): + for d in CORE_COMPONENTS_PATH.iterdir(): if ( - not d.startswith("__") - and os.path.isdir(os.path.join(CORE_COMPONENTS_PATH, d)) - and d not in component_names - and d not in skip_components + not d.name.startswith("__") + and d.is_dir() + and d.name not in component_names + and d.name not in skip_components ): - component_names.append(d) + component_names.append(d.name) return sorted(component_names) @@ -121,7 +121,7 @@ from esphome.util import Registry # noqa: E402 def write_file(name, obj): - full_path = os.path.join(args.output_path, name + ".json") + full_path = Path(args.output_path) / f"{name}.json" if JSON_DUMP_PRETTY: json_str = json.dumps(obj, indent=2) else: @@ -131,9 +131,10 @@ def write_file(name, obj): def delete_extra_files(keep_names): - for d in os.listdir(args.output_path): - if d.endswith(".json") and d[:-5] not in keep_names: - os.remove(os.path.join(args.output_path, d)) + output_path = Path(args.output_path) + for d in output_path.iterdir(): + if d.suffix == ".json" and d.stem not in keep_names: + d.unlink() print(f"Deleted {d}") @@ -367,13 +368,11 @@ def get_logger_tags(): "scheduler", "api.service", ] - for x in os.walk(CORE_COMPONENTS_PATH): - for y in glob.glob(os.path.join(x[0], "*.cpp")): - with open(y, encoding="utf-8") as file: - data = file.read() - match = pattern.search(data) - if match: - tags.append(match.group(1)) + for file in CORE_COMPONENTS_PATH.rglob("*.cpp"): + data = file.read_text() + match = pattern.search(data) + if match: + tags.append(match.group(1)) return tags diff --git a/script/ci-custom.py b/script/ci-custom.py index 61081608d5..bc1ebda93b 100755 --- a/script/ci-custom.py +++ b/script/ci-custom.py @@ -6,6 +6,7 @@ import collections import fnmatch import functools import os.path +from pathlib import Path import re import sys import time @@ -75,12 +76,12 @@ ignore_types = ( LINT_FILE_CHECKS = [] LINT_CONTENT_CHECKS = [] LINT_POST_CHECKS = [] -EXECUTABLE_BIT = {} +EXECUTABLE_BIT: dict[str, int] = {} -errors = collections.defaultdict(list) +errors: collections.defaultdict[Path, list] = collections.defaultdict(list) -def add_errors(fname, errs): +def add_errors(fname: Path, errs: list[tuple[int, int, str] | None]) -> None: if not isinstance(errs, list): errs = [errs] for err in errs: @@ -246,8 +247,8 @@ def lint_ext_check(fname): ".github/copilot-instructions.md", ] ) -def lint_executable_bit(fname): - ex = EXECUTABLE_BIT[fname] +def lint_executable_bit(fname: Path) -> str | None: + ex = EXECUTABLE_BIT[str(fname)] if ex != 100644: return ( f"File has invalid executable bit {ex}. If running from a windows machine please " @@ -506,8 +507,8 @@ def lint_constants_usage(): return errs -def relative_cpp_search_text(fname, content): - parts = fname.split("/") +def relative_cpp_search_text(fname: Path, content) -> str: + parts = fname.parts integration = parts[2] return f'#include "esphome/components/{integration}' @@ -524,8 +525,8 @@ def lint_relative_cpp_import(fname, line, col, content): ) -def relative_py_search_text(fname, content): - parts = fname.split("/") +def relative_py_search_text(fname: Path, content: str) -> str: + parts = fname.parts integration = parts[2] return f"esphome.components.{integration}" @@ -591,10 +592,8 @@ def lint_relative_py_import(fname, line, col, content): "esphome/components/http_request/httplib.h", ], ) -def lint_namespace(fname, content): - expected_name = re.match( - r"^esphome/components/([^/]+)/.*", fname.replace(os.path.sep, "/") - ).group(1) +def lint_namespace(fname: Path, content: str) -> str | None: + expected_name = fname.parts[2] # Check for both old style and C++17 nested namespace syntax search_old = f"namespace {expected_name}" search_new = f"namespace esphome::{expected_name}" @@ -733,9 +732,9 @@ def main(): files.sort() for fname in files: - _, ext = os.path.splitext(fname) + fname = Path(fname) run_checks(LINT_FILE_CHECKS, fname, fname) - if ext in ignore_types: + if fname.suffix in ignore_types: continue try: with codecs.open(fname, "r", encoding="utf-8") as f_handle: diff --git a/script/clang_tidy_hash.py b/script/clang_tidy_hash.py index 19eb2a825e..d0d8438437 100755 --- a/script/clang_tidy_hash.py +++ b/script/clang_tidy_hash.py @@ -48,9 +48,10 @@ def parse_requirement_line(line: str) -> tuple[str, str] | None: return None -def get_clang_tidy_version_from_requirements() -> str: +def get_clang_tidy_version_from_requirements(repo_root: Path | None = None) -> str: """Get clang-tidy version from requirements_dev.txt""" - requirements_path = Path(__file__).parent.parent / "requirements_dev.txt" + repo_root = _ensure_repo_root(repo_root) + requirements_path = repo_root / "requirements_dev.txt" lines = read_file_lines(requirements_path) for line in lines: @@ -68,30 +69,49 @@ def read_file_bytes(path: Path) -> bytes: return f.read() -def calculate_clang_tidy_hash() -> str: +def get_repo_root() -> Path: + """Get the repository root directory.""" + return Path(__file__).parent.parent + + +def _ensure_repo_root(repo_root: Path | None) -> Path: + """Ensure repo_root is a Path, using default if None.""" + return repo_root if repo_root is not None else get_repo_root() + + +def calculate_clang_tidy_hash(repo_root: Path | None = None) -> str: """Calculate hash of clang-tidy configuration and version""" + repo_root = _ensure_repo_root(repo_root) + hasher = hashlib.sha256() # Hash .clang-tidy file - clang_tidy_path = Path(__file__).parent.parent / ".clang-tidy" + clang_tidy_path = repo_root / ".clang-tidy" content = read_file_bytes(clang_tidy_path) hasher.update(content) # Hash clang-tidy version from requirements_dev.txt - version = get_clang_tidy_version_from_requirements() + version = get_clang_tidy_version_from_requirements(repo_root) hasher.update(version.encode()) # Hash the entire platformio.ini file - platformio_path = Path(__file__).parent.parent / "platformio.ini" + platformio_path = repo_root / "platformio.ini" platformio_content = read_file_bytes(platformio_path) hasher.update(platformio_content) + # Hash sdkconfig.defaults file + sdkconfig_path = repo_root / "sdkconfig.defaults" + if sdkconfig_path.exists(): + sdkconfig_content = read_file_bytes(sdkconfig_path) + hasher.update(sdkconfig_content) + return hasher.hexdigest() -def read_stored_hash() -> str | None: +def read_stored_hash(repo_root: Path | None = None) -> str | None: """Read the stored hash from file""" - hash_file = Path(__file__).parent.parent / ".clang-tidy.hash" + repo_root = _ensure_repo_root(repo_root) + hash_file = repo_root / ".clang-tidy.hash" if hash_file.exists(): lines = read_file_lines(hash_file) return lines[0].strip() if lines else None @@ -104,9 +124,10 @@ def write_file_content(path: Path, content: str) -> None: f.write(content) -def write_hash(hash_value: str) -> None: +def write_hash(hash_value: str, repo_root: Path | None = None) -> None: """Write hash to file""" - hash_file = Path(__file__).parent.parent / ".clang-tidy.hash" + repo_root = _ensure_repo_root(repo_root) + hash_file = repo_root / ".clang-tidy.hash" # Strip any trailing newlines to ensure consistent formatting write_file_content(hash_file, hash_value.strip() + "\n") @@ -134,8 +155,28 @@ def main() -> None: stored_hash = read_stored_hash() if args.check: - # Exit 0 if full scan needed (hash changed or no hash file) - sys.exit(0 if current_hash != stored_hash else 1) + # Check if hash changed OR if .clang-tidy.hash was updated in this PR + # This is used in CI to determine if a full clang-tidy scan is needed + hash_changed = current_hash != stored_hash + + # Lazy import to avoid requiring dependencies that aren't needed for other modes + from helpers import changed_files # noqa: E402 + + hash_file_updated = ".clang-tidy.hash" in changed_files() + + # Exit 0 if full scan needed + sys.exit(0 if (hash_changed or hash_file_updated) else 1) + + elif args.verify: + # Verify that hash file is up to date with current configuration + # This is used in pre-commit and CI checks to ensure hash was updated + if current_hash != stored_hash: + print("ERROR: Clang-tidy configuration has changed but hash not updated!") + print(f"Expected: {current_hash}") + print(f"Found: {stored_hash}") + print("\nPlease run: script/clang_tidy_hash.py --update") + sys.exit(1) + print("Hash verification passed") elif args.update: write_hash(current_hash) @@ -151,15 +192,6 @@ def main() -> None: print("Clang-tidy hash unchanged") sys.exit(0) - elif args.verify: - if current_hash != stored_hash: - print("ERROR: Clang-tidy configuration has changed but hash not updated!") - print(f"Expected: {current_hash}") - print(f"Found: {stored_hash}") - print("\nPlease run: script/clang_tidy_hash.py --update") - sys.exit(1) - print("Hash verification passed") - else: print(f"Current hash: {current_hash}") print(f"Stored hash: {stored_hash}") diff --git a/script/generate-esp32-boards.py b/script/generate-esp32-boards.py index 3f444ed455..81b78b04be 100755 --- a/script/generate-esp32-boards.py +++ b/script/generate-esp32-boards.py @@ -1,14 +1,19 @@ #!/usr/bin/env python3 +import argparse import json -import os +from pathlib import Path import subprocess +import sys import tempfile -from esphome.components.esp32 import ESP_IDF_PLATFORM_VERSION as ver +from esphome.components.esp32 import PLATFORM_VERSION_LOOKUP +from esphome.helpers import write_file_if_changed +ver = PLATFORM_VERSION_LOOKUP["recommended"] version_str = f"{ver.major}.{ver.minor:02d}.{ver.patch:02d}" -print(f"ESP32 Platform Version: {version_str}") +root = Path(__file__).parent.parent +boards_file_path = root / "esphome" / "components" / "esp32" / "boards.py" def get_boards(): @@ -17,6 +22,9 @@ def get_boards(): [ "git", "clone", + "-q", + "-c", + "advice.detachedHead=false", "--depth", "1", "--branch", @@ -26,16 +34,14 @@ def get_boards(): ], check=True, ) - boards_file = os.path.join(tempdir, "boards") + boards_directory = Path(tempdir) / "boards" boards = {} - for fname in os.listdir(boards_file): - if not fname.endswith(".json"): - continue - with open(os.path.join(boards_file, fname), encoding="utf-8") as f: + for fname in boards_directory.glob("*.json"): + with fname.open(encoding="utf-8") as f: board_info = json.load(f) mcu = board_info["build"]["mcu"] name = board_info["name"] - board = fname[:-5] + board = fname.stem variant = mcu.upper() boards[board] = { "name": name, @@ -47,33 +53,47 @@ def get_boards(): TEMPLATE = """ "%s": { "name": "%s", "variant": %s, - }, -""" + },""" -def main(): +def main(check: bool): boards = get_boards() # open boards.py, delete existing BOARDS variable and write the new boards dict - boards_file_path = os.path.join( - os.path.dirname(__file__), "..", "esphome", "components", "esp32", "boards.py" - ) - with open(boards_file_path, encoding="UTF-8") as f: - lines = f.readlines() + existing_content = boards_file_path.read_text(encoding="UTF-8") - with open(boards_file_path, "w", encoding="UTF-8") as f: - for line in lines: - if line.startswith("BOARDS = {"): - f.write("BOARDS = {\n") - f.writelines( - TEMPLATE % (board, info["name"], info["variant"]) - for board, info in sorted(boards.items()) - ) - f.write("}\n") - break + parts: list[str] = [] + for line in existing_content.splitlines(): + if line == "BOARDS = {": + parts.append(line) + parts.extend( + TEMPLATE % (board, info["name"], info["variant"]) + for board, info in sorted(boards.items()) + ) + parts.append("}") + parts.append("# DO NOT ADD ANYTHING BELOW THIS LINE") + break - f.write(line) + parts.append(line) + + parts.append("") + content = "\n".join(parts) + + if check: + if existing_content != content: + print("boards.py file is not up to date.") + print("Please run `script/generate-esp32-boards.py`") + sys.exit(1) + print("boards.py file is up to date") + elif write_file_if_changed(boards_file_path, content): + print("ESP32 boards updated successfully.") if __name__ == "__main__": - main() - print("ESP32 boards updated successfully.") + parser = argparse.ArgumentParser() + parser.add_argument( + "--check", + help="Check if the boards.py file is up to date.", + action="store_true", + ) + args = parser.parse_args() + main(args.check) diff --git a/script/helpers.py b/script/helpers.py index 2c2f44a513..61306b9489 100644 --- a/script/helpers.py +++ b/script/helpers.py @@ -52,10 +52,10 @@ def styled(color: str | tuple[str, ...], msg: str, reset: bool = True) -> str: return prefix + msg + suffix -def print_error_for_file(file: str, body: str | None) -> None: +def print_error_for_file(file: str | Path, body: str | None) -> None: print( styled(colorama.Fore.GREEN, "### File ") - + styled((colorama.Fore.GREEN, colorama.Style.BRIGHT), file) + + styled((colorama.Fore.GREEN, colorama.Style.BRIGHT), str(file)) ) print() if body is not None: @@ -513,7 +513,7 @@ def get_all_dependencies(component_names: set[str]) -> set[str]: # Set up fake config path for component loading root = Path(__file__).parent.parent - CORE.config_path = str(root) + CORE.config_path = root CORE.data[KEY_CORE] = {} # Keep finding dependencies until no new ones are found @@ -529,7 +529,16 @@ def get_all_dependencies(component_names: set[str]) -> set[str]: new_components.update(dep.split(".")[0] for dep in comp.dependencies) # Add auto_load components - new_components.update(comp.auto_load) + auto_load = comp.auto_load + if callable(auto_load): + import inspect + + if inspect.signature(auto_load).parameters: + auto_load = auto_load(None) + else: + auto_load = auto_load() + + new_components.update(auto_load) # Check if we found any new components new_components -= all_components @@ -553,7 +562,7 @@ def get_components_from_integration_fixtures() -> set[str]: fixtures_dir = Path(__file__).parent.parent / "tests" / "integration" / "fixtures" for yaml_file in fixtures_dir.glob("*.yaml"): - config: dict[str, any] | None = yaml_util.load_yaml(str(yaml_file)) + config: dict[str, any] | None = yaml_util.load_yaml(yaml_file) if not config: continue diff --git a/script/list-components.py b/script/list-components.py index 66212f44e7..9ab1cdd852 100755 --- a/script/list-components.py +++ b/script/list-components.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import argparse +from collections.abc import Callable from pathlib import Path import sys @@ -13,7 +14,7 @@ from esphome.const import ( PLATFORM_ESP8266, ) from esphome.core import CORE -from esphome.loader import get_component, get_platform +from esphome.loader import ComponentManifest, get_component, get_platform def filter_component_files(str): @@ -45,12 +46,35 @@ def add_item_to_components_graph(components_graph, parent, child): components_graph[parent].append(child) +def resolve_auto_load( + auto_load: list[str] | Callable[[], list[str]] | Callable[[dict | None], list[str]], + config: dict | None = None, +) -> list[str]: + """Resolve AUTO_LOAD to a list, handling callables with or without config parameter. + + Args: + auto_load: The AUTO_LOAD value (list or callable) + config: Optional config to pass to callable AUTO_LOAD functions + + Returns: + List of component names to auto-load + """ + if not callable(auto_load): + return auto_load + + import inspect + + if inspect.signature(auto_load).parameters: + return auto_load(config) + return auto_load() + + def create_components_graph(): # The root directory of the repo root = Path(__file__).parent.parent components_dir = root / "esphome" / "components" # Fake some directory so that get_component works - CORE.config_path = str(root) + CORE.config_path = root # Various configuration to capture different outcomes used by `AUTO_LOAD` function. TARGET_CONFIGURATIONS = [ {KEY_TARGET_FRAMEWORK: None, KEY_TARGET_PLATFORM: None}, @@ -63,7 +87,7 @@ def create_components_graph(): components_graph = {} platforms = [] - components = [] + components: list[tuple[ComponentManifest, str, Path]] = [] for path in components_dir.iterdir(): if not path.is_dir(): @@ -92,8 +116,8 @@ def create_components_graph(): for target_config in TARGET_CONFIGURATIONS: CORE.data[KEY_CORE] = target_config - for auto_load in comp.auto_load: - add_item_to_components_graph(components_graph, auto_load, name) + for item in resolve_auto_load(comp.auto_load, config=None): + add_item_to_components_graph(components_graph, item, name) # restore config CORE.data[KEY_CORE] = TARGET_CONFIGURATIONS[0] @@ -114,8 +138,8 @@ def create_components_graph(): for target_config in TARGET_CONFIGURATIONS: CORE.data[KEY_CORE] = target_config - for auto_load in platform.auto_load: - add_item_to_components_graph(components_graph, auto_load, name) + for item in resolve_auto_load(platform.auto_load, config={}): + add_item_to_components_graph(components_graph, item, name) # restore config CORE.data[KEY_CORE] = TARGET_CONFIGURATIONS[0] diff --git a/script/setup b/script/setup index 1bd7c44575..8cad7017ff 100755 --- a/script/setup +++ b/script/setup @@ -22,8 +22,6 @@ uv pip install -e ".[dev,test]" --config-settings editable_mode=compat pre-commit install -script/platformio_install_deps.py platformio.ini --libraries --tools --platforms - mkdir -p .temp echo diff --git a/script/setup.bat b/script/setup.bat index f89d5aea1a..003ea31b36 100644 --- a/script/setup.bat +++ b/script/setup.bat @@ -19,8 +19,6 @@ pip3 install -e ".[dev,test]" --config-settings editable_mode=compat pre-commit install -python script/platformio_install_deps.py platformio.ini --libraries --tools --platforms - echo . echo . echo Virtual environment created. Run 'venv/Scripts/activate' to use it. diff --git a/sdkconfig.defaults b/sdkconfig.defaults index 72ca3f6e9c..322efb701a 100644 --- a/sdkconfig.defaults +++ b/sdkconfig.defaults @@ -13,6 +13,7 @@ CONFIG_ESP_TASK_WDT=y CONFIG_ESP_TASK_WDT_PANIC=y CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU0=n CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU1=n +CONFIG_AUTOSTART_ARDUINO=y # esp32_ble CONFIG_BT_ENABLED=y diff --git a/tests/component_tests/conftest.py b/tests/component_tests/conftest.py index 2045b03502..0641e698e9 100644 --- a/tests/component_tests/conftest.py +++ b/tests/component_tests/conftest.py @@ -6,6 +6,7 @@ from collections.abc import Callable, Generator from pathlib import Path import sys from typing import Any +from unittest import mock import pytest @@ -17,6 +18,7 @@ from esphome.const import ( PlatformFramework, ) from esphome.types import ConfigType +from esphome.util import OrderedDict # Add package root to python path here = Path(__file__).parent @@ -40,9 +42,9 @@ def config_path(request: pytest.FixtureRequest) -> Generator[None]: if config_dir.exists(): # Set config_path to a dummy yaml file in the config directory # This ensures CORE.config_dir points to the config directory - CORE.config_path = str(config_dir / "dummy.yaml") + CORE.config_path = config_dir / "dummy.yaml" else: - CORE.config_path = str(Path(request.fspath).parent / "dummy.yaml") + CORE.config_path = Path(request.fspath).parent / "dummy.yaml" yield CORE.config_path = original_path @@ -129,9 +131,35 @@ def generate_main() -> Generator[Callable[[str | Path], str]]: """Generates the C++ main.cpp from a given yaml file and returns it in string form.""" def generator(path: str | Path) -> str: - CORE.config_path = str(path) + CORE.config_path = Path(path) CORE.config = read_config({}) generate_cpp_contents(CORE.config) return CORE.cpp_main_section yield generator + + +@pytest.fixture +def mock_clone_or_update() -> Generator[Any]: + """Mock git.clone_or_update for testing.""" + with mock.patch("esphome.git.clone_or_update") as mock_func: + # Default return value + mock_func.return_value = (Path("/tmp/test"), None) + yield mock_func + + +@pytest.fixture +def mock_load_yaml() -> Generator[Any]: + """Mock yaml_util.load_yaml for testing.""" + + with mock.patch("esphome.yaml_util.load_yaml") as mock_func: + # Default return value + mock_func.return_value = OrderedDict({"sensor": []}) + yield mock_func + + +@pytest.fixture +def mock_install_meta_finder() -> Generator[Any]: + """Mock loader.install_meta_finder for testing.""" + with mock.patch("esphome.loader.install_meta_finder") as mock_func: + yield mock_func diff --git a/tests/component_tests/external_components/test_init.py b/tests/component_tests/external_components/test_init.py new file mode 100644 index 0000000000..905c0afa8b --- /dev/null +++ b/tests/component_tests/external_components/test_init.py @@ -0,0 +1,134 @@ +"""Tests for the external_components skip_update functionality.""" + +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +from esphome.components.external_components import do_external_components_pass +from esphome.const import ( + CONF_EXTERNAL_COMPONENTS, + CONF_REFRESH, + CONF_SOURCE, + CONF_URL, + TYPE_GIT, +) + + +def test_external_components_skip_update_true( + tmp_path: Path, mock_clone_or_update: MagicMock, mock_install_meta_finder: MagicMock +) -> None: + """Test that external components don't update when skip_update=True.""" + # Create a components directory structure + components_dir = tmp_path / "components" + components_dir.mkdir() + + # Create a test component + test_component_dir = components_dir / "test_component" + test_component_dir.mkdir() + (test_component_dir / "__init__.py").write_text("# Test component") + + # Set up mock to return our tmp_path + mock_clone_or_update.return_value = (tmp_path, None) + + config: dict[str, Any] = { + CONF_EXTERNAL_COMPONENTS: [ + { + CONF_SOURCE: { + "type": TYPE_GIT, + CONF_URL: "https://github.com/test/components", + }, + CONF_REFRESH: "1d", + "components": "all", + } + ] + } + + # Call with skip_update=True + do_external_components_pass(config, skip_update=True) + + # Verify clone_or_update was called with NEVER_REFRESH + mock_clone_or_update.assert_called_once() + call_args = mock_clone_or_update.call_args + from esphome import git + + assert call_args.kwargs["refresh"] == git.NEVER_REFRESH + + +def test_external_components_skip_update_false( + tmp_path: Path, mock_clone_or_update: MagicMock, mock_install_meta_finder: MagicMock +) -> None: + """Test that external components update when skip_update=False.""" + # Create a components directory structure + components_dir = tmp_path / "components" + components_dir.mkdir() + + # Create a test component + test_component_dir = components_dir / "test_component" + test_component_dir.mkdir() + (test_component_dir / "__init__.py").write_text("# Test component") + + # Set up mock to return our tmp_path + mock_clone_or_update.return_value = (tmp_path, None) + + config: dict[str, Any] = { + CONF_EXTERNAL_COMPONENTS: [ + { + CONF_SOURCE: { + "type": TYPE_GIT, + CONF_URL: "https://github.com/test/components", + }, + CONF_REFRESH: "1d", + "components": "all", + } + ] + } + + # Call with skip_update=False + do_external_components_pass(config, skip_update=False) + + # Verify clone_or_update was called with actual refresh value + mock_clone_or_update.assert_called_once() + call_args = mock_clone_or_update.call_args + from esphome.core import TimePeriodSeconds + + assert call_args.kwargs["refresh"] == TimePeriodSeconds(days=1) + + +def test_external_components_default_no_skip( + tmp_path: Path, mock_clone_or_update: MagicMock, mock_install_meta_finder: MagicMock +) -> None: + """Test that external components update by default when skip_update not specified.""" + # Create a components directory structure + components_dir = tmp_path / "components" + components_dir.mkdir() + + # Create a test component + test_component_dir = components_dir / "test_component" + test_component_dir.mkdir() + (test_component_dir / "__init__.py").write_text("# Test component") + + # Set up mock to return our tmp_path + mock_clone_or_update.return_value = (tmp_path, None) + + config: dict[str, Any] = { + CONF_EXTERNAL_COMPONENTS: [ + { + CONF_SOURCE: { + "type": TYPE_GIT, + CONF_URL: "https://github.com/test/components", + }, + CONF_REFRESH: "1d", + "components": "all", + } + ] + } + + # Call without skip_update parameter + do_external_components_pass(config) + + # Verify clone_or_update was called with actual refresh value + mock_clone_or_update.assert_called_once() + call_args = mock_clone_or_update.call_args + from esphome.core import TimePeriodSeconds + + assert call_args.kwargs["refresh"] == TimePeriodSeconds(days=1) diff --git a/tests/component_tests/packages/test_init.py b/tests/component_tests/packages/test_init.py new file mode 100644 index 0000000000..779244e2ed --- /dev/null +++ b/tests/component_tests/packages/test_init.py @@ -0,0 +1,114 @@ +"""Tests for the packages component skip_update functionality.""" + +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +from esphome.components.packages import do_packages_pass +from esphome.const import CONF_FILES, CONF_PACKAGES, CONF_REFRESH, CONF_URL +from esphome.util import OrderedDict + + +def test_packages_skip_update_true( + tmp_path: Path, mock_clone_or_update: MagicMock, mock_load_yaml: MagicMock +) -> None: + """Test that packages don't update when skip_update=True.""" + # Set up mock to return our tmp_path + mock_clone_or_update.return_value = (tmp_path, None) + + # Create the test yaml file + test_file = tmp_path / "test.yaml" + test_file.write_text("sensor: []") + + # Set mock_load_yaml to return some valid config + mock_load_yaml.return_value = OrderedDict({"sensor": []}) + + config: dict[str, Any] = { + CONF_PACKAGES: { + "test_package": { + CONF_URL: "https://github.com/test/repo", + CONF_FILES: ["test.yaml"], + CONF_REFRESH: "1d", + } + } + } + + # Call with skip_update=True + do_packages_pass(config, skip_update=True) + + # Verify clone_or_update was called with NEVER_REFRESH + mock_clone_or_update.assert_called_once() + call_args = mock_clone_or_update.call_args + from esphome import git + + assert call_args.kwargs["refresh"] == git.NEVER_REFRESH + + +def test_packages_skip_update_false( + tmp_path: Path, mock_clone_or_update: MagicMock, mock_load_yaml: MagicMock +) -> None: + """Test that packages update when skip_update=False.""" + # Set up mock to return our tmp_path + mock_clone_or_update.return_value = (tmp_path, None) + + # Create the test yaml file + test_file = tmp_path / "test.yaml" + test_file.write_text("sensor: []") + + # Set mock_load_yaml to return some valid config + mock_load_yaml.return_value = OrderedDict({"sensor": []}) + + config: dict[str, Any] = { + CONF_PACKAGES: { + "test_package": { + CONF_URL: "https://github.com/test/repo", + CONF_FILES: ["test.yaml"], + CONF_REFRESH: "1d", + } + } + } + + # Call with skip_update=False (default) + do_packages_pass(config, skip_update=False) + + # Verify clone_or_update was called with actual refresh value + mock_clone_or_update.assert_called_once() + call_args = mock_clone_or_update.call_args + from esphome.core import TimePeriodSeconds + + assert call_args.kwargs["refresh"] == TimePeriodSeconds(days=1) + + +def test_packages_default_no_skip( + tmp_path: Path, mock_clone_or_update: MagicMock, mock_load_yaml: MagicMock +) -> None: + """Test that packages update by default when skip_update not specified.""" + # Set up mock to return our tmp_path + mock_clone_or_update.return_value = (tmp_path, None) + + # Create the test yaml file + test_file = tmp_path / "test.yaml" + test_file.write_text("sensor: []") + + # Set mock_load_yaml to return some valid config + mock_load_yaml.return_value = OrderedDict({"sensor": []}) + + config: dict[str, Any] = { + CONF_PACKAGES: { + "test_package": { + CONF_URL: "https://github.com/test/repo", + CONF_FILES: ["test.yaml"], + CONF_REFRESH: "1d", + } + } + } + + # Call without skip_update parameter + do_packages_pass(config) + + # Verify clone_or_update was called with actual refresh value + mock_clone_or_update.assert_called_once() + call_args = mock_clone_or_update.call_args + from esphome.core import TimePeriodSeconds + + assert call_args.kwargs["refresh"] == TimePeriodSeconds(days=1) diff --git a/tests/component_tests/psram/test_psram.py b/tests/component_tests/psram/test_psram.py new file mode 100644 index 0000000000..3e40a8d192 --- /dev/null +++ b/tests/component_tests/psram/test_psram.py @@ -0,0 +1,194 @@ +"""Tests for PSRAM component.""" + +from typing import Any + +import pytest + +from esphome.components.esp32.const import ( + KEY_VARIANT, + VARIANT_ESP32, + VARIANT_ESP32C2, + VARIANT_ESP32C3, + VARIANT_ESP32C5, + VARIANT_ESP32C6, + VARIANT_ESP32H2, + VARIANT_ESP32P4, + VARIANT_ESP32S2, + VARIANT_ESP32S3, +) +import esphome.config_validation as cv +from esphome.const import CONF_ESPHOME, PlatformFramework +from tests.component_tests.types import SetCoreConfigCallable + +UNSUPPORTED_PSRAM_VARIANTS = [ + VARIANT_ESP32C2, + VARIANT_ESP32C3, + VARIANT_ESP32C5, + VARIANT_ESP32C6, + VARIANT_ESP32H2, +] + +SUPPORTED_PSRAM_VARIANTS = [ + VARIANT_ESP32, + VARIANT_ESP32S2, + VARIANT_ESP32S3, + VARIANT_ESP32P4, +] + + +@pytest.mark.parametrize( + ("config", "error_match"), + [ + pytest.param( + {}, + r"PSRAM is not supported on this chip", + id="psram_not_supported", + ), + ], +) +@pytest.mark.parametrize("variant", UNSUPPORTED_PSRAM_VARIANTS) +def test_psram_configuration_errors_unsupported_variants( + config: Any, + error_match: str, + variant: str, + set_core_config: SetCoreConfigCallable, +) -> None: + set_core_config( + PlatformFramework.ESP32_IDF, + platform_data={KEY_VARIANT: variant}, + full_config={CONF_ESPHOME: {}}, + ) + """Test detection of invalid PSRAM configuration on unsupported variants.""" + from esphome.components.psram import CONFIG_SCHEMA + + with pytest.raises(cv.Invalid, match=error_match): + CONFIG_SCHEMA(config) + + +@pytest.mark.parametrize("variant", SUPPORTED_PSRAM_VARIANTS) +def test_psram_configuration_valid_supported_variants( + variant: str, + set_core_config: SetCoreConfigCallable, +) -> None: + set_core_config( + PlatformFramework.ESP32_IDF, + platform_data={KEY_VARIANT: variant}, + full_config={ + CONF_ESPHOME: {}, + "esp32": { + "variant": variant, + "cpu_frequency": "160MHz", + "framework": {"type": "esp-idf"}, + }, + }, + ) + """Test that PSRAM configuration is valid on supported variants.""" + from esphome.components.psram import CONFIG_SCHEMA, FINAL_VALIDATE_SCHEMA + + # This should not raise an exception + config = CONFIG_SCHEMA({}) + FINAL_VALIDATE_SCHEMA(config) + + +def _setup_psram_final_validation_test( + esp32_config: dict, + set_core_config: SetCoreConfigCallable, + set_component_config: Any, +) -> str: + """Helper function to set up ESP32 configuration for PSRAM final validation tests.""" + # Use ESP32S3 for schema validation to allow all options, then override for final validation + schema_variant = "ESP32S3" + final_variant = esp32_config.get("variant", "ESP32S3") + full_esp32_config = { + "variant": final_variant, + "cpu_frequency": esp32_config.get("cpu_frequency", "240MHz"), + "framework": {"type": "esp-idf"}, + } + + set_core_config( + PlatformFramework.ESP32_IDF, + platform_data={KEY_VARIANT: schema_variant}, + full_config={ + CONF_ESPHOME: {}, + "esp32": full_esp32_config, + }, + ) + set_component_config("esp32", full_esp32_config) + + return final_variant + + +@pytest.mark.parametrize( + ("config", "esp32_config", "expect_error", "error_match"), + [ + pytest.param( + {"speed": "120MHz"}, + {"cpu_frequency": "160MHz"}, + True, + r"PSRAM 120MHz requires 240MHz CPU frequency", + id="120mhz_requires_240mhz_cpu", + ), + pytest.param( + {"mode": "octal"}, + {"variant": "ESP32"}, + True, + r"Octal PSRAM is only supported on ESP32-S3", + id="octal_mode_only_esp32s3", + ), + pytest.param( + {"mode": "quad", "enable_ecc": True}, + {}, + True, + r"ECC is only available in octal mode", + id="ecc_only_in_octal_mode", + ), + pytest.param( + {"speed": "120MHZ"}, + {"cpu_frequency": "240MHZ"}, + False, + None, + id="120mhz_with_240mhz_cpu", + ), + pytest.param( + {"mode": "octal"}, + {"variant": "ESP32S3"}, + False, + None, + id="octal_mode_on_esp32s3", + ), + pytest.param( + {"mode": "octal", "enable_ecc": True}, + {"variant": "ESP32S3"}, + False, + None, + id="ecc_in_octal_mode", + ), + ], +) +def test_psram_final_validation( + config: Any, + esp32_config: dict, + expect_error: bool, + error_match: str | None, + set_core_config: SetCoreConfigCallable, + set_component_config: Any, +) -> None: + """Test PSRAM final validation for both error and valid cases.""" + from esphome.components.psram import CONFIG_SCHEMA, FINAL_VALIDATE_SCHEMA + from esphome.core import CORE + + final_variant = _setup_psram_final_validation_test( + esp32_config, set_core_config, set_component_config + ) + + validated_config = CONFIG_SCHEMA(config) + + # Update CORE variant for final validation + CORE.data["esp32"][KEY_VARIANT] = final_variant + + if expect_error: + with pytest.raises(cv.Invalid, match=error_match): + FINAL_VALIDATE_SCHEMA(validated_config) + else: + # This should not raise an exception + FINAL_VALIDATE_SCHEMA(validated_config) diff --git a/tests/components/analog_threshold/test.nrf52-adafruit.yaml b/tests/components/analog_threshold/test.nrf52-adafruit.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/analog_threshold/test.nrf52-adafruit.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/analog_threshold/test.nrf52-mcumgr.yaml b/tests/components/analog_threshold/test.nrf52-mcumgr.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/analog_threshold/test.nrf52-mcumgr.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/api/common.yaml b/tests/components/api/common.yaml index 7ac11e4da6..d87ae56ec2 100644 --- a/tests/components/api/common.yaml +++ b/tests/components/api/common.yaml @@ -10,10 +10,42 @@ esphome: data: message: Button was pressed - homeassistant.tag_scanned: pulse + - homeassistant.action: + action: weather.get_forecasts + data: + entity_id: weather.forecast_home + type: hourly + capture_response: true + on_success: + - lambda: |- + JsonObjectConst next_hour = response["response"]["weather.forecast_home"]["forecast"][0]; + float next_temperature = next_hour["temperature"].as(); + ESP_LOGD("main", "Next hour temperature: %f", next_temperature); + on_error: + - lambda: |- + ESP_LOGE("main", "Action failed with error: %s", error.c_str()); + - homeassistant.action: + action: weather.get_forecasts + data: + entity_id: weather.forecast_home + type: hourly + capture_response: true + response_template: "{{ response['weather.forecast_home']['forecast'][0]['temperature'] }}" + on_success: + - lambda: |- + float temperature = response["response"].as(); + ESP_LOGD("main", "Next hour temperature: %f", temperature); + - homeassistant.action: + action: light.toggle + data: + entity_id: light.demo_light + on_success: + - logger.log: "Toggled demo light" + on_error: + - logger.log: "Failed to toggle demo light" api: port: 8000 - password: pwd reboot_timeout: 0min encryption: key: bOFFzzvfpg5DB94DuBGLXD/hMnhpDKgP9UQyBulwWVU= diff --git a/tests/components/bang_bang/test.nrf52-adafruit.yaml b/tests/components/bang_bang/test.nrf52-adafruit.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/bang_bang/test.nrf52-adafruit.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/bang_bang/test.nrf52-mcumgr.yaml b/tests/components/bang_bang/test.nrf52-mcumgr.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/bang_bang/test.nrf52-mcumgr.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/deep_sleep/common-esp32-all.yaml b/tests/components/deep_sleep/common-esp32-all.yaml new file mode 100644 index 0000000000..b97eec76b9 --- /dev/null +++ b/tests/components/deep_sleep/common-esp32-all.yaml @@ -0,0 +1,14 @@ +deep_sleep: + run_duration: + default: 10s + gpio_wakeup_reason: 30s + touch_wakeup_reason: 15s + sleep_duration: 50s + wakeup_pin: ${wakeup_pin} + wakeup_pin_mode: INVERT_WAKEUP + esp32_ext1_wakeup: + pins: + - number: GPIO2 + - number: GPIO13 + mode: ANY_HIGH + touch_wakeup: true diff --git a/tests/components/deep_sleep/common-esp32-ext1.yaml b/tests/components/deep_sleep/common-esp32-ext1.yaml new file mode 100644 index 0000000000..9ed4279a33 --- /dev/null +++ b/tests/components/deep_sleep/common-esp32-ext1.yaml @@ -0,0 +1,12 @@ +deep_sleep: + run_duration: + default: 10s + gpio_wakeup_reason: 30s + sleep_duration: 50s + wakeup_pin: ${wakeup_pin} + wakeup_pin_mode: INVERT_WAKEUP + esp32_ext1_wakeup: + pins: + - number: GPIO2 + - number: GPIO5 + mode: ANY_HIGH diff --git a/tests/components/deep_sleep/test.esp32-c6-idf.yaml b/tests/components/deep_sleep/test.esp32-c6-idf.yaml index 10c17af0f5..11abe70711 100644 --- a/tests/components/deep_sleep/test.esp32-c6-idf.yaml +++ b/tests/components/deep_sleep/test.esp32-c6-idf.yaml @@ -2,4 +2,4 @@ substitutions: wakeup_pin: GPIO4 <<: !include common.yaml -<<: !include common-esp32.yaml +<<: !include common-esp32-ext1.yaml diff --git a/tests/components/deep_sleep/test.esp32-idf.yaml b/tests/components/deep_sleep/test.esp32-idf.yaml index 10c17af0f5..e45eb08349 100644 --- a/tests/components/deep_sleep/test.esp32-idf.yaml +++ b/tests/components/deep_sleep/test.esp32-idf.yaml @@ -2,4 +2,4 @@ substitutions: wakeup_pin: GPIO4 <<: !include common.yaml -<<: !include common-esp32.yaml +<<: !include common-esp32-all.yaml diff --git a/tests/components/deep_sleep/test.esp32-s2-idf.yaml b/tests/components/deep_sleep/test.esp32-s2-idf.yaml index 10c17af0f5..e45eb08349 100644 --- a/tests/components/deep_sleep/test.esp32-s2-idf.yaml +++ b/tests/components/deep_sleep/test.esp32-s2-idf.yaml @@ -2,4 +2,4 @@ substitutions: wakeup_pin: GPIO4 <<: !include common.yaml -<<: !include common-esp32.yaml +<<: !include common-esp32-all.yaml diff --git a/tests/components/deep_sleep/test.esp32-s3-idf.yaml b/tests/components/deep_sleep/test.esp32-s3-idf.yaml index 10c17af0f5..e45eb08349 100644 --- a/tests/components/deep_sleep/test.esp32-s3-idf.yaml +++ b/tests/components/deep_sleep/test.esp32-s3-idf.yaml @@ -2,4 +2,4 @@ substitutions: wakeup_pin: GPIO4 <<: !include common.yaml -<<: !include common-esp32.yaml +<<: !include common-esp32-all.yaml diff --git a/tests/components/ektf2232/common.yaml b/tests/components/ektf2232/common.yaml index 3271839fd4..91f09b4710 100644 --- a/tests/components/ektf2232/common.yaml +++ b/tests/components/ektf2232/common.yaml @@ -7,7 +7,7 @@ display: - platform: ssd1306_i2c id: ssd1306_display model: SSD1306_128X64 - reset_pin: ${reset_pin} + reset_pin: ${display_reset_pin} pages: - id: page1 lambda: |- @@ -16,7 +16,7 @@ display: touchscreen: - platform: ektf2232 interrupt_pin: ${interrupt_pin} - rts_pin: ${rts_pin} + reset_pin: ${touch_reset_pin} display: ssd1306_display on_touch: - logger.log: diff --git a/tests/components/ektf2232/test.esp32-ard.yaml b/tests/components/ektf2232/test.esp32-ard.yaml index b8f491c0c3..7d3f2ca7a2 100644 --- a/tests/components/ektf2232/test.esp32-ard.yaml +++ b/tests/components/ektf2232/test.esp32-ard.yaml @@ -1,8 +1,8 @@ substitutions: scl_pin: GPIO16 sda_pin: GPIO17 - reset_pin: GPIO13 + display_reset_pin: GPIO13 interrupt_pin: GPIO14 - rts_pin: GPIO15 + touch_reset_pin: GPIO15 <<: !include common.yaml diff --git a/tests/components/ektf2232/test.esp32-c3-ard.yaml b/tests/components/ektf2232/test.esp32-c3-ard.yaml index 9f2149b9d7..4d793a3242 100644 --- a/tests/components/ektf2232/test.esp32-c3-ard.yaml +++ b/tests/components/ektf2232/test.esp32-c3-ard.yaml @@ -1,8 +1,8 @@ substitutions: scl_pin: GPIO5 sda_pin: GPIO4 - reset_pin: GPIO3 + display_reset_pin: GPIO3 interrupt_pin: GPIO6 - rts_pin: GPIO7 + touch_reset_pin: GPIO7 <<: !include common.yaml diff --git a/tests/components/ektf2232/test.esp32-c3-idf.yaml b/tests/components/ektf2232/test.esp32-c3-idf.yaml index 9f2149b9d7..4d793a3242 100644 --- a/tests/components/ektf2232/test.esp32-c3-idf.yaml +++ b/tests/components/ektf2232/test.esp32-c3-idf.yaml @@ -1,8 +1,8 @@ substitutions: scl_pin: GPIO5 sda_pin: GPIO4 - reset_pin: GPIO3 + display_reset_pin: GPIO3 interrupt_pin: GPIO6 - rts_pin: GPIO7 + touch_reset_pin: GPIO7 <<: !include common.yaml diff --git a/tests/components/ektf2232/test.esp32-idf.yaml b/tests/components/ektf2232/test.esp32-idf.yaml index b8f491c0c3..7d3f2ca7a2 100644 --- a/tests/components/ektf2232/test.esp32-idf.yaml +++ b/tests/components/ektf2232/test.esp32-idf.yaml @@ -1,8 +1,8 @@ substitutions: scl_pin: GPIO16 sda_pin: GPIO17 - reset_pin: GPIO13 + display_reset_pin: GPIO13 interrupt_pin: GPIO14 - rts_pin: GPIO15 + touch_reset_pin: GPIO15 <<: !include common.yaml diff --git a/tests/components/ektf2232/test.esp8266-ard.yaml b/tests/components/ektf2232/test.esp8266-ard.yaml index 6d91a6533f..a87e9dfd45 100644 --- a/tests/components/ektf2232/test.esp8266-ard.yaml +++ b/tests/components/ektf2232/test.esp8266-ard.yaml @@ -1,8 +1,8 @@ substitutions: scl_pin: GPIO5 sda_pin: GPIO4 - reset_pin: GPIO3 + display_reset_pin: GPIO3 interrupt_pin: GPIO12 - rts_pin: GPIO13 + touch_reset_pin: GPIO13 <<: !include common.yaml diff --git a/tests/components/ektf2232/test.rp2040-ard.yaml b/tests/components/ektf2232/test.rp2040-ard.yaml index 9f2149b9d7..4d793a3242 100644 --- a/tests/components/ektf2232/test.rp2040-ard.yaml +++ b/tests/components/ektf2232/test.rp2040-ard.yaml @@ -1,8 +1,8 @@ substitutions: scl_pin: GPIO5 sda_pin: GPIO4 - reset_pin: GPIO3 + display_reset_pin: GPIO3 interrupt_pin: GPIO6 - rts_pin: GPIO7 + touch_reset_pin: GPIO7 <<: !include common.yaml diff --git a/tests/components/epaper_spi/test.esp32-s3-idf.yaml b/tests/components/epaper_spi/test.esp32-s3-idf.yaml new file mode 100644 index 0000000000..3d8d62a7ca --- /dev/null +++ b/tests/components/epaper_spi/test.esp32-s3-idf.yaml @@ -0,0 +1,15 @@ +spi: + clk_pin: GPIO7 + mosi_pin: GPIO9 + +display: + - platform: epaper_spi + model: 7.3in-spectra-e6 + cs_pin: GPIO5 + dc_pin: GPIO17 + reset_pin: GPIO16 + busy_pin: GPIO4 + rotation: 0 + update_interval: 60s + lambda: |- + it.circle(64, 64, 50, Color::BLACK); diff --git a/tests/components/esp32_can/test.esp32-c6-idf.yaml b/tests/components/esp32_can/test.esp32-c6-idf.yaml new file mode 100644 index 0000000000..6ef730c378 --- /dev/null +++ b/tests/components/esp32_can/test.esp32-c6-idf.yaml @@ -0,0 +1,89 @@ +esphome: + on_boot: + then: + - canbus.send: + # Extended ID explicit + canbus_id: esp32_internal_can + use_extended_id: true + can_id: 0x100 + data: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08] + - canbus.send: + # Standard ID by default + canbus_id: esp32_internal_can + can_id: 0x100 + data: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08] + - canbus.send: + # Extended ID explicit + canbus_id: esp32_internal_can_2 + use_extended_id: true + can_id: 0x100 + data: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08] + - canbus.send: + # Standard ID by default + canbus_id: esp32_internal_can_2 + can_id: 0x100 + data: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08] + +canbus: + - platform: esp32_can + id: esp32_internal_can + rx_pin: GPIO8 + tx_pin: GPIO7 + can_id: 4 + bit_rate: 50kbps + on_frame: + - can_id: 500 + then: + - lambda: |- + std::string b(x.begin(), x.end()); + ESP_LOGD("canbus1", "canid 500 %s", b.c_str() ); + - can_id: 0b00000000000000000000001000000 + can_id_mask: 0b11111000000000011111111000000 + use_extended_id: true + then: + - lambda: |- + auto pdo_id = can_id >> 14; + switch (pdo_id) + { + case 117: + ESP_LOGD("canbus1", "exhaust_fan_duty"); + break; + case 118: + ESP_LOGD("canbus1", "supply_fan_duty"); + break; + case 119: + ESP_LOGD("canbus1", "supply_fan_flow"); + break; + // to be continued... + } + - platform: esp32_can + id: esp32_internal_can_2 + rx_pin: GPIO10 + tx_pin: GPIO9 + can_id: 4 + bit_rate: 50kbps + on_frame: + - can_id: 500 + then: + - lambda: |- + std::string b(x.begin(), x.end()); + ESP_LOGD("canbus2", "canid 500 %s", b.c_str() ); + - can_id: 0b00000000000000000000001000000 + can_id_mask: 0b11111000000000011111111000000 + use_extended_id: true + then: + - lambda: |- + auto pdo_id = can_id >> 14; + switch (pdo_id) + { + case 117: + ESP_LOGD("canbus2", "exhaust_fan_duty"); + break; + case 118: + ESP_LOGD("canbus2", "supply_fan_duty"); + break; + case 119: + ESP_LOGD("canbus2", "supply_fan_flow"); + break; + // to be continued... + } diff --git a/tests/components/ethernet/common-dm9051.yaml b/tests/components/ethernet/common-dm9051.yaml index c878ca6e59..4526e7732d 100644 --- a/tests/components/ethernet/common-dm9051.yaml +++ b/tests/components/ethernet/common-dm9051.yaml @@ -12,3 +12,4 @@ ethernet: gateway: 192.168.178.1 subnet: 255.255.255.0 domain: .local + mac_address: "02:AA:BB:CC:DD:01" diff --git a/tests/components/ethernet/common-dp83848.yaml b/tests/components/ethernet/common-dp83848.yaml index 140c7d0d1b..7cedfeaf08 100644 --- a/tests/components/ethernet/common-dp83848.yaml +++ b/tests/components/ethernet/common-dp83848.yaml @@ -12,3 +12,4 @@ ethernet: gateway: 192.168.178.1 subnet: 255.255.255.0 domain: .local + mac_address: "02:AA:BB:CC:DD:01" diff --git a/tests/components/ethernet/common-ip101.yaml b/tests/components/ethernet/common-ip101.yaml index b5589220de..2dece15171 100644 --- a/tests/components/ethernet/common-ip101.yaml +++ b/tests/components/ethernet/common-ip101.yaml @@ -12,3 +12,4 @@ ethernet: gateway: 192.168.178.1 subnet: 255.255.255.0 domain: .local + mac_address: "02:AA:BB:CC:DD:01" diff --git a/tests/components/ethernet/common-jl1101.yaml b/tests/components/ethernet/common-jl1101.yaml index 2ada9495a0..b6ea884102 100644 --- a/tests/components/ethernet/common-jl1101.yaml +++ b/tests/components/ethernet/common-jl1101.yaml @@ -12,3 +12,4 @@ ethernet: gateway: 192.168.178.1 subnet: 255.255.255.0 domain: .local + mac_address: "02:AA:BB:CC:DD:01" diff --git a/tests/components/ethernet/common-ksz8081.yaml b/tests/components/ethernet/common-ksz8081.yaml index 7da8adb09a..f70d42319e 100644 --- a/tests/components/ethernet/common-ksz8081.yaml +++ b/tests/components/ethernet/common-ksz8081.yaml @@ -12,3 +12,4 @@ ethernet: gateway: 192.168.178.1 subnet: 255.255.255.0 domain: .local + mac_address: "02:AA:BB:CC:DD:01" diff --git a/tests/components/ethernet/common-ksz8081rna.yaml b/tests/components/ethernet/common-ksz8081rna.yaml index df04f06132..18efdae0e1 100644 --- a/tests/components/ethernet/common-ksz8081rna.yaml +++ b/tests/components/ethernet/common-ksz8081rna.yaml @@ -12,3 +12,4 @@ ethernet: gateway: 192.168.178.1 subnet: 255.255.255.0 domain: .local + mac_address: "02:AA:BB:CC:DD:01" diff --git a/tests/components/ethernet/common-lan8670.yaml b/tests/components/ethernet/common-lan8670.yaml new file mode 100644 index 0000000000..ec2f24273d --- /dev/null +++ b/tests/components/ethernet/common-lan8670.yaml @@ -0,0 +1,14 @@ +ethernet: + type: LAN8670 + mdc_pin: 23 + mdio_pin: 25 + clk: + pin: 0 + mode: CLK_EXT_IN + phy_addr: 0 + power_pin: 26 + manual_ip: + static_ip: 192.168.178.56 + gateway: 192.168.178.1 + subnet: 255.255.255.0 + domain: .local diff --git a/tests/components/ethernet/common-lan8720.yaml b/tests/components/ethernet/common-lan8720.yaml index f227752f42..204c1d9210 100644 --- a/tests/components/ethernet/common-lan8720.yaml +++ b/tests/components/ethernet/common-lan8720.yaml @@ -12,3 +12,4 @@ ethernet: gateway: 192.168.178.1 subnet: 255.255.255.0 domain: .local + mac_address: "02:AA:BB:CC:DD:01" diff --git a/tests/components/ethernet/common-rtl8201.yaml b/tests/components/ethernet/common-rtl8201.yaml index 7c9c9d913c..8b9f2b86f2 100644 --- a/tests/components/ethernet/common-rtl8201.yaml +++ b/tests/components/ethernet/common-rtl8201.yaml @@ -12,3 +12,4 @@ ethernet: gateway: 192.168.178.1 subnet: 255.255.255.0 domain: .local + mac_address: "02:AA:BB:CC:DD:01" diff --git a/tests/components/ethernet/common-w5500.yaml b/tests/components/ethernet/common-w5500.yaml index 76661a75c3..b3e96f000d 100644 --- a/tests/components/ethernet/common-w5500.yaml +++ b/tests/components/ethernet/common-w5500.yaml @@ -12,3 +12,4 @@ ethernet: gateway: 192.168.178.1 subnet: 255.255.255.0 domain: .local + mac_address: "02:AA:BB:CC:DD:01" diff --git a/tests/components/ethernet/test-lan8670.esp32-ard.yaml b/tests/components/ethernet/test-lan8670.esp32-ard.yaml new file mode 100644 index 0000000000..914a06ae88 --- /dev/null +++ b/tests/components/ethernet/test-lan8670.esp32-ard.yaml @@ -0,0 +1 @@ +<<: !include common-lan8670.yaml diff --git a/tests/components/ethernet/test-lan8670.esp32-idf.yaml b/tests/components/ethernet/test-lan8670.esp32-idf.yaml new file mode 100644 index 0000000000..914a06ae88 --- /dev/null +++ b/tests/components/ethernet/test-lan8670.esp32-idf.yaml @@ -0,0 +1 @@ +<<: !include common-lan8670.yaml diff --git a/tests/components/lm75b/common.yaml b/tests/components/lm75b/common.yaml new file mode 100644 index 0000000000..e451c2f679 --- /dev/null +++ b/tests/components/lm75b/common.yaml @@ -0,0 +1,9 @@ +i2c: + - id: i2c_lm75b + scl: ${scl_pin} + sda: ${sda_pin} + +sensor: + - platform: lm75b + name: LM75B Temperature + update_interval: 30s diff --git a/tests/components/lm75b/test.esp32-ard.yaml b/tests/components/lm75b/test.esp32-ard.yaml new file mode 100644 index 0000000000..43264df633 --- /dev/null +++ b/tests/components/lm75b/test.esp32-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO15 + sda_pin: GPIO13 + +<<: !include common.yaml diff --git a/tests/components/lm75b/test.esp32-c3-ard.yaml b/tests/components/lm75b/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..ee2c29ca4e --- /dev/null +++ b/tests/components/lm75b/test.esp32-c3-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO5 + sda_pin: GPIO4 + +<<: !include common.yaml diff --git a/tests/components/lm75b/test.esp32-c3-idf.yaml b/tests/components/lm75b/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..ee2c29ca4e --- /dev/null +++ b/tests/components/lm75b/test.esp32-c3-idf.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO5 + sda_pin: GPIO4 + +<<: !include common.yaml diff --git a/tests/components/lm75b/test.esp32-idf.yaml b/tests/components/lm75b/test.esp32-idf.yaml new file mode 100644 index 0000000000..43264df633 --- /dev/null +++ b/tests/components/lm75b/test.esp32-idf.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO15 + sda_pin: GPIO13 + +<<: !include common.yaml diff --git a/tests/components/lm75b/test.esp8266-ard.yaml b/tests/components/lm75b/test.esp8266-ard.yaml new file mode 100644 index 0000000000..ee2c29ca4e --- /dev/null +++ b/tests/components/lm75b/test.esp8266-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO5 + sda_pin: GPIO4 + +<<: !include common.yaml diff --git a/tests/components/lm75b/test.rp2040-ard.yaml b/tests/components/lm75b/test.rp2040-ard.yaml new file mode 100644 index 0000000000..ee2c29ca4e --- /dev/null +++ b/tests/components/lm75b/test.rp2040-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO5 + sda_pin: GPIO4 + +<<: !include common.yaml diff --git a/tests/components/logger/common-default_uart.yaml b/tests/components/logger/common-default_uart.yaml index e8b56043eb..7939a5f9c5 100644 --- a/tests/components/logger/common-default_uart.yaml +++ b/tests/components/logger/common-default_uart.yaml @@ -6,11 +6,16 @@ esphome: format: "Warning: Logger level is %d" args: [id(logger_id).get_log_level()] - logger.set_level: WARN + - logger.set_level: + level: ERROR + tag: mqtt.client logger: id: logger_id level: DEBUG initial_level: INFO + logs: + mqtt.component: WARN select: - platform: logger diff --git a/tests/components/modbus_controller/common.yaml b/tests/components/modbus_controller/common.yaml index 7d342ee353..c2b5ab737f 100644 --- a/tests/components/modbus_controller/common.yaml +++ b/tests/components/modbus_controller/common.yaml @@ -45,6 +45,22 @@ modbus_controller: printf("address=%d, value=%d", x); return true; max_cmd_retries: 0 + - id: modbus_controller4 + modbus_id: mod_bus2 + address: 0x4 + server_courtesy_response: + enabled: true + register_last_address: 100 + register_value: 0 + server_registers: + - address: 0x0001 + value_type: U_WORD + read_lambda: |- + return 0x8; + - address: 0x0005 + value_type: U_WORD + read_lambda: |- + return (random_uint32() % 100); binary_sensor: - platform: modbus_controller modbus_controller_id: modbus_controller1 diff --git a/tests/components/network/test-ipv6.bk72xx-ard.yaml b/tests/components/network/test-ipv6.bk72xx-ard.yaml index d0c4bbfcb9..da1324b17e 100644 --- a/tests/components/network/test-ipv6.bk72xx-ard.yaml +++ b/tests/components/network/test-ipv6.bk72xx-ard.yaml @@ -1,8 +1,4 @@ substitutions: network_enable_ipv6: "true" -bk72xx: - framework: - version: 1.7.0 - <<: !include common.yaml diff --git a/tests/components/network/test.bk72xx-ard.yaml b/tests/components/network/test.bk72xx-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/network/test.bk72xx-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/qmc5883l/common.yaml b/tests/components/qmc5883l/common.yaml index 5d8ac73b4f..c8ad4ba006 100644 --- a/tests/components/qmc5883l/common.yaml +++ b/tests/components/qmc5883l/common.yaml @@ -17,5 +17,7 @@ sensor: temperature: name: QMC5883L Temperature range: 800uT + data_rate: 200Hz oversampling: 256x update_interval: 15s + drdy_pin: ${drdy_pin} diff --git a/tests/components/qmc5883l/test.esp32-ard.yaml b/tests/components/qmc5883l/test.esp32-ard.yaml index 63c3bd6afd..2cf2041501 100644 --- a/tests/components/qmc5883l/test.esp32-ard.yaml +++ b/tests/components/qmc5883l/test.esp32-ard.yaml @@ -1,5 +1,6 @@ substitutions: scl_pin: GPIO16 sda_pin: GPIO17 + drdy_pin: GPIO18 <<: !include common.yaml diff --git a/tests/components/qmc5883l/test.esp32-c3-ard.yaml b/tests/components/qmc5883l/test.esp32-c3-ard.yaml index ee2c29ca4e..677501d15a 100644 --- a/tests/components/qmc5883l/test.esp32-c3-ard.yaml +++ b/tests/components/qmc5883l/test.esp32-c3-ard.yaml @@ -1,5 +1,6 @@ substitutions: scl_pin: GPIO5 sda_pin: GPIO4 + drdy_pin: GPIO6 <<: !include common.yaml diff --git a/tests/components/qmc5883l/test.esp32-c3-idf.yaml b/tests/components/qmc5883l/test.esp32-c3-idf.yaml index ee2c29ca4e..677501d15a 100644 --- a/tests/components/qmc5883l/test.esp32-c3-idf.yaml +++ b/tests/components/qmc5883l/test.esp32-c3-idf.yaml @@ -1,5 +1,6 @@ substitutions: scl_pin: GPIO5 sda_pin: GPIO4 + drdy_pin: GPIO6 <<: !include common.yaml diff --git a/tests/components/qmc5883l/test.esp32-idf.yaml b/tests/components/qmc5883l/test.esp32-idf.yaml index 63c3bd6afd..2cf2041501 100644 --- a/tests/components/qmc5883l/test.esp32-idf.yaml +++ b/tests/components/qmc5883l/test.esp32-idf.yaml @@ -1,5 +1,6 @@ substitutions: scl_pin: GPIO16 sda_pin: GPIO17 + drdy_pin: GPIO18 <<: !include common.yaml diff --git a/tests/components/qmc5883l/test.esp8266-ard.yaml b/tests/components/qmc5883l/test.esp8266-ard.yaml index ee2c29ca4e..65b0fd75d9 100644 --- a/tests/components/qmc5883l/test.esp8266-ard.yaml +++ b/tests/components/qmc5883l/test.esp8266-ard.yaml @@ -1,5 +1,6 @@ substitutions: scl_pin: GPIO5 sda_pin: GPIO4 + drdy_pin: GPIO2 <<: !include common.yaml diff --git a/tests/components/qmc5883l/test.rp2040-ard.yaml b/tests/components/qmc5883l/test.rp2040-ard.yaml index ee2c29ca4e..65b0fd75d9 100644 --- a/tests/components/qmc5883l/test.rp2040-ard.yaml +++ b/tests/components/qmc5883l/test.rp2040-ard.yaml @@ -1,5 +1,6 @@ substitutions: scl_pin: GPIO5 sda_pin: GPIO4 + drdy_pin: GPIO2 <<: !include common.yaml diff --git a/tests/components/remote_receiver/test.esp32-idf.yaml b/tests/components/remote_receiver/test.esp32-idf.yaml index 10dd767598..cdeeab2c4a 100644 --- a/tests/components/remote_receiver/test.esp32-idf.yaml +++ b/tests/components/remote_receiver/test.esp32-idf.yaml @@ -1,6 +1,8 @@ substitutions: pin: GPIO2 clock_resolution: "2000000" + carrier_duty_percent: "25" + carrier_frequency: "30000" filter_symbols: "2" receive_symbols: "4" rmt_symbols: "64" diff --git a/tests/components/restart/test.nrf52-adafruit.yaml b/tests/components/restart/test.nrf52-adafruit.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/restart/test.nrf52-adafruit.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/restart/test.nrf52-mcumgr.yaml b/tests/components/restart/test.nrf52-mcumgr.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/restart/test.nrf52-mcumgr.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/script/test.nrf52-adafruit.yaml b/tests/components/script/test.nrf52-adafruit.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/script/test.nrf52-adafruit.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/script/test.nrf52-mcumgr.yaml b/tests/components/script/test.nrf52-mcumgr.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/script/test.nrf52-mcumgr.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/sha256/common.yaml b/tests/components/sha256/common.yaml new file mode 100644 index 0000000000..fa884c1958 --- /dev/null +++ b/tests/components/sha256/common.yaml @@ -0,0 +1,32 @@ +esphome: + on_boot: + - lambda: |- + // Test SHA256 functionality + #ifdef USE_SHA256 + using esphome::sha256::SHA256; + SHA256 hasher; + hasher.init(); + + // Test with "Hello World" - known SHA256 + const char* test_string = "Hello World"; + hasher.add(test_string, strlen(test_string)); + hasher.calculate(); + + char hex_output[65]; + hasher.get_hex(hex_output); + hex_output[64] = '\0'; + + ESP_LOGD("SHA256", "SHA256('Hello World') = %s", hex_output); + + // Expected: a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e + const char* expected = "a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e"; + if (strcmp(hex_output, expected) == 0) { + ESP_LOGI("SHA256", "Test PASSED"); + } else { + ESP_LOGE("SHA256", "Test FAILED. Expected %s", expected); + } + #else + ESP_LOGW("SHA256", "SHA256 not available on this platform"); + #endif + +sha256: diff --git a/tests/components/sha256/test.bk72xx-ard.yaml b/tests/components/sha256/test.bk72xx-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/sha256/test.bk72xx-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/sha256/test.esp32-idf.yaml b/tests/components/sha256/test.esp32-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/sha256/test.esp32-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/sha256/test.esp8266-ard.yaml b/tests/components/sha256/test.esp8266-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/sha256/test.esp8266-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/sha256/test.host.yaml b/tests/components/sha256/test.host.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/sha256/test.host.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/sha256/test.rp2040-ard.yaml b/tests/components/sha256/test.rp2040-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/sha256/test.rp2040-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/spi/test.esp32-s3-ard.yaml b/tests/components/spi/test.esp32-s3-ard.yaml new file mode 100644 index 0000000000..e4d4f20586 --- /dev/null +++ b/tests/components/spi/test.esp32-s3-ard.yaml @@ -0,0 +1,13 @@ +spi: + - id: three_spi + interface: spi3 + clk_pin: + number: 47 + mosi_pin: + number: 40 + - id: hw_spi + interface: hardware + clk_pin: + number: 0 + miso_pin: + number: 41 diff --git a/tests/components/sprinkler/test.nrf52-adafruit.yaml b/tests/components/sprinkler/test.nrf52-adafruit.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/sprinkler/test.nrf52-adafruit.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/sprinkler/test.nrf52-mcumgr.yaml b/tests/components/sprinkler/test.nrf52-mcumgr.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/sprinkler/test.nrf52-mcumgr.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/sx126x/common.yaml b/tests/components/sx126x/common.yaml index 3f888c3ce4..05db2ef812 100644 --- a/tests/components/sx126x/common.yaml +++ b/tests/components/sx126x/common.yaml @@ -11,6 +11,10 @@ sx126x: pa_power: 3 bandwidth: 125_0kHz crc_enable: true + crc_initial: 0x1D0F + crc_polynomial: 0x1021 + crc_size: 2 + crc_inverted: true frequency: 433920000 modulation: LORA rx_start: true diff --git a/tests/components/template/common.yaml b/tests/components/template/common.yaml index ae7dc98e57..efbb83ee06 100644 --- a/tests/components/template/common.yaml +++ b/tests/components/template/common.yaml @@ -341,6 +341,7 @@ datetime: time: - platform: sntp # Required for datetime + id: sntp_time wifi: # Required for sntp time ap: diff --git a/tests/components/template/test.nrf52-adafruit.yaml b/tests/components/template/test.nrf52-adafruit.yaml new file mode 100644 index 0000000000..6a8c01560a --- /dev/null +++ b/tests/components/template/test.nrf52-adafruit.yaml @@ -0,0 +1,6 @@ +packages: !include common.yaml + +time: + - id: !remove sntp_time + +wifi: !remove diff --git a/tests/components/template/test.nrf52-mcumgr.yaml b/tests/components/template/test.nrf52-mcumgr.yaml new file mode 100644 index 0000000000..6a8c01560a --- /dev/null +++ b/tests/components/template/test.nrf52-mcumgr.yaml @@ -0,0 +1,6 @@ +packages: !include common.yaml + +time: + - id: !remove sntp_time + +wifi: !remove diff --git a/tests/components/thermostat/test.nrf52-adafruit.yaml b/tests/components/thermostat/test.nrf52-adafruit.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/thermostat/test.nrf52-adafruit.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/thermostat/test.nrf52-mcumgr.yaml b/tests/components/thermostat/test.nrf52-mcumgr.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/thermostat/test.nrf52-mcumgr.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/uart/test-uart_max_with_usb_cdc.esp32-c3-ard.yaml b/tests/components/uart/test-uart_max_with_usb_cdc.esp32-c3-ard.yaml index 2a73826c51..602766869c 100644 --- a/tests/components/uart/test-uart_max_with_usb_cdc.esp32-c3-ard.yaml +++ b/tests/components/uart/test-uart_max_with_usb_cdc.esp32-c3-ard.yaml @@ -14,17 +14,23 @@ uart: - id: uart_1 tx_pin: 4 rx_pin: 5 + flow_control_pin: 6 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 - id: uart_2 - tx_pin: 6 - rx_pin: 7 + tx_pin: 7 + rx_pin: 8 + flow_control_pin: 9 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s2-ard.yaml b/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s2-ard.yaml index 2a73826c51..602766869c 100644 --- a/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s2-ard.yaml +++ b/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s2-ard.yaml @@ -14,17 +14,23 @@ uart: - id: uart_1 tx_pin: 4 rx_pin: 5 + flow_control_pin: 6 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 - id: uart_2 - tx_pin: 6 - rx_pin: 7 + tx_pin: 7 + rx_pin: 8 + flow_control_pin: 9 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s2-idf.yaml b/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s2-idf.yaml index 2a73826c51..602766869c 100644 --- a/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s2-idf.yaml +++ b/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s2-idf.yaml @@ -14,17 +14,23 @@ uart: - id: uart_1 tx_pin: 4 rx_pin: 5 + flow_control_pin: 6 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 - id: uart_2 - tx_pin: 6 - rx_pin: 7 + tx_pin: 7 + rx_pin: 8 + flow_control_pin: 9 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s3-ard.yaml b/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s3-ard.yaml index 2a73826c51..4af255e1e4 100644 --- a/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s3-ard.yaml +++ b/tests/components/uart/test-uart_max_with_usb_cdc.esp32-s3-ard.yaml @@ -14,17 +14,35 @@ uart: - id: uart_1 tx_pin: 4 rx_pin: 5 + flow_control_pin: 6 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 - id: uart_2 - tx_pin: 6 - rx_pin: 7 + tx_pin: 7 + rx_pin: 8 + flow_control_pin: 9 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 + parity: EVEN + stop_bits: 2 + + - id: uart_3 + tx_pin: 10 + rx_pin: 11 + flow_control_pin: 12 + baud_rate: 9600 + data_bits: 8 + rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/uart/test-uart_max_with_usb_serial_jtag.esp32-c3-idf.yaml b/tests/components/uart/test-uart_max_with_usb_serial_jtag.esp32-c3-idf.yaml index e0a07dde91..3151403896 100644 --- a/tests/components/uart/test-uart_max_with_usb_serial_jtag.esp32-c3-idf.yaml +++ b/tests/components/uart/test-uart_max_with_usb_serial_jtag.esp32-c3-idf.yaml @@ -14,17 +14,23 @@ uart: - id: uart_1 tx_pin: 4 rx_pin: 5 + flow_control_pin: 6 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 - id: uart_2 - tx_pin: 6 - rx_pin: 7 + tx_pin: 7 + rx_pin: 8 + flow_control_pin: 9 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/uart/test-uart_max_with_usb_serial_jtag.esp32-s3-idf.yaml b/tests/components/uart/test-uart_max_with_usb_serial_jtag.esp32-s3-idf.yaml index e0a07dde91..88a806eb92 100644 --- a/tests/components/uart/test-uart_max_with_usb_serial_jtag.esp32-s3-idf.yaml +++ b/tests/components/uart/test-uart_max_with_usb_serial_jtag.esp32-s3-idf.yaml @@ -14,17 +14,35 @@ uart: - id: uart_1 tx_pin: 4 rx_pin: 5 + flow_control_pin: 6 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 - id: uart_2 - tx_pin: 6 - rx_pin: 7 + tx_pin: 7 + rx_pin: 8 + flow_control_pin: 9 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 + parity: EVEN + stop_bits: 2 + + - id: uart_3 + tx_pin: 10 + rx_pin: 11 + flow_control_pin: 12 + baud_rate: 9600 + data_bits: 8 + rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/uart/test.esp32-ard.yaml b/tests/components/uart/test.esp32-ard.yaml index bef5b460ab..a201185309 100644 --- a/tests/components/uart/test.esp32-ard.yaml +++ b/tests/components/uart/test.esp32-ard.yaml @@ -8,8 +8,11 @@ uart: - id: uart_uart tx_pin: 17 rx_pin: 16 + flow_control_pin: 4 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/uart/test.esp32-c3-ard.yaml b/tests/components/uart/test.esp32-c3-ard.yaml index 09178f1663..b053290a8b 100644 --- a/tests/components/uart/test.esp32-c3-ard.yaml +++ b/tests/components/uart/test.esp32-c3-ard.yaml @@ -8,8 +8,11 @@ uart: - id: uart_uart tx_pin: 4 rx_pin: 5 + flow_control_pin: 6 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/uart/test.esp32-c3-idf.yaml b/tests/components/uart/test.esp32-c3-idf.yaml index 09178f1663..b053290a8b 100644 --- a/tests/components/uart/test.esp32-c3-idf.yaml +++ b/tests/components/uart/test.esp32-c3-idf.yaml @@ -8,8 +8,11 @@ uart: - id: uart_uart tx_pin: 4 rx_pin: 5 + flow_control_pin: 6 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/uart/test.esp32-idf.yaml b/tests/components/uart/test.esp32-idf.yaml index 5a0ed7eba7..5634c5c6f6 100644 --- a/tests/components/uart/test.esp32-idf.yaml +++ b/tests/components/uart/test.esp32-idf.yaml @@ -8,9 +8,12 @@ uart: - id: uart_uart tx_pin: 17 rx_pin: 16 + flow_control_pin: 4 baud_rate: 9600 data_bits: 8 rx_buffer_size: 512 + rx_full_threshold: 10 + rx_timeout: 1 parity: EVEN stop_bits: 2 diff --git a/tests/components/wts01/common.yaml b/tests/components/wts01/common.yaml new file mode 100644 index 0000000000..c26cc3e475 --- /dev/null +++ b/tests/components/wts01/common.yaml @@ -0,0 +1,7 @@ +uart: + rx_pin: ${rx_pin} + baud_rate: 9600 + +sensor: + - platform: wts01 + id: wts01_sensor diff --git a/tests/components/wts01/test.esp32-ard.yaml b/tests/components/wts01/test.esp32-ard.yaml new file mode 100644 index 0000000000..4904e1f54f --- /dev/null +++ b/tests/components/wts01/test.esp32-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO16 + rx_pin: GPIO17 + +<<: !include common.yaml diff --git a/tests/components/wts01/test.esp32-c3-ard.yaml b/tests/components/wts01/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..00cec5b3b8 --- /dev/null +++ b/tests/components/wts01/test.esp32-c3-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO6 + rx_pin: GPIO7 + +<<: !include common.yaml diff --git a/tests/components/wts01/test.esp32-c3-idf.yaml b/tests/components/wts01/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..00cec5b3b8 --- /dev/null +++ b/tests/components/wts01/test.esp32-c3-idf.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO6 + rx_pin: GPIO7 + +<<: !include common.yaml diff --git a/tests/components/wts01/test.esp32-idf.yaml b/tests/components/wts01/test.esp32-idf.yaml new file mode 100644 index 0000000000..4904e1f54f --- /dev/null +++ b/tests/components/wts01/test.esp32-idf.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO16 + rx_pin: GPIO17 + +<<: !include common.yaml diff --git a/tests/components/wts01/test.esp8266-ard.yaml b/tests/components/wts01/test.esp8266-ard.yaml new file mode 100644 index 0000000000..3b44f9c9c3 --- /dev/null +++ b/tests/components/wts01/test.esp8266-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO1 + rx_pin: GPIO3 + +<<: !include common.yaml diff --git a/tests/components/wts01/test.rp2040-ard.yaml b/tests/components/wts01/test.rp2040-ard.yaml new file mode 100644 index 0000000000..16b2a4b006 --- /dev/null +++ b/tests/components/wts01/test.rp2040-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO0 + rx_pin: GPIO1 + +<<: !include common.yaml diff --git a/tests/components/zwave_proxy/common.yaml b/tests/components/zwave_proxy/common.yaml new file mode 100644 index 0000000000..08092ebe55 --- /dev/null +++ b/tests/components/zwave_proxy/common.yaml @@ -0,0 +1,15 @@ +wifi: + ssid: MySSID + password: password1 + power_save_mode: none + +uart: + - id: uart_zwave_proxy + tx_pin: ${tx_pin} + rx_pin: ${rx_pin} + baud_rate: 115200 + +api: + +zwave_proxy: + id: zw_proxy diff --git a/tests/components/zwave_proxy/test.esp32-ard.yaml b/tests/components/zwave_proxy/test.esp32-ard.yaml new file mode 100644 index 0000000000..f486544afa --- /dev/null +++ b/tests/components/zwave_proxy/test.esp32-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO17 + rx_pin: GPIO16 + +<<: !include common.yaml diff --git a/tests/components/zwave_proxy/test.esp32-c3-ard.yaml b/tests/components/zwave_proxy/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..b516342f3b --- /dev/null +++ b/tests/components/zwave_proxy/test.esp32-c3-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO4 + rx_pin: GPIO5 + +<<: !include common.yaml diff --git a/tests/components/zwave_proxy/test.esp32-c3-idf.yaml b/tests/components/zwave_proxy/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..b516342f3b --- /dev/null +++ b/tests/components/zwave_proxy/test.esp32-c3-idf.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO4 + rx_pin: GPIO5 + +<<: !include common.yaml diff --git a/tests/components/zwave_proxy/test.esp32-idf.yaml b/tests/components/zwave_proxy/test.esp32-idf.yaml new file mode 100644 index 0000000000..f486544afa --- /dev/null +++ b/tests/components/zwave_proxy/test.esp32-idf.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO17 + rx_pin: GPIO16 + +<<: !include common.yaml diff --git a/tests/components/zwave_proxy/test.esp8266-ard.yaml b/tests/components/zwave_proxy/test.esp8266-ard.yaml new file mode 100644 index 0000000000..b516342f3b --- /dev/null +++ b/tests/components/zwave_proxy/test.esp8266-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO4 + rx_pin: GPIO5 + +<<: !include common.yaml diff --git a/tests/components/zwave_proxy/test.rp2040-ard.yaml b/tests/components/zwave_proxy/test.rp2040-ard.yaml new file mode 100644 index 0000000000..b516342f3b --- /dev/null +++ b/tests/components/zwave_proxy/test.rp2040-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + tx_pin: GPIO4 + rx_pin: GPIO5 + +<<: !include common.yaml diff --git a/tests/dashboard/conftest.py b/tests/dashboard/conftest.py new file mode 100644 index 0000000000..f95adef749 --- /dev/null +++ b/tests/dashboard/conftest.py @@ -0,0 +1,43 @@ +"""Common fixtures for dashboard tests.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, Mock + +import pytest +import pytest_asyncio + +from esphome.dashboard.core import ESPHomeDashboard +from esphome.dashboard.entries import DashboardEntries + + +@pytest.fixture +def mock_settings(tmp_path: Path) -> MagicMock: + """Create mock dashboard settings.""" + settings = MagicMock() + settings.config_dir = str(tmp_path) + settings.absolute_config_dir = tmp_path + return settings + + +@pytest.fixture +def mock_dashboard(mock_settings: MagicMock) -> Mock: + """Create a mock dashboard.""" + dashboard = Mock(spec=ESPHomeDashboard) + dashboard.settings = mock_settings + dashboard.entries = Mock() + dashboard.entries.async_all.return_value = [] + dashboard.stop_event = Mock() + dashboard.stop_event.is_set.return_value = True + dashboard.ping_request = Mock() + dashboard.ignored_devices = set() + dashboard.bus = Mock() + dashboard.bus.async_fire = Mock() + return dashboard + + +@pytest_asyncio.fixture +async def dashboard_entries(mock_dashboard: Mock) -> DashboardEntries: + """Create a DashboardEntries instance for testing.""" + return DashboardEntries(mock_dashboard) diff --git a/tests/dashboard/status/__init__.py b/tests/dashboard/status/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dashboard/status/test_dns.py b/tests/dashboard/status/test_dns.py new file mode 100644 index 0000000000..9ca48ba2d8 --- /dev/null +++ b/tests/dashboard/status/test_dns.py @@ -0,0 +1,121 @@ +"""Unit tests for esphome.dashboard.dns module.""" + +from __future__ import annotations + +import time +from unittest.mock import patch + +import pytest + +from esphome.dashboard.dns import DNSCache + + +@pytest.fixture +def dns_cache_fixture() -> DNSCache: + """Create a DNSCache instance.""" + return DNSCache() + + +def test_get_cached_addresses_not_in_cache(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses when hostname is not in cache.""" + now = time.monotonic() + result = dns_cache_fixture.get_cached_addresses("unknown.example.com", now) + assert result is None + + +def test_get_cached_addresses_expired(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses when cache entry is expired.""" + now = time.monotonic() + # Add entry that's already expired + dns_cache_fixture._cache["example.com"] = (now - 1, ["192.168.1.10"]) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result is None + # Expired entry should still be in cache (not removed by get_cached_addresses) + assert "example.com" in dns_cache_fixture._cache + + +def test_get_cached_addresses_valid(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses with valid cache entry.""" + now = time.monotonic() + # Add entry that expires in 60 seconds + dns_cache_fixture._cache["example.com"] = ( + now + 60, + ["192.168.1.10", "192.168.1.11"], + ) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result == ["192.168.1.10", "192.168.1.11"] + # Entry should still be in cache + assert "example.com" in dns_cache_fixture._cache + + +def test_get_cached_addresses_hostname_normalization( + dns_cache_fixture: DNSCache, +) -> None: + """Test get_cached_addresses normalizes hostname.""" + now = time.monotonic() + # Add entry with lowercase hostname + dns_cache_fixture._cache["example.com"] = (now + 60, ["192.168.1.10"]) + + # Test with various forms + assert dns_cache_fixture.get_cached_addresses("EXAMPLE.COM", now) == [ + "192.168.1.10" + ] + assert dns_cache_fixture.get_cached_addresses("example.com.", now) == [ + "192.168.1.10" + ] + assert dns_cache_fixture.get_cached_addresses("EXAMPLE.COM.", now) == [ + "192.168.1.10" + ] + + +def test_get_cached_addresses_ipv6(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses with IPv6 addresses.""" + now = time.monotonic() + dns_cache_fixture._cache["example.com"] = (now + 60, ["2001:db8::1", "fe80::1"]) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result == ["2001:db8::1", "fe80::1"] + + +def test_get_cached_addresses_empty_list(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses with empty address list.""" + now = time.monotonic() + dns_cache_fixture._cache["example.com"] = (now + 60, []) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result == [] + + +def test_get_cached_addresses_exception_in_cache(dns_cache_fixture: DNSCache) -> None: + """Test get_cached_addresses when cache contains an exception.""" + now = time.monotonic() + # Store an exception (from failed resolution) + dns_cache_fixture._cache["example.com"] = (now + 60, OSError("Resolution failed")) + + result = dns_cache_fixture.get_cached_addresses("example.com", now) + assert result is None # Should return None for exceptions + + +def test_async_resolve_not_called(dns_cache_fixture: DNSCache) -> None: + """Test that get_cached_addresses never calls async_resolve.""" + now = time.monotonic() + + with patch.object(dns_cache_fixture, "async_resolve") as mock_resolve: + # Test non-cached + result = dns_cache_fixture.get_cached_addresses("uncached.com", now) + assert result is None + mock_resolve.assert_not_called() + + # Test expired + dns_cache_fixture._cache["expired.com"] = (now - 1, ["192.168.1.10"]) + result = dns_cache_fixture.get_cached_addresses("expired.com", now) + assert result is None + mock_resolve.assert_not_called() + + # Test valid + dns_cache_fixture._cache["valid.com"] = (now + 60, ["192.168.1.10"]) + result = dns_cache_fixture.get_cached_addresses("valid.com", now) + assert result == ["192.168.1.10"] + mock_resolve.assert_not_called() diff --git a/tests/dashboard/status/test_mdns.py b/tests/dashboard/status/test_mdns.py new file mode 100644 index 0000000000..56c6d254cf --- /dev/null +++ b/tests/dashboard/status/test_mdns.py @@ -0,0 +1,240 @@ +"""Unit tests for esphome.dashboard.status.mdns module.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +import pytest_asyncio +from zeroconf import AddressResolver, IPVersion + +from esphome.dashboard.const import DashboardEvent +from esphome.dashboard.status.mdns import MDNSStatus +from esphome.zeroconf import DiscoveredImport + + +@pytest_asyncio.fixture +async def mdns_status(mock_dashboard: Mock) -> MDNSStatus: + """Create an MDNSStatus instance in async context.""" + # We're in an async context so get_running_loop will work + return MDNSStatus(mock_dashboard) + + +@pytest.mark.asyncio +async def test_get_cached_addresses_no_zeroconf(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses when no zeroconf instance is available.""" + mdns_status.aiozc = None + result = mdns_status.get_cached_addresses("device.local") + assert result is None + + +@pytest.mark.asyncio +async def test_get_cached_addresses_not_in_cache(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses when address is not in cache.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = False + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local") + assert result is None + mock_info.load_from_cache.assert_called_once_with(mdns_status.aiozc.zeroconf) + + +@pytest.mark.asyncio +async def test_get_cached_addresses_found_in_cache(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses when address is found in cache.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10", "fe80::1"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local") + assert result == ["192.168.1.10", "fe80::1"] + mock_info.load_from_cache.assert_called_once_with(mdns_status.aiozc.zeroconf) + mock_info.parsed_scoped_addresses.assert_called_once_with(IPVersion.All) + + +@pytest.mark.asyncio +async def test_get_cached_addresses_with_trailing_dot(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses with hostname having trailing dot.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local.") + assert result == ["192.168.1.10"] + # Should normalize to device.local. for zeroconf + mock_resolver.assert_called_once_with("device.local.") + + +@pytest.mark.asyncio +async def test_get_cached_addresses_uppercase_hostname(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses with uppercase hostname.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("DEVICE.LOCAL") + assert result == ["192.168.1.10"] + # Should normalize to device.local. for zeroconf + mock_resolver.assert_called_once_with("device.local.") + + +@pytest.mark.asyncio +async def test_get_cached_addresses_simple_hostname(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses with simple hostname (no domain).""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["192.168.1.10"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device") + assert result == ["192.168.1.10"] + # Should append .local. for zeroconf + mock_resolver.assert_called_once_with("device.local.") + + +@pytest.mark.asyncio +async def test_get_cached_addresses_ipv6_only(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses returning only IPv6 addresses.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = ["fe80::1", "2001:db8::1"] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local") + assert result == ["fe80::1", "2001:db8::1"] + + +@pytest.mark.asyncio +async def test_get_cached_addresses_empty_list(mdns_status: MDNSStatus) -> None: + """Test get_cached_addresses returning empty list from cache.""" + mdns_status.aiozc = Mock() + mdns_status.aiozc.zeroconf = Mock() + + with patch("esphome.dashboard.status.mdns.AddressResolver") as mock_resolver: + mock_info = Mock(spec=AddressResolver) + mock_info.load_from_cache.return_value = True + mock_info.parsed_scoped_addresses.return_value = [] + mock_resolver.return_value = mock_info + + result = mdns_status.get_cached_addresses("device.local") + assert result == [] + + +@pytest.mark.asyncio +async def test_async_setup_success(mock_dashboard: Mock) -> None: + """Test successful async_setup.""" + mdns_status = MDNSStatus(mock_dashboard) + with patch("esphome.dashboard.status.mdns.AsyncEsphomeZeroconf") as mock_zc: + mock_zc.return_value = Mock() + result = mdns_status.async_setup() + assert result is True + assert mdns_status.aiozc is not None + + +@pytest.mark.asyncio +async def test_async_setup_failure(mock_dashboard: Mock) -> None: + """Test async_setup with OSError.""" + mdns_status = MDNSStatus(mock_dashboard) + with patch("esphome.dashboard.status.mdns.AsyncEsphomeZeroconf") as mock_zc: + mock_zc.side_effect = OSError("Network error") + result = mdns_status.async_setup() + assert result is False + assert mdns_status.aiozc is None + + +@pytest.mark.asyncio +async def test_on_import_update_device_added(mdns_status: MDNSStatus) -> None: + """Test _on_import_update when a device is added.""" + # Create a DiscoveredImport object + discovered = DiscoveredImport( + device_name="test_device", + friendly_name="Test Device", + package_import_url="https://example.com/package", + project_name="test_project", + project_version="1.0.0", + network="wifi", + ) + + # Call _on_import_update with a device + mdns_status._on_import_update("test_device", discovered) + + # Should fire IMPORTABLE_DEVICE_ADDED event + mock_dashboard = mdns_status.dashboard + mock_dashboard.bus.async_fire.assert_called_once() + call_args = mock_dashboard.bus.async_fire.call_args + assert call_args[0][0] == DashboardEvent.IMPORTABLE_DEVICE_ADDED + assert "device" in call_args[0][1] + device_data = call_args[0][1]["device"] + assert device_data["name"] == "test_device" + assert device_data["friendly_name"] == "Test Device" + assert device_data["project_name"] == "test_project" + assert device_data["ignored"] is False + + +@pytest.mark.asyncio +async def test_on_import_update_device_ignored(mdns_status: MDNSStatus) -> None: + """Test _on_import_update when a device is ignored.""" + # Add device to ignored list + mdns_status.dashboard.ignored_devices.add("ignored_device") + + # Create a DiscoveredImport object for ignored device + discovered = DiscoveredImport( + device_name="ignored_device", + friendly_name="Ignored Device", + package_import_url="https://example.com/package", + project_name="test_project", + project_version="1.0.0", + network="ethernet", + ) + + # Call _on_import_update with an ignored device + mdns_status._on_import_update("ignored_device", discovered) + + # Should fire IMPORTABLE_DEVICE_ADDED event with ignored=True + mock_dashboard = mdns_status.dashboard + mock_dashboard.bus.async_fire.assert_called_once() + call_args = mock_dashboard.bus.async_fire.call_args + assert call_args[0][0] == DashboardEvent.IMPORTABLE_DEVICE_ADDED + device_data = call_args[0][1]["device"] + assert device_data["name"] == "ignored_device" + assert device_data["ignored"] is True + + +@pytest.mark.asyncio +async def test_on_import_update_device_removed(mdns_status: MDNSStatus) -> None: + """Test _on_import_update when a device is removed.""" + # Call _on_import_update with None (device removed) + mdns_status._on_import_update("removed_device", None) + + # Should fire IMPORTABLE_DEVICE_REMOVED event + mdns_status.dashboard.bus.async_fire.assert_called_once_with( + DashboardEvent.IMPORTABLE_DEVICE_REMOVED, {"name": "removed_device"} + ) diff --git a/tests/dashboard/test_entries.py b/tests/dashboard/test_entries.py new file mode 100644 index 0000000000..9a3a776b28 --- /dev/null +++ b/tests/dashboard/test_entries.py @@ -0,0 +1,288 @@ +"""Tests for dashboard entries Path-related functionality.""" + +from __future__ import annotations + +import os +from pathlib import Path +import tempfile +from unittest.mock import Mock + +import pytest + +from esphome.core import CORE +from esphome.dashboard.const import DashboardEvent +from esphome.dashboard.entries import DashboardEntries, DashboardEntry + + +def create_cache_key() -> tuple[int, int, float, int]: + """Helper to create a valid DashboardCacheKeyType.""" + return (0, 0, 0.0, 0) + + +@pytest.fixture(autouse=True) +def setup_core(): + """Set up CORE for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + CORE.config_path = Path(tmpdir) / "test.yaml" + yield + CORE.reset() + + +def test_dashboard_entry_path_initialization() -> None: + """Test DashboardEntry initializes with path correctly.""" + test_path = Path("/test/config/device.yaml") + cache_key = create_cache_key() + + entry = DashboardEntry(test_path, cache_key) + + assert entry.path == test_path + assert entry.cache_key == cache_key + + +def test_dashboard_entry_path_with_absolute_path() -> None: + """Test DashboardEntry handles absolute paths.""" + # Use a truly absolute path for the platform + test_path = Path.cwd() / "absolute" / "path" / "to" / "config.yaml" + cache_key = create_cache_key() + + entry = DashboardEntry(test_path, cache_key) + + assert entry.path == test_path + assert entry.path.is_absolute() + + +def test_dashboard_entry_path_with_relative_path() -> None: + """Test DashboardEntry handles relative paths.""" + test_path = Path("configs/device.yaml") + cache_key = create_cache_key() + + entry = DashboardEntry(test_path, cache_key) + + assert entry.path == test_path + assert not entry.path.is_absolute() + + +@pytest.mark.asyncio +async def test_dashboard_entries_get_by_path( + dashboard_entries: DashboardEntries, tmp_path: Path +) -> None: + """Test getting entry by path.""" + # Create a test file + test_file = tmp_path / "device.yaml" + test_file.write_text("test config") + + # Update entries to load the file + await dashboard_entries.async_update_entries() + + # Verify the entry was loaded + all_entries = dashboard_entries.async_all() + assert len(all_entries) == 1 + entry = all_entries[0] + assert entry.path == test_file + + # Also verify get() works with Path + result = dashboard_entries.get(test_file) + assert result == entry + + +@pytest.mark.asyncio +async def test_dashboard_entries_get_nonexistent_path( + dashboard_entries: DashboardEntries, +) -> None: + """Test getting non-existent entry returns None.""" + result = dashboard_entries.get("/nonexistent/path.yaml") + assert result is None + + +@pytest.mark.asyncio +async def test_dashboard_entries_path_normalization( + dashboard_entries: DashboardEntries, tmp_path: Path +) -> None: + """Test that paths are handled consistently.""" + # Create a test file + test_file = tmp_path / "device.yaml" + test_file.write_text("test config") + + # Update entries to load the file + await dashboard_entries.async_update_entries() + + # Get the entry by path + result = dashboard_entries.get(test_file) + assert result is not None + + +@pytest.mark.asyncio +async def test_dashboard_entries_path_with_spaces( + dashboard_entries: DashboardEntries, tmp_path: Path +) -> None: + """Test handling paths with spaces.""" + # Create a test file with spaces in name + test_file = tmp_path / "my device.yaml" + test_file.write_text("test config") + + # Update entries to load the file + await dashboard_entries.async_update_entries() + + # Get the entry by path + result = dashboard_entries.get(test_file) + assert result is not None + assert result.path == test_file + + +@pytest.mark.asyncio +async def test_dashboard_entries_path_with_special_chars( + dashboard_entries: DashboardEntries, tmp_path: Path +) -> None: + """Test handling paths with special characters.""" + # Create a test file with special characters + test_file = tmp_path / "device-01_test.yaml" + test_file.write_text("test config") + + # Update entries to load the file + await dashboard_entries.async_update_entries() + + # Get the entry by path + result = dashboard_entries.get(test_file) + assert result is not None + + +def test_dashboard_entries_windows_path() -> None: + """Test handling Windows-style paths.""" + test_path = Path(r"C:\Users\test\esphome\device.yaml") + cache_key = create_cache_key() + + entry = DashboardEntry(test_path, cache_key) + + assert entry.path == test_path + + +@pytest.mark.asyncio +async def test_dashboard_entries_path_to_cache_key_mapping( + dashboard_entries: DashboardEntries, tmp_path: Path +) -> None: + """Test internal entries storage with paths and cache keys.""" + # Create test files + file1 = tmp_path / "device1.yaml" + file2 = tmp_path / "device2.yaml" + file1.write_text("test config 1") + file2.write_text("test config 2") + + # Update entries to load the files + await dashboard_entries.async_update_entries() + + # Get entries and verify they have different cache keys + entry1 = dashboard_entries.get(file1) + entry2 = dashboard_entries.get(file2) + + assert entry1 is not None + assert entry2 is not None + assert entry1.cache_key != entry2.cache_key + + +def test_dashboard_entry_path_property() -> None: + """Test that path property returns expected value.""" + test_path = Path("/test/config/device.yaml") + entry = DashboardEntry(test_path, create_cache_key()) + + assert entry.path == test_path + assert isinstance(entry.path, Path) + + +@pytest.mark.asyncio +async def test_dashboard_entries_all_returns_entries_with_paths( + dashboard_entries: DashboardEntries, tmp_path: Path +) -> None: + """Test that all() returns entries with their paths intact.""" + # Create test files + files = [ + tmp_path / "device1.yaml", + tmp_path / "device2.yaml", + tmp_path / "device3.yaml", + ] + + for file in files: + file.write_text("test config") + + # Update entries to load the files + await dashboard_entries.async_update_entries() + + all_entries = dashboard_entries.async_all() + + assert len(all_entries) == len(files) + retrieved_paths = [entry.path for entry in all_entries] + assert set(retrieved_paths) == set(files) + + +@pytest.mark.asyncio +async def test_async_update_entries_removed_path( + dashboard_entries: DashboardEntries, mock_dashboard: Mock, tmp_path: Path +) -> None: + """Test that removed files trigger ENTRY_REMOVED event.""" + + # Create a test file + test_file = tmp_path / "device.yaml" + test_file.write_text("test config") + + # First update to add the entry + await dashboard_entries.async_update_entries() + + # Verify entry was added + all_entries = dashboard_entries.async_all() + assert len(all_entries) == 1 + entry = all_entries[0] + + # Delete the file + test_file.unlink() + + # Second update to detect removal + await dashboard_entries.async_update_entries() + + # Verify entry was removed + all_entries = dashboard_entries.async_all() + assert len(all_entries) == 0 + + # Verify ENTRY_REMOVED event was fired + mock_dashboard.bus.async_fire.assert_any_call( + DashboardEvent.ENTRY_REMOVED, {"entry": entry} + ) + + +@pytest.mark.asyncio +async def test_async_update_entries_updated_path( + dashboard_entries: DashboardEntries, mock_dashboard: Mock, tmp_path: Path +) -> None: + """Test that modified files trigger ENTRY_UPDATED event.""" + + # Create a test file + test_file = tmp_path / "device.yaml" + test_file.write_text("test config") + + # First update to add the entry + await dashboard_entries.async_update_entries() + + # Verify entry was added + all_entries = dashboard_entries.async_all() + assert len(all_entries) == 1 + entry = all_entries[0] + original_cache_key = entry.cache_key + + # Modify the file to change its mtime + test_file.write_text("updated config") + # Explicitly change the mtime to ensure it's different + stat = test_file.stat() + os.utime(test_file, (stat.st_atime, stat.st_mtime + 1)) + + # Second update to detect modification + await dashboard_entries.async_update_entries() + + # Verify entry is still there with updated cache key + all_entries = dashboard_entries.async_all() + assert len(all_entries) == 1 + updated_entry = all_entries[0] + assert updated_entry == entry # Same entry object + assert updated_entry.cache_key != original_cache_key # But cache key updated + + # Verify ENTRY_UPDATED event was fired + mock_dashboard.bus.async_fire.assert_any_call( + DashboardEvent.ENTRY_UPDATED, {"entry": entry} + ) diff --git a/tests/dashboard/test_settings.py b/tests/dashboard/test_settings.py new file mode 100644 index 0000000000..c9097fe5e2 --- /dev/null +++ b/tests/dashboard/test_settings.py @@ -0,0 +1,161 @@ +"""Tests for dashboard settings Path-related functionality.""" + +from __future__ import annotations + +from pathlib import Path +import tempfile + +import pytest + +from esphome.dashboard.settings import DashboardSettings + + +@pytest.fixture +def dashboard_settings(tmp_path: Path) -> DashboardSettings: + """Create DashboardSettings instance with temp directory.""" + settings = DashboardSettings() + # Resolve symlinks to ensure paths match + resolved_dir = tmp_path.resolve() + settings.config_dir = resolved_dir + settings.absolute_config_dir = resolved_dir + return settings + + +def test_rel_path_simple(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with simple relative path.""" + result = dashboard_settings.rel_path("config.yaml") + + expected = dashboard_settings.config_dir / "config.yaml" + assert result == expected + + +def test_rel_path_multiple_components(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with multiple path components.""" + result = dashboard_settings.rel_path("subfolder", "device", "config.yaml") + + expected = dashboard_settings.config_dir / "subfolder" / "device" / "config.yaml" + assert result == expected + + +def test_rel_path_with_dots(dashboard_settings: DashboardSettings) -> None: + """Test rel_path prevents directory traversal.""" + # This should raise ValueError as it tries to go outside config_dir + with pytest.raises(ValueError): + dashboard_settings.rel_path("..", "outside.yaml") + + +def test_rel_path_absolute_path_within_config( + dashboard_settings: DashboardSettings, +) -> None: + """Test rel_path with absolute path that's within config dir.""" + internal_path = dashboard_settings.absolute_config_dir / "internal.yaml" + + internal_path.touch() + result = dashboard_settings.rel_path("internal.yaml") + expected = dashboard_settings.config_dir / "internal.yaml" + assert result == expected + + +def test_rel_path_absolute_path_outside_config( + dashboard_settings: DashboardSettings, +) -> None: + """Test rel_path with absolute path outside config dir raises error.""" + outside_path = "/tmp/outside/config.yaml" + + with pytest.raises(ValueError): + dashboard_settings.rel_path(outside_path) + + +def test_rel_path_empty_args(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with no arguments returns config_dir.""" + result = dashboard_settings.rel_path() + assert result == dashboard_settings.config_dir + + +def test_rel_path_with_pathlib_path(dashboard_settings: DashboardSettings) -> None: + """Test rel_path works with Path objects as arguments.""" + path_obj = Path("subfolder") / "config.yaml" + result = dashboard_settings.rel_path(path_obj) + + expected = dashboard_settings.config_dir / "subfolder" / "config.yaml" + assert result == expected + + +def test_rel_path_normalizes_slashes(dashboard_settings: DashboardSettings) -> None: + """Test rel_path normalizes path separators.""" + # os.path.join normalizes slashes on Windows but preserves them on Unix + # Test that providing components separately gives same result + result1 = dashboard_settings.rel_path("folder", "subfolder", "file.yaml") + result2 = dashboard_settings.rel_path("folder", "subfolder", "file.yaml") + assert result1 == result2 + + # Also test that the result is as expected + expected = dashboard_settings.config_dir / "folder" / "subfolder" / "file.yaml" + assert result1 == expected + + +def test_rel_path_handles_spaces(dashboard_settings: DashboardSettings) -> None: + """Test rel_path handles paths with spaces.""" + result = dashboard_settings.rel_path("my folder", "my config.yaml") + + expected = dashboard_settings.config_dir / "my folder" / "my config.yaml" + assert result == expected + + +def test_rel_path_handles_special_chars(dashboard_settings: DashboardSettings) -> None: + """Test rel_path handles paths with special characters.""" + result = dashboard_settings.rel_path("device-01_test", "config.yaml") + + expected = dashboard_settings.config_dir / "device-01_test" / "config.yaml" + assert result == expected + + +def test_config_dir_as_path_property(dashboard_settings: DashboardSettings) -> None: + """Test that config_dir can be accessed and used with Path operations.""" + config_path = dashboard_settings.config_dir + + assert config_path.exists() + assert config_path.is_dir() + assert config_path.is_absolute() + + +def test_absolute_config_dir_property(dashboard_settings: DashboardSettings) -> None: + """Test absolute_config_dir is a Path object.""" + assert isinstance(dashboard_settings.absolute_config_dir, Path) + assert dashboard_settings.absolute_config_dir.exists() + assert dashboard_settings.absolute_config_dir.is_dir() + assert dashboard_settings.absolute_config_dir.is_absolute() + + +def test_rel_path_symlink_inside_config(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with symlink that points inside config dir.""" + target = dashboard_settings.absolute_config_dir / "target.yaml" + target.touch() + symlink = dashboard_settings.absolute_config_dir / "link.yaml" + symlink.symlink_to(target) + result = dashboard_settings.rel_path("link.yaml") + expected = dashboard_settings.config_dir / "link.yaml" + assert result == expected + + +def test_rel_path_symlink_outside_config(dashboard_settings: DashboardSettings) -> None: + """Test rel_path with symlink that points outside config dir.""" + with tempfile.NamedTemporaryFile(suffix=".yaml") as tmp: + symlink = dashboard_settings.absolute_config_dir / "external_link.yaml" + symlink.symlink_to(tmp.name) + with pytest.raises(ValueError): + dashboard_settings.rel_path("external_link.yaml") + + +def test_rel_path_with_none_arg(dashboard_settings: DashboardSettings) -> None: + """Test rel_path handles None arguments gracefully.""" + result = dashboard_settings.rel_path("None") + expected = dashboard_settings.config_dir / "None" + assert result == expected + + +def test_rel_path_with_numeric_args(dashboard_settings: DashboardSettings) -> None: + """Test rel_path handles numeric arguments.""" + result = dashboard_settings.rel_path("123", "456.789") + expected = dashboard_settings.config_dir / "123" / "456.789" + assert result == expected diff --git a/tests/dashboard/test_web_server.py b/tests/dashboard/test_web_server.py index 1938617f20..5bbe7e78fc 100644 --- a/tests/dashboard/test_web_server.py +++ b/tests/dashboard/test_web_server.py @@ -2,11 +2,12 @@ from __future__ import annotations import asyncio from collections.abc import Generator +from contextlib import asynccontextmanager import gzip import json import os from pathlib import Path -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest import pytest_asyncio @@ -14,9 +15,19 @@ from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop from tornado.testing import bind_unused_port +from tornado.websocket import WebSocketClientConnection, websocket_connect from esphome.dashboard import web_server +from esphome.dashboard.const import DashboardEvent from esphome.dashboard.core import DASHBOARD +from esphome.dashboard.entries import ( + DashboardEntry, + EntryStateSource, + bool_to_entry_state, +) +from esphome.dashboard.models import build_importable_device_dict +from esphome.dashboard.web_server import DashboardSubscriber +from esphome.zeroconf import DiscoveredImport from .common import get_fixture_path @@ -49,7 +60,7 @@ def mock_trash_storage_path(tmp_path: Path) -> Generator[MagicMock]: """Fixture to mock trash_storage_path.""" trash_dir = tmp_path / "trash" with patch( - "esphome.dashboard.web_server.trash_storage_path", return_value=str(trash_dir) + "esphome.dashboard.web_server.trash_storage_path", return_value=trash_dir ) as mock: yield mock @@ -60,7 +71,7 @@ def mock_archive_storage_path(tmp_path: Path) -> Generator[MagicMock]: archive_dir = tmp_path / "archive" with patch( "esphome.dashboard.web_server.archive_storage_path", - return_value=str(archive_dir), + return_value=archive_dir, ) as mock: yield mock @@ -126,6 +137,33 @@ async def dashboard() -> DashboardTestHelper: io_loop.close() +@asynccontextmanager +async def websocket_connection(dashboard: DashboardTestHelper): + """Async context manager for WebSocket connections.""" + url = f"ws://127.0.0.1:{dashboard.port}/events" + ws = await websocket_connect(url) + try: + yield ws + finally: + if ws: + ws.close() + + +@pytest_asyncio.fixture +async def websocket_client(dashboard: DashboardTestHelper) -> WebSocketClientConnection: + """Create a WebSocket connection for testing.""" + url = f"ws://127.0.0.1:{dashboard.port}/events" + ws = await websocket_connect(url) + + # Read and discard initial state message + await ws.read_message() + + yield ws + + if ws: + ws.close() + + @pytest.mark.asyncio async def test_main_page(dashboard: DashboardTestHelper) -> None: response = await dashboard.fetch("/") @@ -257,7 +295,7 @@ async def test_download_binary_handler_with_file( # Mock storage JSON mock_storage = Mock() mock_storage.name = "test_device" - mock_storage.firmware_bin_path = str(firmware_file) + mock_storage.firmware_bin_path = firmware_file mock_storage_json.load.return_value = mock_storage response = await dashboard.fetch( @@ -289,7 +327,7 @@ async def test_download_binary_handler_compressed( # Mock storage JSON mock_storage = Mock() mock_storage.name = "test_device" - mock_storage.firmware_bin_path = str(firmware_file) + mock_storage.firmware_bin_path = firmware_file mock_storage_json.load.return_value = mock_storage response = await dashboard.fetch( @@ -321,7 +359,7 @@ async def test_download_binary_handler_custom_download_name( # Mock storage JSON mock_storage = Mock() mock_storage.name = "test_device" - mock_storage.firmware_bin_path = str(firmware_file) + mock_storage.firmware_bin_path = firmware_file mock_storage_json.load.return_value = mock_storage response = await dashboard.fetch( @@ -355,7 +393,7 @@ async def test_download_binary_handler_idedata_fallback( # Mock storage JSON mock_storage = Mock() mock_storage.name = "test_device" - mock_storage.firmware_bin_path = str(firmware_file) + mock_storage.firmware_bin_path = firmware_file mock_storage_json.load.return_value = mock_storage # Mock idedata response @@ -402,7 +440,7 @@ async def test_edit_request_handler_post_existing( test_file.write_text("esphome:\n name: original\n") # Configure the mock settings - mock_dashboard_settings.rel_path.return_value = str(test_file) + mock_dashboard_settings.rel_path.return_value = test_file mock_dashboard_settings.absolute_config_dir = test_file.parent new_content = "esphome:\n name: modified\n" @@ -426,7 +464,7 @@ async def test_unarchive_request_handler( ) -> None: """Test the UnArchiveRequestHandler.post method.""" # Set up an archived file - archive_dir = Path(mock_archive_storage_path.return_value) + archive_dir = mock_archive_storage_path.return_value archive_dir.mkdir(parents=True, exist_ok=True) archived_file = archive_dir / "archived.yaml" archived_file.write_text("test content") @@ -435,7 +473,7 @@ async def test_unarchive_request_handler( config_dir = tmp_path / "config" config_dir.mkdir(parents=True, exist_ok=True) destination_file = config_dir / "archived.yaml" - mock_dashboard_settings.rel_path.return_value = str(destination_file) + mock_dashboard_settings.rel_path.return_value = destination_file response = await dashboard.fetch( "/unarchive?configuration=archived.yaml", @@ -474,7 +512,7 @@ async def test_secret_keys_handler_with_file( # Configure mock to return our temp secrets file # Since the file actually exists, os.path.isfile will return True naturally - mock_dashboard_settings.rel_path.return_value = str(secrets_file) + mock_dashboard_settings.rel_path.return_value = secrets_file response = await dashboard.fetch("/secret_keys", method="GET") assert response.code == 200 @@ -538,8 +576,8 @@ def test_start_web_server_with_address_port( ) -> None: """Test the start_web_server function with address and port.""" app = Mock() - trash_dir = Path(mock_trash_storage_path.return_value) - archive_dir = Path(mock_archive_storage_path.return_value) + trash_dir = mock_trash_storage_path.return_value + archive_dir = mock_archive_storage_path.return_value # Create trash dir to test migration trash_dir.mkdir() @@ -643,12 +681,12 @@ async def test_archive_handler_with_build_folder( (build_folder / ".pioenvs").mkdir() mock_dashboard_settings.config_dir = str(config_dir) - mock_dashboard_settings.rel_path.return_value = str(test_config) - mock_archive_storage_path.return_value = str(archive_dir) + mock_dashboard_settings.rel_path.return_value = test_config + mock_archive_storage_path.return_value = archive_dir mock_storage = MagicMock() mock_storage.name = "test_device" - mock_storage.build_path = str(build_folder) + mock_storage.build_path = build_folder mock_storage_json.load.return_value = mock_storage response = await dashboard.fetch( @@ -686,8 +724,8 @@ async def test_archive_handler_no_build_folder( test_config.write_text("esphome:\n name: test_device\n") mock_dashboard_settings.config_dir = str(config_dir) - mock_dashboard_settings.rel_path.return_value = str(test_config) - mock_archive_storage_path.return_value = str(archive_dir) + mock_dashboard_settings.rel_path.return_value = test_config + mock_archive_storage_path.return_value = archive_dir mock_storage = MagicMock() mock_storage.name = "test_device" @@ -730,3 +768,537 @@ def test_start_web_server_with_unix_socket(tmp_path: Path) -> None: mock_server_class.assert_called_once_with(app) mock_bind.assert_called_once_with(str(socket_path), mode=0o666) server.add_socket.assert_called_once() + + +def test_build_cache_arguments_no_entry(mock_dashboard: Mock) -> None: + """Test with no entry returns empty list.""" + result = web_server.build_cache_arguments(None, mock_dashboard, 0.0) + assert result == [] + + +def test_build_cache_arguments_no_address_no_name(mock_dashboard: Mock) -> None: + """Test with entry but no address or name.""" + entry = Mock(spec=web_server.DashboardEntry) + entry.address = None + entry.name = None + result = web_server.build_cache_arguments(entry, mock_dashboard, 0.0) + assert result == [] + + +def test_build_cache_arguments_mdns_address_cached(mock_dashboard: Mock) -> None: + """Test with .local address that has cached mDNS results.""" + entry = Mock(spec=web_server.DashboardEntry) + entry.address = "device.local" + entry.name = None + mock_dashboard.mdns_status = Mock() + mock_dashboard.mdns_status.get_cached_addresses.return_value = [ + "192.168.1.10", + "fe80::1", + ] + + result = web_server.build_cache_arguments(entry, mock_dashboard, 0.0) + + assert result == [ + "--mdns-address-cache", + "device.local=192.168.1.10,fe80::1", + ] + mock_dashboard.mdns_status.get_cached_addresses.assert_called_once_with( + "device.local" + ) + + +def test_build_cache_arguments_dns_address_cached(mock_dashboard: Mock) -> None: + """Test with non-.local address that has cached DNS results.""" + entry = Mock(spec=web_server.DashboardEntry) + entry.address = "example.com" + entry.name = None + mock_dashboard.dns_cache = Mock() + mock_dashboard.dns_cache.get_cached_addresses.return_value = [ + "93.184.216.34", + "2606:2800:220:1:248:1893:25c8:1946", + ] + + now = 100.0 + result = web_server.build_cache_arguments(entry, mock_dashboard, now) + + # IPv6 addresses are sorted before IPv4 + assert result == [ + "--dns-address-cache", + "example.com=2606:2800:220:1:248:1893:25c8:1946,93.184.216.34", + ] + mock_dashboard.dns_cache.get_cached_addresses.assert_called_once_with( + "example.com", now + ) + + +def test_build_cache_arguments_name_without_address(mock_dashboard: Mock) -> None: + """Test with name but no address - should check mDNS with .local suffix.""" + entry = Mock(spec=web_server.DashboardEntry) + entry.name = "my-device" + entry.address = None + mock_dashboard.mdns_status = Mock() + mock_dashboard.mdns_status.get_cached_addresses.return_value = ["192.168.1.20"] + + result = web_server.build_cache_arguments(entry, mock_dashboard, 0.0) + + assert result == [ + "--mdns-address-cache", + "my-device.local=192.168.1.20", + ] + mock_dashboard.mdns_status.get_cached_addresses.assert_called_once_with( + "my-device.local" + ) + + +@pytest.mark.asyncio +async def test_websocket_connection_initial_state( + dashboard: DashboardTestHelper, +) -> None: + """Test WebSocket connection and initial state.""" + async with websocket_connection(dashboard) as ws: + # Should receive initial state with configured and importable devices + msg = await ws.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "initial_state" + assert "devices" in data["data"] + assert "configured" in data["data"]["devices"] + assert "importable" in data["data"]["devices"] + + # Check configured devices + configured = data["data"]["devices"]["configured"] + assert len(configured) > 0 + assert configured[0]["name"] == "pico" # From test fixtures + + +@pytest.mark.asyncio +async def test_websocket_ping_pong( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test WebSocket ping/pong mechanism.""" + # Send ping + await websocket_client.write_message(json.dumps({"event": "ping"})) + + # Should receive pong + msg = await websocket_client.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "pong" + + +@pytest.mark.asyncio +async def test_websocket_invalid_json( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test WebSocket handling of invalid JSON.""" + # Send invalid JSON + await websocket_client.write_message("not valid json {]") + + # Send a valid ping to verify connection is still alive + await websocket_client.write_message(json.dumps({"event": "ping"})) + + # Should receive pong, confirming the connection wasn't closed by invalid JSON + msg = await websocket_client.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "pong" + + +@pytest.mark.asyncio +async def test_websocket_authentication_required( + dashboard: DashboardTestHelper, +) -> None: + """Test WebSocket authentication when auth is required.""" + with patch( + "esphome.dashboard.web_server.is_authenticated" + ) as mock_is_authenticated: + mock_is_authenticated.return_value = False + + # Try to connect - should be rejected with 401 + url = f"ws://127.0.0.1:{dashboard.port}/events" + with pytest.raises(HTTPClientError) as exc_info: + await websocket_connect(url) + # Should get HTTP 401 Unauthorized + assert exc_info.value.code == 401 + + +@pytest.mark.asyncio +async def test_websocket_authentication_not_required( + dashboard: DashboardTestHelper, +) -> None: + """Test WebSocket connection when no auth is required.""" + with patch( + "esphome.dashboard.web_server.is_authenticated" + ) as mock_is_authenticated: + mock_is_authenticated.return_value = True + + # Should be able to connect successfully + async with websocket_connection(dashboard) as ws: + msg = await ws.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "initial_state" + + +@pytest.mark.asyncio +async def test_websocket_entry_state_changed( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test WebSocket entry state changed event.""" + # Simulate entry state change + entry = DASHBOARD.entries.async_all()[0] + state = bool_to_entry_state(True, EntryStateSource.MDNS) + DASHBOARD.bus.async_fire( + DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state} + ) + + # Should receive state change event + msg = await websocket_client.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "entry_state_changed" + assert data["data"]["filename"] == entry.filename + assert data["data"]["name"] == entry.name + assert data["data"]["state"] is True + + +@pytest.mark.asyncio +async def test_websocket_entry_added( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test WebSocket entry added event.""" + # Create a mock entry + mock_entry = Mock(spec=DashboardEntry) + mock_entry.filename = "test.yaml" + mock_entry.name = "test_device" + mock_entry.to_dict.return_value = { + "name": "test_device", + "filename": "test.yaml", + "configuration": "test.yaml", + } + + # Simulate entry added + DASHBOARD.bus.async_fire(DashboardEvent.ENTRY_ADDED, {"entry": mock_entry}) + + # Should receive entry added event + msg = await websocket_client.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "entry_added" + assert data["data"]["device"]["name"] == "test_device" + assert data["data"]["device"]["filename"] == "test.yaml" + + +@pytest.mark.asyncio +async def test_websocket_entry_removed( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test WebSocket entry removed event.""" + # Create a mock entry + mock_entry = Mock(spec=DashboardEntry) + mock_entry.filename = "removed.yaml" + mock_entry.name = "removed_device" + mock_entry.to_dict.return_value = { + "name": "removed_device", + "filename": "removed.yaml", + "configuration": "removed.yaml", + } + + # Simulate entry removed + DASHBOARD.bus.async_fire(DashboardEvent.ENTRY_REMOVED, {"entry": mock_entry}) + + # Should receive entry removed event + msg = await websocket_client.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "entry_removed" + assert data["data"]["device"]["name"] == "removed_device" + assert data["data"]["device"]["filename"] == "removed.yaml" + + +@pytest.mark.asyncio +async def test_websocket_importable_device_added( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test WebSocket importable device added event with real DiscoveredImport.""" + # Create a real DiscoveredImport object + discovered = DiscoveredImport( + device_name="new_import_device", + friendly_name="New Import Device", + package_import_url="https://example.com/package", + project_name="test_project", + project_version="1.0.0", + network="wifi", + ) + + # Directly fire the event as the mDNS system would + device_dict = build_importable_device_dict(DASHBOARD, discovered) + DASHBOARD.bus.async_fire( + DashboardEvent.IMPORTABLE_DEVICE_ADDED, {"device": device_dict} + ) + + # Should receive importable device added event + msg = await websocket_client.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "importable_device_added" + assert data["data"]["device"]["name"] == "new_import_device" + assert data["data"]["device"]["friendly_name"] == "New Import Device" + assert data["data"]["device"]["project_name"] == "test_project" + assert data["data"]["device"]["network"] == "wifi" + assert data["data"]["device"]["ignored"] is False + + +@pytest.mark.asyncio +async def test_websocket_importable_device_added_ignored( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test WebSocket importable device added event for ignored device.""" + # Add device to ignored list + DASHBOARD.ignored_devices.add("ignored_device") + + # Create a real DiscoveredImport object + discovered = DiscoveredImport( + device_name="ignored_device", + friendly_name="Ignored Device", + package_import_url="https://example.com/package", + project_name="test_project", + project_version="1.0.0", + network="ethernet", + ) + + # Directly fire the event as the mDNS system would + device_dict = build_importable_device_dict(DASHBOARD, discovered) + DASHBOARD.bus.async_fire( + DashboardEvent.IMPORTABLE_DEVICE_ADDED, {"device": device_dict} + ) + + # Should receive importable device added event with ignored=True + msg = await websocket_client.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "importable_device_added" + assert data["data"]["device"]["name"] == "ignored_device" + assert data["data"]["device"]["friendly_name"] == "Ignored Device" + assert data["data"]["device"]["network"] == "ethernet" + assert data["data"]["device"]["ignored"] is True + + +@pytest.mark.asyncio +async def test_websocket_importable_device_removed( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test WebSocket importable device removed event.""" + # Simulate importable device removed + DASHBOARD.bus.async_fire( + DashboardEvent.IMPORTABLE_DEVICE_REMOVED, + {"name": "removed_import_device"}, + ) + + # Should receive importable device removed event + msg = await websocket_client.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "importable_device_removed" + assert data["data"]["name"] == "removed_import_device" + + +@pytest.mark.asyncio +async def test_websocket_importable_device_already_configured( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test that importable device event is not sent if device is already configured.""" + # Get an existing configured device name + existing_entry = DASHBOARD.entries.async_all()[0] + + # Simulate importable device added with same name as configured device + DASHBOARD.bus.async_fire( + DashboardEvent.IMPORTABLE_DEVICE_ADDED, + { + "device": { + "name": existing_entry.name, + "friendly_name": "Should Not Be Sent", + "package_import_url": "https://example.com/package", + "project_name": "test_project", + "project_version": "1.0.0", + "network": "wifi", + } + }, + ) + + # Send a ping to ensure connection is still alive + await websocket_client.write_message(json.dumps({"event": "ping"})) + + # Should only receive pong, not the importable device event + msg = await websocket_client.read_message() + assert msg is not None + data = json.loads(msg) + assert data["event"] == "pong" + + +@pytest.mark.asyncio +async def test_websocket_multiple_connections(dashboard: DashboardTestHelper) -> None: + """Test multiple WebSocket connections.""" + async with ( + websocket_connection(dashboard) as ws1, + websocket_connection(dashboard) as ws2, + ): + # Both should receive initial state + msg1 = await ws1.read_message() + assert msg1 is not None + data1 = json.loads(msg1) + assert data1["event"] == "initial_state" + + msg2 = await ws2.read_message() + assert msg2 is not None + data2 = json.loads(msg2) + assert data2["event"] == "initial_state" + + # Fire an event - both should receive it + entry = DASHBOARD.entries.async_all()[0] + state = bool_to_entry_state(False, EntryStateSource.MDNS) + DASHBOARD.bus.async_fire( + DashboardEvent.ENTRY_STATE_CHANGED, {"entry": entry, "state": state} + ) + + msg1 = await ws1.read_message() + assert msg1 is not None + data1 = json.loads(msg1) + assert data1["event"] == "entry_state_changed" + + msg2 = await ws2.read_message() + assert msg2 is not None + data2 = json.loads(msg2) + assert data2["event"] == "entry_state_changed" + + +@pytest.mark.asyncio +async def test_dashboard_subscriber_lifecycle(dashboard: DashboardTestHelper) -> None: + """Test DashboardSubscriber lifecycle.""" + subscriber = DashboardSubscriber() + + # Initially no subscribers + assert len(subscriber._subscribers) == 0 + assert subscriber._event_loop_task is None + + # Add a subscriber + mock_websocket = Mock() + unsubscribe = subscriber.subscribe(mock_websocket) + + # Should have started the event loop task + assert len(subscriber._subscribers) == 1 + assert subscriber._event_loop_task is not None + + # Unsubscribe + unsubscribe() + + # Should have stopped the task + assert len(subscriber._subscribers) == 0 + + +@pytest.mark.asyncio +async def test_dashboard_subscriber_entries_update_interval( + dashboard: DashboardTestHelper, +) -> None: + """Test DashboardSubscriber entries update interval.""" + # Patch the constants to make the test run faster + with ( + patch("esphome.dashboard.web_server.DASHBOARD_POLL_INTERVAL", 0.01), + patch("esphome.dashboard.web_server.DASHBOARD_ENTRIES_UPDATE_ITERATIONS", 2), + patch("esphome.dashboard.web_server.settings") as mock_settings, + patch("esphome.dashboard.web_server.DASHBOARD") as mock_dashboard, + ): + mock_settings.status_use_mqtt = False + + # Mock dashboard dependencies + mock_dashboard.ping_request = Mock() + mock_dashboard.ping_request.set = Mock() + mock_dashboard.entries = Mock() + mock_dashboard.entries.async_request_update_entries = Mock() + + subscriber = DashboardSubscriber() + mock_websocket = Mock() + + # Subscribe to start the event loop + unsubscribe = subscriber.subscribe(mock_websocket) + + # Wait for a few iterations to ensure entries update is called + await asyncio.sleep(0.05) # Should be enough for 2+ iterations + + # Unsubscribe to stop the task + unsubscribe() + + # Verify entries update was called + assert mock_dashboard.entries.async_request_update_entries.call_count >= 1 + # Verify ping request was set multiple times + assert mock_dashboard.ping_request.set.call_count >= 2 + + +@pytest.mark.asyncio +async def test_websocket_refresh_command( + dashboard: DashboardTestHelper, websocket_client: WebSocketClientConnection +) -> None: + """Test WebSocket refresh command triggers dashboard update.""" + with patch("esphome.dashboard.web_server.DASHBOARD_SUBSCRIBER") as mock_subscriber: + mock_subscriber.request_refresh = Mock() + + # Send refresh command + await websocket_client.write_message(json.dumps({"event": "refresh"})) + + # Give it a moment to process + await asyncio.sleep(0.01) + + # Verify request_refresh was called + mock_subscriber.request_refresh.assert_called_once() + + +@pytest.mark.asyncio +async def test_dashboard_subscriber_refresh_event( + dashboard: DashboardTestHelper, +) -> None: + """Test DashboardSubscriber refresh event triggers immediate update.""" + # Patch the constants to make the test run faster + with ( + patch( + "esphome.dashboard.web_server.DASHBOARD_POLL_INTERVAL", 1.0 + ), # Long timeout + patch( + "esphome.dashboard.web_server.DASHBOARD_ENTRIES_UPDATE_ITERATIONS", 100 + ), # Won't reach naturally + patch("esphome.dashboard.web_server.settings") as mock_settings, + patch("esphome.dashboard.web_server.DASHBOARD") as mock_dashboard, + ): + mock_settings.status_use_mqtt = False + + # Mock dashboard dependencies + mock_dashboard.ping_request = Mock() + mock_dashboard.ping_request.set = Mock() + mock_dashboard.entries = Mock() + mock_dashboard.entries.async_request_update_entries = AsyncMock() + + subscriber = DashboardSubscriber() + mock_websocket = Mock() + + # Subscribe to start the event loop + unsubscribe = subscriber.subscribe(mock_websocket) + + # Wait a bit to ensure loop is running + await asyncio.sleep(0.01) + + # Verify entries update hasn't been called yet (iterations not reached) + assert mock_dashboard.entries.async_request_update_entries.call_count == 0 + + # Request refresh + subscriber.request_refresh() + + # Wait for the refresh to be processed + await asyncio.sleep(0.01) + + # Now entries update should have been called + assert mock_dashboard.entries.async_request_update_entries.call_count == 1 + + # Unsubscribe to stop the task + unsubscribe() + + # Give it a moment to clean up + await asyncio.sleep(0.01) diff --git a/tests/dashboard/test_web_server_paths.py b/tests/dashboard/test_web_server_paths.py new file mode 100644 index 0000000000..b596ebb581 --- /dev/null +++ b/tests/dashboard/test_web_server_paths.py @@ -0,0 +1,223 @@ +"""Tests for dashboard web_server Path-related functionality.""" + +from __future__ import annotations + +import gzip +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +from esphome.dashboard import web_server + + +def test_get_base_frontend_path_production() -> None: + """Test get_base_frontend_path in production mode.""" + mock_module = MagicMock() + mock_module.where.return_value = Path("/usr/local/lib/esphome_dashboard") + + with ( + patch.dict(os.environ, {}, clear=True), + patch.dict("sys.modules", {"esphome_dashboard": mock_module}), + ): + result = web_server.get_base_frontend_path() + assert result == Path("/usr/local/lib/esphome_dashboard") + mock_module.where.assert_called_once() + + +def test_get_base_frontend_path_dev_mode() -> None: + """Test get_base_frontend_path in development mode.""" + test_path = "/home/user/esphome/dashboard" + + with patch.dict(os.environ, {"ESPHOME_DASHBOARD_DEV": test_path}): + result = web_server.get_base_frontend_path() + + # The function uses Path.resolve() which resolves symlinks + # The actual function adds "/" to the path, so we simulate that + test_path_with_slash = test_path if test_path.endswith("/") else test_path + "/" + expected = ( + Path(os.getcwd()) / test_path_with_slash / "esphome_dashboard" + ).resolve() + assert result == expected + + +def test_get_base_frontend_path_dev_mode_with_trailing_slash() -> None: + """Test get_base_frontend_path in dev mode with trailing slash.""" + test_path = "/home/user/esphome/dashboard/" + + with patch.dict(os.environ, {"ESPHOME_DASHBOARD_DEV": test_path}): + result = web_server.get_base_frontend_path() + + # The function uses Path.resolve() which resolves symlinks + expected = (Path.cwd() / test_path / "esphome_dashboard").resolve() + assert result == expected + + +def test_get_base_frontend_path_dev_mode_relative_path() -> None: + """Test get_base_frontend_path with relative dev path.""" + test_path = "./dashboard" + + with patch.dict(os.environ, {"ESPHOME_DASHBOARD_DEV": test_path}): + result = web_server.get_base_frontend_path() + + # The function uses Path.resolve() which resolves symlinks + # The actual function adds "/" to the path, so we simulate that + test_path_with_slash = test_path if test_path.endswith("/") else test_path + "/" + expected = ( + Path(os.getcwd()) / test_path_with_slash / "esphome_dashboard" + ).resolve() + assert result == expected + assert result.is_absolute() + + +def test_get_static_path_single_component() -> None: + """Test get_static_path with single path component.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = Path("/base/frontend") + + result = web_server.get_static_path("file.js") + + assert result == Path("/base/frontend") / "static" / "file.js" + + +def test_get_static_path_multiple_components() -> None: + """Test get_static_path with multiple path components.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = Path("/base/frontend") + + result = web_server.get_static_path("js", "esphome", "index.js") + + assert ( + result == Path("/base/frontend") / "static" / "js" / "esphome" / "index.js" + ) + + +def test_get_static_path_empty_args() -> None: + """Test get_static_path with no arguments.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = Path("/base/frontend") + + result = web_server.get_static_path() + + assert result == Path("/base/frontend") / "static" + + +def test_get_static_path_with_pathlib_path() -> None: + """Test get_static_path with Path objects.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = Path("/base/frontend") + + path_obj = Path("js") / "app.js" + result = web_server.get_static_path(str(path_obj)) + + assert result == Path("/base/frontend") / "static" / "js" / "app.js" + + +def test_get_static_file_url_production() -> None: + """Test get_static_file_url in production mode.""" + web_server.get_static_file_url.cache_clear() + mock_module = MagicMock() + mock_path = MagicMock(spec=Path) + mock_path.read_bytes.return_value = b"test content" + + with ( + patch.dict(os.environ, {}, clear=True), + patch.dict("sys.modules", {"esphome_dashboard": mock_module}), + patch("esphome.dashboard.web_server.get_static_path") as mock_get_path, + ): + mock_get_path.return_value = mock_path + result = web_server.get_static_file_url("js/app.js") + assert result.startswith("./static/js/app.js?hash=") + + +def test_get_static_file_url_dev_mode() -> None: + """Test get_static_file_url in development mode.""" + with patch.dict(os.environ, {"ESPHOME_DASHBOARD_DEV": "/dev/path"}): + web_server.get_static_file_url.cache_clear() + result = web_server.get_static_file_url("js/app.js") + + assert result == "./static/js/app.js" + + +def test_get_static_file_url_index_js_special_case() -> None: + """Test get_static_file_url replaces index.js with entrypoint.""" + web_server.get_static_file_url.cache_clear() + mock_module = MagicMock() + mock_module.entrypoint.return_value = "main.js" + + with ( + patch.dict(os.environ, {}, clear=True), + patch.dict("sys.modules", {"esphome_dashboard": mock_module}), + ): + result = web_server.get_static_file_url("js/esphome/index.js") + assert result == "./static/js/esphome/main.js" + + +def test_load_file_path(tmp_path: Path) -> None: + """Test loading a file.""" + test_file = tmp_path / "test.txt" + test_file.write_bytes(b"test content") + + with open(test_file, "rb") as f: + content = f.read() + assert content == b"test content" + + +def test_load_file_compressed_path(tmp_path: Path) -> None: + """Test loading a compressed file.""" + test_file = tmp_path / "test.txt.gz" + + with gzip.open(test_file, "wb") as gz: + gz.write(b"compressed content") + + with gzip.open(test_file, "rb") as gz: + content = gz.read() + assert content == b"compressed content" + + +def test_path_normalization_in_static_path() -> None: + """Test that paths are normalized correctly.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = Path("/base/frontend") + + # Test with separate components + result1 = web_server.get_static_path("js", "app.js") + result2 = web_server.get_static_path("js", "app.js") + + assert result1 == result2 + assert result1 == Path("/base/frontend") / "static" / "js" / "app.js" + + +def test_windows_path_handling() -> None: + """Test handling of Windows-style paths.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = Path(r"C:\Program Files\esphome\frontend") + + result = web_server.get_static_path("js", "app.js") + + # Path should handle this correctly on the platform + expected = ( + Path(r"C:\Program Files\esphome\frontend") / "static" / "js" / "app.js" + ) + assert result == expected + + +def test_path_with_special_characters() -> None: + """Test paths with special characters.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = Path("/base/frontend") + + result = web_server.get_static_path("js-modules", "app_v1.0.js") + + assert ( + result == Path("/base/frontend") / "static" / "js-modules" / "app_v1.0.js" + ) + + +def test_path_with_spaces() -> None: + """Test paths with spaces.""" + with patch("esphome.dashboard.web_server.get_base_frontend_path") as mock_base: + mock_base.return_value = Path("/base/my frontend") + + result = web_server.get_static_path("my js", "my app.js") + + assert result == Path("/base/my frontend") / "static" / "my js" / "my app.js" diff --git a/tests/dashboard/util/test_file.py b/tests/dashboard/util/test_file.py deleted file mode 100644 index 51ba10b328..0000000000 --- a/tests/dashboard/util/test_file.py +++ /dev/null @@ -1,56 +0,0 @@ -import os -from pathlib import Path -from unittest.mock import patch - -import py -import pytest - -from esphome.dashboard.util.file import write_file, write_utf8_file - - -def test_write_utf8_file(tmp_path: Path) -> None: - write_utf8_file(tmp_path.joinpath("foo.txt"), "foo") - assert tmp_path.joinpath("foo.txt").read_text() == "foo" - - with pytest.raises(OSError): - write_utf8_file(Path("/dev/not-writable"), "bar") - - -def test_write_file(tmp_path: Path) -> None: - write_file(tmp_path.joinpath("foo.txt"), b"foo") - assert tmp_path.joinpath("foo.txt").read_text() == "foo" - - -def test_write_utf8_file_fails_at_rename( - tmpdir: py.path.local, caplog: pytest.LogCaptureFixture -) -> None: - """Test that if rename fails not not remove, we do not log the failed cleanup.""" - test_dir = tmpdir.mkdir("files") - test_file = Path(test_dir / "test.json") - - with ( - pytest.raises(OSError), - patch("esphome.dashboard.util.file.os.replace", side_effect=OSError), - ): - write_utf8_file(test_file, '{"some":"data"}', False) - - assert not os.path.exists(test_file) - - assert "File replacement cleanup failed" not in caplog.text - - -def test_write_utf8_file_fails_at_rename_and_remove( - tmpdir: py.path.local, caplog: pytest.LogCaptureFixture -) -> None: - """Test that if rename and remove both fail, we log the failed cleanup.""" - test_dir = tmpdir.mkdir("files") - test_file = Path(test_dir / "test.json") - - with ( - pytest.raises(OSError), - patch("esphome.dashboard.util.file.os.remove", side_effect=OSError), - patch("esphome.dashboard.util.file.os.replace", side_effect=OSError), - ): - write_utf8_file(test_file, '{"some":"data"}', False) - - assert "File replacement cleanup failed" in caplog.text diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 0530752551..965363972f 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -58,6 +58,8 @@ def _get_platformio_env(cache_dir: Path) -> dict[str, str]: env["PLATFORMIO_CORE_DIR"] = str(cache_dir) env["PLATFORMIO_CACHE_DIR"] = str(cache_dir / ".cache") env["PLATFORMIO_LIBDEPS_DIR"] = str(cache_dir / "libdeps") + # Prevent cache cleaning during integration tests + env["ESPHOME_SKIP_CLEAN_BUILD"] = "1" return env @@ -68,6 +70,11 @@ def shared_platformio_cache() -> Generator[Path]: test_cache_dir = Path.home() / ".esphome-integration-tests" cache_dir = test_cache_dir / "platformio" + # Create the temp directory that PlatformIO uses to avoid race conditions + # This ensures it exists and won't be deleted by parallel processes + platformio_tmp_dir = cache_dir / ".cache" / "tmp" + platformio_tmp_dir.mkdir(parents=True, exist_ok=True) + # Use a lock file in the home directory to ensure only one process initializes the cache # This is needed when running with pytest-xdist # The lock file must be in a directory that already exists to avoid race conditions @@ -83,17 +90,11 @@ def shared_platformio_cache() -> Generator[Path]: test_cache_dir.mkdir(exist_ok=True) with tempfile.TemporaryDirectory() as tmpdir: - # Create a basic host config + # Use the cache_init fixture for initialization init_dir = Path(tmpdir) + fixture_path = Path(__file__).parent / "fixtures" / "cache_init.yaml" config_path = init_dir / "cache_init.yaml" - config_path.write_text("""esphome: - name: cache-init -host: -api: - encryption: - key: "IIevImVI42I0FGos5nLqFK91jrJehrgidI0ArwMLr8w=" -logger: -""") + config_path.write_text(fixture_path.read_text()) # Run compilation to populate the cache # We must succeed here to avoid race conditions where multiple @@ -271,7 +272,7 @@ async def compile_esphome( def _read_config_and_get_binary(): CORE.reset() # Reset CORE state between test runs - CORE.config_path = str(config_path) + CORE.config_path = config_path config = esphome.config.read_config( {"command": "compile", "config": str(config_path)} ) @@ -346,7 +347,8 @@ async def wait_and_connect_api_client( noise_psk: str | None = None, client_info: str = "integration-test", timeout: float = API_CONNECTION_TIMEOUT, -) -> AsyncGenerator[APIClient]: + return_disconnect_event: bool = False, +) -> AsyncGenerator[APIClient | tuple[APIClient, asyncio.Event]]: """Wait for API to be available and connect.""" client = APIClient( address=address, @@ -359,14 +361,17 @@ async def wait_and_connect_api_client( # Create a future to signal when connected loop = asyncio.get_running_loop() connected_future: asyncio.Future[None] = loop.create_future() + disconnect_event = asyncio.Event() async def on_connect() -> None: """Called when successfully connected.""" + disconnect_event.clear() # Clear the disconnect event on new connection if not connected_future.done(): connected_future.set_result(None) async def on_disconnect(expected_disconnect: bool) -> None: """Called when disconnected.""" + disconnect_event.set() if not connected_future.done() and not expected_disconnect: connected_future.set_exception( APIConnectionError("Disconnected before fully connected") @@ -397,7 +402,10 @@ async def wait_and_connect_api_client( except TimeoutError: raise TimeoutError(f"Failed to connect to API after {timeout} seconds") - yield client + if return_disconnect_event: + yield client, disconnect_event + else: + yield client finally: # Stop reconnect logic and disconnect await reconnect_logic.stop() @@ -430,6 +438,33 @@ async def api_client_connected( yield _connect_client +@pytest_asyncio.fixture +async def api_client_connected_with_disconnect( + unused_tcp_port: int, +) -> AsyncGenerator: + """Factory for creating connected API client context managers with disconnect event.""" + + def _connect_client_with_disconnect( + address: str = LOCALHOST, + port: int | None = None, + password: str = "", + noise_psk: str | None = None, + client_info: str = "integration-test", + timeout: float = API_CONNECTION_TIMEOUT, + ): + return wait_and_connect_api_client( + address=address, + port=port if port is not None else unused_tcp_port, + password=password, + noise_psk=noise_psk, + client_info=client_info, + timeout=timeout, + return_disconnect_event=True, + ) + + yield _connect_client_with_disconnect + + async def _read_stream_lines( stream: asyncio.StreamReader, lines: list[str], diff --git a/tests/integration/fixtures/cache_init.yaml b/tests/integration/fixtures/cache_init.yaml new file mode 100644 index 0000000000..de208196cd --- /dev/null +++ b/tests/integration/fixtures/cache_init.yaml @@ -0,0 +1,10 @@ +esphome: + name: cache-init + +host: + +api: + encryption: + key: "IIevImVI42I0FGos5nLqFK91jrJehrgidI0ArwMLr8w=" + +logger: diff --git a/tests/integration/fixtures/noise_corrupt_encrypted_frame.yaml b/tests/integration/fixtures/noise_corrupt_encrypted_frame.yaml new file mode 100644 index 0000000000..6f0266c6fd --- /dev/null +++ b/tests/integration/fixtures/noise_corrupt_encrypted_frame.yaml @@ -0,0 +1,11 @@ +esphome: + name: oversized-noise + +host: + +api: + encryption: + key: N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU= + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/fixtures/oversized_payload_noise.yaml b/tests/integration/fixtures/oversized_payload_noise.yaml new file mode 100644 index 0000000000..6f0266c6fd --- /dev/null +++ b/tests/integration/fixtures/oversized_payload_noise.yaml @@ -0,0 +1,11 @@ +esphome: + name: oversized-noise + +host: + +api: + encryption: + key: N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU= + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/fixtures/oversized_payload_plaintext.yaml b/tests/integration/fixtures/oversized_payload_plaintext.yaml new file mode 100644 index 0000000000..44ece4f770 --- /dev/null +++ b/tests/integration/fixtures/oversized_payload_plaintext.yaml @@ -0,0 +1,9 @@ +esphome: + name: oversized-plaintext + +host: + +api: + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/fixtures/oversized_protobuf_message_id_noise.yaml b/tests/integration/fixtures/oversized_protobuf_message_id_noise.yaml new file mode 100644 index 0000000000..6f0266c6fd --- /dev/null +++ b/tests/integration/fixtures/oversized_protobuf_message_id_noise.yaml @@ -0,0 +1,11 @@ +esphome: + name: oversized-noise + +host: + +api: + encryption: + key: N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU= + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/fixtures/oversized_protobuf_message_id_plaintext.yaml b/tests/integration/fixtures/oversized_protobuf_message_id_plaintext.yaml new file mode 100644 index 0000000000..1e9eadfdc5 --- /dev/null +++ b/tests/integration/fixtures/oversized_protobuf_message_id_plaintext.yaml @@ -0,0 +1,9 @@ +esphome: + name: oversized-protobuf-plaintext + +host: + +api: + +logger: + level: VERY_VERBOSE diff --git a/tests/integration/test_host_mode_api_password.py b/tests/integration/test_host_mode_api_password.py index 825c2c55f2..5c5e689e45 100644 --- a/tests/integration/test_host_mode_api_password.py +++ b/tests/integration/test_host_mode_api_password.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio -from aioesphomeapi import APIConnectionError +from aioesphomeapi import APIConnectionError, InvalidAuthAPIError import pytest from .types import APIClientConnectedFactory, RunCompiledFunction @@ -48,6 +48,22 @@ async def test_host_mode_api_password( assert len(states) > 0 # Test with wrong password - should fail - with pytest.raises(APIConnectionError, match="Invalid password"): - async with api_client_connected(password="wrong_password"): - pass # Should not reach here + # Try connecting with wrong password + try: + async with api_client_connected( + password="wrong_password", timeout=5 + ) as client: + # If we get here without exception, try to use the connection + # which should fail if auth failed + await client.device_info_and_list_entities() + # If we successfully got device info and entities, auth didn't fail properly + pytest.fail("Connection succeeded with wrong password") + except (InvalidAuthAPIError, APIConnectionError) as e: + # Expected - auth should fail + # Accept either InvalidAuthAPIError or generic APIConnectionError + # since the client might not always distinguish + assert ( + "password" in str(e).lower() + or "auth" in str(e).lower() + or "invalid" in str(e).lower() + ) diff --git a/tests/integration/test_oversized_payloads.py b/tests/integration/test_oversized_payloads.py new file mode 100644 index 0000000000..ba18e3d348 --- /dev/null +++ b/tests/integration/test_oversized_payloads.py @@ -0,0 +1,337 @@ +"""Integration tests for oversized payloads and headers that should cause disconnection.""" + +from __future__ import annotations + +import asyncio + +import pytest + +from .types import APIClientConnectedWithDisconnectFactory, RunCompiledFunction + + +@pytest.mark.asyncio +async def test_oversized_payload_plaintext( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that oversized payloads (>32768 bytes) from client cause disconnection without crashing.""" + process_exited = False + helper_log_found = False + + def check_logs(line: str) -> None: + nonlocal process_exited, helper_log_found + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + # Check for HELPER_LOG message about message size exceeding maximum + if ( + "[VV]" in line + and "Bad packet: message size" in line + and "exceeds maximum" in line + ): + helper_log_found = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect() as (client, disconnect_event): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-plaintext" + + # Create an oversized payload (>32768 bytes which is our new limit) + oversized_data = b"X" * 40000 # ~40KiB, exceeds the 32768 byte limit + + # Access the internal connection to send raw data + frame_helper = client._connection._frame_helper + # Create a message with oversized payload + # Using message type 1 (DeviceInfoRequest) as an example + message_type = 1 + frame_helper.write_packets([(message_type, oversized_data)], True) + + # Wait for the connection to be closed by ESPHome + await asyncio.wait_for(disconnect_event.wait(), timeout=5.0) + + # After disconnection, verify process didn't crash + assert not process_exited, "ESPHome process should not crash" + # Verify we saw the expected HELPER_LOG message + assert helper_log_found, ( + "Expected to see HELPER_LOG about message size exceeding maximum" + ) + + # Try to reconnect to verify the process is still running + async with api_client_connected_with_disconnect() as (client2, _): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-plaintext" + + +@pytest.mark.asyncio +async def test_oversized_protobuf_message_id_plaintext( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that protobuf messages with ID > UINT16_MAX cause disconnection without crashing. + + This tests the message type limit - message IDs must fit in a uint16_t (0-65535). + """ + process_exited = False + helper_log_found = False + + def check_logs(line: str) -> None: + nonlocal process_exited, helper_log_found + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + # Check for HELPER_LOG message about message type exceeding maximum + if ( + "[VV]" in line + and "Bad packet: message type" in line + and "exceeds maximum" in line + ): + helper_log_found = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect() as (client, disconnect_event): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-protobuf-plaintext" + + # Access the internal connection to send raw message with large ID + frame_helper = client._connection._frame_helper + # Message ID that exceeds uint16_t limit (> 65535) + large_message_id = 65536 # 2^16, exceeds UINT16_MAX + # Small payload for the test + payload = b"test" + + # This should cause disconnection due to oversized varint + frame_helper.write_packets([(large_message_id, payload)], True) + + # Wait for the connection to be closed by ESPHome + await asyncio.wait_for(disconnect_event.wait(), timeout=5.0) + + # After disconnection, verify process didn't crash + assert not process_exited, "ESPHome process should not crash" + # Verify we saw the expected HELPER_LOG message + assert helper_log_found, ( + "Expected to see HELPER_LOG about message type exceeding maximum" + ) + + # Try to reconnect to verify the process is still running + async with api_client_connected_with_disconnect() as (client2, _): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-protobuf-plaintext" + + +@pytest.mark.asyncio +async def test_oversized_payload_noise( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that oversized payloads from client cause disconnection without crashing with noise encryption.""" + noise_key = "N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU=" + process_exited = False + helper_log_found = False + + def check_logs(line: str) -> None: + nonlocal process_exited, helper_log_found + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + # Check for HELPER_LOG message about message size exceeding maximum + # With our new protection, oversized messages are rejected at frame level + if ( + "[VV]" in line + and "Bad packet: message size" in line + and "exceeds maximum" in line + ): + helper_log_found = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client, + disconnect_event, + ): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + # Create an oversized payload (>32768 bytes which is our new limit) + oversized_data = b"Y" * 40000 # ~40KiB, exceeds the 32768 byte limit + + # Access the internal connection to send raw data + frame_helper = client._connection._frame_helper + # For noise connections, we still send through write_packets + # but the frame helper will handle encryption + # Using message type 1 (DeviceInfoRequest) as an example + message_type = 1 + frame_helper.write_packets([(message_type, oversized_data)], True) + + # Wait for the connection to be closed by ESPHome + await asyncio.wait_for(disconnect_event.wait(), timeout=5.0) + + # After disconnection, verify process didn't crash + assert not process_exited, "ESPHome process should not crash" + # Verify we saw the expected HELPER_LOG message + assert helper_log_found, ( + "Expected to see HELPER_LOG about message size exceeding maximum" + ) + + # Try to reconnect to verify the process is still running + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client2, + _, + ): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + +@pytest.mark.asyncio +async def test_oversized_protobuf_message_id_noise( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that the noise protocol handles unknown message types correctly. + + With noise encryption, message types are stored as uint16_t (2 bytes) after decryption. + Unknown message types should be ignored without disconnecting, as ESPHome needs to + read the full message to maintain encryption stream continuity. + """ + noise_key = "N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU=" + process_exited = False + + def check_logs(line: str) -> None: + nonlocal process_exited + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client, + disconnect_event, + ): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + # With noise, message types are uint16_t, so we test with an unknown but valid value + frame_helper = client._connection._frame_helper + + # Test with an unknown message type (65535 is not used by ESPHome) + unknown_message_id = 65535 # Valid uint16_t but unknown to ESPHome + payload = b"test" + + # Send the unknown message type - ESPHome should read and ignore it + frame_helper.write_packets([(unknown_message_id, payload)], True) + + # Give ESPHome a moment to process (but expect no disconnection) + # The connection should stay alive as ESPHome ignores unknown message types + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(disconnect_event.wait(), timeout=0.5) + + # Connection should still be alive - unknown types are ignored, not fatal + assert client._connection.is_connected, ( + "Connection should remain open for unknown message types" + ) + + # Verify we can still communicate by sending a valid request + device_info2 = await client.device_info() + assert device_info2 is not None + assert device_info2.name == "oversized-noise" + + # After test, verify process didn't crash + assert not process_exited, "ESPHome process should not crash" + + # Verify we can still reconnect + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client2, + _, + ): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + +@pytest.mark.asyncio +async def test_noise_corrupt_encrypted_frame( + yaml_config: str, + run_compiled: RunCompiledFunction, + api_client_connected_with_disconnect: APIClientConnectedWithDisconnectFactory, +) -> None: + """Test that noise protocol properly handles corrupt encrypted frames. + + Send a frame with valid size but corrupt encrypted content (garbage bytes). + This should fail decryption and cause disconnection. + """ + noise_key = "N4Yle5YirwZhPiHHsdZLdOA73ndj/84veVaLhTvxCuU=" + process_exited = False + cipherstate_failed = False + + def check_logs(line: str) -> None: + nonlocal process_exited, cipherstate_failed + # Check for signs that the process exited/crashed + if "Segmentation fault" in line or "core dumped" in line: + process_exited = True + # Check for the expected warning about decryption failure + if ( + "[W][api.connection" in line + and "Reading failed CIPHERSTATE_DECRYPT_FAILED" in line + ): + cipherstate_failed = True + + async with run_compiled(yaml_config, line_callback=check_logs): + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client, + disconnect_event, + ): + # Verify basic connection works first + device_info = await client.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" + + # Get the socket to send raw corrupt data + socket = client._connection._socket + + # Send a corrupt noise frame directly to the socket + # Format: [indicator=0x01][size_high][size_low][garbage_encrypted_data] + # Size of 32 bytes (reasonable size for a noise frame with MAC) + corrupt_frame = bytes( + [ + 0x01, # Noise indicator + 0x00, # Size high byte + 0x20, # Size low byte (32 bytes) + ] + ) + bytes(32) # 32 bytes of zeros (invalid encrypted data) + + # Send the corrupt frame + socket.sendall(corrupt_frame) + + # Wait for ESPHome to disconnect due to decryption failure + await asyncio.wait_for(disconnect_event.wait(), timeout=5.0) + + # After disconnection, verify process didn't crash + assert not process_exited, ( + "ESPHome process should not crash on corrupt encrypted frames" + ) + # Verify we saw the expected warning message + assert cipherstate_failed, ( + "Expected to see warning about CIPHERSTATE_DECRYPT_FAILED" + ) + + # Verify we can still reconnect after handling the corrupt frame + async with api_client_connected_with_disconnect(noise_psk=noise_key) as ( + client2, + _, + ): + device_info = await client2.device_info() + assert device_info is not None + assert device_info.name == "oversized-noise" diff --git a/tests/integration/types.py b/tests/integration/types.py index 5e4bfaa29d..b6728a2fcb 100644 --- a/tests/integration/types.py +++ b/tests/integration/types.py @@ -54,3 +54,17 @@ class APIClientConnectedFactory(Protocol): client_info: str = "integration-test", timeout: float = 30, ) -> AbstractAsyncContextManager[APIClient]: ... + + +class APIClientConnectedWithDisconnectFactory(Protocol): + """Protocol for connected API client factory that returns disconnect event.""" + + def __call__( # noqa: E704 + self, + address: str = "localhost", + port: int | None = None, + password: str = "", + noise_psk: str | None = None, + client_info: str = "integration-test", + timeout: float = 30, + ) -> AbstractAsyncContextManager[tuple[APIClient, asyncio.Event]]: ... diff --git a/tests/script/test_clang_tidy_hash.py b/tests/script/test_clang_tidy_hash.py index 2f84d11a0d..b1690a6a2d 100644 --- a/tests/script/test_clang_tidy_hash.py +++ b/tests/script/test_clang_tidy_hash.py @@ -44,37 +44,53 @@ def test_get_clang_tidy_version_from_requirements( assert result == expected -def test_calculate_clang_tidy_hash() -> None: - """Test calculating hash from all configuration sources.""" +def test_calculate_clang_tidy_hash_with_sdkconfig(tmp_path: Path) -> None: + """Test calculating hash from all configuration sources including sdkconfig.defaults.""" clang_tidy_content = b"Checks: '-*,readability-*'\n" requirements_version = "clang-tidy==18.1.5" platformio_content = b"[env:esp32]\nplatform = espressif32\n" + sdkconfig_content = b"CONFIG_AUTOSTART_ARDUINO=y\n" + requirements_content = "clang-tidy==18.1.5\n" + + # Create temporary files + (tmp_path / ".clang-tidy").write_bytes(clang_tidy_content) + (tmp_path / "platformio.ini").write_bytes(platformio_content) + (tmp_path / "sdkconfig.defaults").write_bytes(sdkconfig_content) + (tmp_path / "requirements_dev.txt").write_text(requirements_content) # Expected hash calculation expected_hasher = hashlib.sha256() expected_hasher.update(clang_tidy_content) expected_hasher.update(requirements_version.encode()) expected_hasher.update(platformio_content) + expected_hasher.update(sdkconfig_content) expected_hash = expected_hasher.hexdigest() - # Mock the dependencies - with ( - patch("clang_tidy_hash.read_file_bytes") as mock_read_bytes, - patch( - "clang_tidy_hash.get_clang_tidy_version_from_requirements", - return_value=requirements_version, - ), - ): - # Set up mock to return different content based on the file being read - def read_file_mock(path: Path) -> bytes: - if ".clang-tidy" in str(path): - return clang_tidy_content - if "platformio.ini" in str(path): - return platformio_content - return b"" + result = clang_tidy_hash.calculate_clang_tidy_hash(repo_root=tmp_path) - mock_read_bytes.side_effect = read_file_mock - result = clang_tidy_hash.calculate_clang_tidy_hash() + assert result == expected_hash + + +def test_calculate_clang_tidy_hash_without_sdkconfig(tmp_path: Path) -> None: + """Test calculating hash without sdkconfig.defaults file.""" + clang_tidy_content = b"Checks: '-*,readability-*'\n" + requirements_version = "clang-tidy==18.1.5" + platformio_content = b"[env:esp32]\nplatform = espressif32\n" + requirements_content = "clang-tidy==18.1.5\n" + + # Create temporary files (without sdkconfig.defaults) + (tmp_path / ".clang-tidy").write_bytes(clang_tidy_content) + (tmp_path / "platformio.ini").write_bytes(platformio_content) + (tmp_path / "requirements_dev.txt").write_text(requirements_content) + + # Expected hash calculation (no sdkconfig) + expected_hasher = hashlib.sha256() + expected_hasher.update(clang_tidy_content) + expected_hasher.update(requirements_version.encode()) + expected_hasher.update(platformio_content) + expected_hash = expected_hasher.hexdigest() + + result = clang_tidy_hash.calculate_clang_tidy_hash(repo_root=tmp_path) assert result == expected_hash @@ -85,67 +101,63 @@ def test_read_stored_hash_exists(tmp_path: Path) -> None: hash_file = tmp_path / ".clang-tidy.hash" hash_file.write_text(f"{stored_hash}\n") - with ( - patch("clang_tidy_hash.Path") as mock_path_class, - patch("clang_tidy_hash.read_file_lines", return_value=[f"{stored_hash}\n"]), - ): - # Mock the path calculation and exists check - mock_hash_file = Mock() - mock_hash_file.exists.return_value = True - mock_path_class.return_value.parent.parent.__truediv__.return_value = ( - mock_hash_file - ) - - result = clang_tidy_hash.read_stored_hash() + result = clang_tidy_hash.read_stored_hash(repo_root=tmp_path) assert result == stored_hash -def test_read_stored_hash_not_exists() -> None: +def test_read_stored_hash_not_exists(tmp_path: Path) -> None: """Test reading hash when file doesn't exist.""" - with patch("clang_tidy_hash.Path") as mock_path_class: - # Mock the path calculation and exists check - mock_hash_file = Mock() - mock_hash_file.exists.return_value = False - mock_path_class.return_value.parent.parent.__truediv__.return_value = ( - mock_hash_file - ) - - result = clang_tidy_hash.read_stored_hash() + result = clang_tidy_hash.read_stored_hash(repo_root=tmp_path) assert result is None -def test_write_hash() -> None: +def test_write_hash(tmp_path: Path) -> None: """Test writing hash to file.""" hash_value = "abc123def456" + hash_file = tmp_path / ".clang-tidy.hash" - with patch("clang_tidy_hash.write_file_content") as mock_write: - clang_tidy_hash.write_hash(hash_value) + clang_tidy_hash.write_hash(hash_value, repo_root=tmp_path) - # Verify write_file_content was called with correct parameters - mock_write.assert_called_once() - args = mock_write.call_args[0] - assert str(args[0]).endswith(".clang-tidy.hash") - assert args[1] == hash_value.strip() + "\n" + assert hash_file.exists() + assert hash_file.read_text() == hash_value.strip() + "\n" @pytest.mark.parametrize( - ("args", "current_hash", "stored_hash", "expected_exit"), + ("args", "current_hash", "stored_hash", "hash_file_in_changed", "expected_exit"), [ - (["--check"], "abc123", "abc123", 1), # Hashes match, no scan needed - (["--check"], "abc123", "def456", 0), # Hashes differ, scan needed - (["--check"], "abc123", None, 0), # No stored hash, scan needed + (["--check"], "abc123", "abc123", False, 1), # Hashes match, no scan needed + (["--check"], "abc123", "def456", False, 0), # Hashes differ, scan needed + (["--check"], "abc123", None, False, 0), # No stored hash, scan needed + ( + ["--check"], + "abc123", + "abc123", + True, + 0, + ), # Hash file updated in PR, scan needed ], ) def test_main_check_mode( - args: list[str], current_hash: str, stored_hash: str | None, expected_exit: int + args: list[str], + current_hash: str, + stored_hash: str | None, + hash_file_in_changed: bool, + expected_exit: int, ) -> None: """Test main function in check mode.""" + changed = [".clang-tidy.hash"] if hash_file_in_changed else [] + + # Create a mock module that can be imported + mock_helpers = Mock() + mock_helpers.changed_files = Mock(return_value=changed) + with ( patch("sys.argv", ["clang_tidy_hash.py"] + args), patch("clang_tidy_hash.calculate_clang_tidy_hash", return_value=current_hash), patch("clang_tidy_hash.read_stored_hash", return_value=stored_hash), + patch.dict("sys.modules", {"helpers": mock_helpers}), pytest.raises(SystemExit) as exc_info, ): clang_tidy_hash.main() diff --git a/tests/unit_tests/build_gen/test_platformio.py b/tests/unit_tests/build_gen/test_platformio.py new file mode 100644 index 0000000000..a124dbc128 --- /dev/null +++ b/tests/unit_tests/build_gen/test_platformio.py @@ -0,0 +1,188 @@ +"""Tests for esphome.build_gen.platformio module.""" + +from __future__ import annotations + +from collections.abc import Generator +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from esphome.build_gen import platformio +from esphome.core import CORE + + +@pytest.fixture +def mock_update_storage_json() -> Generator[MagicMock]: + """Mock update_storage_json for all tests.""" + with patch("esphome.build_gen.platformio.update_storage_json") as mock: + yield mock + + +@pytest.fixture +def mock_write_file_if_changed() -> Generator[MagicMock]: + """Mock write_file_if_changed for tests.""" + with patch("esphome.build_gen.platformio.write_file_if_changed") as mock: + yield mock + + +def test_write_ini_creates_new_file( + tmp_path: Path, mock_update_storage_json: MagicMock +) -> None: + """Test write_ini creates a new platformio.ini file.""" + CORE.build_path = str(tmp_path) + + content = """ +[env:test] +platform = espressif32 +board = esp32dev +framework = arduino +""" + + platformio.write_ini(content) + + ini_file = tmp_path / "platformio.ini" + assert ini_file.exists() + + file_content = ini_file.read_text() + assert content in file_content + assert platformio.INI_AUTO_GENERATE_BEGIN in file_content + assert platformio.INI_AUTO_GENERATE_END in file_content + + +def test_write_ini_updates_existing_file( + tmp_path: Path, mock_update_storage_json: MagicMock +) -> None: + """Test write_ini updates existing platformio.ini file.""" + CORE.build_path = str(tmp_path) + + # Create existing file with custom content + ini_file = tmp_path / "platformio.ini" + existing_content = f""" +; Custom header +[platformio] +default_envs = test + +{platformio.INI_AUTO_GENERATE_BEGIN} +; Old auto-generated content +[env:old] +platform = old +{platformio.INI_AUTO_GENERATE_END} + +; Custom footer +""" + ini_file.write_text(existing_content) + + # New content to write + new_content = """ +[env:test] +platform = espressif32 +board = esp32dev +framework = arduino +""" + + platformio.write_ini(new_content) + + file_content = ini_file.read_text() + + # Check that custom parts are preserved + assert "; Custom header" in file_content + assert "[platformio]" in file_content + assert "default_envs = test" in file_content + assert "; Custom footer" in file_content + + # Check that new content replaced old auto-generated content + assert new_content in file_content + assert "[env:old]" not in file_content + assert "platform = old" not in file_content + + +def test_write_ini_preserves_custom_sections( + tmp_path: Path, mock_update_storage_json: MagicMock +) -> None: + """Test write_ini preserves custom sections outside auto-generate markers.""" + CORE.build_path = str(tmp_path) + + # Create existing file with multiple custom sections + ini_file = tmp_path / "platformio.ini" + existing_content = f""" +[platformio] +src_dir = . +include_dir = . + +[common] +lib_deps = + Wire + SPI + +{platformio.INI_AUTO_GENERATE_BEGIN} +[env:old] +platform = old +{platformio.INI_AUTO_GENERATE_END} + +[env:custom] +upload_speed = 921600 +monitor_speed = 115200 +""" + ini_file.write_text(existing_content) + + new_content = "[env:auto]\nplatform = new" + + platformio.write_ini(new_content) + + file_content = ini_file.read_text() + + # All custom sections should be preserved + assert "[platformio]" in file_content + assert "src_dir = ." in file_content + assert "[common]" in file_content + assert "lib_deps" in file_content + assert "[env:custom]" in file_content + assert "upload_speed = 921600" in file_content + + # New auto-generated content should replace old + assert "[env:auto]" in file_content + assert "platform = new" in file_content + assert "[env:old]" not in file_content + + +def test_write_ini_no_change_when_content_same( + tmp_path: Path, + mock_update_storage_json: MagicMock, + mock_write_file_if_changed: MagicMock, +) -> None: + """Test write_ini doesn't rewrite file when content is unchanged.""" + CORE.build_path = str(tmp_path) + + content = "[env:test]\nplatform = esp32" + full_content = ( + f"{platformio.INI_BASE_FORMAT[0]}" + f"{platformio.INI_AUTO_GENERATE_BEGIN}\n" + f"{content}" + f"{platformio.INI_AUTO_GENERATE_END}" + f"{platformio.INI_BASE_FORMAT[1]}" + ) + + ini_file = tmp_path / "platformio.ini" + ini_file.write_text(full_content) + + mock_write_file_if_changed.return_value = False # Indicate no change + platformio.write_ini(content) + + # write_file_if_changed should be called with the same content + mock_write_file_if_changed.assert_called_once() + call_args = mock_write_file_if_changed.call_args[0] + assert call_args[0] == ini_file + assert content in call_args[1] + + +def test_write_ini_calls_update_storage_json( + tmp_path: Path, mock_update_storage_json: MagicMock +) -> None: + """Test write_ini calls update_storage_json.""" + CORE.build_path = str(tmp_path) + + content = "[env:test]\nplatform = esp32" + + platformio.write_ini(content) + mock_update_storage_json.assert_called_once() diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 06d06d0506..932221997c 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -43,7 +43,7 @@ def fixture_path() -> Path: @pytest.fixture def setup_core(tmp_path: Path) -> Path: """Set up CORE with test paths.""" - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" return tmp_path @@ -87,3 +87,24 @@ def mock_run_external_command() -> Generator[Mock, None, None]: """Mock run_external_command for platformio_api.""" with patch("esphome.platformio_api.run_external_command") as mock: yield mock + + +@pytest.fixture +def mock_run_git_command() -> Generator[Mock, None, None]: + """Mock run_git_command for git module.""" + with patch("esphome.git.run_git_command") as mock: + yield mock + + +@pytest.fixture +def mock_get_idedata() -> Generator[Mock, None, None]: + """Mock get_idedata for platformio_api.""" + with patch("esphome.platformio_api.get_idedata") as mock: + yield mock + + +@pytest.fixture +def mock_get_component() -> Generator[Mock, None, None]: + """Mock get_component for config module.""" + with patch("esphome.config.get_component") as mock: + yield mock diff --git a/tests/unit_tests/core/common.py b/tests/unit_tests/core/common.py index 1848d5397b..daa429dc96 100644 --- a/tests/unit_tests/core/common.py +++ b/tests/unit_tests/core/common.py @@ -10,7 +10,7 @@ from esphome.core import CORE def load_config_from_yaml( - yaml_file: Callable[[str], str], yaml_content: str + yaml_file: Callable[[str], Path], yaml_content: str ) -> Config | None: """Load configuration from YAML content.""" yaml_path = yaml_file(yaml_content) @@ -25,7 +25,7 @@ def load_config_from_yaml( def load_config_from_fixture( - yaml_file: Callable[[str], str], fixture_name: str, fixtures_dir: Path + yaml_file: Callable[[str], Path], fixture_name: str, fixtures_dir: Path ) -> Config | None: """Load configuration from a fixture file.""" fixture_path = fixtures_dir / fixture_name diff --git a/tests/unit_tests/core/conftest.py b/tests/unit_tests/core/conftest.py index 60d6738ce9..42e59c15e6 100644 --- a/tests/unit_tests/core/conftest.py +++ b/tests/unit_tests/core/conftest.py @@ -7,12 +7,12 @@ import pytest @pytest.fixture -def yaml_file(tmp_path: Path) -> Callable[[str], str]: +def yaml_file(tmp_path: Path) -> Callable[[str], Path]: """Create a temporary YAML file for testing.""" - def _yaml_file(content: str) -> str: + def _yaml_file(content: str) -> Path: yaml_path = tmp_path / "test.yaml" yaml_path.write_text(content) - return str(yaml_path) + return yaml_path return _yaml_file diff --git a/tests/unit_tests/core/test_config.py b/tests/unit_tests/core/test_config.py index 46fe0148d8..4fddfc9678 100644 --- a/tests/unit_tests/core/test_config.py +++ b/tests/unit_tests/core/test_config.py @@ -35,6 +35,22 @@ from .common import load_config_from_fixture FIXTURES_DIR = Path(__file__).parent.parent / "fixtures" / "core" / "config" +@pytest.fixture +def mock_cg_with_include_capture() -> tuple[Mock, list[str]]: + """Mock code generation with include capture.""" + includes_added: list[str] = [] + + with patch("esphome.core.config.cg") as mock_cg: + mock_raw_statement = MagicMock() + + def capture_include(text: str) -> MagicMock: + includes_added.append(text) + return mock_raw_statement + + mock_cg.RawStatement.side_effect = capture_include + yield mock_cg, includes_added + + def test_validate_area_config_with_string() -> None: """Test that string area config is converted to structured format.""" result = validate_area_config("Living Room") @@ -273,7 +289,7 @@ def test_valid_include_with_angle_brackets() -> None: def test_valid_include_with_valid_file(tmp_path: Path) -> None: """Test valid_include accepts valid include files.""" - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" include_file = tmp_path / "include.h" include_file.touch() @@ -282,7 +298,7 @@ def test_valid_include_with_valid_file(tmp_path: Path) -> None: def test_valid_include_with_valid_directory(tmp_path: Path) -> None: """Test valid_include accepts valid directories.""" - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" include_dir = tmp_path / "includes" include_dir.mkdir() @@ -291,7 +307,7 @@ def test_valid_include_with_valid_directory(tmp_path: Path) -> None: def test_valid_include_invalid_extension(tmp_path: Path) -> None: """Test valid_include rejects files with invalid extensions.""" - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" invalid_file = tmp_path / "file.txt" invalid_file.touch() @@ -465,7 +481,7 @@ def test_include_file_header(tmp_path: Path, mock_copy_file_if_changed: Mock) -> src_file = tmp_path / "source.h" src_file.write_text("// Header content") - CORE.build_path = str(tmp_path / "build") + CORE.build_path = tmp_path / "build" with patch("esphome.core.config.cg") as mock_cg: # Mock RawStatement to capture the text @@ -478,7 +494,7 @@ def test_include_file_header(tmp_path: Path, mock_copy_file_if_changed: Mock) -> mock_cg.RawStatement.side_effect = raw_statement_side_effect - config.include_file(str(src_file), "test.h") + config.include_file(src_file, Path("test.h")) mock_copy_file_if_changed.assert_called_once() mock_cg.add_global.assert_called_once() @@ -491,10 +507,10 @@ def test_include_file_cpp(tmp_path: Path, mock_copy_file_if_changed: Mock) -> No src_file = tmp_path / "source.cpp" src_file.write_text("// CPP content") - CORE.build_path = str(tmp_path / "build") + CORE.build_path = tmp_path / "build" with patch("esphome.core.config.cg") as mock_cg: - config.include_file(str(src_file), "test.cpp") + config.include_file(src_file, Path("test.cpp")) mock_copy_file_if_changed.assert_called_once() # Should not add include statement for .cpp files @@ -577,3 +593,262 @@ def test_is_target_platform() -> None: assert config._is_target_platform("rp2040") is True assert config._is_target_platform("invalid_platform") is False assert config._is_target_platform("api") is False # Component but not platform + + +@pytest.mark.asyncio +async def test_add_includes_with_single_file( + tmp_path: Path, + mock_copy_file_if_changed: Mock, + mock_cg_with_include_capture: tuple[Mock, list[str]], +) -> None: + """Test add_includes copies a single header file to build directory.""" + CORE.config_path = tmp_path / "config.yaml" + CORE.build_path = tmp_path / "build" + os.makedirs(CORE.build_path, exist_ok=True) + + # Create include file + include_file = tmp_path / "my_header.h" + include_file.write_text("#define MY_CONSTANT 42") + + mock_cg, includes_added = mock_cg_with_include_capture + + await config.add_includes([str(include_file)]) + + # Verify copy_file_if_changed was called to copy the file + # Note: add_includes adds files to a src/ subdirectory + mock_copy_file_if_changed.assert_called_once_with( + include_file, CORE.build_path / "src" / "my_header.h" + ) + + # Verify include statement was added + assert any('#include "my_header.h"' in inc for inc in includes_added) + + +@pytest.mark.asyncio +@pytest.mark.skipif(os.name == "nt", reason="Unix-specific test") +async def test_add_includes_with_directory_unix( + tmp_path: Path, + mock_copy_file_if_changed: Mock, + mock_cg_with_include_capture: tuple[Mock, list[str]], +) -> None: + """Test add_includes copies all files from a directory on Unix.""" + CORE.config_path = tmp_path / "config.yaml" + CORE.build_path = tmp_path / "build" + os.makedirs(CORE.build_path, exist_ok=True) + + # Create include directory with files + include_dir = tmp_path / "includes" + include_dir.mkdir() + (include_dir / "header1.h").write_text("#define HEADER1") + (include_dir / "header2.hpp").write_text("#define HEADER2") + (include_dir / "source.cpp").write_text("// Implementation") + (include_dir / "README.md").write_text( + "# Documentation" + ) # Should be copied but not included + + # Create subdirectory with files + subdir = include_dir / "subdir" + subdir.mkdir() + (subdir / "nested.h").write_text("#define NESTED") + + mock_cg, includes_added = mock_cg_with_include_capture + + await config.add_includes([str(include_dir)]) + + # Verify copy_file_if_changed was called for all files + assert mock_copy_file_if_changed.call_count == 5 # 4 code files + 1 README + + # Verify include statements were added for valid extensions + include_strings = " ".join(includes_added) + assert "includes/header1.h" in include_strings + assert "includes/header2.hpp" in include_strings + assert "includes/subdir/nested.h" in include_strings + # CPP files are copied but not included + assert "source.cpp" not in include_strings or "#include" not in include_strings + # README.md should not have an include statement + assert "README.md" not in include_strings + + +@pytest.mark.asyncio +@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test") +async def test_add_includes_with_directory_windows( + tmp_path: Path, + mock_copy_file_if_changed: Mock, + mock_cg_with_include_capture: tuple[Mock, list[str]], +) -> None: + """Test add_includes copies all files from a directory on Windows.""" + CORE.config_path = tmp_path / "config.yaml" + CORE.build_path = tmp_path / "build" + os.makedirs(CORE.build_path, exist_ok=True) + + # Create include directory with files + include_dir = tmp_path / "includes" + include_dir.mkdir() + (include_dir / "header1.h").write_text("#define HEADER1") + (include_dir / "header2.hpp").write_text("#define HEADER2") + (include_dir / "source.cpp").write_text("// Implementation") + (include_dir / "README.md").write_text( + "# Documentation" + ) # Should be copied but not included + + # Create subdirectory with files + subdir = include_dir / "subdir" + subdir.mkdir() + (subdir / "nested.h").write_text("#define NESTED") + + mock_cg, includes_added = mock_cg_with_include_capture + + await config.add_includes([str(include_dir)]) + + # Verify copy_file_if_changed was called for all files + assert mock_copy_file_if_changed.call_count == 5 # 4 code files + 1 README + + # Verify include statements were added for valid extensions + include_strings = " ".join(includes_added) + assert "includes\\header1.h" in include_strings + assert "includes\\header2.hpp" in include_strings + assert "includes\\subdir\\nested.h" in include_strings + # CPP files are copied but not included + assert "source.cpp" not in include_strings or "#include" not in include_strings + # README.md should not have an include statement + assert "README.md" not in include_strings + + +@pytest.mark.asyncio +async def test_add_includes_with_multiple_sources( + tmp_path: Path, mock_copy_file_if_changed: Mock +) -> None: + """Test add_includes with multiple files and directories.""" + CORE.config_path = tmp_path / "config.yaml" + CORE.build_path = tmp_path / "build" + os.makedirs(CORE.build_path, exist_ok=True) + + # Create various include sources + single_file = tmp_path / "single.h" + single_file.write_text("#define SINGLE") + + dir1 = tmp_path / "dir1" + dir1.mkdir() + (dir1 / "file1.h").write_text("#define FILE1") + + dir2 = tmp_path / "dir2" + dir2.mkdir() + (dir2 / "file2.cpp").write_text("// File2") + + with patch("esphome.core.config.cg"): + await config.add_includes([str(single_file), str(dir1), str(dir2)]) + + # Verify copy_file_if_changed was called for all files + assert mock_copy_file_if_changed.call_count == 3 # 3 files total + + +@pytest.mark.asyncio +async def test_add_includes_empty_directory( + tmp_path: Path, mock_copy_file_if_changed: Mock +) -> None: + """Test add_includes with an empty directory doesn't fail.""" + CORE.config_path = tmp_path / "config.yaml" + CORE.build_path = tmp_path / "build" + os.makedirs(CORE.build_path, exist_ok=True) + + # Create empty directory + empty_dir = tmp_path / "empty" + empty_dir.mkdir() + + with patch("esphome.core.config.cg"): + # Should not raise any errors + await config.add_includes([str(empty_dir)]) + + # No files to copy from empty directory + mock_copy_file_if_changed.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.skipif(os.name == "nt", reason="Unix-specific test") +async def test_add_includes_preserves_directory_structure_unix( + tmp_path: Path, mock_copy_file_if_changed: Mock +) -> None: + """Test that add_includes preserves relative directory structure on Unix.""" + CORE.config_path = tmp_path / "config.yaml" + CORE.build_path = tmp_path / "build" + os.makedirs(CORE.build_path, exist_ok=True) + + # Create nested directory structure + lib_dir = tmp_path / "lib" + lib_dir.mkdir() + + src_dir = lib_dir / "src" + src_dir.mkdir() + (src_dir / "core.h").write_text("#define CORE") + + utils_dir = lib_dir / "utils" + utils_dir.mkdir() + (utils_dir / "helper.h").write_text("#define HELPER") + + with patch("esphome.core.config.cg"): + await config.add_includes([str(lib_dir)]) + + # Verify copy_file_if_changed was called with correct paths + calls = mock_copy_file_if_changed.call_args_list + dest_paths = [call[0][1] for call in calls] + + # Check that relative paths are preserved + assert any("lib/src/core.h" in str(path) for path in dest_paths) + assert any("lib/utils/helper.h" in str(path) for path in dest_paths) + + +@pytest.mark.asyncio +@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test") +async def test_add_includes_preserves_directory_structure_windows( + tmp_path: Path, mock_copy_file_if_changed: Mock +) -> None: + """Test that add_includes preserves relative directory structure on Windows.""" + CORE.config_path = tmp_path / "config.yaml" + CORE.build_path = tmp_path / "build" + os.makedirs(CORE.build_path, exist_ok=True) + + # Create nested directory structure + lib_dir = tmp_path / "lib" + lib_dir.mkdir() + + src_dir = lib_dir / "src" + src_dir.mkdir() + (src_dir / "core.h").write_text("#define CORE") + + utils_dir = lib_dir / "utils" + utils_dir.mkdir() + (utils_dir / "helper.h").write_text("#define HELPER") + + with patch("esphome.core.config.cg"): + await config.add_includes([str(lib_dir)]) + + # Verify copy_file_if_changed was called with correct paths + calls = mock_copy_file_if_changed.call_args_list + dest_paths = [call[0][1] for call in calls] + + # Check that relative paths are preserved + assert any("lib\\src\\core.h" in str(path) for path in dest_paths) + assert any("lib\\utils\\helper.h" in str(path) for path in dest_paths) + + +@pytest.mark.asyncio +async def test_add_includes_overwrites_existing_files( + tmp_path: Path, mock_copy_file_if_changed: Mock +) -> None: + """Test that add_includes overwrites existing files in build directory.""" + CORE.config_path = tmp_path / "config.yaml" + CORE.build_path = tmp_path / "build" + os.makedirs(CORE.build_path, exist_ok=True) + + # Create include file + include_file = tmp_path / "header.h" + include_file.write_text("#define NEW_VALUE 42") + + with patch("esphome.core.config.cg"): + await config.add_includes([str(include_file)]) + + # Verify copy_file_if_changed was called (it handles overwriting) + # Note: add_includes adds files to a src/ subdirectory + mock_copy_file_if_changed.assert_called_once_with( + include_file, CORE.build_path / "src" / "header.h" + ) diff --git a/tests/unit_tests/fixtures/auto_load_dynamic.yaml b/tests/unit_tests/fixtures/auto_load_dynamic.yaml new file mode 100644 index 0000000000..b604a2a42b --- /dev/null +++ b/tests/unit_tests/fixtures/auto_load_dynamic.yaml @@ -0,0 +1,10 @@ +esphome: + name: test-device + +esp32: + board: esp32dev + +# Test component with dynamic AUTO_LOAD +test_component: + enable_logger: true + enable_api: false diff --git a/tests/unit_tests/fixtures/auto_load_static.yaml b/tests/unit_tests/fixtures/auto_load_static.yaml new file mode 100644 index 0000000000..c8f9e6222a --- /dev/null +++ b/tests/unit_tests/fixtures/auto_load_static.yaml @@ -0,0 +1,8 @@ +esphome: + name: test-device + +esp32: + board: esp32dev + +# Test component with static AUTO_LOAD +test_component: diff --git a/tests/unit_tests/fixtures/ota_empty_dict.yaml b/tests/unit_tests/fixtures/ota_empty_dict.yaml new file mode 100644 index 0000000000..cf9b166afa --- /dev/null +++ b/tests/unit_tests/fixtures/ota_empty_dict.yaml @@ -0,0 +1,17 @@ +esphome: + name: test-device2 + +esp32: + board: esp32dev + framework: + type: esp-idf + +# OTA with empty dict - should be normalized +ota: {} + +wifi: + ssid: "test" + password: "test" + +# Captive portal auto-loads ota.web_server which triggers the issue +captive_portal: diff --git a/tests/unit_tests/fixtures/ota_no_platform.yaml b/tests/unit_tests/fixtures/ota_no_platform.yaml new file mode 100644 index 0000000000..0b09c836fb --- /dev/null +++ b/tests/unit_tests/fixtures/ota_no_platform.yaml @@ -0,0 +1,17 @@ +esphome: + name: test-device + +esp32: + board: esp32dev + framework: + type: esp-idf + +# OTA with no value - this should be normalized to empty list +ota: + +wifi: + ssid: "test" + password: "test" + +# Captive portal auto-loads ota.web_server which triggers the issue +captive_portal: diff --git a/tests/unit_tests/fixtures/ota_with_platform_list.yaml b/tests/unit_tests/fixtures/ota_with_platform_list.yaml new file mode 100644 index 0000000000..b1b03743ae --- /dev/null +++ b/tests/unit_tests/fixtures/ota_with_platform_list.yaml @@ -0,0 +1,19 @@ +esphome: + name: test-device3 + +esp32: + board: esp32dev + framework: + type: esp-idf + +# OTA with proper list format +ota: + - platform: esphome + password: "test123" + +wifi: + ssid: "test" + password: "test" + +# Captive portal auto-loads ota.web_server +captive_portal: diff --git a/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml b/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml index f5d2f8aa20..795a788f62 100644 --- a/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml +++ b/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml @@ -1,7 +1,14 @@ substitutions: + substituted: 99 var1: '1' var2: '2' var21: '79' + value: 33 + values: 44 + position: + x: 79 + y: 82 + esphome: name: test test_list: @@ -19,3 +26,10 @@ test_list: - ${ undefined_var } - key1: 1 key2: 2 + - Literal $values ${are not substituted} + - ["list $value", "${is not}", "${substituted}"] + - {"$dictionary": "$value", "${is not}": "${substituted}"} + - |- + {{{ "x", "79"}, { "y", "82"}}} + - '{{{"AA"}}}' + - '"HELLO"' diff --git a/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml b/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml index 5717433c7e..722e116d36 100644 --- a/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml +++ b/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml @@ -2,9 +2,15 @@ esphome: name: test substitutions: + substituted: 99 var1: "1" var2: "2" var21: "79" + value: 33 + values: 44 + position: + x: 79 + y: 82 test_list: - "$var1" @@ -21,3 +27,10 @@ test_list: - ${ undefined_var } - key${var1}: 1 key${var2}: 2 + - !literal Literal $values ${are not substituted} + - !literal ["list $value", "${is not}", "${substituted}"] + - !literal {"$dictionary": "$value", "${is not}": "${substituted}"} + - |- # Test parsing things that look like a python set of sets when rendered: + {{{ "x", "${ position.x }"}, { "y", "${ position.y }"}}} + - ${ '{{{"AA"}}}' } + - ${ '"HELLO"' } diff --git a/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml b/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml index 9e401ec5d6..443cba144e 100644 --- a/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml +++ b/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml @@ -22,3 +22,6 @@ test_list: - The pin number is 18 - The square root is: 5.0 - The number is 80 + - ord("a") = 97 + - chr(97) = a + - len([1,2,3]) = 3 diff --git a/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml b/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml index 1777b46f67..07ad992f1f 100644 --- a/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml +++ b/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml @@ -20,3 +20,6 @@ test_list: - The pin number is ${pin.number} - The square root is: ${math.sqrt(area)} - The number is ${var${numberOne} + 1} + - ord("a") = ${ ord("a") } + - chr(97) = ${ chr(97) } + - len([1,2,3]) = ${ len([1,2,3]) } diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/.hidden.yaml b/tests/unit_tests/fixtures/yaml_util/named_dir/.hidden.yaml new file mode 100644 index 0000000000..75eb989ea5 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/.hidden.yaml @@ -0,0 +1,3 @@ +# This file should be ignored +platform: template +name: "Hidden Sensor" diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/not_yaml.txt b/tests/unit_tests/fixtures/yaml_util/named_dir/not_yaml.txt new file mode 100644 index 0000000000..98efb74b0f --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/not_yaml.txt @@ -0,0 +1 @@ +This is not a YAML file and should be ignored diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/sensor1.yaml b/tests/unit_tests/fixtures/yaml_util/named_dir/sensor1.yaml new file mode 100644 index 0000000000..a4b0a11916 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/sensor1.yaml @@ -0,0 +1,4 @@ +platform: template +name: "Sensor 1" +lambda: |- + return 42.0; diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/sensor2.yaml b/tests/unit_tests/fixtures/yaml_util/named_dir/sensor2.yaml new file mode 100644 index 0000000000..72d4b714b6 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/sensor2.yaml @@ -0,0 +1,4 @@ +platform: template +name: "Sensor 2" +lambda: |- + return 100.0; diff --git a/tests/unit_tests/fixtures/yaml_util/named_dir/subdir/sensor3.yaml b/tests/unit_tests/fixtures/yaml_util/named_dir/subdir/sensor3.yaml new file mode 100644 index 0000000000..bcb8dd320d --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/named_dir/subdir/sensor3.yaml @@ -0,0 +1,4 @@ +platform: template +name: "Sensor 3 in subdir" +lambda: |- + return 200.0; diff --git a/tests/unit_tests/fixtures/yaml_util/secrets.yaml b/tests/unit_tests/fixtures/yaml_util/secrets.yaml new file mode 100644 index 0000000000..4eef570926 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/secrets.yaml @@ -0,0 +1,4 @@ +test_secret: "my_secret_value" +another_secret: "another_value" +wifi_password: "super_secret_wifi" +api_key: "0123456789abcdef" diff --git a/tests/unit_tests/fixtures/yaml_util/test_secret.yaml b/tests/unit_tests/fixtures/yaml_util/test_secret.yaml new file mode 100644 index 0000000000..c23afaee94 --- /dev/null +++ b/tests/unit_tests/fixtures/yaml_util/test_secret.yaml @@ -0,0 +1,17 @@ +esphome: + name: test_device + platform: ESP32 + board: esp32dev + +wifi: + ssid: "TestNetwork" + password: !secret wifi_password + +api: + encryption: + key: !secret api_key + +sensor: + - platform: template + name: "Test Sensor" + id: !secret test_secret diff --git a/tests/unit_tests/test_address_cache.py b/tests/unit_tests/test_address_cache.py new file mode 100644 index 0000000000..de43830d53 --- /dev/null +++ b/tests/unit_tests/test_address_cache.py @@ -0,0 +1,305 @@ +"""Tests for the address_cache module.""" + +from __future__ import annotations + +import logging + +import pytest +from pytest import LogCaptureFixture + +from esphome.address_cache import AddressCache, normalize_hostname + + +def test_normalize_simple_hostname() -> None: + """Test normalizing a simple hostname.""" + assert normalize_hostname("device") == "device" + assert normalize_hostname("device.local") == "device.local" + assert normalize_hostname("server.example.com") == "server.example.com" + + +def test_normalize_removes_trailing_dots() -> None: + """Test that trailing dots are removed.""" + assert normalize_hostname("device.") == "device" + assert normalize_hostname("device.local.") == "device.local" + assert normalize_hostname("server.example.com.") == "server.example.com" + assert normalize_hostname("device...") == "device" + + +def test_normalize_converts_to_lowercase() -> None: + """Test that hostnames are converted to lowercase.""" + assert normalize_hostname("DEVICE") == "device" + assert normalize_hostname("Device.Local") == "device.local" + assert normalize_hostname("Server.Example.COM") == "server.example.com" + + +def test_normalize_combined() -> None: + """Test combination of trailing dots and case conversion.""" + assert normalize_hostname("DEVICE.LOCAL.") == "device.local" + assert normalize_hostname("Server.Example.COM...") == "server.example.com" + + +def test_init_empty() -> None: + """Test initialization with empty caches.""" + cache = AddressCache() + assert cache.mdns_cache == {} + assert cache.dns_cache == {} + assert not cache.has_cache() + + +def test_init_with_caches() -> None: + """Test initialization with provided caches.""" + mdns_cache: dict[str, list[str]] = {"device.local": ["192.168.1.10"]} + dns_cache: dict[str, list[str]] = {"server.com": ["10.0.0.1"]} + cache = AddressCache(mdns_cache=mdns_cache, dns_cache=dns_cache) + assert cache.mdns_cache == mdns_cache + assert cache.dns_cache == dns_cache + assert cache.has_cache() + + +def test_get_mdns_addresses() -> None: + """Test getting mDNS addresses.""" + cache = AddressCache(mdns_cache={"device.local": ["192.168.1.10", "192.168.1.11"]}) + + # Direct lookup + assert cache.get_mdns_addresses("device.local") == [ + "192.168.1.10", + "192.168.1.11", + ] + + # Case insensitive lookup + assert cache.get_mdns_addresses("Device.Local") == [ + "192.168.1.10", + "192.168.1.11", + ] + + # With trailing dot + assert cache.get_mdns_addresses("device.local.") == [ + "192.168.1.10", + "192.168.1.11", + ] + + # Not found + assert cache.get_mdns_addresses("unknown.local") is None + + +def test_get_dns_addresses() -> None: + """Test getting DNS addresses.""" + cache = AddressCache(dns_cache={"server.com": ["10.0.0.1", "10.0.0.2"]}) + + # Direct lookup + assert cache.get_dns_addresses("server.com") == ["10.0.0.1", "10.0.0.2"] + + # Case insensitive lookup + assert cache.get_dns_addresses("Server.COM") == ["10.0.0.1", "10.0.0.2"] + + # With trailing dot + assert cache.get_dns_addresses("server.com.") == ["10.0.0.1", "10.0.0.2"] + + # Not found + assert cache.get_dns_addresses("unknown.com") is None + + +def test_get_addresses_auto_detection() -> None: + """Test automatic cache selection based on hostname.""" + cache = AddressCache( + mdns_cache={"device.local": ["192.168.1.10"]}, + dns_cache={"server.com": ["10.0.0.1"]}, + ) + + # Should use mDNS cache for .local domains + assert cache.get_addresses("device.local") == ["192.168.1.10"] + assert cache.get_addresses("device.local.") == ["192.168.1.10"] + assert cache.get_addresses("Device.Local") == ["192.168.1.10"] + + # Should use DNS cache for non-.local domains + assert cache.get_addresses("server.com") == ["10.0.0.1"] + assert cache.get_addresses("server.com.") == ["10.0.0.1"] + assert cache.get_addresses("Server.COM") == ["10.0.0.1"] + + # Not found + assert cache.get_addresses("unknown.local") is None + assert cache.get_addresses("unknown.com") is None + + +def test_has_cache() -> None: + """Test checking if cache has entries.""" + # Empty cache + cache = AddressCache() + assert not cache.has_cache() + + # Only mDNS cache + cache = AddressCache(mdns_cache={"device.local": ["192.168.1.10"]}) + assert cache.has_cache() + + # Only DNS cache + cache = AddressCache(dns_cache={"server.com": ["10.0.0.1"]}) + assert cache.has_cache() + + # Both caches + cache = AddressCache( + mdns_cache={"device.local": ["192.168.1.10"]}, + dns_cache={"server.com": ["10.0.0.1"]}, + ) + assert cache.has_cache() + + +def test_from_cli_args_empty() -> None: + """Test creating cache from empty CLI arguments.""" + cache = AddressCache.from_cli_args([], []) + assert cache.mdns_cache == {} + assert cache.dns_cache == {} + + +def test_from_cli_args_single_entry() -> None: + """Test creating cache from single CLI argument.""" + mdns_args: list[str] = ["device.local=192.168.1.10"] + dns_args: list[str] = ["server.com=10.0.0.1"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == {"device.local": ["192.168.1.10"]} + assert cache.dns_cache == {"server.com": ["10.0.0.1"]} + + +def test_from_cli_args_multiple_ips() -> None: + """Test creating cache with multiple IPs per host.""" + mdns_args: list[str] = ["device.local=192.168.1.10,192.168.1.11"] + dns_args: list[str] = ["server.com=10.0.0.1,10.0.0.2,10.0.0.3"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == {"device.local": ["192.168.1.10", "192.168.1.11"]} + assert cache.dns_cache == {"server.com": ["10.0.0.1", "10.0.0.2", "10.0.0.3"]} + + +def test_from_cli_args_multiple_entries() -> None: + """Test creating cache with multiple host entries.""" + mdns_args: list[str] = [ + "device1.local=192.168.1.10", + "device2.local=192.168.1.20,192.168.1.21", + ] + dns_args: list[str] = ["server1.com=10.0.0.1", "server2.com=10.0.0.2"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == { + "device1.local": ["192.168.1.10"], + "device2.local": ["192.168.1.20", "192.168.1.21"], + } + assert cache.dns_cache == { + "server1.com": ["10.0.0.1"], + "server2.com": ["10.0.0.2"], + } + + +def test_from_cli_args_normalization() -> None: + """Test that CLI arguments are normalized.""" + mdns_args: list[str] = ["Device1.Local.=192.168.1.10", "DEVICE2.LOCAL=192.168.1.20"] + dns_args: list[str] = ["Server1.COM.=10.0.0.1", "SERVER2.com=10.0.0.2"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + # Hostnames should be normalized (lowercase, no trailing dots) + assert cache.mdns_cache == { + "device1.local": ["192.168.1.10"], + "device2.local": ["192.168.1.20"], + } + assert cache.dns_cache == { + "server1.com": ["10.0.0.1"], + "server2.com": ["10.0.0.2"], + } + + +def test_from_cli_args_whitespace_handling() -> None: + """Test that whitespace in IPs is handled.""" + mdns_args: list[str] = ["device.local= 192.168.1.10 , 192.168.1.11 "] + dns_args: list[str] = ["server.com= 10.0.0.1 , 10.0.0.2 "] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == {"device.local": ["192.168.1.10", "192.168.1.11"]} + assert cache.dns_cache == {"server.com": ["10.0.0.1", "10.0.0.2"]} + + +def test_from_cli_args_invalid_format(caplog: LogCaptureFixture) -> None: + """Test handling of invalid argument format.""" + mdns_args: list[str] = ["invalid_format", "device.local=192.168.1.10"] + dns_args: list[str] = ["server.com=10.0.0.1", "also_invalid"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + # Valid entries should still be processed + assert cache.mdns_cache == {"device.local": ["192.168.1.10"]} + assert cache.dns_cache == {"server.com": ["10.0.0.1"]} + + # Check that warnings were logged for invalid entries + assert "Invalid cache format: invalid_format" in caplog.text + assert "Invalid cache format: also_invalid" in caplog.text + + +def test_from_cli_args_ipv6() -> None: + """Test handling of IPv6 addresses.""" + mdns_args: list[str] = ["device.local=fe80::1,2001:db8::1"] + dns_args: list[str] = ["server.com=2001:db8::2,::1"] + + cache = AddressCache.from_cli_args(mdns_args, dns_args) + + assert cache.mdns_cache == {"device.local": ["fe80::1", "2001:db8::1"]} + assert cache.dns_cache == {"server.com": ["2001:db8::2", "::1"]} + + +def test_logging_output(caplog: LogCaptureFixture) -> None: + """Test that appropriate debug logging occurs.""" + caplog.set_level(logging.DEBUG) + + cache = AddressCache( + mdns_cache={"device.local": ["192.168.1.10"]}, + dns_cache={"server.com": ["10.0.0.1"]}, + ) + + # Test successful lookups log at debug level + result: list[str] | None = cache.get_mdns_addresses("device.local") + assert result == ["192.168.1.10"] + assert "Using mDNS cache for device.local" in caplog.text + + caplog.clear() + result = cache.get_dns_addresses("server.com") + assert result == ["10.0.0.1"] + assert "Using DNS cache for server.com" in caplog.text + + # Test that failed lookups don't log + caplog.clear() + result = cache.get_mdns_addresses("unknown.local") + assert result is None + assert "Using mDNS cache" not in caplog.text + + +@pytest.mark.parametrize( + "hostname,expected", + [ + ("test.local", "test.local"), + ("Test.Local.", "test.local"), + ("TEST.LOCAL...", "test.local"), + ("example.com", "example.com"), + ("EXAMPLE.COM.", "example.com"), + ], +) +def test_normalize_hostname_parametrized(hostname: str, expected: str) -> None: + """Test hostname normalization with various inputs.""" + assert normalize_hostname(hostname) == expected + + +@pytest.mark.parametrize( + "mdns_arg,expected", + [ + ("host=1.2.3.4", {"host": ["1.2.3.4"]}), + ("Host.Local=1.2.3.4,5.6.7.8", {"host.local": ["1.2.3.4", "5.6.7.8"]}), + ("HOST.LOCAL.=::1", {"host.local": ["::1"]}), + ], +) +def test_parse_cache_args_parametrized( + mdns_arg: str, expected: dict[str, list[str]] +) -> None: + """Test parsing of cache arguments with various formats.""" + cache = AddressCache.from_cli_args([mdns_arg], []) + assert cache.mdns_cache == expected diff --git a/tests/unit_tests/test_config_auto_load.py b/tests/unit_tests/test_config_auto_load.py new file mode 100644 index 0000000000..d31b17eeec --- /dev/null +++ b/tests/unit_tests/test_config_auto_load.py @@ -0,0 +1,131 @@ +"""Tests for AUTO_LOAD functionality including dynamic AUTO_LOAD.""" + +from pathlib import Path +from typing import Any +from unittest.mock import Mock + +import pytest + +from esphome import config, config_validation as cv, yaml_util +from esphome.core import CORE + + +@pytest.fixture +def fixtures_dir() -> Path: + """Get the fixtures directory.""" + return Path(__file__).parent / "fixtures" + + +@pytest.fixture +def default_component() -> Mock: + """Create a default mock component for unmocked components.""" + return Mock( + auto_load=[], + is_platform_component=False, + is_platform=False, + multi_conf=False, + multi_conf_no_default=False, + dependencies=[], + conflicts_with=[], + config_schema=cv.Schema({}, extra=cv.ALLOW_EXTRA), + ) + + +@pytest.fixture +def static_auto_load_component() -> Mock: + """Create a mock component with static AUTO_LOAD.""" + return Mock( + auto_load=["logger"], + is_platform_component=False, + is_platform=False, + multi_conf=False, + multi_conf_no_default=False, + dependencies=[], + conflicts_with=[], + config_schema=cv.Schema({}, extra=cv.ALLOW_EXTRA), + ) + + +def test_static_auto_load_adds_components( + mock_get_component: Mock, + fixtures_dir: Path, + static_auto_load_component: Mock, + default_component: Mock, +) -> None: + """Test that static AUTO_LOAD triggers loading of specified components.""" + CORE.config_path = fixtures_dir / "auto_load_static.yaml" + + config_file = fixtures_dir / "auto_load_static.yaml" + raw_config = yaml_util.load_yaml(config_file) + + component_mocks = {"test_component": static_auto_load_component} + mock_get_component.side_effect = lambda name: component_mocks.get( + name, default_component + ) + + result = config.validate_config(raw_config, {}) + + # Check for validation errors + assert not result.errors, f"Validation errors: {result.errors}" + + # Logger should have been auto-loaded by test_component + assert "logger" in result + assert "test_component" in result + + +def test_dynamic_auto_load_with_config_param( + mock_get_component: Mock, + fixtures_dir: Path, + default_component: Mock, +) -> None: + """Test that dynamic AUTO_LOAD evaluates based on configuration.""" + CORE.config_path = fixtures_dir / "auto_load_dynamic.yaml" + + config_file = fixtures_dir / "auto_load_dynamic.yaml" + raw_config = yaml_util.load_yaml(config_file) + + # Track if auto_load was called with config + auto_load_calls = [] + + def dynamic_auto_load(conf: dict[str, Any]) -> list[str]: + """Dynamically load components based on config.""" + auto_load_calls.append(conf) + component_map = { + "enable_logger": "logger", + "enable_api": "api", + } + return [comp for key, comp in component_map.items() if conf.get(key)] + + dynamic_component = Mock( + auto_load=dynamic_auto_load, + is_platform_component=False, + is_platform=False, + multi_conf=False, + multi_conf_no_default=False, + dependencies=[], + conflicts_with=[], + config_schema=cv.Schema({}, extra=cv.ALLOW_EXTRA), + ) + + component_mocks = {"test_component": dynamic_component} + mock_get_component.side_effect = lambda name: component_mocks.get( + name, default_component + ) + + result = config.validate_config(raw_config, {}) + + # Check for validation errors + assert not result.errors, f"Validation errors: {result.errors}" + + # Verify auto_load was called with the validated config + assert len(auto_load_calls) == 1, "auto_load should be called exactly once" + assert auto_load_calls[0].get("enable_logger") is True + assert auto_load_calls[0].get("enable_api") is False + + # Only logger should be auto-loaded (enable_logger=true in YAML) + assert "logger" in result, ( + f"Logger not found in result. Result keys: {list(result.keys())}" + ) + # API should NOT be auto-loaded (enable_api=false in YAML) + assert "api" not in result + assert "test_component" in result diff --git a/tests/unit_tests/test_config_normalization.py b/tests/unit_tests/test_config_normalization.py new file mode 100644 index 0000000000..d70f3c24e0 --- /dev/null +++ b/tests/unit_tests/test_config_normalization.py @@ -0,0 +1,115 @@ +"""Unit tests for esphome.config module.""" + +from collections.abc import Generator +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from esphome import config, yaml_util +from esphome.core import CORE + + +@pytest.fixture +def mock_get_platform() -> Generator[Mock, None, None]: + """Fixture for mocking get_platform.""" + with patch("esphome.config.get_platform") as mock_get_platform: + # Default mock platform + mock_get_platform.return_value = MagicMock() + yield mock_get_platform + + +@pytest.fixture +def fixtures_dir() -> Path: + """Get the fixtures directory.""" + return Path(__file__).parent / "fixtures" + + +def test_ota_component_configs_with_proper_platform_list( + mock_get_component: Mock, + mock_get_platform: Mock, +) -> None: + """Test iter_component_configs handles OTA properly configured as a list.""" + test_config = { + "ota": [ + {"platform": "esphome", "password": "test123", "id": "my_ota"}, + ], + } + + mock_get_component.return_value = MagicMock( + is_platform_component=True, multi_conf=False + ) + + configs = list(config.iter_component_configs(test_config)) + assert len(configs) == 2 + + assert configs[0][0] == "ota" + assert configs[0][2] == test_config["ota"] # The list itself + + assert configs[1][0] == "ota.esphome" + assert configs[1][2]["platform"] == "esphome" + assert configs[1][2]["password"] == "test123" + + +def test_iter_component_configs_with_multi_conf(mock_get_component: Mock) -> None: + """Test that iter_component_configs handles multi_conf components correctly.""" + test_config = { + "switch": [ + {"name": "Switch 1"}, + {"name": "Switch 2"}, + ], + } + + mock_get_component.return_value = MagicMock( + is_platform_component=False, multi_conf=True + ) + + configs = list(config.iter_component_configs(test_config)) + assert len(configs) == 2 + + for domain, component, conf in configs: + assert domain == "switch" + assert "name" in conf + + +def test_ota_no_platform_with_captive_portal(fixtures_dir: Path) -> None: + """Test OTA with no platform (ota:) gets normalized when captive_portal auto-loads.""" + CORE.config_path = fixtures_dir / "dummy.yaml" + + config_file = fixtures_dir / "ota_no_platform.yaml" + raw_config = yaml_util.load_yaml(config_file) + result = config.validate_config(raw_config, {}) + + assert "ota" in result + assert isinstance(result["ota"], list), f"Expected list, got {type(result['ota'])}" + platforms = {p.get("platform") for p in result["ota"]} + assert "web_server" in platforms, f"Expected web_server platform in {platforms}" + + +def test_ota_empty_dict_with_captive_portal(fixtures_dir: Path) -> None: + """Test OTA with empty dict ({}) gets normalized when captive_portal auto-loads.""" + CORE.config_path = fixtures_dir / "dummy.yaml" + + config_file = fixtures_dir / "ota_empty_dict.yaml" + raw_config = yaml_util.load_yaml(config_file) + result = config.validate_config(raw_config, {}) + + assert "ota" in result + assert isinstance(result["ota"], list), f"Expected list, got {type(result['ota'])}" + platforms = {p.get("platform") for p in result["ota"]} + assert "web_server" in platforms, f"Expected web_server platform in {platforms}" + + +def test_ota_with_platform_list_and_captive_portal(fixtures_dir: Path) -> None: + """Test OTA with proper platform list remains valid when captive_portal auto-loads.""" + CORE.config_path = fixtures_dir / "dummy.yaml" + + config_file = fixtures_dir / "ota_with_platform_list.yaml" + raw_config = yaml_util.load_yaml(config_file) + result = config.validate_config(raw_config, {}) + + assert "ota" in result + assert isinstance(result["ota"], list), f"Expected list, got {type(result['ota'])}" + platforms = {p.get("platform") for p in result["ota"]} + assert "esphome" in platforms, f"Expected esphome platform in {platforms}" + assert "web_server" in platforms, f"Expected web_server platform in {platforms}" diff --git a/tests/unit_tests/test_config_validation_paths.py b/tests/unit_tests/test_config_validation_paths.py index f8f038390e..f327e9c443 100644 --- a/tests/unit_tests/test_config_validation_paths.py +++ b/tests/unit_tests/test_config_validation_paths.py @@ -15,7 +15,7 @@ def test_directory_valid_path(setup_core: Path) -> None: result = cv.directory("test_directory") - assert result == "test_directory" + assert result == test_dir def test_directory_absolute_path(setup_core: Path) -> None: @@ -25,7 +25,7 @@ def test_directory_absolute_path(setup_core: Path) -> None: result = cv.directory(str(test_dir)) - assert result == str(test_dir) + assert result == test_dir def test_directory_nonexistent_path(setup_core: Path) -> None: @@ -52,7 +52,7 @@ def test_directory_with_parent_directory(setup_core: Path) -> None: result = cv.directory("parent/child/grandchild") - assert result == "parent/child/grandchild" + assert result == nested_dir def test_file_valid_path(setup_core: Path) -> None: @@ -62,7 +62,7 @@ def test_file_valid_path(setup_core: Path) -> None: result = cv.file_("test_file.yaml") - assert result == "test_file.yaml" + assert result == test_file def test_file_absolute_path(setup_core: Path) -> None: @@ -72,7 +72,7 @@ def test_file_absolute_path(setup_core: Path) -> None: result = cv.file_(str(test_file)) - assert result == str(test_file) + assert result == test_file def test_file_nonexistent_path(setup_core: Path) -> None: @@ -99,7 +99,7 @@ def test_file_with_parent_directory(setup_core: Path) -> None: result = cv.file_("configs/sensors/temperature.yaml") - assert result == "configs/sensors/temperature.yaml" + assert result == test_file def test_directory_handles_trailing_slash(setup_core: Path) -> None: @@ -108,29 +108,29 @@ def test_directory_handles_trailing_slash(setup_core: Path) -> None: test_dir.mkdir() result = cv.directory("test_dir/") - assert result == "test_dir/" + assert result == test_dir result = cv.directory("test_dir") - assert result == "test_dir" + assert result == test_dir def test_file_handles_various_extensions(setup_core: Path) -> None: """Test file_ validator works with different file extensions.""" yaml_file = setup_core / "config.yaml" yaml_file.write_text("yaml content") - assert cv.file_("config.yaml") == "config.yaml" + assert cv.file_("config.yaml") == yaml_file yml_file = setup_core / "config.yml" yml_file.write_text("yml content") - assert cv.file_("config.yml") == "config.yml" + assert cv.file_("config.yml") == yml_file txt_file = setup_core / "readme.txt" txt_file.write_text("text content") - assert cv.file_("readme.txt") == "readme.txt" + assert cv.file_("readme.txt") == txt_file no_ext_file = setup_core / "LICENSE" no_ext_file.write_text("license content") - assert cv.file_("LICENSE") == "LICENSE" + assert cv.file_("LICENSE") == no_ext_file def test_directory_with_symlink(setup_core: Path) -> None: @@ -142,7 +142,7 @@ def test_directory_with_symlink(setup_core: Path) -> None: symlink_dir.symlink_to(actual_dir) result = cv.directory("symlink_directory") - assert result == "symlink_directory" + assert result == symlink_dir def test_file_with_symlink(setup_core: Path) -> None: @@ -154,7 +154,7 @@ def test_file_with_symlink(setup_core: Path) -> None: symlink_file.symlink_to(actual_file) result = cv.file_("symlink_file.txt") - assert result == "symlink_file.txt" + assert result == symlink_file def test_directory_error_shows_full_path(setup_core: Path) -> None: @@ -175,7 +175,7 @@ def test_directory_with_spaces_in_name(setup_core: Path) -> None: dir_with_spaces.mkdir() result = cv.directory("my test directory") - assert result == "my test directory" + assert result == dir_with_spaces def test_file_with_spaces_in_name(setup_core: Path) -> None: @@ -184,4 +184,4 @@ def test_file_with_spaces_in_name(setup_core: Path) -> None: file_with_spaces.write_text("content") result = cv.file_("my test file.yaml") - assert result == "my test file.yaml" + assert result == file_with_spaces diff --git a/tests/unit_tests/test_core.py b/tests/unit_tests/test_core.py index f7dda9fb95..48eae06ea6 100644 --- a/tests/unit_tests/test_core.py +++ b/tests/unit_tests/test_core.py @@ -1,3 +1,7 @@ +import os +from pathlib import Path +from unittest.mock import patch + from hypothesis import given import pytest from strategies import mac_addr_strings @@ -533,8 +537,8 @@ class TestEsphomeCore: @pytest.fixture def target(self, fixture_path): target = core.EsphomeCore() - target.build_path = "foo/build" - target.config_path = "foo/config" + target.build_path = Path("foo/build") + target.config_path = Path("foo/config") return target def test_reset(self, target): @@ -577,3 +581,83 @@ class TestEsphomeCore: assert target.is_esp32 is False assert target.is_esp8266 is True + + @pytest.mark.skipif(os.name == "nt", reason="Unix-specific test") + def test_data_dir_default_unix(self, target): + """Test data_dir returns .esphome in config directory by default on Unix.""" + target.config_path = Path("/home/user/config.yaml") + assert target.data_dir == Path("/home/user/.esphome") + + @pytest.mark.skipif(os.name != "nt", reason="Windows-specific test") + def test_data_dir_default_windows(self, target): + """Test data_dir returns .esphome in config directory by default on Windows.""" + target.config_path = Path("D:\\home\\user\\config.yaml") + assert target.data_dir == Path("D:\\home\\user\\.esphome") + + def test_data_dir_ha_addon(self, target): + """Test data_dir returns /data when running as Home Assistant addon.""" + target.config_path = Path("/config/test.yaml") + + with patch.dict(os.environ, {"ESPHOME_IS_HA_ADDON": "true"}): + assert target.data_dir == Path("/data") + + def test_data_dir_env_override(self, target): + """Test data_dir uses ESPHOME_DATA_DIR environment variable when set.""" + target.config_path = Path("/home/user/config.yaml") + + with patch.dict(os.environ, {"ESPHOME_DATA_DIR": "/custom/data/path"}): + assert target.data_dir == Path("/custom/data/path") + + @pytest.mark.skipif(os.name == "nt", reason="Unix-specific test") + def test_data_dir_priority_unix(self, target): + """Test data_dir priority on Unix: HA addon > env var > default.""" + target.config_path = Path("/config/test.yaml") + expected_default = "/config/.esphome" + + # Test HA addon takes priority over env var + with patch.dict( + os.environ, + {"ESPHOME_IS_HA_ADDON": "true", "ESPHOME_DATA_DIR": "/custom/path"}, + ): + assert target.data_dir == Path("/data") + + # Test env var is used when not HA addon + with patch.dict( + os.environ, + {"ESPHOME_IS_HA_ADDON": "false", "ESPHOME_DATA_DIR": "/custom/path"}, + ): + assert target.data_dir == Path("/custom/path") + + # Test default when neither is set + with patch.dict(os.environ, {}, clear=True): + # Ensure these env vars are not set + os.environ.pop("ESPHOME_IS_HA_ADDON", None) + os.environ.pop("ESPHOME_DATA_DIR", None) + assert target.data_dir == Path(expected_default) + + @pytest.mark.skipif(os.name != "nt", reason="Windows-specific test") + def test_data_dir_priority_windows(self, target): + """Test data_dir priority on Windows: HA addon > env var > default.""" + target.config_path = Path("D:\\config\\test.yaml") + expected_default = "D:\\config\\.esphome" + + # Test HA addon takes priority over env var + with patch.dict( + os.environ, + {"ESPHOME_IS_HA_ADDON": "true", "ESPHOME_DATA_DIR": "/custom/path"}, + ): + assert target.data_dir == Path("/data") + + # Test env var is used when not HA addon + with patch.dict( + os.environ, + {"ESPHOME_IS_HA_ADDON": "false", "ESPHOME_DATA_DIR": "/custom/path"}, + ): + assert target.data_dir == Path("/custom/path") + + # Test default when neither is set + with patch.dict(os.environ, {}, clear=True): + # Ensure these env vars are not set + os.environ.pop("ESPHOME_IS_HA_ADDON", None) + os.environ.pop("ESPHOME_DATA_DIR", None) + assert target.data_dir == Path(expected_default) diff --git a/tests/unit_tests/test_espota2.py b/tests/unit_tests/test_espota2.py new file mode 100644 index 0000000000..bd1a6bde81 --- /dev/null +++ b/tests/unit_tests/test_espota2.py @@ -0,0 +1,738 @@ +"""Unit tests for esphome.espota2 module.""" + +from __future__ import annotations + +from collections.abc import Generator +import gzip +import hashlib +import io +from pathlib import Path +import socket +import struct +from unittest.mock import Mock, call, patch + +import pytest +from pytest import CaptureFixture + +from esphome import espota2 +from esphome.core import EsphomeError + +# Test constants +MOCK_RANDOM_VALUE = 0.123456 +MOCK_RANDOM_BYTES = b"0.123456" +MOCK_MD5_NONCE = b"12345678901234567890123456789012" # 32 char nonce for MD5 +MOCK_SHA256_NONCE = b"1234567890123456789012345678901234567890123456789012345678901234" # 64 char nonce for SHA256 + + +@pytest.fixture +def mock_socket() -> Mock: + """Create a mock socket for testing.""" + socket_mock = Mock() + socket_mock.close = Mock() + socket_mock.recv = Mock() + socket_mock.sendall = Mock() + socket_mock.settimeout = Mock() + socket_mock.connect = Mock() + socket_mock.setsockopt = Mock() + return socket_mock + + +@pytest.fixture +def mock_file() -> io.BytesIO: + """Create a mock firmware file for testing.""" + return io.BytesIO(b"firmware content here") + + +@pytest.fixture +def mock_time() -> Generator[None]: + """Mock time-related functions for consistent testing.""" + # Provide enough values for multiple calls (tests may call perform_ota multiple times) + with ( + patch("time.sleep"), + patch("time.perf_counter", side_effect=[0, 1, 0, 1, 0, 1]), + ): + yield + + +@pytest.fixture +def mock_random() -> Generator[Mock]: + """Mock random for predictable test values.""" + with patch("random.random", return_value=MOCK_RANDOM_VALUE) as mock_rand: + yield mock_rand + + +@pytest.fixture +def mock_resolve_ip() -> Generator[Mock]: + """Mock resolve_ip_address for testing.""" + with patch("esphome.espota2.resolve_ip_address") as mock: + mock.return_value = [ + (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.100", 3232)) + ] + yield mock + + +@pytest.fixture +def mock_perform_ota() -> Generator[Mock]: + """Mock perform_ota function for testing.""" + with patch("esphome.espota2.perform_ota") as mock: + yield mock + + +@pytest.fixture +def mock_run_ota_impl() -> Generator[Mock]: + """Mock run_ota_impl_ function for testing.""" + with patch("esphome.espota2.run_ota_impl_") as mock: + mock.return_value = (0, "192.168.1.100") + yield mock + + +@pytest.fixture +def mock_socket_constructor(mock_socket: Mock) -> Generator[Mock]: + """Mock socket.socket constructor to return our mock socket.""" + with patch("socket.socket", return_value=mock_socket) as mock_constructor: + yield mock_constructor + + +def test_recv_decode_with_decode(mock_socket: Mock) -> None: + """Test recv_decode with decode=True returns list.""" + mock_socket.recv.return_value = b"\x01\x02\x03" + + result = espota2.recv_decode(mock_socket, 3, decode=True) + + assert result == [1, 2, 3] + mock_socket.recv.assert_called_once_with(3) + + +def test_recv_decode_without_decode(mock_socket: Mock) -> None: + """Test recv_decode with decode=False returns bytes.""" + mock_socket.recv.return_value = b"\x01\x02\x03" + + result = espota2.recv_decode(mock_socket, 3, decode=False) + + assert result == b"\x01\x02\x03" + mock_socket.recv.assert_called_once_with(3) + + +def test_receive_exactly_success(mock_socket: Mock) -> None: + """Test receive_exactly successfully receives expected data.""" + mock_socket.recv.side_effect = [b"\x00", b"\x01\x02"] + + result = espota2.receive_exactly(mock_socket, 3, "test", espota2.RESPONSE_OK) + + assert result == [0, 1, 2] + assert mock_socket.recv.call_count == 2 + + +def test_receive_exactly_with_error_response(mock_socket: Mock) -> None: + """Test receive_exactly raises OTAError on error response.""" + mock_socket.recv.return_value = bytes([espota2.RESPONSE_ERROR_AUTH_INVALID]) + + with pytest.raises(espota2.OTAError, match="Error auth:.*Authentication invalid"): + espota2.receive_exactly(mock_socket, 1, "auth", [espota2.RESPONSE_OK]) + + mock_socket.close.assert_called_once() + + +def test_receive_exactly_socket_error(mock_socket: Mock) -> None: + """Test receive_exactly handles socket errors.""" + mock_socket.recv.side_effect = OSError("Connection reset") + + with pytest.raises(espota2.OTAError, match="Error receiving acknowledge test"): + espota2.receive_exactly(mock_socket, 1, "test", espota2.RESPONSE_OK) + + +@pytest.mark.parametrize( + ("error_code", "expected_msg"), + [ + (espota2.RESPONSE_ERROR_MAGIC, "Error: Invalid magic byte"), + (espota2.RESPONSE_ERROR_UPDATE_PREPARE, "Error: Couldn't prepare flash memory"), + (espota2.RESPONSE_ERROR_AUTH_INVALID, "Error: Authentication invalid"), + ( + espota2.RESPONSE_ERROR_WRITING_FLASH, + "Error: Writing OTA data to flash memory failed", + ), + (espota2.RESPONSE_ERROR_UPDATE_END, "Error: Finishing update failed"), + ( + espota2.RESPONSE_ERROR_INVALID_BOOTSTRAPPING, + "Error: Please press the reset button", + ), + ( + espota2.RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG, + "Error: ESP has been flashed with wrong flash size", + ), + ( + espota2.RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG, + "Error: ESP does not have the requested flash size", + ), + ( + espota2.RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE, + "Error: ESP does not have enough space", + ), + ( + espota2.RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE, + "Error: The OTA partition on the ESP is too small", + ), + ( + espota2.RESPONSE_ERROR_NO_UPDATE_PARTITION, + "Error: The OTA partition on the ESP couldn't be found", + ), + (espota2.RESPONSE_ERROR_MD5_MISMATCH, "Error: Application MD5 code mismatch"), + (espota2.RESPONSE_ERROR_UNKNOWN, "Unknown error from ESP"), + ], +) +def test_check_error_with_various_errors(error_code: int, expected_msg: str) -> None: + """Test check_error raises appropriate errors for different error codes.""" + with pytest.raises(espota2.OTAError, match=expected_msg): + espota2.check_error([error_code], [espota2.RESPONSE_OK]) + + +def test_check_error_unexpected_response() -> None: + """Test check_error raises error for unexpected response.""" + with pytest.raises(espota2.OTAError, match="Unexpected response from ESP: 0x7F"): + espota2.check_error([0x7F], [espota2.RESPONSE_OK, espota2.RESPONSE_AUTH_OK]) + + +def test_send_check_with_various_data_types(mock_socket: Mock) -> None: + """Test send_check handles different data types.""" + + # Test with list/tuple + espota2.send_check(mock_socket, [0x01, 0x02], "list") + mock_socket.sendall.assert_called_with(b"\x01\x02") + + # Test with int + espota2.send_check(mock_socket, 0x42, "int") + mock_socket.sendall.assert_called_with(b"\x42") + + # Test with string + espota2.send_check(mock_socket, "hello", "string") + mock_socket.sendall.assert_called_with(b"hello") + + # Test with bytes (should pass through) + espota2.send_check(mock_socket, b"\xaa\xbb", "bytes") + mock_socket.sendall.assert_called_with(b"\xaa\xbb") + + +def test_send_check_socket_error(mock_socket: Mock) -> None: + """Test send_check handles socket errors.""" + mock_socket.sendall.side_effect = OSError("Broken pipe") + + with pytest.raises(espota2.OTAError, match="Error sending test"): + espota2.send_check(mock_socket, b"data", "test") + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_successful_md5_auth( + mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock +) -> None: + """Test successful OTA with MD5 authentication.""" + # Setup socket responses for recv calls + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Features response + bytes([espota2.RESPONSE_REQUEST_AUTH]), # Auth request + MOCK_MD5_NONCE, # 32 char hex nonce + bytes([espota2.RESPONSE_AUTH_OK]), # Auth result + bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK + bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK + bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK + bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK + bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK + ] + + mock_socket.recv.side_effect = recv_responses + + # Run OTA + espota2.perform_ota(mock_socket, "testpass", mock_file, "test.bin") + + # Verify magic bytes were sent + assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES)) + + # Verify features were sent (compression + SHA256 support) + assert mock_socket.sendall.call_args_list[1] == call( + bytes( + [ + espota2.FEATURE_SUPPORTS_COMPRESSION + | espota2.FEATURE_SUPPORTS_SHA256_AUTH + ] + ) + ) + + # Verify cnonce was sent (MD5 of random.random()) + cnonce = hashlib.md5(MOCK_RANDOM_BYTES).hexdigest() + assert mock_socket.sendall.call_args_list[2] == call(cnonce.encode()) + + # Verify auth result was computed correctly + expected_hash = hashlib.md5() + expected_hash.update(b"testpass") + expected_hash.update(MOCK_MD5_NONCE) + expected_hash.update(cnonce.encode()) + expected_result = expected_hash.hexdigest() + assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode()) + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_no_auth(mock_socket: Mock, mock_file: io.BytesIO) -> None: + """Test OTA without authentication.""" + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_1_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Features response + bytes([espota2.RESPONSE_AUTH_OK]), # No auth required + bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK + bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK + bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK + bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK + ] + + mock_socket.recv.side_effect = recv_responses + + espota2.perform_ota(mock_socket, "", mock_file, "test.bin") + + # Should not send any auth-related data + auth_calls = [ + call + for call in mock_socket.sendall.call_args_list + if "cnonce" in str(call) or "result" in str(call) + ] + assert len(auth_calls) == 0 + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_with_compression(mock_socket: Mock) -> None: + """Test OTA with compression support.""" + original_content = b"firmware" * 100 # Repeating content for compression + mock_file = io.BytesIO(original_content) + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_SUPPORTS_COMPRESSION]), # Device supports compression + bytes([espota2.RESPONSE_AUTH_OK]), # No auth required + bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK + bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK + bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK + bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK + bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK + ] + + mock_socket.recv.side_effect = recv_responses + + espota2.perform_ota(mock_socket, "", mock_file, "test.bin") + + # Verify compressed content was sent + # Get the binary size that was sent (4 bytes after features) + size_bytes = mock_socket.sendall.call_args_list[2][0][0] + sent_size = struct.unpack(">I", size_bytes)[0] + + # Size should be less than original due to compression + assert sent_size < len(original_content) + + # Verify the content sent was gzipped + compressed = gzip.compress(original_content, compresslevel=9) + assert sent_size == len(compressed) + + +def test_perform_ota_auth_without_password(mock_socket: Mock) -> None: + """Test OTA fails when auth is required but no password provided.""" + mock_file = io.BytesIO(b"firmware") + + responses = [ + bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]), + bytes([espota2.RESPONSE_HEADER_OK]), + bytes([espota2.RESPONSE_REQUEST_AUTH]), + ] + + mock_socket.recv.side_effect = responses + + with pytest.raises( + espota2.OTAError, match="ESP requests password, but no password given" + ): + espota2.perform_ota(mock_socket, "", mock_file, "test.bin") + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_md5_auth_wrong_password( + mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock +) -> None: + """Test OTA fails when MD5 authentication is rejected due to wrong password.""" + # Setup socket responses for recv calls + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Features response + bytes([espota2.RESPONSE_REQUEST_AUTH]), # Auth request + MOCK_MD5_NONCE, # 32 char hex nonce + bytes([espota2.RESPONSE_ERROR_AUTH_INVALID]), # Auth rejected! + ] + + mock_socket.recv.side_effect = recv_responses + + with pytest.raises(espota2.OTAError, match="Error auth.*Authentication invalid"): + espota2.perform_ota(mock_socket, "wrongpassword", mock_file, "test.bin") + + # Verify the socket was closed after auth failure + mock_socket.close.assert_called() + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_sha256_auth_wrong_password( + mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock +) -> None: + """Test OTA fails when SHA256 authentication is rejected due to wrong password.""" + # Setup socket responses for recv calls + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Features response + bytes([espota2.RESPONSE_REQUEST_SHA256_AUTH]), # SHA256 Auth request + MOCK_SHA256_NONCE, # 64 char hex nonce + bytes([espota2.RESPONSE_ERROR_AUTH_INVALID]), # Auth rejected! + ] + + mock_socket.recv.side_effect = recv_responses + + with pytest.raises(espota2.OTAError, match="Error auth.*Authentication invalid"): + espota2.perform_ota(mock_socket, "wrongpassword", mock_file, "test.bin") + + # Verify the socket was closed after auth failure + mock_socket.close.assert_called() + + +def test_perform_ota_sha256_auth_without_password(mock_socket: Mock) -> None: + """Test OTA fails when SHA256 auth is required but no password provided.""" + mock_file = io.BytesIO(b"firmware") + + responses = [ + bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]), + bytes([espota2.RESPONSE_HEADER_OK]), + bytes([espota2.RESPONSE_REQUEST_SHA256_AUTH]), + ] + + mock_socket.recv.side_effect = responses + + with pytest.raises( + espota2.OTAError, match="ESP requests password, but no password given" + ): + espota2.perform_ota(mock_socket, "", mock_file, "test.bin") + + +def test_perform_ota_unexpected_auth_response(mock_socket: Mock) -> None: + """Test OTA fails when device sends an unexpected auth response.""" + mock_file = io.BytesIO(b"firmware") + + # Use 0x03 which is not in the expected auth responses + # This will be caught by check_error and raise "Unexpected response from ESP" + UNKNOWN_AUTH_METHOD = 0x03 + + responses = [ + bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]), + bytes([espota2.RESPONSE_HEADER_OK]), + bytes([UNKNOWN_AUTH_METHOD]), # Unknown auth method + ] + + mock_socket.recv.side_effect = responses + + # This will actually raise "Unexpected response from ESP" from check_error + with pytest.raises( + espota2.OTAError, match=r"Error auth: Unexpected response from ESP: 0x03" + ): + espota2.perform_ota(mock_socket, "password", mock_file, "test.bin") + + +def test_perform_ota_unsupported_version(mock_socket: Mock) -> None: + """Test OTA fails with unsupported version.""" + mock_file = io.BytesIO(b"firmware") + + responses = [ + bytes([espota2.RESPONSE_OK, 99]), # Unsupported version + ] + + mock_socket.recv.side_effect = responses + + with pytest.raises(espota2.OTAError, match="Device uses unsupported OTA version"): + espota2.perform_ota(mock_socket, "", mock_file, "test.bin") + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_upload_error(mock_socket: Mock, mock_file: io.BytesIO) -> None: + """Test OTA handles upload errors.""" + # Setup responses - provide enough for the recv calls + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Features response + bytes([espota2.RESPONSE_AUTH_OK]), # No auth required + bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK + bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK + ] + # Add OSError to recv to simulate connection loss during chunk read + recv_responses.append(OSError("Connection lost")) + + mock_socket.recv.side_effect = recv_responses + + with pytest.raises(espota2.OTAError, match="Error receiving acknowledge chunk OK"): + espota2.perform_ota(mock_socket, "", mock_file, "test.bin") + + +@pytest.mark.usefixtures("mock_socket_constructor", "mock_resolve_ip") +def test_run_ota_impl_successful( + mock_socket: Mock, tmp_path: Path, mock_perform_ota: Mock +) -> None: + """Test run_ota_impl_ with successful upload.""" + # Create a real firmware file + firmware_file = tmp_path / "firmware.bin" + firmware_file.write_bytes(b"firmware content") + + # Run OTA with real file path + result_code, result_host = espota2.run_ota_impl_( + "test.local", 3232, "password", str(firmware_file) + ) + + # Verify success + assert result_code == 0 + assert result_host == "192.168.1.100" + + # Verify socket was configured correctly + mock_socket.settimeout.assert_called_with(10.0) + mock_socket.connect.assert_called_once_with(("192.168.1.100", 3232)) + mock_socket.close.assert_called_once() + + # Verify perform_ota was called with real file + mock_perform_ota.assert_called_once() + call_args = mock_perform_ota.call_args[0] + assert call_args[0] == mock_socket + assert call_args[1] == "password" + # Verify the file object is a proper file handle + assert isinstance(call_args[2], io.IOBase) + assert call_args[3] == str(firmware_file) + + +@pytest.mark.usefixtures("mock_socket_constructor", "mock_resolve_ip") +def test_run_ota_impl_connection_failed(mock_socket: Mock, tmp_path: Path) -> None: + """Test run_ota_impl_ when connection fails.""" + mock_socket.connect.side_effect = OSError("Connection refused") + + # Create a real firmware file + firmware_file = tmp_path / "firmware.bin" + firmware_file.write_bytes(b"firmware content") + + result_code, result_host = espota2.run_ota_impl_( + "test.local", 3232, "password", str(firmware_file) + ) + + assert result_code == 1 + assert result_host is None + mock_socket.close.assert_called_once() + + +def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip: Mock) -> None: + """Test run_ota_impl_ when DNS resolution fails.""" + # Create a real firmware file + firmware_file = tmp_path / "firmware.bin" + firmware_file.write_bytes(b"firmware content") + + mock_resolve_ip.side_effect = EsphomeError("DNS resolution failed") + + with pytest.raises(espota2.OTAError, match="DNS resolution failed"): + result_code, result_host = espota2.run_ota_impl_( + "unknown.host", 3232, "password", str(firmware_file) + ) + + +def test_run_ota_wrapper(mock_run_ota_impl: Mock) -> None: + """Test run_ota wrapper function.""" + # Test successful case + mock_run_ota_impl.return_value = (0, "192.168.1.100") + result = espota2.run_ota("test.local", 3232, "pass", "fw.bin") + assert result == (0, "192.168.1.100") + + # Test error case + mock_run_ota_impl.side_effect = espota2.OTAError("Test error") + result = espota2.run_ota("test.local", 3232, "pass", "fw.bin") + assert result == (1, None) + + +def test_progress_bar(capsys: CaptureFixture[str]) -> None: + """Test ProgressBar functionality.""" + progress = espota2.ProgressBar() + + # Test initial update + progress.update(0.0) + captured = capsys.readouterr() + assert "0%" in captured.err + assert "[" in captured.err + + # Test progress update + progress.update(0.5) + captured = capsys.readouterr() + assert "50%" in captured.err + + # Test completion + progress.update(1.0) + captured = capsys.readouterr() + assert "100%" in captured.err + assert "Done" in captured.err + + # Test done method + progress.done() + captured = capsys.readouterr() + assert captured.err == "\n" + + # Test same progress doesn't update + progress.update(0.5) + progress.update(0.5) + captured = capsys.readouterr() + # Should only see one update (second call shouldn't write) + assert captured.err.count("50%") == 1 + + +# Tests for SHA256 authentication +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_successful_sha256_auth( + mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock +) -> None: + """Test successful OTA with SHA256 authentication.""" + # Setup socket responses for recv calls + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Features response + bytes([espota2.RESPONSE_REQUEST_SHA256_AUTH]), # SHA256 Auth request + MOCK_SHA256_NONCE, # 64 char hex nonce + bytes([espota2.RESPONSE_AUTH_OK]), # Auth result + bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK + bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK + bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK + bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK + bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK + ] + + mock_socket.recv.side_effect = recv_responses + + # Run OTA + espota2.perform_ota(mock_socket, "testpass", mock_file, "test.bin") + + # Verify magic bytes were sent + assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES)) + + # Verify features were sent (compression + SHA256 support) + assert mock_socket.sendall.call_args_list[1] == call( + bytes( + [ + espota2.FEATURE_SUPPORTS_COMPRESSION + | espota2.FEATURE_SUPPORTS_SHA256_AUTH + ] + ) + ) + + # Verify cnonce was sent (SHA256 of random.random()) + cnonce = hashlib.sha256(MOCK_RANDOM_BYTES).hexdigest() + assert mock_socket.sendall.call_args_list[2] == call(cnonce.encode()) + + # Verify auth result was computed correctly with SHA256 + expected_hash = hashlib.sha256() + expected_hash.update(b"testpass") + expected_hash.update(MOCK_SHA256_NONCE) + expected_hash.update(cnonce.encode()) + expected_result = expected_hash.hexdigest() + assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode()) + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_sha256_fallback_to_md5( + mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock +) -> None: + """Test SHA256-capable client falls back to MD5 for compatibility.""" + # This test verifies the temporary backward compatibility + # where a SHA256-capable client can still authenticate with MD5 + # This compatibility will be removed in 2026.1.0 + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Features response + bytes( + [espota2.RESPONSE_REQUEST_AUTH] + ), # MD5 Auth request (device doesn't support SHA256) + MOCK_MD5_NONCE, # 32 char hex nonce for MD5 + bytes([espota2.RESPONSE_AUTH_OK]), # Auth result + bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK + bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK + bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK + bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK + bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK + ] + + mock_socket.recv.side_effect = recv_responses + + # Run OTA - should work even though device requested MD5 + espota2.perform_ota(mock_socket, "testpass", mock_file, "test.bin") + + # Verify client still advertised SHA256 support + assert mock_socket.sendall.call_args_list[1] == call( + bytes( + [ + espota2.FEATURE_SUPPORTS_COMPRESSION + | espota2.FEATURE_SUPPORTS_SHA256_AUTH + ] + ) + ) + + # But authentication was done with MD5 + cnonce = hashlib.md5(MOCK_RANDOM_BYTES).hexdigest() + expected_hash = hashlib.md5() + expected_hash.update(b"testpass") + expected_hash.update(MOCK_MD5_NONCE) + expected_hash.update(cnonce.encode()) + expected_result = expected_hash.hexdigest() + assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode()) + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_version_differences( + mock_socket: Mock, mock_file: io.BytesIO +) -> None: + """Test OTA behavior differences between version 1.0 and 2.0.""" + # Test version 1.0 - no chunk acknowledgments + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_1_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Features response + bytes([espota2.RESPONSE_AUTH_OK]), # No auth required + bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK + bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK + # No RESPONSE_CHUNK_OK for v1 + bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK + bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK + ] + + mock_socket.recv.side_effect = recv_responses + espota2.perform_ota(mock_socket, "", mock_file, "test.bin") + + # For v1.0, verify that we only get the expected number of recv calls + # v1.0 doesn't have chunk acknowledgments, so fewer recv calls + assert mock_socket.recv.call_count == 8 # v1.0 has 8 recv calls + + # Reset mock for v2.0 test + mock_socket.reset_mock() + + # Reset file position for second test + mock_file.seek(0) + + # Test version 2.0 - with chunk acknowledgments + recv_responses_v2 = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Features response + bytes([espota2.RESPONSE_AUTH_OK]), # No auth required + bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK + bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK + bytes([espota2.RESPONSE_CHUNK_OK]), # v2.0 has chunk acknowledgment + bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK + bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK + ] + + mock_socket.recv.side_effect = recv_responses_v2 + espota2.perform_ota(mock_socket, "", mock_file, "test.bin") + + # For v2.0, verify more recv calls due to chunk acknowledgments + assert mock_socket.recv.call_count == 9 # v2.0 has 9 recv calls (includes chunk OK) diff --git a/tests/unit_tests/test_external_files.py b/tests/unit_tests/test_external_files.py index 3fa7de2f64..05e0bd3523 100644 --- a/tests/unit_tests/test_external_files.py +++ b/tests/unit_tests/test_external_files.py @@ -42,7 +42,7 @@ def test_is_file_recent_with_recent_file(setup_core: Path) -> None: refresh = TimePeriod(seconds=3600) - result = external_files.is_file_recent(str(test_file), refresh) + result = external_files.is_file_recent(test_file, refresh) assert result is True @@ -53,11 +53,13 @@ def test_is_file_recent_with_old_file(setup_core: Path) -> None: test_file.write_text("content") old_time = time.time() - 7200 + mock_stat = MagicMock() + mock_stat.st_ctime = old_time - with patch("os.path.getctime", return_value=old_time): + with patch.object(Path, "stat", return_value=mock_stat): refresh = TimePeriod(seconds=3600) - result = external_files.is_file_recent(str(test_file), refresh) + result = external_files.is_file_recent(test_file, refresh) assert result is False @@ -67,7 +69,7 @@ def test_is_file_recent_nonexistent_file(setup_core: Path) -> None: test_file = setup_core / "nonexistent.txt" refresh = TimePeriod(seconds=3600) - result = external_files.is_file_recent(str(test_file), refresh) + result = external_files.is_file_recent(test_file, refresh) assert result is False @@ -77,10 +79,12 @@ def test_is_file_recent_with_zero_refresh(setup_core: Path) -> None: test_file = setup_core / "test.txt" test_file.write_text("content") - # Mock getctime to return a time 10 seconds ago - with patch("os.path.getctime", return_value=time.time() - 10): + # Mock stat to return a time 10 seconds ago + mock_stat = MagicMock() + mock_stat.st_ctime = time.time() - 10 + with patch.object(Path, "stat", return_value=mock_stat): refresh = TimePeriod(seconds=0) - result = external_files.is_file_recent(str(test_file), refresh) + result = external_files.is_file_recent(test_file, refresh) assert result is False @@ -97,7 +101,7 @@ def test_has_remote_file_changed_not_modified( mock_head.return_value = mock_response url = "https://example.com/file.txt" - result = external_files.has_remote_file_changed(url, str(test_file)) + result = external_files.has_remote_file_changed(url, test_file) assert result is False mock_head.assert_called_once() @@ -121,7 +125,7 @@ def test_has_remote_file_changed_modified( mock_head.return_value = mock_response url = "https://example.com/file.txt" - result = external_files.has_remote_file_changed(url, str(test_file)) + result = external_files.has_remote_file_changed(url, test_file) assert result is True @@ -131,7 +135,7 @@ def test_has_remote_file_changed_no_local_file(setup_core: Path) -> None: test_file = setup_core / "nonexistent.txt" url = "https://example.com/file.txt" - result = external_files.has_remote_file_changed(url, str(test_file)) + result = external_files.has_remote_file_changed(url, test_file) assert result is True @@ -149,7 +153,7 @@ def test_has_remote_file_changed_network_error( url = "https://example.com/file.txt" with pytest.raises(Invalid, match="Could not check if.*Network error"): - external_files.has_remote_file_changed(url, str(test_file)) + external_files.has_remote_file_changed(url, test_file) @patch("esphome.external_files.requests.head") @@ -165,7 +169,7 @@ def test_has_remote_file_changed_timeout( mock_head.return_value = mock_response url = "https://example.com/file.txt" - external_files.has_remote_file_changed(url, str(test_file)) + external_files.has_remote_file_changed(url, test_file) call_args = mock_head.call_args assert call_args[1]["timeout"] == external_files.NETWORK_TIMEOUT @@ -191,6 +195,6 @@ def test_is_file_recent_handles_float_seconds(setup_core: Path) -> None: refresh = TimePeriod(seconds=3600.5) - result = external_files.is_file_recent(str(test_file), refresh) + result = external_files.is_file_recent(test_file, refresh) assert result is True diff --git a/tests/unit_tests/test_git.py b/tests/unit_tests/test_git.py new file mode 100644 index 0000000000..6a51206ec2 --- /dev/null +++ b/tests/unit_tests/test_git.py @@ -0,0 +1,246 @@ +"""Tests for git.py module.""" + +from datetime import datetime, timedelta +import hashlib +import os +from pathlib import Path +from unittest.mock import Mock + +from esphome import git +from esphome.core import CORE, TimePeriodSeconds + + +def test_clone_or_update_with_never_refresh( + tmp_path: Path, mock_run_git_command: Mock +) -> None: + """Test that NEVER_REFRESH skips updates for existing repos.""" + # Set up CORE.config_path so data_dir uses tmp_path + CORE.config_path = tmp_path / "test.yaml" + + # Compute the expected repo directory path + url = "https://github.com/test/repo" + ref = None + key = f"{url}@{ref}" + domain = "test" + + # Compute hash-based directory name (matching _compute_destination_path logic) + h = hashlib.new("sha256") + h.update(key.encode()) + repo_dir = tmp_path / ".esphome" / domain / h.hexdigest()[:8] + + # Create the git repo directory structure + repo_dir.mkdir(parents=True) + git_dir = repo_dir / ".git" + git_dir.mkdir() + + # Create FETCH_HEAD file with current timestamp + fetch_head = git_dir / "FETCH_HEAD" + fetch_head.write_text("test") + + # Call with NEVER_REFRESH + result_dir, revert = git.clone_or_update( + url=url, + ref=ref, + refresh=git.NEVER_REFRESH, + domain=domain, + ) + + # Should NOT call git commands since NEVER_REFRESH and repo exists + mock_run_git_command.assert_not_called() + assert result_dir == repo_dir + assert revert is None + + +def test_clone_or_update_with_refresh_updates_old_repo( + tmp_path: Path, mock_run_git_command: Mock +) -> None: + """Test that refresh triggers update for old repos.""" + # Set up CORE.config_path so data_dir uses tmp_path + CORE.config_path = tmp_path / "test.yaml" + + # Compute the expected repo directory path + url = "https://github.com/test/repo" + ref = None + key = f"{url}@{ref}" + domain = "test" + + # Compute hash-based directory name (matching _compute_destination_path logic) + h = hashlib.new("sha256") + h.update(key.encode()) + repo_dir = tmp_path / ".esphome" / domain / h.hexdigest()[:8] + + # Create the git repo directory structure + repo_dir.mkdir(parents=True) + git_dir = repo_dir / ".git" + git_dir.mkdir() + + # Create FETCH_HEAD file with old timestamp (2 days ago) + fetch_head = git_dir / "FETCH_HEAD" + fetch_head.write_text("test") + old_time = datetime.now() - timedelta(days=2) + fetch_head.touch() # Create the file + # Set modification time to 2 days ago + os.utime(fetch_head, (old_time.timestamp(), old_time.timestamp())) + + # Mock git command responses + mock_run_git_command.return_value = "abc123" # SHA for rev-parse + + # Call with refresh=1d (1 day) + refresh = TimePeriodSeconds(days=1) + result_dir, revert = git.clone_or_update( + url=url, + ref=ref, + refresh=refresh, + domain=domain, + ) + + # Should call git fetch and update commands since repo is older than refresh + assert mock_run_git_command.called + # Check for fetch command + fetch_calls = [ + call + for call in mock_run_git_command.call_args_list + if len(call[0]) > 0 and "fetch" in call[0][0] + ] + assert len(fetch_calls) > 0 + + +def test_clone_or_update_with_refresh_skips_fresh_repo( + tmp_path: Path, mock_run_git_command: Mock +) -> None: + """Test that refresh doesn't update fresh repos.""" + # Set up CORE.config_path so data_dir uses tmp_path + CORE.config_path = tmp_path / "test.yaml" + + # Compute the expected repo directory path + url = "https://github.com/test/repo" + ref = None + key = f"{url}@{ref}" + domain = "test" + + # Compute hash-based directory name (matching _compute_destination_path logic) + h = hashlib.new("sha256") + h.update(key.encode()) + repo_dir = tmp_path / ".esphome" / domain / h.hexdigest()[:8] + + # Create the git repo directory structure + repo_dir.mkdir(parents=True) + git_dir = repo_dir / ".git" + git_dir.mkdir() + + # Create FETCH_HEAD file with recent timestamp (1 hour ago) + fetch_head = git_dir / "FETCH_HEAD" + fetch_head.write_text("test") + recent_time = datetime.now() - timedelta(hours=1) + fetch_head.touch() # Create the file + # Set modification time to 1 hour ago + os.utime(fetch_head, (recent_time.timestamp(), recent_time.timestamp())) + + # Call with refresh=1d (1 day) + refresh = TimePeriodSeconds(days=1) + result_dir, revert = git.clone_or_update( + url=url, + ref=ref, + refresh=refresh, + domain=domain, + ) + + # Should NOT call git fetch since repo is fresh + mock_run_git_command.assert_not_called() + assert result_dir == repo_dir + assert revert is None + + +def test_clone_or_update_clones_missing_repo( + tmp_path: Path, mock_run_git_command: Mock +) -> None: + """Test that missing repos are cloned regardless of refresh setting.""" + # Set up CORE.config_path so data_dir uses tmp_path + CORE.config_path = tmp_path / "test.yaml" + + # Compute the expected repo directory path + url = "https://github.com/test/repo" + ref = None + key = f"{url}@{ref}" + domain = "test" + + # Compute hash-based directory name (matching _compute_destination_path logic) + h = hashlib.new("sha256") + h.update(key.encode()) + repo_dir = tmp_path / ".esphome" / domain / h.hexdigest()[:8] + + # Create base directory but NOT the repo itself + base_dir = tmp_path / ".esphome" / domain + base_dir.mkdir(parents=True) + # repo_dir should NOT exist + assert not repo_dir.exists() + + # Test with NEVER_REFRESH - should still clone since repo doesn't exist + result_dir, revert = git.clone_or_update( + url=url, + ref=ref, + refresh=git.NEVER_REFRESH, + domain=domain, + ) + + # Should call git clone + assert mock_run_git_command.called + clone_calls = [ + call + for call in mock_run_git_command.call_args_list + if len(call[0]) > 0 and "clone" in call[0][0] + ] + assert len(clone_calls) > 0 + + +def test_clone_or_update_with_none_refresh_always_updates( + tmp_path: Path, mock_run_git_command: Mock +) -> None: + """Test that refresh=None always updates existing repos.""" + # Set up CORE.config_path so data_dir uses tmp_path + CORE.config_path = tmp_path / "test.yaml" + + # Compute the expected repo directory path + url = "https://github.com/test/repo" + ref = None + key = f"{url}@{ref}" + domain = "test" + + # Compute hash-based directory name (matching _compute_destination_path logic) + h = hashlib.new("sha256") + h.update(key.encode()) + repo_dir = tmp_path / ".esphome" / domain / h.hexdigest()[:8] + + # Create the git repo directory structure + repo_dir.mkdir(parents=True) + git_dir = repo_dir / ".git" + git_dir.mkdir() + + # Create FETCH_HEAD file with very recent timestamp (1 second ago) + fetch_head = git_dir / "FETCH_HEAD" + fetch_head.write_text("test") + recent_time = datetime.now() - timedelta(seconds=1) + fetch_head.touch() # Create the file + # Set modification time to 1 second ago + os.utime(fetch_head, (recent_time.timestamp(), recent_time.timestamp())) + + # Mock git command responses + mock_run_git_command.return_value = "abc123" # SHA for rev-parse + + # Call with refresh=None (default behavior) + result_dir, revert = git.clone_or_update( + url=url, + ref=ref, + refresh=None, + domain=domain, + ) + + # Should call git fetch and update commands since refresh=None means always update + assert mock_run_git_command.called + # Check for fetch command + fetch_calls = [ + call + for call in mock_run_git_command.call_args_list + if len(call[0]) > 0 and "fetch" in call[0][0] + ] + assert len(fetch_calls) > 0 diff --git a/tests/unit_tests/test_helpers.py b/tests/unit_tests/test_helpers.py index 9f51206ff9..87ed901ecb 100644 --- a/tests/unit_tests/test_helpers.py +++ b/tests/unit_tests/test_helpers.py @@ -1,5 +1,8 @@ import logging +import os +from pathlib import Path import socket +import stat from unittest.mock import patch from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr @@ -8,6 +11,7 @@ from hypothesis.strategies import ip_addresses import pytest from esphome import helpers +from esphome.address_cache import AddressCache from esphome.core import EsphomeError @@ -150,11 +154,11 @@ def test_walk_files(fixture_path): actual = list(helpers.walk_files(path)) # Ensure paths start with the root - assert all(p.startswith(str(path)) for p in actual) + assert all(p.is_relative_to(path) for p in actual) class Test_write_file_if_changed: - def test_src_and_dst_match(self, tmp_path): + def test_src_and_dst_match(self, tmp_path: Path): text = "A files are unique.\n" initial = text dst = tmp_path / "file-a.txt" @@ -164,7 +168,7 @@ class Test_write_file_if_changed: assert dst.read_text() == text - def test_src_and_dst_do_not_match(self, tmp_path): + def test_src_and_dst_do_not_match(self, tmp_path: Path): text = "A files are unique.\n" initial = "B files are unique.\n" dst = tmp_path / "file-a.txt" @@ -174,7 +178,7 @@ class Test_write_file_if_changed: assert dst.read_text() == text - def test_dst_does_not_exist(self, tmp_path): + def test_dst_does_not_exist(self, tmp_path: Path): text = "A files are unique.\n" dst = tmp_path / "file-a.txt" @@ -184,7 +188,7 @@ class Test_write_file_if_changed: class Test_copy_file_if_changed: - def test_src_and_dst_match(self, tmp_path, fixture_path): + def test_src_and_dst_match(self, tmp_path: Path, fixture_path: Path): src = fixture_path / "helpers" / "file-a.txt" initial = fixture_path / "helpers" / "file-a.txt" dst = tmp_path / "file-a.txt" @@ -193,7 +197,7 @@ class Test_copy_file_if_changed: helpers.copy_file_if_changed(src, dst) - def test_src_and_dst_do_not_match(self, tmp_path, fixture_path): + def test_src_and_dst_do_not_match(self, tmp_path: Path, fixture_path: Path): src = fixture_path / "helpers" / "file-a.txt" initial = fixture_path / "helpers" / "file-c.txt" dst = tmp_path / "file-a.txt" @@ -204,7 +208,7 @@ class Test_copy_file_if_changed: assert src.read_text() == dst.read_text() - def test_dst_does_not_exist(self, tmp_path, fixture_path): + def test_dst_does_not_exist(self, tmp_path: Path, fixture_path: Path): src = fixture_path / "helpers" / "file-a.txt" dst = tmp_path / "file-a.txt" @@ -554,6 +558,217 @@ def test_addr_preference_ipv6_link_local_with_scope() -> None: assert helpers.addr_preference_(addr_info) == 1 # Has scope, so it's usable +def test_mkdir_p(tmp_path: Path) -> None: + """Test mkdir_p creates directories recursively.""" + # Test creating nested directories + nested_path = tmp_path / "level1" / "level2" / "level3" + helpers.mkdir_p(nested_path) + assert nested_path.exists() + assert nested_path.is_dir() + + # Test that mkdir_p is idempotent (doesn't fail if directory exists) + helpers.mkdir_p(nested_path) + assert nested_path.exists() + + # Test with empty path (should do nothing) + helpers.mkdir_p("") + + # Test with existing directory + existing_dir = tmp_path / "existing" + existing_dir.mkdir() + helpers.mkdir_p(existing_dir) + assert existing_dir.exists() + + +def test_mkdir_p_file_exists_error(tmp_path: Path) -> None: + """Test mkdir_p raises error when path is a file.""" + # Create a file + file_path = tmp_path / "test_file.txt" + file_path.write_text("test content") + + # Try to create directory with same name as existing file + with pytest.raises(EsphomeError, match=r"Error creating directories"): + helpers.mkdir_p(file_path) + + +def test_mkdir_p_with_existing_file_raises_error(tmp_path: Path) -> None: + """Test mkdir_p raises error when trying to create dir over existing file.""" + # Create a file where we want to create a directory + file_path = tmp_path / "existing_file" + file_path.write_text("content") + + # Try to create a directory with a path that goes through the file + dir_path = file_path / "subdir" + + with pytest.raises(EsphomeError, match=r"Error creating directories"): + helpers.mkdir_p(dir_path) + + +def test_read_file(tmp_path: Path) -> None: + """Test read_file reads file content correctly.""" + # Test reading regular file + test_file = tmp_path / "test.txt" + expected_content = "Test content\nLine 2\n" + test_file.write_text(expected_content) + + content = helpers.read_file(test_file) + assert content == expected_content + + # Test reading file with UTF-8 characters + utf8_file = tmp_path / "utf8.txt" + utf8_content = "Hello 世界 🌍" + utf8_file.write_text(utf8_content, encoding="utf-8") + + content = helpers.read_file(utf8_file) + assert content == utf8_content + + +def test_read_file_not_found() -> None: + """Test read_file raises error for non-existent file.""" + with pytest.raises(EsphomeError, match=r"Error reading file"): + helpers.read_file(Path("/nonexistent/file.txt")) + + +def test_read_file_unicode_decode_error(tmp_path: Path) -> None: + """Test read_file raises error for invalid UTF-8.""" + test_file = tmp_path / "invalid.txt" + # Write invalid UTF-8 bytes + test_file.write_bytes(b"\xff\xfe") + + with pytest.raises(EsphomeError, match=r"Error reading file"): + helpers.read_file(test_file) + + +@pytest.mark.skipif(os.name == "nt", reason="Unix-specific test") +def test_write_file_unix(tmp_path: Path) -> None: + """Test write_file writes content correctly on Unix.""" + # Test writing string content + test_file = tmp_path / "test.txt" + content = "Test content\nLine 2" + helpers.write_file(test_file, content) + + assert test_file.read_text() == content + # Check file permissions + assert oct(test_file.stat().st_mode)[-3:] == "644" + + # Test overwriting existing file + new_content = "New content" + helpers.write_file(test_file, new_content) + assert test_file.read_text() == new_content + + # Test writing to nested directories (should create them) + nested_file = tmp_path / "dir1" / "dir2" / "file.txt" + helpers.write_file(nested_file, content) + assert nested_file.read_text() == content + + +@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test") +def test_write_file_windows(tmp_path: Path) -> None: + """Test write_file writes content correctly on Windows.""" + # Test writing string content + test_file = tmp_path / "test.txt" + content = "Test content\nLine 2" + helpers.write_file(test_file, content) + + assert test_file.read_text() == content + # Windows doesn't have Unix-style 644 permissions + + # Test overwriting existing file + new_content = "New content" + helpers.write_file(test_file, new_content) + assert test_file.read_text() == new_content + + # Test writing to nested directories (should create them) + nested_file = tmp_path / "dir1" / "dir2" / "file.txt" + helpers.write_file(nested_file, content) + assert nested_file.read_text() == content + + +@pytest.mark.skipif(os.name == "nt", reason="Unix-specific permission test") +def test_write_file_to_non_writable_directory_unix(tmp_path: Path) -> None: + """Test write_file raises error when directory is not writable on Unix.""" + # Create a directory and make it read-only + read_only_dir = tmp_path / "readonly" + read_only_dir.mkdir() + test_file = read_only_dir / "test.txt" + + # Make directory read-only (no write permission) + read_only_dir.chmod(0o555) + + try: + with pytest.raises(EsphomeError, match=r"Could not write file"): + helpers.write_file(test_file, "content") + finally: + # Restore write permissions for cleanup + read_only_dir.chmod(0o755) + + +@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test") +def test_write_file_to_non_writable_directory_windows(tmp_path: Path) -> None: + """Test write_file error handling on Windows.""" + # Windows handles permissions differently - test a different error case + # Try to write to a file path that contains an existing file as a directory component + existing_file = tmp_path / "file.txt" + existing_file.write_text("content") + + # Try to write to a path that treats the file as a directory + invalid_path = existing_file / "subdir" / "test.txt" + + with pytest.raises(EsphomeError, match=r"Could not write file"): + helpers.write_file(invalid_path, "content") + + +@pytest.mark.skipif(os.name == "nt", reason="Unix-specific permission test") +def test_write_file_with_permission_bits_unix(tmp_path: Path) -> None: + """Test that write_file sets correct permissions on Unix.""" + test_file = tmp_path / "test.txt" + helpers.write_file(test_file, "content") + + # Check that file has 644 permissions + file_mode = test_file.stat().st_mode + assert stat.S_IMODE(file_mode) == 0o644 + + +@pytest.mark.skipif(os.name == "nt", reason="Unix-specific permission test") +def test_copy_file_if_changed_permission_recovery_unix(tmp_path: Path) -> None: + """Test copy_file_if_changed handles permission errors correctly on Unix.""" + # Test with read-only destination file + src = tmp_path / "source.txt" + dst = tmp_path / "dest.txt" + src.write_text("new content") + dst.write_text("old content") + dst.chmod(0o444) # Make destination read-only + + try: + # Should handle permission error by deleting and retrying + helpers.copy_file_if_changed(src, dst) + assert dst.read_text() == "new content" + finally: + # Restore write permissions for cleanup + if dst.exists(): + dst.chmod(0o644) + + +def test_copy_file_if_changed_creates_directories(tmp_path: Path) -> None: + """Test copy_file_if_changed creates missing directories.""" + src = tmp_path / "source.txt" + dst = tmp_path / "subdir" / "nested" / "dest.txt" + src.write_text("content") + + helpers.copy_file_if_changed(src, dst) + assert dst.exists() + assert dst.read_text() == "content" + + +def test_copy_file_if_changed_nonexistent_source(tmp_path: Path) -> None: + """Test copy_file_if_changed with non-existent source.""" + src = tmp_path / "nonexistent.txt" + dst = tmp_path / "dest.txt" + + with pytest.raises(EsphomeError, match=r"Error copying file"): + helpers.copy_file_if_changed(src, dst) + + def test_resolve_ip_address_sorting() -> None: """Test that results are sorted by preference.""" # Create multiple address infos with different preferences @@ -594,3 +809,84 @@ def test_resolve_ip_address_sorting() -> None: assert result[0][4][0] == "2001:db8::1" # IPv6 (preference 1) assert result[1][4][0] == "192.168.1.100" # IPv4 (preference 2) assert result[2][4][0] == "fe80::1" # Link-local no scope (preference 3) + + +def test_resolve_ip_address_with_cache() -> None: + """Test that the cache is used when provided.""" + cache = AddressCache( + mdns_cache={"test.local": ["192.168.1.100", "192.168.1.101"]}, + dns_cache={ + "example.com": ["93.184.216.34", "2606:2800:220:1:248:1893:25c8:1946"] + }, + ) + + # Test mDNS cache hit + result = helpers.resolve_ip_address("test.local", 6053, address_cache=cache) + + # Should return cached addresses without calling resolver + assert len(result) == 2 + assert result[0][4][0] == "192.168.1.100" + assert result[1][4][0] == "192.168.1.101" + + # Test DNS cache hit + result = helpers.resolve_ip_address("example.com", 6053, address_cache=cache) + + # Should return cached addresses with IPv6 first due to preference + assert len(result) == 2 + assert result[0][4][0] == "2606:2800:220:1:248:1893:25c8:1946" # IPv6 first + assert result[1][4][0] == "93.184.216.34" # IPv4 second + + +def test_resolve_ip_address_cache_miss() -> None: + """Test that resolver is called when not in cache.""" + cache = AddressCache(mdns_cache={"other.local": ["192.168.1.200"]}) + + mock_addr_info = AddrInfo( + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + sockaddr=IPv4Sockaddr(address="192.168.1.100", port=6053), + ) + + with patch("esphome.resolver.AsyncResolver") as MockResolver: + mock_resolver = MockResolver.return_value + mock_resolver.resolve.return_value = [mock_addr_info] + + result = helpers.resolve_ip_address("test.local", 6053, address_cache=cache) + + # Should call resolver since test.local is not in cache + MockResolver.assert_called_once_with(["test.local"], 6053) + assert len(result) == 1 + assert result[0][4][0] == "192.168.1.100" + + +def test_resolve_ip_address_mixed_cached_uncached() -> None: + """Test resolution with mix of cached and uncached hosts.""" + cache = AddressCache(mdns_cache={"cached.local": ["192.168.1.50"]}) + + mock_addr_info = AddrInfo( + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + sockaddr=IPv4Sockaddr(address="192.168.1.100", port=6053), + ) + + with patch("esphome.resolver.AsyncResolver") as MockResolver: + mock_resolver = MockResolver.return_value + mock_resolver.resolve.return_value = [mock_addr_info] + + # Pass a list with cached IP, cached hostname, and uncached hostname + result = helpers.resolve_ip_address( + ["192.168.1.10", "cached.local", "uncached.local"], + 6053, + address_cache=cache, + ) + + # Should only resolve uncached.local + MockResolver.assert_called_once_with(["uncached.local"], 6053) + + # Results should include all addresses + addresses = [r[4][0] for r in result] + assert "192.168.1.10" in addresses # Direct IP + assert "192.168.1.50" in addresses # From cache + assert "192.168.1.100" in addresses # From resolver diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index fff0b2cd48..e35378145a 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -4,17 +4,22 @@ from __future__ import annotations from collections.abc import Generator from dataclasses import dataclass +import logging from pathlib import Path +import re from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest from pytest import CaptureFixture +from esphome import platformio_api from esphome.__main__ import ( Purpose, choose_upload_log_host, + command_clean_all, command_rename, + command_update_all, command_wizard, get_port_type, has_ip_address, @@ -26,7 +31,9 @@ from esphome.__main__ import ( mqtt_get_ip, show_logs, upload_program, + upload_using_esptool, ) +from esphome.components.esp32.const import KEY_ESP32, KEY_VARIANT, VARIANT_ESP32 from esphome.const import ( CONF_API, CONF_BROKER, @@ -55,6 +62,17 @@ from esphome.const import ( from esphome.core import CORE, EsphomeError +def strip_ansi_codes(text: str) -> str: + """Remove ANSI escape codes from text. + + This helps make test assertions cleaner by removing color codes and other + terminal formatting that can make tests brittle. + """ + # Pattern to match ANSI escape sequences + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", text) + + @dataclass class MockSerialPort: """Mock serial port for testing. @@ -207,6 +225,14 @@ def mock_run_external_process() -> Generator[Mock]: yield mock +@pytest.fixture +def mock_run_external_command() -> Generator[Mock]: + """Mock run_external_command for testing.""" + with patch("esphome.__main__.run_external_command") as mock: + mock.return_value = 0 # Default to success + yield mock + + def test_choose_upload_log_host_with_string_default() -> None: """Test with a single string default device.""" setup_core() @@ -805,6 +831,122 @@ def test_upload_program_serial_esp8266_with_file( ) +def test_upload_using_esptool_path_conversion( + tmp_path: Path, + mock_run_external_command: Mock, + mock_get_idedata: Mock, +) -> None: + """Test upload_using_esptool properly converts Path objects to strings for esptool. + + This test ensures that img.path (Path object) is converted to string before + passing to esptool, preventing AttributeError. + """ + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path, name="test") + + # Set up ESP32-specific data required by get_esp32_variant() + CORE.data[KEY_ESP32] = {KEY_VARIANT: VARIANT_ESP32} + + # Create mock IDEData with Path objects + mock_idedata = MagicMock(spec=platformio_api.IDEData) + mock_idedata.firmware_bin_path = tmp_path / "firmware.bin" + mock_idedata.extra_flash_images = [ + platformio_api.FlashImage(path=tmp_path / "bootloader.bin", offset="0x1000"), + platformio_api.FlashImage(path=tmp_path / "partitions.bin", offset="0x8000"), + ] + + mock_get_idedata.return_value = mock_idedata + + # Create the actual firmware files so they exist + (tmp_path / "firmware.bin").touch() + (tmp_path / "bootloader.bin").touch() + (tmp_path / "partitions.bin").touch() + + config = {CONF_ESPHOME: {"platformio_options": {}}} + + # Call upload_using_esptool without custom file argument + result = upload_using_esptool(config, "/dev/ttyUSB0", None, None) + + assert result == 0 + + # Verify that run_external_command was called + assert mock_run_external_command.call_count == 1 + + # Get the actual call arguments + call_args = mock_run_external_command.call_args[0] + + # The first argument should be esptool.main function, + # followed by the command arguments + assert len(call_args) > 1 + + # Find the indices of the flash image arguments + # They should come after "write-flash" and "-z" + cmd_list = list(call_args[1:]) # Skip the esptool.main function + + # Verify all paths are strings, not Path objects + # The firmware and flash images should be at specific positions + write_flash_idx = cmd_list.index("write-flash") + + # After write-flash we have: -z, --flash-size, detect, then offset/path pairs + # Check firmware at offset 0x10000 (ESP32) + firmware_offset_idx = write_flash_idx + 4 + assert cmd_list[firmware_offset_idx] == "0x10000" + firmware_path = cmd_list[firmware_offset_idx + 1] + assert isinstance(firmware_path, str) + assert firmware_path.endswith("firmware.bin") + + # Check bootloader + bootloader_offset_idx = firmware_offset_idx + 2 + assert cmd_list[bootloader_offset_idx] == "0x1000" + bootloader_path = cmd_list[bootloader_offset_idx + 1] + assert isinstance(bootloader_path, str) + assert bootloader_path.endswith("bootloader.bin") + + # Check partitions + partitions_offset_idx = bootloader_offset_idx + 2 + assert cmd_list[partitions_offset_idx] == "0x8000" + partitions_path = cmd_list[partitions_offset_idx + 1] + assert isinstance(partitions_path, str) + assert partitions_path.endswith("partitions.bin") + + +def test_upload_using_esptool_with_file_path( + tmp_path: Path, + mock_run_external_command: Mock, +) -> None: + """Test upload_using_esptool with a custom file that's a Path object.""" + setup_core(platform=PLATFORM_ESP8266, tmp_path=tmp_path, name="test") + + # Create a test firmware file + firmware_file = tmp_path / "custom_firmware.bin" + firmware_file.touch() + + config = {CONF_ESPHOME: {"platformio_options": {}}} + + # Call with a Path object as the file argument (though usually it's a string) + result = upload_using_esptool(config, "/dev/ttyUSB0", str(firmware_file), None) + + assert result == 0 + + # Verify that run_external_command was called + mock_run_external_command.assert_called_once() + + # Get the actual call arguments + call_args = mock_run_external_command.call_args[0] + cmd_list = list(call_args[1:]) # Skip the esptool.main function + + # Find the firmware path in the command + write_flash_idx = cmd_list.index("write-flash") + + # For custom file, it should be at offset 0x0 + firmware_offset_idx = write_flash_idx + 4 + assert cmd_list[firmware_offset_idx] == "0x0" + firmware_path = cmd_list[firmware_offset_idx + 1] + + # Verify it's a string, not a Path object + assert isinstance(firmware_path, str) + assert firmware_path.endswith("custom_firmware.bin") + + @pytest.mark.parametrize( "platform,device", [ @@ -885,7 +1027,7 @@ def test_upload_program_ota_success( assert exit_code == 0 assert host == "192.168.1.100" - expected_firmware = str( + expected_firmware = ( tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" ) mock_run_ota.assert_called_once_with( @@ -919,7 +1061,9 @@ def test_upload_program_ota_with_file_arg( assert exit_code == 0 assert host == "192.168.1.100" - mock_run_ota.assert_called_once_with(["192.168.1.100"], 3232, "", "custom.bin") + mock_run_ota.assert_called_once_with( + ["192.168.1.100"], 3232, "", Path("custom.bin") + ) def test_upload_program_ota_no_config( @@ -972,7 +1116,7 @@ def test_upload_program_ota_with_mqtt_resolution( assert exit_code == 0 assert host == "192.168.1.100" mock_mqtt_get_ip.assert_called_once_with(config, "user", "pass", "client") - expected_firmware = str( + expected_firmware = ( tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" ) mock_run_ota.assert_called_once_with(["192.168.1.100"], 3232, "", expected_firmware) @@ -1382,7 +1526,7 @@ def test_command_wizard(tmp_path: Path) -> None: result = command_wizard(args) assert result == 0 - mock_wizard.assert_called_once_with(str(config_file)) + mock_wizard.assert_called_once_with(config_file) def test_command_rename_invalid_characters( @@ -1407,7 +1551,7 @@ def test_command_rename_complex_yaml( config_file = tmp_path / "test.yaml" config_file.write_text("# Complex YAML without esphome section\nsome_key: value\n") setup_core(tmp_path=tmp_path) - CORE.config_path = str(config_file) + CORE.config_path = config_file args = MockArgs(name="newname") result = command_rename(args, {}) @@ -1436,7 +1580,7 @@ wifi: password: "test1234" """) setup_core(tmp_path=tmp_path) - CORE.config_path = str(config_file) + CORE.config_path = config_file # Set up CORE.config to avoid ValueError when accessing CORE.address CORE.config = {CONF_ESPHOME: {CONF_NAME: "oldname"}} @@ -1486,7 +1630,7 @@ esp32: board: nodemcu-32s """) setup_core(tmp_path=tmp_path) - CORE.config_path = str(config_file) + CORE.config_path = config_file # Set up CORE.config to avoid ValueError when accessing CORE.address CORE.config = { @@ -1523,7 +1667,7 @@ esp32: board: nodemcu-32s """) setup_core(tmp_path=tmp_path) - CORE.config_path = str(config_file) + CORE.config_path = config_file args = MockArgs(name="newname", dashboard=False) @@ -1543,3 +1687,263 @@ esp32: captured = capfd.readouterr() assert "Rename failed" in captured.out + + +def test_command_update_all_path_string_conversion( + tmp_path: Path, + mock_run_external_process: Mock, + capfd: CaptureFixture[str], +) -> None: + """Test that command_update_all properly converts Path objects to strings in output.""" + yaml1 = tmp_path / "device1.yaml" + yaml1.write_text(""" +esphome: + name: device1 + +esp32: + board: nodemcu-32s +""") + + yaml2 = tmp_path / "device2.yaml" + yaml2.write_text(""" +esphome: + name: device2 + +esp8266: + board: nodemcuv2 +""") + + setup_core(tmp_path=tmp_path) + mock_run_external_process.return_value = 0 + + assert command_update_all(MockArgs(configuration=[str(tmp_path)])) == 0 + + captured = capfd.readouterr() + clean_output = strip_ansi_codes(captured.out) + + # Check that Path objects were properly converted to strings + # The output should contain file paths without causing TypeError + assert "device1.yaml" in clean_output + assert "device2.yaml" in clean_output + assert "SUCCESS" in clean_output + assert "SUMMARY" in clean_output + + # Verify run_external_process was called for each file + assert mock_run_external_process.call_count == 2 + + +def test_command_update_all_with_failures( + tmp_path: Path, + mock_run_external_process: Mock, + capfd: CaptureFixture[str], +) -> None: + """Test command_update_all handles mixed success/failure cases properly.""" + yaml1 = tmp_path / "success_device.yaml" + yaml1.write_text(""" +esphome: + name: success_device + +esp32: + board: nodemcu-32s +""") + + yaml2 = tmp_path / "failed_device.yaml" + yaml2.write_text(""" +esphome: + name: failed_device + +esp8266: + board: nodemcuv2 +""") + + setup_core(tmp_path=tmp_path) + + # Mock mixed results - first succeeds, second fails + mock_run_external_process.side_effect = [0, 1] + + # Should return 1 (failure) since one device failed + assert command_update_all(MockArgs(configuration=[str(tmp_path)])) == 1 + + captured = capfd.readouterr() + clean_output = strip_ansi_codes(captured.out) + + # Check that both success and failure are properly displayed + assert "SUCCESS" in clean_output + assert "ERROR" in clean_output or "FAILED" in clean_output + assert "SUMMARY" in clean_output + + # Files are processed in alphabetical order, so we need to check which one succeeded/failed + # The mock_run_external_process.side_effect = [0, 1] applies to files in alphabetical order + # So "failed_device.yaml" gets 0 (success) and "success_device.yaml" gets 1 (failure) + assert "failed_device.yaml: SUCCESS" in clean_output + assert "success_device.yaml: FAILED" in clean_output + + +def test_command_update_all_empty_directory( + tmp_path: Path, + mock_run_external_process: Mock, + capfd: CaptureFixture[str], +) -> None: + """Test command_update_all with an empty directory (no YAML files).""" + setup_core(tmp_path=tmp_path) + + assert command_update_all(MockArgs(configuration=[str(tmp_path)])) == 0 + mock_run_external_process.assert_not_called() + + captured = capfd.readouterr() + clean_output = strip_ansi_codes(captured.out) + + assert "SUMMARY" in clean_output + + +def test_command_update_all_single_file( + tmp_path: Path, + mock_run_external_process: Mock, + capfd: CaptureFixture[str], +) -> None: + """Test command_update_all with a single YAML file specified.""" + yaml_file = tmp_path / "single_device.yaml" + yaml_file.write_text(""" +esphome: + name: single_device + +esp32: + board: nodemcu-32s +""") + + setup_core(tmp_path=tmp_path) + mock_run_external_process.return_value = 0 + + assert command_update_all(MockArgs(configuration=[str(yaml_file)])) == 0 + + captured = capfd.readouterr() + clean_output = strip_ansi_codes(captured.out) + + assert "single_device.yaml" in clean_output + assert "SUCCESS" in clean_output + mock_run_external_process.assert_called_once() + + +def test_command_update_all_path_formatting_in_color_calls( + tmp_path: Path, + mock_run_external_process: Mock, + capfd: CaptureFixture[str], +) -> None: + """Test that Path objects are properly converted when passed to color() function.""" + yaml_file = tmp_path / "test-device_123.yaml" + yaml_file.write_text(""" +esphome: + name: test-device_123 + +esp32: + board: nodemcu-32s +""") + + setup_core(tmp_path=tmp_path) + mock_run_external_process.return_value = 0 + + assert command_update_all(MockArgs(configuration=[str(tmp_path)])) == 0 + + captured = capfd.readouterr() + clean_output = strip_ansi_codes(captured.out) + + assert "test-device_123.yaml" in clean_output + assert "Updating" in clean_output + assert "SUCCESS" in clean_output + assert "SUMMARY" in clean_output + + # Should not have any Python error messages + assert "TypeError" not in clean_output + assert "can only concatenate str" not in clean_output + + +def test_command_clean_all_success( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test command_clean_all when writer.clean_all() succeeds.""" + args = MockArgs(configuration=["/path/to/config1", "/path/to/config2"]) + + # Set logger level to capture INFO messages + with ( + caplog.at_level(logging.INFO), + patch("esphome.writer.clean_all") as mock_clean_all, + ): + result = command_clean_all(args) + + assert result == 0 + mock_clean_all.assert_called_once_with(["/path/to/config1", "/path/to/config2"]) + + # Check that success message was logged + assert "Done!" in caplog.text + + +def test_command_clean_all_oserror( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test command_clean_all when writer.clean_all() raises OSError.""" + args = MockArgs(configuration=["/path/to/config1"]) + + # Create a mock OSError with a specific message + mock_error = OSError("Permission denied: cannot delete directory") + + # Set logger level to capture ERROR and INFO messages + with ( + caplog.at_level(logging.INFO), + patch("esphome.writer.clean_all", side_effect=mock_error) as mock_clean_all, + ): + result = command_clean_all(args) + + assert result == 1 + mock_clean_all.assert_called_once_with(["/path/to/config1"]) + + # Check that error message was logged + assert ( + "Error cleaning all files: Permission denied: cannot delete directory" + in caplog.text + ) + # Should not have success message + assert "Done!" not in caplog.text + + +def test_command_clean_all_oserror_no_message( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test command_clean_all when writer.clean_all() raises OSError without message.""" + args = MockArgs(configuration=["/path/to/config1"]) + + # Create a mock OSError without a message + mock_error = OSError() + + # Set logger level to capture ERROR and INFO messages + with ( + caplog.at_level(logging.INFO), + patch("esphome.writer.clean_all", side_effect=mock_error) as mock_clean_all, + ): + result = command_clean_all(args) + + assert result == 1 + mock_clean_all.assert_called_once_with(["/path/to/config1"]) + + # Check that error message was logged (should show empty string for OSError without message) + assert "Error cleaning all files:" in caplog.text + # Should not have success message + assert "Done!" not in caplog.text + + +def test_command_clean_all_args_used() -> None: + """Test that command_clean_all uses args.configuration parameter.""" + # Test with different configuration paths + args1 = MockArgs(configuration=["/path/to/config1"]) + args2 = MockArgs(configuration=["/path/to/config2", "/path/to/config3"]) + + with patch("esphome.writer.clean_all") as mock_clean_all: + result1 = command_clean_all(args1) + result2 = command_clean_all(args2) + + assert result1 == 0 + assert result2 == 0 + assert mock_clean_all.call_count == 2 + + # Verify the correct configuration paths were passed + mock_clean_all.assert_any_call(["/path/to/config1"]) + mock_clean_all.assert_any_call(["/path/to/config2", "/path/to/config3"]) diff --git a/tests/unit_tests/test_platformio_api.py b/tests/unit_tests/test_platformio_api.py index 7c7883d391..07948cc6ad 100644 --- a/tests/unit_tests/test_platformio_api.py +++ b/tests/unit_tests/test_platformio_api.py @@ -15,45 +15,45 @@ from esphome.core import CORE, EsphomeError def test_idedata_firmware_elf_path(setup_core: Path) -> None: """Test IDEData.firmware_elf_path returns correct path.""" - CORE.build_path = str(setup_core / "build" / "test") + CORE.build_path = setup_core / "build" / "test" CORE.name = "test" raw_data = {"prog_path": "/path/to/firmware.elf"} idedata = platformio_api.IDEData(raw_data) - assert idedata.firmware_elf_path == "/path/to/firmware.elf" + assert idedata.firmware_elf_path == Path("/path/to/firmware.elf") def test_idedata_firmware_bin_path(setup_core: Path) -> None: """Test IDEData.firmware_bin_path returns Path with .bin extension.""" - CORE.build_path = str(setup_core / "build" / "test") + CORE.build_path = setup_core / "build" / "test" CORE.name = "test" prog_path = str(Path("/path/to/firmware.elf")) raw_data = {"prog_path": prog_path} idedata = platformio_api.IDEData(raw_data) result = idedata.firmware_bin_path - assert isinstance(result, str) - expected = str(Path("/path/to/firmware.bin")) + assert isinstance(result, Path) + expected = Path("/path/to/firmware.bin") assert result == expected - assert result.endswith(".bin") + assert str(result).endswith(".bin") def test_idedata_firmware_bin_path_preserves_directory(setup_core: Path) -> None: """Test firmware_bin_path preserves the directory structure.""" - CORE.build_path = str(setup_core / "build" / "test") + CORE.build_path = setup_core / "build" / "test" CORE.name = "test" prog_path = str(Path("/complex/path/to/build/firmware.elf")) raw_data = {"prog_path": prog_path} idedata = platformio_api.IDEData(raw_data) result = idedata.firmware_bin_path - expected = str(Path("/complex/path/to/build/firmware.bin")) + expected = Path("/complex/path/to/build/firmware.bin") assert result == expected def test_idedata_extra_flash_images(setup_core: Path) -> None: """Test IDEData.extra_flash_images returns list of FlashImage objects.""" - CORE.build_path = str(setup_core / "build" / "test") + CORE.build_path = setup_core / "build" / "test" CORE.name = "test" raw_data = { "prog_path": "/path/to/firmware.elf", @@ -69,15 +69,15 @@ def test_idedata_extra_flash_images(setup_core: Path) -> None: images = idedata.extra_flash_images assert len(images) == 2 assert all(isinstance(img, platformio_api.FlashImage) for img in images) - assert images[0].path == "/path/to/bootloader.bin" + assert images[0].path == Path("/path/to/bootloader.bin") assert images[0].offset == "0x1000" - assert images[1].path == "/path/to/partition.bin" + assert images[1].path == Path("/path/to/partition.bin") assert images[1].offset == "0x8000" def test_idedata_extra_flash_images_empty(setup_core: Path) -> None: """Test extra_flash_images returns empty list when no extra images.""" - CORE.build_path = str(setup_core / "build" / "test") + CORE.build_path = setup_core / "build" / "test" CORE.name = "test" raw_data = {"prog_path": "/path/to/firmware.elf", "extra": {"flash_images": []}} idedata = platformio_api.IDEData(raw_data) @@ -88,7 +88,7 @@ def test_idedata_extra_flash_images_empty(setup_core: Path) -> None: def test_idedata_cc_path(setup_core: Path) -> None: """Test IDEData.cc_path returns compiler path.""" - CORE.build_path = str(setup_core / "build" / "test") + CORE.build_path = setup_core / "build" / "test" CORE.name = "test" raw_data = { "prog_path": "/path/to/firmware.elf", @@ -104,9 +104,9 @@ def test_idedata_cc_path(setup_core: Path) -> None: def test_flash_image_dataclass() -> None: """Test FlashImage dataclass stores path and offset correctly.""" - image = platformio_api.FlashImage(path="/path/to/image.bin", offset="0x10000") + image = platformio_api.FlashImage(path=Path("/path/to/image.bin"), offset="0x10000") - assert image.path == "/path/to/image.bin" + assert image.path == Path("/path/to/image.bin") assert image.offset == "0x10000" @@ -114,7 +114,7 @@ def test_load_idedata_returns_dict( setup_core: Path, mock_run_platformio_cli_run ) -> None: """Test _load_idedata returns parsed idedata dict when successful.""" - CORE.build_path = str(setup_core / "build" / "test") + CORE.build_path = setup_core / "build" / "test" CORE.name = "test" # Create required files @@ -366,7 +366,7 @@ def test_get_idedata_caches_result( assert result1 is result2 assert isinstance(result1, platformio_api.IDEData) - assert result1.firmware_elf_path == "/test/firmware.elf" + assert result1.firmware_elf_path == Path("/test/firmware.elf") def test_idedata_addr2line_path_windows(setup_core: Path) -> None: @@ -434,9 +434,9 @@ def test_patched_clean_build_dir_removes_outdated(setup_core: Path) -> None: os.utime(platformio_ini, (build_mtime + 1, build_mtime + 1)) # Track if directory was removed - removed_paths: list[str] = [] + removed_paths: list[Path] = [] - def track_rmtree(path: str) -> None: + def track_rmtree(path: Path) -> None: removed_paths.append(path) shutil.rmtree(path) @@ -466,7 +466,7 @@ def test_patched_clean_build_dir_removes_outdated(setup_core: Path) -> None: # Verify directory was removed and recreated assert len(removed_paths) == 1 - assert removed_paths[0] == str(build_dir) + assert removed_paths[0] == build_dir assert build_dir.exists() # makedirs recreated it diff --git a/tests/unit_tests/test_storage_json.py b/tests/unit_tests/test_storage_json.py index e1abe565b1..a3a38960e7 100644 --- a/tests/unit_tests/test_storage_json.py +++ b/tests/unit_tests/test_storage_json.py @@ -15,12 +15,12 @@ from esphome.core import CORE def test_storage_path(setup_core: Path) -> None: """Test storage_path returns correct path for current config.""" - CORE.config_path = str(setup_core / "my_device.yaml") + CORE.config_path = setup_core / "my_device.yaml" result = storage_json.storage_path() data_dir = Path(CORE.data_dir) - expected = str(data_dir / "storage" / "my_device.yaml.json") + expected = data_dir / "storage" / "my_device.yaml.json" assert result == expected @@ -29,20 +29,20 @@ def test_ext_storage_path(setup_core: Path) -> None: result = storage_json.ext_storage_path("other_device.yaml") data_dir = Path(CORE.data_dir) - expected = str(data_dir / "storage" / "other_device.yaml.json") + expected = data_dir / "storage" / "other_device.yaml.json" assert result == expected def test_ext_storage_path_handles_various_extensions(setup_core: Path) -> None: """Test ext_storage_path works with different file extensions.""" result_yml = storage_json.ext_storage_path("device.yml") - assert result_yml.endswith("device.yml.json") + assert str(result_yml).endswith("device.yml.json") result_no_ext = storage_json.ext_storage_path("device") - assert result_no_ext.endswith("device.json") + assert str(result_no_ext).endswith("device.json") result_path = storage_json.ext_storage_path("my/device.yaml") - assert result_path.endswith("device.yaml.json") + assert str(result_path).endswith("device.yaml.json") def test_esphome_storage_path(setup_core: Path) -> None: @@ -50,7 +50,7 @@ def test_esphome_storage_path(setup_core: Path) -> None: result = storage_json.esphome_storage_path() data_dir = Path(CORE.data_dir) - expected = str(data_dir / "esphome.json") + expected = data_dir / "esphome.json" assert result == expected @@ -59,27 +59,27 @@ def test_ignored_devices_storage_path(setup_core: Path) -> None: result = storage_json.ignored_devices_storage_path() data_dir = Path(CORE.data_dir) - expected = str(data_dir / "ignored-devices.json") + expected = data_dir / "ignored-devices.json" assert result == expected def test_trash_storage_path(setup_core: Path) -> None: """Test trash_storage_path returns correct path.""" - CORE.config_path = str(setup_core / "configs" / "device.yaml") + CORE.config_path = setup_core / "configs" / "device.yaml" result = storage_json.trash_storage_path() - expected = str(setup_core / "configs" / "trash") + expected = setup_core / "configs" / "trash" assert result == expected def test_archive_storage_path(setup_core: Path) -> None: """Test archive_storage_path returns correct path.""" - CORE.config_path = str(setup_core / "configs" / "device.yaml") + CORE.config_path = setup_core / "configs" / "device.yaml" result = storage_json.archive_storage_path() - expected = str(setup_core / "configs" / "archive") + expected = setup_core / "configs" / "archive" assert result == expected @@ -87,12 +87,12 @@ def test_storage_path_with_subdirectory(setup_core: Path) -> None: """Test storage paths work correctly when config is in subdirectory.""" subdir = setup_core / "configs" / "basement" subdir.mkdir(parents=True, exist_ok=True) - CORE.config_path = str(subdir / "sensor.yaml") + CORE.config_path = subdir / "sensor.yaml" result = storage_json.storage_path() data_dir = Path(CORE.data_dir) - expected = str(data_dir / "storage" / "sensor.yaml.json") + expected = data_dir / "storage" / "sensor.yaml.json" assert result == expected @@ -173,16 +173,16 @@ def test_storage_paths_with_ha_addon(mock_is_ha_addon: bool, tmp_path: Path) -> """Test storage paths when running as Home Assistant addon.""" mock_is_ha_addon.return_value = True - CORE.config_path = str(tmp_path / "test.yaml") + CORE.config_path = tmp_path / "test.yaml" result = storage_json.storage_path() # When is_ha_addon is True, CORE.data_dir returns "/data" # This is the standard mount point for HA addon containers - expected = str(Path("/data") / "storage" / "test.yaml.json") + expected = Path("/data") / "storage" / "test.yaml.json" assert result == expected result = storage_json.esphome_storage_path() - expected = str(Path("/data") / "esphome.json") + expected = Path("/data") / "esphome.json" assert result == expected @@ -375,7 +375,7 @@ def test_storage_json_load_valid_file(tmp_path: Path) -> None: file_path = tmp_path / "storage.json" file_path.write_text(json.dumps(storage_data)) - result = storage_json.StorageJSON.load(str(file_path)) + result = storage_json.StorageJSON.load(file_path) assert result is not None assert result.name == "loaded_device" @@ -386,8 +386,8 @@ def test_storage_json_load_valid_file(tmp_path: Path) -> None: assert result.address == "10.0.0.1" assert result.web_port == 8080 assert result.target_platform == "ESP32" - assert result.build_path == "/loaded/build" - assert result.firmware_bin_path == "/loaded/firmware.bin" + assert result.build_path == Path("/loaded/build") + assert result.firmware_bin_path == Path("/loaded/firmware.bin") assert result.loaded_integrations == {"wifi", "api"} assert result.loaded_platforms == {"sensor"} assert result.no_mdns is True @@ -400,7 +400,7 @@ def test_storage_json_load_invalid_file(tmp_path: Path) -> None: file_path = tmp_path / "invalid.json" file_path.write_text("not valid json{") - result = storage_json.StorageJSON.load(str(file_path)) + result = storage_json.StorageJSON.load(file_path) assert result is None @@ -654,7 +654,7 @@ def test_storage_json_load_legacy_esphomeyaml_version(tmp_path: Path) -> None: file_path = tmp_path / "legacy.json" file_path.write_text(json.dumps(storage_data)) - result = storage_json.StorageJSON.load(str(file_path)) + result = storage_json.StorageJSON.load(file_path) assert result is not None assert result.esphome_version == "1.14.0" # Should map to esphome_version diff --git a/tests/unit_tests/test_substitutions.py b/tests/unit_tests/test_substitutions.py index b2b7cb1ea4..dd419aba9c 100644 --- a/tests/unit_tests/test_substitutions.py +++ b/tests/unit_tests/test_substitutions.py @@ -1,6 +1,6 @@ import glob import logging -import os +from pathlib import Path from esphome import yaml_util from esphome.components import substitutions @@ -52,9 +52,8 @@ def dict_diff(a, b, path=""): return diffs -def write_yaml(path, data): - with open(path, "w", encoding="utf-8") as f: - f.write(yaml_util.dump(data)) +def write_yaml(path: Path, data: dict) -> None: + path.write_text(yaml_util.dump(data), encoding="utf-8") def test_substitutions_fixtures(fixture_path): @@ -64,11 +63,10 @@ def test_substitutions_fixtures(fixture_path): failures = [] for source_path in sources: + source_path = Path(source_path) try: - expected_path = source_path.replace(".input.yaml", ".approved.yaml") - test_case = os.path.splitext(os.path.basename(source_path))[0].replace( - ".input", "" - ) + expected_path = source_path.with_suffix("").with_suffix(".approved.yaml") + test_case = source_path.with_suffix("").stem # Load using ESPHome's YAML loader config = yaml_util.load_yaml(source_path) @@ -81,12 +79,12 @@ def test_substitutions_fixtures(fixture_path): substitutions.do_substitution_pass(config, None) # Also load expected using ESPHome's loader, or use {} if missing and DEV_MODE - if os.path.isfile(expected_path): + if expected_path.is_file(): expected = yaml_util.load_yaml(expected_path) elif DEV_MODE: expected = {} else: - assert os.path.isfile(expected_path), ( + assert expected_path.is_file(), ( f"Expected file missing: {expected_path}" ) @@ -97,16 +95,14 @@ def test_substitutions_fixtures(fixture_path): if got_sorted != expected_sorted: diff = "\n".join(dict_diff(got_sorted, expected_sorted)) msg = ( - f"Substitution result mismatch for {os.path.basename(source_path)}\n" + f"Substitution result mismatch for {source_path.name}\n" f"Diff:\n{diff}\n\n" f"Got: {got_sorted}\n" f"Expected: {expected_sorted}" ) # Write out the received file when test fails if DEV_MODE: - received_path = os.path.join( - os.path.dirname(source_path), f"{test_case}.received.yaml" - ) + received_path = source_path.with_name(f"{test_case}.received.yaml") write_yaml(received_path, config) print(msg) failures.append(msg) diff --git a/tests/unit_tests/test_util.py b/tests/unit_tests/test_util.py index 34f40a651f..85873caea8 100644 --- a/tests/unit_tests/test_util.py +++ b/tests/unit_tests/test_util.py @@ -1,5 +1,7 @@ """Tests for esphome.util module.""" +from __future__ import annotations + from pathlib import Path import pytest @@ -30,21 +32,21 @@ def test_list_yaml_files_with_files_and_directories(tmp_path: Path) -> None: # Test with mixed input (directories and files) configs = [ - str(dir1), - str(standalone1), - str(dir2), - str(standalone2), + dir1, + standalone1, + dir2, + standalone2, ] result = util.list_yaml_files(configs) # Should include all YAML files but not the .txt file assert set(result) == { - str(dir1 / "config1.yaml"), - str(dir1 / "config2.yml"), - str(dir2 / "config3.yaml"), - str(standalone1), - str(standalone2), + dir1 / "config1.yaml", + dir1 / "config2.yml", + dir2 / "config3.yaml", + standalone1, + standalone2, } # Check that results are sorted assert result == sorted(result) @@ -61,12 +63,12 @@ def test_list_yaml_files_only_directories(tmp_path: Path) -> None: (dir1 / "b.yml").write_text("test: b") (dir2 / "c.yaml").write_text("test: c") - result = util.list_yaml_files([str(dir1), str(dir2)]) + result = util.list_yaml_files([dir1, dir2]) assert set(result) == { - str(dir1 / "a.yaml"), - str(dir1 / "b.yml"), - str(dir2 / "c.yaml"), + dir1 / "a.yaml", + dir1 / "b.yml", + dir2 / "c.yaml", } assert result == sorted(result) @@ -86,17 +88,17 @@ def test_list_yaml_files_only_files(tmp_path: Path) -> None: # Include a non-YAML file to test filtering result = util.list_yaml_files( [ - str(file1), - str(file2), - str(file3), - str(non_yaml), + file1, + file2, + file3, + non_yaml, ] ) assert set(result) == { - str(file1), - str(file2), - str(file3), + file1, + file2, + file3, } assert result == sorted(result) @@ -106,7 +108,7 @@ def test_list_yaml_files_empty_directory(tmp_path: Path) -> None: empty_dir = tmp_path / "empty" empty_dir.mkdir() - result = util.list_yaml_files([str(empty_dir)]) + result = util.list_yaml_files([empty_dir]) assert result == [] @@ -119,7 +121,7 @@ def test_list_yaml_files_nonexistent_path(tmp_path: Path) -> None: # Should raise an error for non-existent directory with pytest.raises(FileNotFoundError): - util.list_yaml_files([str(nonexistent), str(existing)]) + util.list_yaml_files([nonexistent, existing]) def test_list_yaml_files_mixed_extensions(tmp_path: Path) -> None: @@ -135,11 +137,11 @@ def test_list_yaml_files_mixed_extensions(tmp_path: Path) -> None: yml_file.write_text("test: yml") other_file.write_text("test: txt") - result = util.list_yaml_files([str(dir1)]) + result = util.list_yaml_files([dir1]) assert set(result) == { - str(yaml_file), - str(yml_file), + yaml_file, + yml_file, } @@ -172,17 +174,18 @@ def test_list_yaml_files_does_not_recurse_into_subdirectories(tmp_path: Path) -> assert len(result) == 3 # Check that only root-level files are found - assert str(root / "config1.yaml") in result - assert str(root / "config2.yml") in result - assert str(root / "device.yaml") in result + assert root / "config1.yaml" in result + assert root / "config2.yml" in result + assert root / "device.yaml" in result # Ensure nested files are NOT found for r in result: - assert "subdir" not in r - assert "deeper" not in r - assert "nested1.yaml" not in r - assert "nested2.yml" not in r - assert "very_nested.yaml" not in r + r_str = str(r) + assert "subdir" not in r_str + assert "deeper" not in r_str + assert "nested1.yaml" not in r_str + assert "nested2.yml" not in r_str + assert "very_nested.yaml" not in r_str def test_list_yaml_files_excludes_secrets(tmp_path: Path) -> None: @@ -200,10 +203,10 @@ def test_list_yaml_files_excludes_secrets(tmp_path: Path) -> None: # Should find 2 files (config.yaml and device.yaml), not secrets assert len(result) == 2 - assert str(root / "config.yaml") in result - assert str(root / "device.yaml") in result - assert str(root / "secrets.yaml") not in result - assert str(root / "secrets.yml") not in result + assert root / "config.yaml" in result + assert root / "device.yaml" in result + assert root / "secrets.yaml" not in result + assert root / "secrets.yml" not in result def test_list_yaml_files_excludes_hidden_files(tmp_path: Path) -> None: @@ -221,90 +224,181 @@ def test_list_yaml_files_excludes_hidden_files(tmp_path: Path) -> None: # Should find only non-hidden files assert len(result) == 2 - assert str(root / "config.yaml") in result - assert str(root / "device.yaml") in result - assert str(root / ".hidden.yaml") not in result - assert str(root / ".backup.yml") not in result + assert root / "config.yaml" in result + assert root / "device.yaml" in result + assert root / ".hidden.yaml" not in result + assert root / ".backup.yml" not in result def test_filter_yaml_files_basic() -> None: """Test filter_yaml_files function.""" files = [ - "/path/to/config.yaml", - "/path/to/device.yml", - "/path/to/readme.txt", - "/path/to/script.py", - "/path/to/data.json", - "/path/to/another.yaml", + Path("/path/to/config.yaml"), + Path("/path/to/device.yml"), + Path("/path/to/readme.txt"), + Path("/path/to/script.py"), + Path("/path/to/data.json"), + Path("/path/to/another.yaml"), ] result = util.filter_yaml_files(files) assert len(result) == 3 - assert "/path/to/config.yaml" in result - assert "/path/to/device.yml" in result - assert "/path/to/another.yaml" in result - assert "/path/to/readme.txt" not in result - assert "/path/to/script.py" not in result - assert "/path/to/data.json" not in result + assert Path("/path/to/config.yaml") in result + assert Path("/path/to/device.yml") in result + assert Path("/path/to/another.yaml") in result + assert Path("/path/to/readme.txt") not in result + assert Path("/path/to/script.py") not in result + assert Path("/path/to/data.json") not in result def test_filter_yaml_files_excludes_secrets() -> None: """Test that filter_yaml_files excludes secrets files.""" files = [ - "/path/to/config.yaml", - "/path/to/secrets.yaml", - "/path/to/secrets.yml", - "/path/to/device.yaml", - "/some/dir/secrets.yaml", + Path("/path/to/config.yaml"), + Path("/path/to/secrets.yaml"), + Path("/path/to/secrets.yml"), + Path("/path/to/device.yaml"), + Path("/some/dir/secrets.yaml"), ] result = util.filter_yaml_files(files) assert len(result) == 2 - assert "/path/to/config.yaml" in result - assert "/path/to/device.yaml" in result - assert "/path/to/secrets.yaml" not in result - assert "/path/to/secrets.yml" not in result - assert "/some/dir/secrets.yaml" not in result + assert Path("/path/to/config.yaml") in result + assert Path("/path/to/device.yaml") in result + assert Path("/path/to/secrets.yaml") not in result + assert Path("/path/to/secrets.yml") not in result + assert Path("/some/dir/secrets.yaml") not in result def test_filter_yaml_files_excludes_hidden() -> None: """Test that filter_yaml_files excludes hidden files.""" files = [ - "/path/to/config.yaml", - "/path/to/.hidden.yaml", - "/path/to/.backup.yml", - "/path/to/device.yaml", - "/some/dir/.config.yaml", + Path("/path/to/config.yaml"), + Path("/path/to/.hidden.yaml"), + Path("/path/to/.backup.yml"), + Path("/path/to/device.yaml"), + Path("/some/dir/.config.yaml"), ] result = util.filter_yaml_files(files) assert len(result) == 2 - assert "/path/to/config.yaml" in result - assert "/path/to/device.yaml" in result - assert "/path/to/.hidden.yaml" not in result - assert "/path/to/.backup.yml" not in result - assert "/some/dir/.config.yaml" not in result + assert Path("/path/to/config.yaml") in result + assert Path("/path/to/device.yaml") in result + assert Path("/path/to/.hidden.yaml") not in result + assert Path("/path/to/.backup.yml") not in result + assert Path("/some/dir/.config.yaml") not in result def test_filter_yaml_files_case_sensitive() -> None: """Test that filter_yaml_files is case-sensitive for extensions.""" files = [ - "/path/to/config.yaml", - "/path/to/config.YAML", - "/path/to/config.YML", - "/path/to/config.Yaml", - "/path/to/config.yml", + Path("/path/to/config.yaml"), + Path("/path/to/config.YAML"), + Path("/path/to/config.YML"), + Path("/path/to/config.Yaml"), + Path("/path/to/config.yml"), ] result = util.filter_yaml_files(files) # Should only match lowercase .yaml and .yml assert len(result) == 2 - assert "/path/to/config.yaml" in result - assert "/path/to/config.yml" in result - assert "/path/to/config.YAML" not in result - assert "/path/to/config.YML" not in result - assert "/path/to/config.Yaml" not in result + + # Check the actual suffixes to ensure case-sensitive filtering + result_suffixes = [p.suffix for p in result] + assert ".yaml" in result_suffixes + assert ".yml" in result_suffixes + + # Verify the filtered files have the expected names + result_names = [p.name for p in result] + assert "config.yaml" in result_names + assert "config.yml" in result_names + # Ensure uppercase extensions are NOT included + assert "config.YAML" not in result_names + assert "config.YML" not in result_names + assert "config.Yaml" not in result_names + + +@pytest.mark.parametrize( + ("input_str", "expected"), + [ + # Empty string + ("", "''"), + # Simple strings that don't need quoting + ("hello", "hello"), + ("test123", "test123"), + ("file.txt", "file.txt"), + ("/path/to/file", "/path/to/file"), + ("user@host", "user@host"), + ("value:123", "value:123"), + ("item,list", "item,list"), + ("path-with-dash", "path-with-dash"), + # Strings that need quoting + ("hello world", "'hello world'"), + ("test\ttab", "'test\ttab'"), + ("line\nbreak", "'line\nbreak'"), + ("semicolon;here", "'semicolon;here'"), + ("pipe|symbol", "'pipe|symbol'"), + ("redirect>file", "'redirect>file'"), + ("redirect None: + """Test shlex_quote properly escapes shell arguments.""" + assert util.shlex_quote(input_str) == expected + + +def test_shlex_quote_safe_characters() -> None: + """Test that safe characters are not quoted.""" + # These characters are considered safe and shouldn't be quoted + safe_chars = ( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789@%+=:,./-_" + ) + for char in safe_chars: + assert util.shlex_quote(char) == char + assert util.shlex_quote(f"test{char}test") == f"test{char}test" + + +def test_shlex_quote_unsafe_characters() -> None: + """Test that unsafe characters trigger quoting.""" + # These characters should trigger quoting + unsafe_chars = ' \t\n;|>&<$`"\\?*[](){}!#~^' + for char in unsafe_chars: + result = util.shlex_quote(f"test{char}test") + assert result.startswith("'") + assert result.endswith("'") + + +def test_shlex_quote_edge_cases() -> None: + """Test edge cases for shlex_quote.""" + # Multiple single quotes + assert util.shlex_quote("'''") == "''\"'\"''\"'\"''\"'\"''" + + # Mixed quotes + assert util.shlex_quote('"\'"') == "'\"'\"'\"'\"'" + + # Only whitespace + assert util.shlex_quote(" ") == "' '" + assert util.shlex_quote("\t") == "'\t'" + assert util.shlex_quote("\n") == "'\n'" + assert util.shlex_quote(" ") == "' '" diff --git a/tests/unit_tests/test_vscode.py b/tests/unit_tests/test_vscode.py index 4b28a2215b..63bdf3e255 100644 --- a/tests/unit_tests/test_vscode.py +++ b/tests/unit_tests/test_vscode.py @@ -1,5 +1,5 @@ import json -import os +from pathlib import Path from unittest.mock import Mock, patch from esphome import vscode @@ -45,7 +45,7 @@ RESULT_NO_ERROR = '{"type": "result", "yaml_errors": [], "validation_errors": [] def test_multi_file(): - source_path = os.path.join("dir_path", "x.yaml") + source_path = str(Path("dir_path", "x.yaml")) output_lines = _run_repl_test( [ _validate(source_path), @@ -62,7 +62,7 @@ esp8266: expected_lines = [ _read_file(source_path), - _read_file(os.path.join("dir_path", "secrets.yaml")), + _read_file(str(Path("dir_path", "secrets.yaml"))), RESULT_NO_ERROR, ] @@ -70,7 +70,7 @@ esp8266: def test_shows_correct_range_error(): - source_path = os.path.join("dir_path", "x.yaml") + source_path = str(Path("dir_path", "x.yaml")) output_lines = _run_repl_test( [ _validate(source_path), @@ -98,7 +98,7 @@ esp8266: def test_shows_correct_loaded_file_error(): - source_path = os.path.join("dir_path", "x.yaml") + source_path = str(Path("dir_path", "x.yaml")) output_lines = _run_repl_test( [ _validate(source_path), @@ -121,7 +121,7 @@ packages: validation_error = error["validation_errors"][0] assert validation_error["message"].startswith("[broad] is an invalid option for") range = validation_error["range"] - assert range["document"] == os.path.join("dir_path", ".pkg.esp8266.yaml") + assert range["document"] == str(Path("dir_path", ".pkg.esp8266.yaml")) assert range["start_line"] == 1 assert range["start_col"] == 2 assert range["end_line"] == 1 diff --git a/tests/unit_tests/test_wizard.py b/tests/unit_tests/test_wizard.py index 7af4db813a..fd53a0b0b7 100644 --- a/tests/unit_tests/test_wizard.py +++ b/tests/unit_tests/test_wizard.py @@ -1,6 +1,5 @@ """Tests for the wizard.py file.""" -import os from pathlib import Path from typing import Any from unittest.mock import MagicMock @@ -127,7 +126,7 @@ def test_wizard_write_sets_platform( # Given del default_config["platform"] monkeypatch.setattr(wz, "write_file", MagicMock()) - monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + monkeypatch.setattr(CORE, "config_path", tmp_path.parent) # When wz.wizard_write(tmp_path, **default_config) @@ -147,7 +146,7 @@ def test_wizard_empty_config(tmp_path: Path, monkeypatch: MonkeyPatch): "name": "test-empty", } monkeypatch.setattr(wz, "write_file", MagicMock()) - monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + monkeypatch.setattr(CORE, "config_path", tmp_path.parent) # When wz.wizard_write(tmp_path, **empty_config) @@ -168,7 +167,7 @@ def test_wizard_upload_config(tmp_path: Path, monkeypatch: MonkeyPatch): "file_text": "# imported file 📁\n\n", } monkeypatch.setattr(wz, "write_file", MagicMock()) - monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + monkeypatch.setattr(CORE, "config_path", tmp_path.parent) # When wz.wizard_write(tmp_path, **empty_config) @@ -189,7 +188,7 @@ def test_wizard_write_defaults_platform_from_board_esp8266( default_config["board"] = [*ESP8266_BOARD_PINS][0] monkeypatch.setattr(wz, "write_file", MagicMock()) - monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + monkeypatch.setattr(CORE, "config_path", tmp_path.parent) # When wz.wizard_write(tmp_path, **default_config) @@ -210,7 +209,7 @@ def test_wizard_write_defaults_platform_from_board_esp32( default_config["board"] = [*ESP32_BOARD_PINS][0] monkeypatch.setattr(wz, "write_file", MagicMock()) - monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + monkeypatch.setattr(CORE, "config_path", tmp_path.parent) # When wz.wizard_write(tmp_path, **default_config) @@ -231,7 +230,7 @@ def test_wizard_write_defaults_platform_from_board_bk72xx( default_config["board"] = [*BK72XX_BOARD_PINS][0] monkeypatch.setattr(wz, "write_file", MagicMock()) - monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + monkeypatch.setattr(CORE, "config_path", tmp_path.parent) # When wz.wizard_write(tmp_path, **default_config) @@ -252,7 +251,7 @@ def test_wizard_write_defaults_platform_from_board_ln882x( default_config["board"] = [*LN882X_BOARD_PINS][0] monkeypatch.setattr(wz, "write_file", MagicMock()) - monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + monkeypatch.setattr(CORE, "config_path", tmp_path.parent) # When wz.wizard_write(tmp_path, **default_config) @@ -273,7 +272,7 @@ def test_wizard_write_defaults_platform_from_board_rtl87xx( default_config["board"] = [*RTL87XX_BOARD_PINS][0] monkeypatch.setattr(wz, "write_file", MagicMock()) - monkeypatch.setattr(CORE, "config_path", os.path.dirname(tmp_path)) + monkeypatch.setattr(CORE, "config_path", tmp_path.parent) # When wz.wizard_write(tmp_path, **default_config) @@ -362,7 +361,7 @@ def test_wizard_rejects_path_with_invalid_extension(): """ # Given - config_file = "test.json" + config_file = Path("test.json") # When retval = wz.wizard(config_file) @@ -371,31 +370,31 @@ def test_wizard_rejects_path_with_invalid_extension(): assert retval == 1 -def test_wizard_rejects_existing_files(tmpdir): +def test_wizard_rejects_existing_files(tmp_path): """ The wizard should reject any configuration file that already exists """ # Given - config_file = tmpdir.join("test.yaml") - config_file.write("") + config_file = tmp_path / "test.yaml" + config_file.write_text("") # When - retval = wz.wizard(str(config_file)) + retval = wz.wizard(config_file) # Then assert retval == 2 def test_wizard_accepts_default_answers_esp8266( - tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] + tmp_path: Path, monkeypatch: MonkeyPatch, wizard_answers: list[str] ): """ The wizard should accept the given default answers for esp8266 """ # Given - config_file = tmpdir.join("test.yaml") + config_file = tmp_path / "test.yaml" input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) @@ -403,14 +402,14 @@ def test_wizard_accepts_default_answers_esp8266( monkeypatch.setattr(wz, "wizard_write", MagicMock()) # When - retval = wz.wizard(str(config_file)) + retval = wz.wizard(config_file) # Then assert retval == 0 def test_wizard_accepts_default_answers_esp32( - tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] + tmp_path: Path, monkeypatch: MonkeyPatch, wizard_answers: list[str] ): """ The wizard should accept the given default answers for esp32 @@ -419,7 +418,7 @@ def test_wizard_accepts_default_answers_esp32( # Given wizard_answers[1] = "ESP32" wizard_answers[2] = "nodemcu-32s" - config_file = tmpdir.join("test.yaml") + config_file = tmp_path / "test.yaml" input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) @@ -427,14 +426,14 @@ def test_wizard_accepts_default_answers_esp32( monkeypatch.setattr(wz, "wizard_write", MagicMock()) # When - retval = wz.wizard(str(config_file)) + retval = wz.wizard(config_file) # Then assert retval == 0 def test_wizard_offers_better_node_name( - tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] + tmp_path: Path, monkeypatch: MonkeyPatch, wizard_answers: list[str] ): """ When the node name does not conform, a better alternative is offered @@ -451,7 +450,7 @@ def test_wizard_offers_better_node_name( wz, "default_input", MagicMock(side_effect=lambda _, default: default) ) - config_file = tmpdir.join("test.yaml") + config_file = tmp_path / "test.yaml" input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) @@ -459,7 +458,7 @@ def test_wizard_offers_better_node_name( monkeypatch.setattr(wz, "wizard_write", MagicMock()) # When - retval = wz.wizard(str(config_file)) + retval = wz.wizard(config_file) # Then assert retval == 0 @@ -467,7 +466,7 @@ def test_wizard_offers_better_node_name( def test_wizard_requires_correct_platform( - tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] + tmp_path: Path, monkeypatch: MonkeyPatch, wizard_answers: list[str] ): """ When the platform is not either esp32 or esp8266, the wizard should reject it @@ -476,7 +475,7 @@ def test_wizard_requires_correct_platform( # Given wizard_answers.insert(1, "foobar") # add invalid entry for platform - config_file = tmpdir.join("test.yaml") + config_file = tmp_path / "test.yaml" input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) @@ -484,14 +483,14 @@ def test_wizard_requires_correct_platform( monkeypatch.setattr(wz, "wizard_write", MagicMock()) # When - retval = wz.wizard(str(config_file)) + retval = wz.wizard(config_file) # Then assert retval == 0 def test_wizard_requires_correct_board( - tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] + tmp_path: Path, monkeypatch: MonkeyPatch, wizard_answers: list[str] ): """ When the board is not a valid esp8266 board, the wizard should reject it @@ -500,7 +499,7 @@ def test_wizard_requires_correct_board( # Given wizard_answers.insert(2, "foobar") # add an invalid entry for board - config_file = tmpdir.join("test.yaml") + config_file = tmp_path / "test.yaml" input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) @@ -508,14 +507,14 @@ def test_wizard_requires_correct_board( monkeypatch.setattr(wz, "wizard_write", MagicMock()) # When - retval = wz.wizard(str(config_file)) + retval = wz.wizard(config_file) # Then assert retval == 0 def test_wizard_requires_valid_ssid( - tmpdir, monkeypatch: MonkeyPatch, wizard_answers: list[str] + tmp_path: Path, monkeypatch: MonkeyPatch, wizard_answers: list[str] ): """ When the board is not a valid esp8266 board, the wizard should reject it @@ -524,7 +523,7 @@ def test_wizard_requires_valid_ssid( # Given wizard_answers.insert(3, "") # add an invalid entry for ssid - config_file = tmpdir.join("test.yaml") + config_file = tmp_path / "test.yaml" input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) @@ -532,28 +531,28 @@ def test_wizard_requires_valid_ssid( monkeypatch.setattr(wz, "wizard_write", MagicMock()) # When - retval = wz.wizard(str(config_file)) + retval = wz.wizard(config_file) # Then assert retval == 0 def test_wizard_write_protects_existing_config( - tmpdir, default_config: dict[str, Any], monkeypatch: MonkeyPatch + tmp_path: Path, default_config: dict[str, Any], monkeypatch: MonkeyPatch ): """ The wizard_write function should not overwrite existing config files and return False """ # Given - config_file = tmpdir.join("test.yaml") + config_file = tmp_path / "test.yaml" original_content = "# Original config content\n" - config_file.write(original_content) + config_file.write_text(original_content) - monkeypatch.setattr(CORE, "config_path", str(tmpdir)) + monkeypatch.setattr(CORE, "config_path", tmp_path.parent) # When - result = wz.wizard_write(str(config_file), **default_config) + result = wz.wizard_write(config_file, **default_config) # Then assert result is False # Should return False when file exists - assert config_file.read() == original_content + assert config_file.read_text() == original_content diff --git a/tests/unit_tests/test_writer.py b/tests/unit_tests/test_writer.py index 970e0fada6..bffd2b3881 100644 --- a/tests/unit_tests/test_writer.py +++ b/tests/unit_tests/test_writer.py @@ -257,10 +257,7 @@ def test_clean_cmake_cache( cmake_cache_file.write_text("# CMake cache file") # Setup mocks - mock_core.relative_pioenvs_path.side_effect = [ - str(pioenvs_dir), # First call for directory check - str(cmake_cache_file), # Second call for file path - ] + mock_core.relative_pioenvs_path.return_value = pioenvs_dir mock_core.name = "test_device" # Verify file exists before @@ -288,7 +285,7 @@ def test_clean_cmake_cache_no_pioenvs_dir( pioenvs_dir = tmp_path / ".pioenvs" # Setup mocks - mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) + mock_core.relative_pioenvs_path.return_value = pioenvs_dir # Verify directory doesn't exist assert not pioenvs_dir.exists() @@ -314,10 +311,7 @@ def test_clean_cmake_cache_no_cmake_file( cmake_cache_file = device_dir / "CMakeCache.txt" # Setup mocks - mock_core.relative_pioenvs_path.side_effect = [ - str(pioenvs_dir), # First call for directory check - str(cmake_cache_file), # Second call for file path - ] + mock_core.relative_pioenvs_path.return_value = pioenvs_dir mock_core.name = "test_device" # Verify file doesn't exist @@ -358,9 +352,9 @@ def test_clean_build( (platformio_cache_dir / "downloads" / "package.tar.gz").write_text("package") # Setup mocks - mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) - mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) - mock_core.relative_build_path.return_value = str(dependencies_lock) + mock_core.relative_pioenvs_path.return_value = pioenvs_dir + mock_core.relative_piolibdeps_path.return_value = piolibdeps_dir + mock_core.relative_build_path.return_value = dependencies_lock # Verify all exist before assert pioenvs_dir.exists() @@ -368,11 +362,17 @@ def test_clean_build( assert dependencies_lock.exists() assert platformio_cache_dir.exists() - # Mock PlatformIO's get_project_cache_dir + # Mock PlatformIO's ProjectConfig cache_dir with patch( - "platformio.project.helpers.get_project_cache_dir" - ) as mock_get_cache_dir: - mock_get_cache_dir.return_value = str(platformio_cache_dir) + "platformio.project.config.ProjectConfig.get_instance" + ) as mock_get_instance: + mock_config = MagicMock() + mock_get_instance.return_value = mock_config + mock_config.get.side_effect = ( + lambda section, option: str(platformio_cache_dir) + if (section, option) == ("platformio", "cache_dir") + else "" + ) # Call the function with caplog.at_level("INFO"): @@ -408,9 +408,9 @@ def test_clean_build_partial_exists( dependencies_lock = tmp_path / "dependencies.lock" # Setup mocks - mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) - mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) - mock_core.relative_build_path.return_value = str(dependencies_lock) + mock_core.relative_pioenvs_path.return_value = pioenvs_dir + mock_core.relative_piolibdeps_path.return_value = piolibdeps_dir + mock_core.relative_build_path.return_value = dependencies_lock # Verify only pioenvs exists assert pioenvs_dir.exists() @@ -445,9 +445,9 @@ def test_clean_build_nothing_exists( dependencies_lock = tmp_path / "dependencies.lock" # Setup mocks - mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) - mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) - mock_core.relative_build_path.return_value = str(dependencies_lock) + mock_core.relative_pioenvs_path.return_value = pioenvs_dir + mock_core.relative_piolibdeps_path.return_value = piolibdeps_dir + mock_core.relative_build_path.return_value = dependencies_lock # Verify nothing exists assert not pioenvs_dir.exists() @@ -481,9 +481,9 @@ def test_clean_build_platformio_not_available( dependencies_lock.write_text("lock file") # Setup mocks - mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) - mock_core.relative_piolibdeps_path.return_value = str(piolibdeps_dir) - mock_core.relative_build_path.return_value = str(dependencies_lock) + mock_core.relative_pioenvs_path.return_value = pioenvs_dir + mock_core.relative_piolibdeps_path.return_value = piolibdeps_dir + mock_core.relative_build_path.return_value = dependencies_lock # Verify all exist before assert pioenvs_dir.exists() @@ -492,7 +492,7 @@ def test_clean_build_platformio_not_available( # Mock import error for platformio with ( - patch.dict("sys.modules", {"platformio.project.helpers": None}), + patch.dict("sys.modules", {"platformio.project.config": None}), caplog.at_level("INFO"), ): # Call the function @@ -519,18 +519,24 @@ def test_clean_build_empty_cache_dir( pioenvs_dir.mkdir() # Setup mocks - mock_core.relative_pioenvs_path.return_value = str(pioenvs_dir) - mock_core.relative_piolibdeps_path.return_value = str(tmp_path / ".piolibdeps") - mock_core.relative_build_path.return_value = str(tmp_path / "dependencies.lock") + mock_core.relative_pioenvs_path.return_value = pioenvs_dir + mock_core.relative_piolibdeps_path.return_value = tmp_path / ".piolibdeps" + mock_core.relative_build_path.return_value = tmp_path / "dependencies.lock" # Verify pioenvs exists before assert pioenvs_dir.exists() - # Mock PlatformIO's get_project_cache_dir to return whitespace + # Mock PlatformIO's ProjectConfig cache_dir to return whitespace with patch( - "platformio.project.helpers.get_project_cache_dir" - ) as mock_get_cache_dir: - mock_get_cache_dir.return_value = " " # Whitespace only + "platformio.project.config.ProjectConfig.get_instance" + ) as mock_get_instance: + mock_config = MagicMock() + mock_get_instance.return_value = mock_config + mock_config.get.side_effect = ( + lambda section, option: " " # Whitespace only + if (section, option) == ("platformio", "cache_dir") + else "" + ) # Call the function with caplog.at_level("INFO"): @@ -552,7 +558,7 @@ def test_write_gitignore_creates_new_file( gitignore_path = tmp_path / ".gitignore" # Setup mocks - mock_core.relative_config_path.return_value = str(gitignore_path) + mock_core.relative_config_path.return_value = gitignore_path # Verify file doesn't exist assert not gitignore_path.exists() @@ -576,7 +582,7 @@ def test_write_gitignore_skips_existing_file( gitignore_path.write_text(existing_content) # Setup mocks - mock_core.relative_config_path.return_value = str(gitignore_path) + mock_core.relative_config_path.return_value = gitignore_path # Verify file exists with custom content assert gitignore_path.exists() @@ -615,7 +621,7 @@ void loop() {{}}""" main_cpp.write_text(existing_content) # Setup mocks - mock_core.relative_src_path.return_value = str(main_cpp) + mock_core.relative_src_path.return_value = main_cpp mock_core.cpp_global_section = "// Global section" # Call the function @@ -652,7 +658,7 @@ def test_write_cpp_creates_new_file( main_cpp = tmp_path / "main.cpp" # Setup mocks - mock_core.relative_src_path.return_value = str(main_cpp) + mock_core.relative_src_path.return_value = main_cpp mock_core.cpp_global_section = "// Global section" # Verify file doesn't exist @@ -668,7 +674,7 @@ def test_write_cpp_creates_new_file( # Get the content that would be written mock_write_file.assert_called_once() written_path, written_content = mock_write_file.call_args[0] - assert written_path == str(main_cpp) + assert written_path == main_cpp # Check that all necessary parts are in the new file assert '#include "esphome.h"' in written_content @@ -698,7 +704,7 @@ def test_write_cpp_with_missing_end_marker( main_cpp.write_text(existing_content) # Setup mocks - mock_core.relative_src_path.return_value = str(main_cpp) + mock_core.relative_src_path.return_value = main_cpp # Call should raise an error with pytest.raises(EsphomeError, match="Could not find auto generated code end"): @@ -724,8 +730,258 @@ def test_write_cpp_with_duplicate_markers( main_cpp.write_text(existing_content) # Setup mocks - mock_core.relative_src_path.return_value = str(main_cpp) + mock_core.relative_src_path.return_value = main_cpp # Call should raise an error with pytest.raises(EsphomeError, match="Found multiple auto generate code begins"): write_cpp("// New code") + + +@patch("esphome.writer.CORE") +def test_clean_all( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_all removes build and PlatformIO dirs.""" + # Create build directories for multiple configurations + config1_dir = tmp_path / "config1" + config2_dir = tmp_path / "config2" + config1_dir.mkdir() + config2_dir.mkdir() + + build_dir1 = config1_dir / ".esphome" + build_dir2 = config2_dir / ".esphome" + build_dir1.mkdir() + build_dir2.mkdir() + (build_dir1 / "dummy.txt").write_text("x") + (build_dir2 / "dummy.txt").write_text("x") + + # Create PlatformIO directories + pio_cache = tmp_path / "pio_cache" + pio_packages = tmp_path / "pio_packages" + pio_platforms = tmp_path / "pio_platforms" + pio_core = tmp_path / "pio_core" + for d in (pio_cache, pio_packages, pio_platforms, pio_core): + d.mkdir() + (d / "keep").write_text("x") + + # Mock ProjectConfig + with patch( + "platformio.project.config.ProjectConfig.get_instance" + ) as mock_get_instance: + mock_config = MagicMock() + mock_get_instance.return_value = mock_config + + def cfg_get(section: str, option: str) -> str: + mapping = { + ("platformio", "cache_dir"): str(pio_cache), + ("platformio", "packages_dir"): str(pio_packages), + ("platformio", "platforms_dir"): str(pio_platforms), + ("platformio", "core_dir"): str(pio_core), + } + return mapping.get((section, option), "") + + mock_config.get.side_effect = cfg_get + + # Call + from esphome.writer import clean_all + + with caplog.at_level("INFO"): + clean_all([str(config1_dir), str(config2_dir)]) + + # Verify deletions - .esphome directories remain but contents are cleaned + # The .esphome directory itself is not removed because it may contain storage + assert build_dir1.exists() + assert build_dir2.exists() + + # Verify that files in .esphome were removed + assert not (build_dir1 / "dummy.txt").exists() + assert not (build_dir2 / "dummy.txt").exists() + assert not pio_cache.exists() + assert not pio_packages.exists() + assert not pio_platforms.exists() + assert not pio_core.exists() + + # Verify logging mentions each + assert "Cleaning" in caplog.text + assert str(build_dir1) in caplog.text + assert str(build_dir2) in caplog.text + assert "PlatformIO cache" in caplog.text + assert "PlatformIO packages" in caplog.text + assert "PlatformIO platforms" in caplog.text + assert "PlatformIO core" in caplog.text + + +@patch("esphome.writer.CORE") +def test_clean_all_preserves_storage( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_all preserves storage directory.""" + # Create build directory with storage subdirectory + config_dir = tmp_path / "config" + config_dir.mkdir() + + build_dir = config_dir / ".esphome" + build_dir.mkdir() + (build_dir / "dummy.txt").write_text("x") + (build_dir / "other_file.txt").write_text("y") + + # Create storage directory with content + storage_dir = build_dir / "storage" + storage_dir.mkdir() + (storage_dir / "storage.json").write_text('{"test": "data"}') + (storage_dir / "other_storage.txt").write_text("storage content") + + # Call clean_all + from esphome.writer import clean_all + + with caplog.at_level("INFO"): + clean_all([str(config_dir)]) + + # Verify .esphome directory still exists + assert build_dir.exists() + + # Verify storage directory still exists with its contents + assert storage_dir.exists() + assert (storage_dir / "storage.json").exists() + assert (storage_dir / "other_storage.txt").exists() + + # Verify storage contents are intact + assert (storage_dir / "storage.json").read_text() == '{"test": "data"}' + assert (storage_dir / "other_storage.txt").read_text() == "storage content" + + # Verify other files were removed + assert not (build_dir / "dummy.txt").exists() + assert not (build_dir / "other_file.txt").exists() + + # Verify logging mentions deletion + assert "Cleaning" in caplog.text + assert str(build_dir) in caplog.text + + +@patch("esphome.writer.CORE") +def test_clean_all_platformio_not_available( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_all when PlatformIO is not available.""" + # Build dirs + config_dir = tmp_path / "config" + config_dir.mkdir() + build_dir = config_dir / ".esphome" + build_dir.mkdir() + + # PlatformIO dirs that should remain untouched + pio_cache = tmp_path / "pio_cache" + pio_cache.mkdir() + + from esphome.writer import clean_all + + with ( + patch.dict("sys.modules", {"platformio.project.config": None}), + caplog.at_level("INFO"), + ): + clean_all([str(config_dir)]) + + # Build dir contents cleaned, PlatformIO dirs remain + assert build_dir.exists() + assert pio_cache.exists() + + # No PlatformIO-specific logs + assert "PlatformIO" not in caplog.text + + +@patch("esphome.writer.CORE") +def test_clean_all_partial_exists( + mock_core: MagicMock, + tmp_path: Path, +) -> None: + """Test clean_all when only some build dirs exist.""" + config_dir = tmp_path / "config" + config_dir.mkdir() + build_dir = config_dir / ".esphome" + build_dir.mkdir() + + with patch( + "platformio.project.config.ProjectConfig.get_instance" + ) as mock_get_instance: + mock_config = MagicMock() + mock_get_instance.return_value = mock_config + # Return non-existent dirs + mock_config.get.side_effect = lambda *_args, **_kw: str( + tmp_path / "does_not_exist" + ) + + from esphome.writer import clean_all + + clean_all([str(config_dir)]) + + assert build_dir.exists() + + +@patch("esphome.writer.CORE") +def test_clean_all_removes_non_storage_directories( + mock_core: MagicMock, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test clean_all removes directories other than storage.""" + # Create build directory with various subdirectories + config_dir = tmp_path / "config" + config_dir.mkdir() + + build_dir = config_dir / ".esphome" + build_dir.mkdir() + + # Create files + (build_dir / "file1.txt").write_text("content1") + (build_dir / "file2.txt").write_text("content2") + + # Create storage directory (should be preserved) + storage_dir = build_dir / "storage" + storage_dir.mkdir() + (storage_dir / "storage.json").write_text('{"test": "data"}') + + # Create other directories (should be removed) + cache_dir = build_dir / "cache" + cache_dir.mkdir() + (cache_dir / "cache_file.txt").write_text("cache content") + + logs_dir = build_dir / "logs" + logs_dir.mkdir() + (logs_dir / "log1.txt").write_text("log content") + + temp_dir = build_dir / "temp" + temp_dir.mkdir() + (temp_dir / "temp_file.txt").write_text("temp content") + + # Call clean_all + from esphome.writer import clean_all + + with caplog.at_level("INFO"): + clean_all([str(config_dir)]) + + # Verify .esphome directory still exists + assert build_dir.exists() + + # Verify storage directory and its contents are preserved + assert storage_dir.exists() + assert (storage_dir / "storage.json").exists() + assert (storage_dir / "storage.json").read_text() == '{"test": "data"}' + + # Verify files were removed + assert not (build_dir / "file1.txt").exists() + assert not (build_dir / "file2.txt").exists() + + # Verify non-storage directories were removed + assert not cache_dir.exists() + assert not logs_dir.exists() + assert not temp_dir.exists() + + # Verify logging mentions cleaning + assert "Cleaning" in caplog.text + assert str(build_dir) in caplog.text diff --git a/tests/unit_tests/test_yaml_util.py b/tests/unit_tests/test_yaml_util.py index f31e9554dc..eac0ceabb8 100644 --- a/tests/unit_tests/test_yaml_util.py +++ b/tests/unit_tests/test_yaml_util.py @@ -1,9 +1,26 @@ -from esphome import yaml_util +from pathlib import Path +import shutil +from unittest.mock import patch + +import pytest + +from esphome import core, yaml_util from esphome.components import substitutions from esphome.core import EsphomeError +from esphome.util import OrderedDict -def test_include_with_vars(fixture_path): +@pytest.fixture(autouse=True) +def clear_secrets_cache() -> None: + """Clear the secrets cache before each test.""" + yaml_util._SECRET_VALUES.clear() + yaml_util._SECRET_CACHE.clear() + yield + yaml_util._SECRET_VALUES.clear() + yaml_util._SECRET_CACHE.clear() + + +def test_include_with_vars(fixture_path: Path) -> None: yaml_file = fixture_path / "yaml_util" / "includetest.yaml" actual = yaml_util.load_yaml(yaml_file) @@ -50,15 +67,214 @@ def test_parsing_with_custom_loader(fixture_path): """ yaml_file = fixture_path / "yaml_util" / "includetest.yaml" - loader_calls = [] + loader_calls: list[Path] = [] - def custom_loader(fname): + def custom_loader(fname: Path): loader_calls.append(fname) - with open(yaml_file, encoding="utf-8") as f_handle: + with yaml_file.open(encoding="utf-8") as f_handle: yaml_util.parse_yaml(yaml_file, f_handle, custom_loader) assert len(loader_calls) == 3 - assert loader_calls[0].endswith("includes/included.yaml") - assert loader_calls[1].endswith("includes/list.yaml") - assert loader_calls[2].endswith("includes/scalar.yaml") + assert loader_calls[0].parts[-2:] == ("includes", "included.yaml") + assert loader_calls[1].parts[-2:] == ("includes", "list.yaml") + assert loader_calls[2].parts[-2:] == ("includes", "scalar.yaml") + + +def test_construct_secret_simple(fixture_path: Path) -> None: + """Test loading a YAML file with !secret tags.""" + yaml_file = fixture_path / "yaml_util" / "test_secret.yaml" + + actual = yaml_util.load_yaml(yaml_file) + + # Check that secrets were properly loaded + assert actual["wifi"]["password"] == "super_secret_wifi" + assert actual["api"]["encryption"]["key"] == "0123456789abcdef" + assert actual["sensor"][0]["id"] == "my_secret_value" + + +def test_construct_secret_missing(fixture_path: Path, tmp_path: Path) -> None: + """Test that missing secrets raise proper errors.""" + # Create a YAML file with a secret that doesn't exist + test_yaml = tmp_path / "test.yaml" + test_yaml.write_text(""" +esphome: + name: test + +wifi: + password: !secret nonexistent_secret +""") + + # Create an empty secrets file + secrets_yaml = tmp_path / "secrets.yaml" + secrets_yaml.write_text("some_other_secret: value") + + with pytest.raises(EsphomeError, match="Secret 'nonexistent_secret' not defined"): + yaml_util.load_yaml(test_yaml) + + +def test_construct_secret_no_secrets_file(tmp_path: Path) -> None: + """Test that missing secrets.yaml file raises proper error.""" + # Create a YAML file with a secret but no secrets.yaml + test_yaml = tmp_path / "test.yaml" + test_yaml.write_text(""" +wifi: + password: !secret some_secret +""") + + # Mock CORE.config_path to avoid NoneType error + with ( + patch.object(core.CORE, "config_path", tmp_path / "main.yaml"), + pytest.raises(EsphomeError, match="secrets.yaml"), + ): + yaml_util.load_yaml(test_yaml) + + +def test_construct_secret_fallback_to_main_config_dir( + fixture_path: Path, tmp_path: Path +) -> None: + """Test fallback to main config directory for secrets.""" + # Create a subdirectory with a YAML file that uses secrets + subdir = tmp_path / "subdir" + subdir.mkdir() + + test_yaml = subdir / "test.yaml" + test_yaml.write_text(""" +wifi: + password: !secret test_secret +""") + + # Create secrets.yaml in the main directory + main_secrets = tmp_path / "secrets.yaml" + main_secrets.write_text("test_secret: main_secret_value") + + # Mock CORE.config_path to point to main directory + with patch.object(core.CORE, "config_path", tmp_path / "main.yaml"): + actual = yaml_util.load_yaml(test_yaml) + assert actual["wifi"]["password"] == "main_secret_value" + + +def test_construct_include_dir_named(fixture_path: Path, tmp_path: Path) -> None: + """Test !include_dir_named directive.""" + # Copy fixture directory to temporary location + src_dir = fixture_path / "yaml_util" + dst_dir = tmp_path / "yaml_util" + shutil.copytree(src_dir, dst_dir) + + # Create test YAML that uses include_dir_named + test_yaml = dst_dir / "test_include_named.yaml" + test_yaml.write_text(""" +sensor: !include_dir_named named_dir +""") + + actual = yaml_util.load_yaml(test_yaml) + actual_sensor = actual["sensor"] + + # Check that files were loaded with their names as keys + assert isinstance(actual_sensor, OrderedDict) + assert "sensor1" in actual_sensor + assert "sensor2" in actual_sensor + assert "sensor3" in actual_sensor # Files from subdirs are included with basename + + # Check content of loaded files + assert actual_sensor["sensor1"]["platform"] == "template" + assert actual_sensor["sensor1"]["name"] == "Sensor 1" + assert actual_sensor["sensor2"]["platform"] == "template" + assert actual_sensor["sensor2"]["name"] == "Sensor 2" + + # Check that subdirectory files are included with their basename + assert actual_sensor["sensor3"]["platform"] == "template" + assert actual_sensor["sensor3"]["name"] == "Sensor 3 in subdir" + + # Check that hidden files and non-YAML files are not included + assert ".hidden" not in actual_sensor + assert "not_yaml" not in actual_sensor + + +def test_construct_include_dir_named_empty_dir(tmp_path: Path) -> None: + """Test !include_dir_named with empty directory.""" + # Create empty directory + empty_dir = tmp_path / "empty_dir" + empty_dir.mkdir() + + test_yaml = tmp_path / "test.yaml" + test_yaml.write_text(""" +sensor: !include_dir_named empty_dir +""") + + actual = yaml_util.load_yaml(test_yaml) + + # Should return empty OrderedDict + assert isinstance(actual["sensor"], OrderedDict) + assert len(actual["sensor"]) == 0 + + +def test_construct_include_dir_named_with_dots(tmp_path: Path) -> None: + """Test that include_dir_named ignores files starting with dots.""" + # Create directory with various files + test_dir = tmp_path / "test_dir" + test_dir.mkdir() + + # Create visible file + visible_file = test_dir / "visible.yaml" + visible_file.write_text("key: visible_value") + + # Create hidden file + hidden_file = test_dir / ".hidden.yaml" + hidden_file.write_text("key: hidden_value") + + # Create hidden directory with files + hidden_dir = test_dir / ".hidden_dir" + hidden_dir.mkdir() + hidden_subfile = hidden_dir / "subfile.yaml" + hidden_subfile.write_text("key: hidden_subfile_value") + + test_yaml = tmp_path / "test.yaml" + test_yaml.write_text(""" +test: !include_dir_named test_dir +""") + + actual = yaml_util.load_yaml(test_yaml) + + # Should only include visible file + assert "visible" in actual["test"] + assert actual["test"]["visible"]["key"] == "visible_value" + + # Should not include hidden files or directories + assert ".hidden" not in actual["test"] + assert ".hidden_dir" not in actual["test"] + + +def test_find_files_recursive(fixture_path: Path, tmp_path: Path) -> None: + """Test that _find_files works recursively through include_dir_named.""" + # Copy fixture directory to temporary location + src_dir = fixture_path / "yaml_util" + dst_dir = tmp_path / "yaml_util" + shutil.copytree(src_dir, dst_dir) + + # This indirectly tests _find_files by using include_dir_named + test_yaml = dst_dir / "test_include_recursive.yaml" + test_yaml.write_text(""" +all_sensors: !include_dir_named named_dir +""") + + actual = yaml_util.load_yaml(test_yaml) + + # Should find sensor1.yaml, sensor2.yaml, and subdir/sensor3.yaml (all flattened) + assert len(actual["all_sensors"]) == 3 + assert "sensor1" in actual["all_sensors"] + assert "sensor2" in actual["all_sensors"] + assert "sensor3" in actual["all_sensors"] + + +def test_secret_values_tracking(fixture_path: Path) -> None: + """Test that secret values are properly tracked for dumping.""" + yaml_file = fixture_path / "yaml_util" / "test_secret.yaml" + + yaml_util.load_yaml(yaml_file) + + # Check that secret values are tracked + assert "super_secret_wifi" in yaml_util._SECRET_VALUES + assert yaml_util._SECRET_VALUES["super_secret_wifi"] == "wifi_password" + assert "0123456789abcdef" in yaml_util._SECRET_VALUES + assert yaml_util._SECRET_VALUES["0123456789abcdef"] == "api_key"