1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-29 08:32:26 +01:00

Merge branch 'integration' into memory_api

This commit is contained in:
J. Nick Koston
2025-09-21 11:06:30 -06:00
18 changed files with 1344 additions and 75 deletions

View File

@@ -16,7 +16,7 @@ from esphome.const import (
CONF_SAFE_MODE,
CONF_VERSION,
)
from esphome.core import coroutine_with_priority
from esphome.core import CORE, coroutine_with_priority
from esphome.coroutine import CoroPriority
import esphome.final_validate as fv
@@ -24,9 +24,22 @@ _LOGGER = logging.getLogger(__name__)
CODEOWNERS = ["@esphome/core"]
AUTO_LOAD = ["md5", "socket"]
DEPENDENCIES = ["network"]
def supports_sha256() -> bool:
"""Check if the current platform supports SHA256 for OTA authentication."""
return bool(CORE.is_esp32 or CORE.is_esp8266 or CORE.is_rp2040 or CORE.is_libretiny)
def AUTO_LOAD() -> list[str]:
"""Conditionally auto-load sha256 only on platforms that support it."""
base_components = ["md5", "socket"]
if supports_sha256():
return base_components + ["sha256"]
return base_components
esphome = cg.esphome_ns.namespace("esphome")
ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", OTAComponent)
@@ -126,6 +139,11 @@ FINAL_VALIDATE_SCHEMA = ota_esphome_final_validate
async def to_code(config):
var = cg.new_Pvariable(config[CONF_ID])
cg.add(var.set_port(config[CONF_PORT]))
# Only include SHA256 support on platforms that have it
if supports_sha256():
cg.add_define("USE_OTA_SHA256")
if CONF_PASSWORD in config:
cg.add(var.set_auth_password(config[CONF_PASSWORD]))
cg.add_define("USE_OTA_PASSWORD")

View File

@@ -1,6 +1,9 @@
#include "ota_esphome.h"
#ifdef USE_OTA
#include "esphome/components/md5/md5.h"
#ifdef USE_OTA_SHA256
#include "esphome/components/sha256/sha256.h"
#endif
#include "esphome/components/network/util.h"
#include "esphome/components/ota/ota_backend.h"
#include "esphome/components/ota/ota_backend_arduino_esp32.h"
@@ -95,6 +98,33 @@ void ESPHomeOTAComponent::loop() {
}
static const uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01;
#ifdef USE_OTA_SHA256
static const uint8_t FEATURE_SUPPORTS_SHA256_AUTH = 0x02;
#endif
// Temporary flag to allow MD5 downgrade for ~3 versions (until 2026.1.0)
// This allows users to downgrade via OTA if they encounter issues after updating.
// Without this, users would need to do a serial flash to downgrade.
// TODO: Remove this flag and all associated code in 2026.1.0
#define ALLOW_OTA_DOWNGRADE_MD5
template<typename HashClass> struct HashTraits;
template<> struct HashTraits<md5::MD5Digest> {
static constexpr int NONCE_SIZE = 8;
static constexpr int HEX_SIZE = 32;
static constexpr const char *NAME = "MD5";
static constexpr ota::OTAResponseTypes AUTH_REQUEST = ota::OTA_RESPONSE_REQUEST_AUTH;
};
#ifdef USE_OTA_SHA256
template<> struct HashTraits<sha256::SHA256> {
static constexpr int NONCE_SIZE = 16;
static constexpr int HEX_SIZE = 64;
static constexpr const char *NAME = "SHA256";
static constexpr ota::OTAResponseTypes AUTH_REQUEST = ota::OTA_RESPONSE_REQUEST_SHA256_AUTH;
};
#endif
void ESPHomeOTAComponent::handle_handshake_() {
/// Handle the initial OTA handshake.
@@ -225,57 +255,55 @@ void ESPHomeOTAComponent::handle_data_() {
#ifdef USE_OTA_PASSWORD
if (!this->password_.empty()) {
buf[0] = ota::OTA_RESPONSE_REQUEST_AUTH;
this->writeall_(buf, 1);
md5::MD5Digest md5{};
md5.init();
sprintf(sbuf, "%08" PRIx32, random_uint32());
md5.add(sbuf, 8);
md5.calculate();
md5.get_hex(sbuf);
ESP_LOGV(TAG, "Auth: Nonce is %s", sbuf);
bool auth_success = false;
// Send nonce, 32 bytes hex MD5
if (!this->writeall_(reinterpret_cast<uint8_t *>(sbuf), 32)) {
ESP_LOGW(TAG, "Auth: Writing nonce failed");
#ifdef USE_OTA_SHA256
// SECURITY HARDENING: Prefer SHA256 authentication on platforms that support it.
//
// This is a hardening measure to prevent future downgrade attacks where an attacker
// could force the use of MD5 authentication by manipulating the feature flags.
//
// While MD5 is currently still acceptable for our OTA authentication use case
// (where the password is a shared secret and we're only authenticating, not
// encrypting), at some point in the future MD5 will likely become so weak that
// it could be practically attacked.
//
// We enforce SHA256 now on capable platforms because:
// 1. We can't retroactively update device firmware in the field
// 2. Clients (like esphome CLI) can always be updated to support SHA256
// 3. This prevents any possibility of downgrade attacks in the future
//
// Devices that don't support SHA256 (due to platform limitations) will
// continue to use MD5 as their only option (see #else branch below).
bool client_supports_sha256 = (ota_features & FEATURE_SUPPORTS_SHA256_AUTH) != 0;
#ifdef ALLOW_OTA_DOWNGRADE_MD5
// Temporary compatibility mode: Allow MD5 for ~3 versions to enable OTA downgrades
// This prevents users from being locked out if they need to downgrade after updating
// TODO: Remove this entire ifdef block in 2026.1.0
if (client_supports_sha256) {
auth_success = this->perform_hash_auth_<sha256::SHA256>(this->password_);
} else {
ESP_LOGW(TAG, "Using MD5 auth for compatibility (deprecated)");
auth_success = this->perform_hash_auth_<md5::MD5Digest>(this->password_);
}
#else
// Strict mode: SHA256 required on capable platforms (future default)
if (!client_supports_sha256) {
ESP_LOGW(TAG, "Client requires SHA256");
error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
auth_success = this->perform_hash_auth_<sha256::SHA256>(this->password_);
#endif // ALLOW_OTA_DOWNGRADE_MD5
#else
// Platform only supports MD5 - use it as the only available option
// This is not a security downgrade as the platform cannot support SHA256
auth_success = this->perform_hash_auth_<md5::MD5Digest>(this->password_);
#endif // USE_OTA_SHA256
// prepare challenge
md5.init();
md5.add(this->password_.c_str(), this->password_.length());
// add nonce
md5.add(sbuf, 32);
// Receive cnonce, 32 bytes hex MD5
if (!this->readall_(buf, 32)) {
ESP_LOGW(TAG, "Auth: Reading cnonce failed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
sbuf[32] = '\0';
ESP_LOGV(TAG, "Auth: CNonce is %s", sbuf);
// add cnonce
md5.add(sbuf, 32);
// calculate result
md5.calculate();
md5.get_hex(sbuf);
ESP_LOGV(TAG, "Auth: Result is %s", sbuf);
// Receive result, 32 bytes hex MD5
if (!this->readall_(buf + 64, 32)) {
ESP_LOGW(TAG, "Auth: Reading response failed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
sbuf[64 + 32] = '\0';
ESP_LOGV(TAG, "Auth: Response is %s", sbuf + 64);
bool matches = true;
for (uint8_t i = 0; i < 32; i++)
matches = matches && buf[i] == buf[64 + i];
if (!matches) {
ESP_LOGW(TAG, "Auth failed! Passwords do not match");
if (!auth_success) {
error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
@@ -499,5 +527,110 @@ void ESPHomeOTAComponent::yield_and_feed_watchdog_() {
delay(1);
}
// Template function definition - placed at end to ensure all types are complete
template<typename HashClass> bool ESPHomeOTAComponent::perform_hash_auth_(const std::string &password) {
using Traits = HashTraits<HashClass>;
// Minimize stack usage by reusing buffers
// We only need 2 buffers at most at the same time
constexpr size_t hex_buffer_size = Traits::HEX_SIZE + 1;
// These two buffers are reused throughout the function
char hex_buffer1[hex_buffer_size]; // Used for: nonce -> expected result
char hex_buffer2[hex_buffer_size]; // Used for: cnonce -> response
// Small stack buffer for auth request and nonce seed bytes
uint8_t buf[1];
uint8_t nonce_bytes[8]; // Max 8 bytes (2 x uint32_t for SHA256)
// Send auth request type
buf[0] = Traits::AUTH_REQUEST;
this->writeall_(buf, 1);
HashClass hasher;
hasher.init();
// Generate nonce seed bytes
uint32_t r1 = random_uint32();
// Convert first uint32 to bytes (always needed for MD5)
nonce_bytes[0] = (r1 >> 24) & 0xFF;
nonce_bytes[1] = (r1 >> 16) & 0xFF;
nonce_bytes[2] = (r1 >> 8) & 0xFF;
nonce_bytes[3] = r1 & 0xFF;
if (Traits::NONCE_SIZE == 8) {
// MD5: 8 chars = "%08x" format = 4 bytes from one random uint32
hasher.add(nonce_bytes, 4);
}
#ifdef USE_OTA_SHA256
else {
// SHA256: 16 chars = "%08x%08x" format = 8 bytes from two random uint32s
uint32_t r2 = random_uint32();
nonce_bytes[4] = (r2 >> 24) & 0xFF;
nonce_bytes[5] = (r2 >> 16) & 0xFF;
nonce_bytes[6] = (r2 >> 8) & 0xFF;
nonce_bytes[7] = r2 & 0xFF;
hasher.add(nonce_bytes, 8);
}
#endif
hasher.calculate();
// Use hex_buffer1 for nonce
hasher.get_hex(hex_buffer1);
hex_buffer1[Traits::HEX_SIZE] = '\0';
ESP_LOGV(TAG, "Auth: %s Nonce is %s", Traits::NAME, hex_buffer1);
// Send nonce
if (!this->writeall_(reinterpret_cast<uint8_t *>(hex_buffer1), Traits::HEX_SIZE)) {
ESP_LOGW(TAG, "Auth: Writing %s nonce failed", Traits::NAME);
return false;
}
// Prepare challenge
hasher.init();
hasher.add(password.c_str(), password.length());
hasher.add(hex_buffer1, Traits::HEX_SIZE); // Add nonce
// Receive cnonce into hex_buffer2
if (!this->readall_(reinterpret_cast<uint8_t *>(hex_buffer2), Traits::HEX_SIZE)) {
ESP_LOGW(TAG, "Auth: Reading %s cnonce failed", Traits::NAME);
return false;
}
hex_buffer2[Traits::HEX_SIZE] = '\0';
ESP_LOGV(TAG, "Auth: %s CNonce is %s", Traits::NAME, hex_buffer2);
// Add cnonce to hash
hasher.add(hex_buffer2, Traits::HEX_SIZE);
// Calculate result - reuse hex_buffer1 for expected
hasher.calculate();
hasher.get_hex(hex_buffer1);
hex_buffer1[Traits::HEX_SIZE] = '\0';
ESP_LOGV(TAG, "Auth: %s Result is %s", Traits::NAME, hex_buffer1);
// Receive response - reuse hex_buffer2
if (!this->readall_(reinterpret_cast<uint8_t *>(hex_buffer2), Traits::HEX_SIZE)) {
ESP_LOGW(TAG, "Auth: Reading %s response failed", Traits::NAME);
return false;
}
hex_buffer2[Traits::HEX_SIZE] = '\0';
ESP_LOGV(TAG, "Auth: %s Response is %s", Traits::NAME, hex_buffer2);
// Compare
bool matches = memcmp(hex_buffer1, hex_buffer2, Traits::HEX_SIZE) == 0;
if (!matches) {
ESP_LOGW(TAG, "Auth failed! %s passwords do not match", Traits::NAME);
}
return matches;
}
// Explicit template instantiations
template bool ESPHomeOTAComponent::perform_hash_auth_<md5::MD5Digest>(const std::string &);
#ifdef USE_OTA_SHA256
template bool ESPHomeOTAComponent::perform_hash_auth_<sha256::SHA256>(const std::string &);
#endif
} // namespace esphome
#endif

View File

@@ -30,6 +30,7 @@ class ESPHomeOTAComponent : public ota::OTAComponent {
protected:
void handle_handshake_();
void handle_data_();
template<typename HashClass> bool perform_hash_auth_(const std::string &password);
bool readall_(uint8_t *buf, size_t len);
bool writeall_(const uint8_t *buf, size_t len);
void log_socket_error_(const LogString *msg);

View File

@@ -14,6 +14,7 @@ namespace ota {
enum OTAResponseTypes {
OTA_RESPONSE_OK = 0x00,
OTA_RESPONSE_REQUEST_AUTH = 0x01,
OTA_RESPONSE_REQUEST_SHA256_AUTH = 0x02,
OTA_RESPONSE_HEADER_OK = 0x40,
OTA_RESPONSE_AUTH_OK = 0x41,

View File

@@ -0,0 +1,24 @@
import esphome.codegen as cg
import esphome.config_validation as cv
from esphome.core import CORE
from esphome.helpers import IS_MACOS
from esphome.types import ConfigType
CODEOWNERS = ["@esphome/core"]
sha256_ns = cg.esphome_ns.namespace("sha256")
CONFIG_SCHEMA = cv.Schema({})
async def to_code(config: ConfigType) -> None:
# Add OpenSSL library for host platform
if CORE.is_host:
if IS_MACOS:
# macOS needs special handling for Homebrew OpenSSL
cg.add_build_flag("-I/opt/homebrew/opt/openssl/include")
cg.add_build_flag("-L/opt/homebrew/opt/openssl/lib")
cg.add_build_flag("-lcrypto")
else:
# Linux and other Unix systems usually have OpenSSL in standard paths
cg.add_build_flag("-lcrypto")

View File

@@ -0,0 +1,160 @@
#include "sha256.h"
// Only compile SHA256 implementation on platforms that support it
#if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_HOST)
#include "esphome/core/helpers.h"
#include <cstring>
namespace esphome::sha256 {
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
SHA256::~SHA256() {
if (this->ctx_) {
mbedtls_sha256_free(&this->ctx_->ctx);
}
}
void SHA256::init() {
if (!this->ctx_) {
this->ctx_ = std::make_unique<SHA256Context>();
}
mbedtls_sha256_init(&this->ctx_->ctx);
mbedtls_sha256_starts(&this->ctx_->ctx, 0); // 0 = SHA256, not SHA224
}
void SHA256::add(const uint8_t *data, size_t len) {
if (!this->ctx_) {
this->init();
}
mbedtls_sha256_update(&this->ctx_->ctx, data, len);
}
void SHA256::calculate() {
if (!this->ctx_) {
this->init();
}
mbedtls_sha256_finish(&this->ctx_->ctx, this->ctx_->hash);
}
#elif defined(USE_ESP8266) || defined(USE_RP2040)
SHA256::~SHA256() = default;
void SHA256::init() {
if (!this->ctx_) {
this->ctx_ = std::make_unique<SHA256Context>();
}
br_sha256_init(&this->ctx_->ctx);
this->ctx_->calculated = false;
}
void SHA256::add(const uint8_t *data, size_t len) {
if (!this->ctx_) {
this->init();
}
br_sha256_update(&this->ctx_->ctx, data, len);
}
void SHA256::calculate() {
if (!this->ctx_) {
this->init();
}
if (!this->ctx_->calculated) {
br_sha256_out(&this->ctx_->ctx, this->ctx_->hash);
this->ctx_->calculated = true;
}
}
#elif defined(USE_HOST)
SHA256::~SHA256() {
if (this->ctx_ && this->ctx_->ctx) {
EVP_MD_CTX_free(this->ctx_->ctx);
this->ctx_->ctx = nullptr;
}
}
void SHA256::init() {
if (!this->ctx_) {
this->ctx_ = std::make_unique<SHA256Context>();
}
if (this->ctx_->ctx) {
EVP_MD_CTX_free(this->ctx_->ctx);
}
this->ctx_->ctx = EVP_MD_CTX_new();
EVP_DigestInit_ex(this->ctx_->ctx, EVP_sha256(), nullptr);
this->ctx_->calculated = false;
}
void SHA256::add(const uint8_t *data, size_t len) {
if (!this->ctx_) {
this->init();
}
EVP_DigestUpdate(this->ctx_->ctx, data, len);
}
void SHA256::calculate() {
if (!this->ctx_) {
this->init();
}
if (!this->ctx_->calculated) {
unsigned int len = 32;
EVP_DigestFinal_ex(this->ctx_->ctx, this->ctx_->hash, &len);
this->ctx_->calculated = true;
}
}
#else
#error "SHA256 not supported on this platform"
#endif
void SHA256::get_bytes(uint8_t *output) {
if (!this->ctx_) {
memset(output, 0, 32);
return;
}
memcpy(output, this->ctx_->hash, 32);
}
void SHA256::get_hex(char *output) {
if (!this->ctx_) {
memset(output, '0', 64);
output[64] = '\0';
return;
}
for (size_t i = 0; i < 32; i++) {
uint8_t byte = this->ctx_->hash[i];
output[i * 2] = format_hex_char(byte >> 4);
output[i * 2 + 1] = format_hex_char(byte & 0x0F);
}
}
std::string SHA256::get_hex_string() {
char buf[65];
this->get_hex(buf);
return std::string(buf);
}
bool SHA256::equals_bytes(const uint8_t *expected) {
if (!this->ctx_) {
return false;
}
return memcmp(this->ctx_->hash, expected, 32) == 0;
}
bool SHA256::equals_hex(const char *expected) {
if (!this->ctx_) {
return false;
}
uint8_t parsed[32];
if (!parse_hex(expected, parsed, 32)) {
return false;
}
return this->equals_bytes(parsed);
}
} // namespace esphome::sha256
#endif // Platform check

View File

@@ -0,0 +1,69 @@
#pragma once
#include "esphome/core/defines.h"
// Only define SHA256 on platforms that support it
#if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_HOST)
#include <cstdint>
#include <string>
#include <memory>
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
#include "mbedtls/sha256.h"
#elif defined(USE_ESP8266) || defined(USE_RP2040)
#include <bearssl/bearssl_hash.h>
#elif defined(USE_HOST)
#include <openssl/evp.h>
#else
#error "SHA256 not supported on this platform"
#endif
namespace esphome::sha256 {
class SHA256 {
public:
SHA256() = default;
~SHA256();
void init();
void add(const uint8_t *data, size_t len);
void add(const char *data, size_t len) { this->add((const uint8_t *) data, len); }
void add(const std::string &data) { this->add(data.c_str(), data.length()); }
void calculate();
void get_bytes(uint8_t *output);
void get_hex(char *output);
std::string get_hex_string();
bool equals_bytes(const uint8_t *expected);
bool equals_hex(const char *expected);
protected:
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
struct SHA256Context {
mbedtls_sha256_context ctx;
uint8_t hash[32];
};
#elif defined(USE_ESP8266) || defined(USE_RP2040)
struct SHA256Context {
br_sha256_context ctx;
uint8_t hash[32];
bool calculated{false};
};
#elif defined(USE_HOST)
struct SHA256Context {
EVP_MD_CTX *ctx{nullptr};
uint8_t hash[32];
bool calculated{false};
};
#else
#error "SHA256 not supported on this platform"
#endif
std::unique_ptr<SHA256Context> ctx_;
};
} // namespace esphome::sha256
#endif // Platform check

View File

@@ -116,6 +116,7 @@
#define USE_API_PLAINTEXT
#define USE_API_SERVICES
#define USE_MD5
#define USE_SHA256
#define USE_MQTT
#define USE_NETWORK
#define USE_ONLINE_IMAGE_BMP_SUPPORT

View File

@@ -82,6 +82,16 @@ template<typename T> constexpr T byteswap(T n) {
return m;
}
template<> constexpr uint8_t byteswap(uint8_t n) { return n; }
#ifdef USE_LIBRETINY
// LibreTiny's Beken framework redefines __builtin_bswap functions as non-constexpr
template<> inline uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); }
template<> inline uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); }
template<> inline uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); }
template<> inline int8_t byteswap(int8_t n) { return n; }
template<> inline int16_t byteswap(int16_t n) { return __builtin_bswap16(n); }
template<> inline int32_t byteswap(int32_t n) { return __builtin_bswap32(n); }
template<> inline int64_t byteswap(int64_t n) { return __builtin_bswap64(n); }
#else
template<> constexpr uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); }
template<> constexpr uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); }
template<> constexpr uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); }
@@ -89,6 +99,7 @@ template<> constexpr int8_t byteswap(int8_t n) { return n; }
template<> constexpr int16_t byteswap(int16_t n) { return __builtin_bswap16(n); }
template<> constexpr int32_t byteswap(int32_t n) { return __builtin_bswap32(n); }
template<> constexpr int64_t byteswap(int64_t n) { return __builtin_bswap64(n); }
#endif
///@}

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from collections.abc import Callable
import gzip
import hashlib
import io
@@ -9,12 +10,14 @@ import random
import socket
import sys
import time
from typing import Any
from esphome.core import EsphomeError
from esphome.helpers import resolve_ip_address
RESPONSE_OK = 0x00
RESPONSE_REQUEST_AUTH = 0x01
RESPONSE_REQUEST_SHA256_AUTH = 0x02
RESPONSE_HEADER_OK = 0x40
RESPONSE_AUTH_OK = 0x41
@@ -45,6 +48,7 @@ OTA_VERSION_2_0 = 2
MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45]
FEATURE_SUPPORTS_COMPRESSION = 0x01
FEATURE_SUPPORTS_SHA256_AUTH = 0x02
UPLOAD_BLOCK_SIZE = 8192
@@ -52,6 +56,12 @@ UPLOAD_BUFFER_SIZE = UPLOAD_BLOCK_SIZE * 8
_LOGGER = logging.getLogger(__name__)
# Authentication method lookup table: response -> (hash_func, nonce_size, name)
_AUTH_METHODS: dict[int, tuple[Callable[..., Any], int, str]] = {
RESPONSE_REQUEST_SHA256_AUTH: (hashlib.sha256, 64, "SHA256"),
RESPONSE_REQUEST_AUTH: (hashlib.md5, 32, "MD5"),
}
class ProgressBar:
def __init__(self):
@@ -81,18 +91,43 @@ class OTAError(EsphomeError):
pass
def recv_decode(sock, amount, decode=True):
def recv_decode(
sock: socket.socket, amount: int, decode: bool = True
) -> bytes | list[int]:
"""Receive data from socket and optionally decode to list of integers.
:param sock: Socket to receive data from.
:param amount: Number of bytes to receive.
:param decode: If True, convert bytes to list of integers, otherwise return raw bytes.
:return: List of integers if decode=True, otherwise raw bytes.
"""
data = sock.recv(amount)
if not decode:
return data
return list(data)
def receive_exactly(sock, amount, msg, expect, decode=True):
data = [] if decode else b""
def receive_exactly(
sock: socket.socket,
amount: int,
msg: str,
expect: int | list[int] | None,
decode: bool = True,
) -> list[int] | bytes:
"""Receive exactly the specified amount of data from socket with error checking.
:param sock: Socket to receive data from.
:param amount: Exact number of bytes to receive.
:param msg: Description of what is being received for error messages.
:param expect: Expected response code(s) for validation, None to skip validation.
:param decode: If True, return list of integers, otherwise return raw bytes.
:return: List of integers if decode=True, otherwise raw bytes.
:raises OTAError: If receiving fails or response doesn't match expected.
"""
data: list[int] | bytes = [] if decode else b""
try:
data += recv_decode(sock, 1, decode=decode)
data += recv_decode(sock, 1, decode=decode) # type: ignore[operator]
except OSError as err:
raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err
@@ -104,13 +139,19 @@ def receive_exactly(sock, amount, msg, expect, decode=True):
while len(data) < amount:
try:
data += recv_decode(sock, amount - len(data), decode=decode)
data += recv_decode(sock, amount - len(data), decode=decode) # type: ignore[operator]
except OSError as err:
raise OTAError(f"Error receiving {msg}: {err}") from err
return data
def check_error(data, expect):
def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None:
"""Check response data for error codes and validate against expected response.
:param data: Response data from device (first byte is the response code).
:param expect: Expected response code(s), None to skip validation.
:raises OTAError: If an error code is detected or response doesn't match expected.
"""
if not expect:
return
dat = data[0]
@@ -177,7 +218,16 @@ def check_error(data, expect):
raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}")
def send_check(sock, data, msg):
def send_check(
sock: socket.socket, data: list[int] | tuple[int, ...] | int | str | bytes, msg: str
) -> None:
"""Send data to socket with error handling.
:param sock: Socket to send data to.
:param data: Data to send (can be list/tuple of ints, single int, string, or bytes).
:param msg: Description of what is being sent for error messages.
:raises OTAError: If sending fails.
"""
try:
if isinstance(data, (list, tuple)):
data = bytes(data)
@@ -210,10 +260,14 @@ def perform_ota(
f"Device uses unsupported OTA version {version}, this ESPHome supports {supported_versions}"
)
# Features
send_check(sock, FEATURE_SUPPORTS_COMPRESSION, "features")
# Features - send both compression and SHA256 auth support
features_to_send = FEATURE_SUPPORTS_COMPRESSION | FEATURE_SUPPORTS_SHA256_AUTH
send_check(sock, features_to_send, "features")
features = receive_exactly(
sock, 1, "features", [RESPONSE_HEADER_OK, RESPONSE_SUPPORTS_COMPRESSION]
sock,
1,
"features",
None, # Accept any response
)[0]
if features == RESPONSE_SUPPORTS_COMPRESSION:
@@ -222,31 +276,52 @@ def perform_ota(
else:
upload_contents = file_contents
(auth,) = receive_exactly(
sock, 1, "auth", [RESPONSE_REQUEST_AUTH, RESPONSE_AUTH_OK]
)
if auth == RESPONSE_REQUEST_AUTH:
def perform_auth(
sock: socket.socket,
password: str,
hash_func: Callable[..., Any],
nonce_size: int,
hash_name: str,
) -> None:
"""Perform challenge-response authentication using specified hash algorithm."""
if not password:
raise OTAError("ESP requests password, but no password given!")
nonce = receive_exactly(
sock, 32, "authentication nonce", [], decode=False
).decode()
_LOGGER.debug("Auth: Nonce is %s", nonce)
cnonce = hashlib.md5(str(random.random()).encode()).hexdigest()
_LOGGER.debug("Auth: CNonce is %s", cnonce)
nonce_bytes = receive_exactly(
sock, nonce_size, f"{hash_name} authentication nonce", [], decode=False
)
assert isinstance(nonce_bytes, bytes)
nonce = nonce_bytes.decode()
_LOGGER.debug("Auth: %s Nonce is %s", hash_name, nonce)
# Generate cnonce
cnonce = hash_func(str(random.random()).encode()).hexdigest()
_LOGGER.debug("Auth: %s CNonce is %s", hash_name, cnonce)
send_check(sock, cnonce, "auth cnonce")
result_md5 = hashlib.md5()
result_md5.update(password.encode("utf-8"))
result_md5.update(nonce.encode())
result_md5.update(cnonce.encode())
result = result_md5.hexdigest()
_LOGGER.debug("Auth: Result is %s", result)
# Calculate challenge response
hasher = hash_func()
hasher.update(password.encode("utf-8"))
hasher.update(nonce.encode())
hasher.update(cnonce.encode())
result = hasher.hexdigest()
_LOGGER.debug("Auth: %s Result is %s", hash_name, result)
send_check(sock, result, "auth result")
receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK)
(auth,) = receive_exactly(
sock,
1,
"auth",
[RESPONSE_REQUEST_AUTH, RESPONSE_REQUEST_SHA256_AUTH, RESPONSE_AUTH_OK],
)
if auth != RESPONSE_AUTH_OK:
hash_func, nonce_size, hash_name = _AUTH_METHODS[auth]
perform_auth(sock, password, hash_func, nonce_size, hash_name)
# Set higher timeout during upload
sock.settimeout(30.0)