1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-22 21:22:22 +01:00

Merge branch 'sha256_ota' into integration

This commit is contained in:
J. Nick Koston
2025-09-21 11:06:21 -06:00
18 changed files with 1344 additions and 75 deletions

View File

@@ -407,6 +407,7 @@ esphome/components/sensor/* @esphome/core
esphome/components/sfa30/* @ghsensdev esphome/components/sfa30/* @ghsensdev
esphome/components/sgp40/* @SenexCrenshaw esphome/components/sgp40/* @SenexCrenshaw
esphome/components/sgp4x/* @martgras @SenexCrenshaw esphome/components/sgp4x/* @martgras @SenexCrenshaw
esphome/components/sha256/* @esphome/core
esphome/components/shelly_dimmer/* @edge90 @rnauber esphome/components/shelly_dimmer/* @edge90 @rnauber
esphome/components/sht3xd/* @mrtoy-me esphome/components/sht3xd/* @mrtoy-me
esphome/components/sht4x/* @sjtrny esphome/components/sht4x/* @sjtrny

View File

@@ -16,7 +16,7 @@ from esphome.const import (
CONF_SAFE_MODE, CONF_SAFE_MODE,
CONF_VERSION, CONF_VERSION,
) )
from esphome.core import coroutine_with_priority from esphome.core import CORE, coroutine_with_priority
from esphome.coroutine import CoroPriority from esphome.coroutine import CoroPriority
import esphome.final_validate as fv import esphome.final_validate as fv
@@ -24,9 +24,22 @@ _LOGGER = logging.getLogger(__name__)
CODEOWNERS = ["@esphome/core"] CODEOWNERS = ["@esphome/core"]
AUTO_LOAD = ["md5", "socket"]
DEPENDENCIES = ["network"] DEPENDENCIES = ["network"]
def supports_sha256() -> bool:
"""Check if the current platform supports SHA256 for OTA authentication."""
return bool(CORE.is_esp32 or CORE.is_esp8266 or CORE.is_rp2040 or CORE.is_libretiny)
def AUTO_LOAD() -> list[str]:
"""Conditionally auto-load sha256 only on platforms that support it."""
base_components = ["md5", "socket"]
if supports_sha256():
return base_components + ["sha256"]
return base_components
esphome = cg.esphome_ns.namespace("esphome") esphome = cg.esphome_ns.namespace("esphome")
ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", OTAComponent) ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", OTAComponent)
@@ -126,6 +139,11 @@ FINAL_VALIDATE_SCHEMA = ota_esphome_final_validate
async def to_code(config): async def to_code(config):
var = cg.new_Pvariable(config[CONF_ID]) var = cg.new_Pvariable(config[CONF_ID])
cg.add(var.set_port(config[CONF_PORT])) cg.add(var.set_port(config[CONF_PORT]))
# Only include SHA256 support on platforms that have it
if supports_sha256():
cg.add_define("USE_OTA_SHA256")
if CONF_PASSWORD in config: if CONF_PASSWORD in config:
cg.add(var.set_auth_password(config[CONF_PASSWORD])) cg.add(var.set_auth_password(config[CONF_PASSWORD]))
cg.add_define("USE_OTA_PASSWORD") cg.add_define("USE_OTA_PASSWORD")

View File

@@ -1,6 +1,9 @@
#include "ota_esphome.h" #include "ota_esphome.h"
#ifdef USE_OTA #ifdef USE_OTA
#include "esphome/components/md5/md5.h" #include "esphome/components/md5/md5.h"
#ifdef USE_OTA_SHA256
#include "esphome/components/sha256/sha256.h"
#endif
#include "esphome/components/network/util.h" #include "esphome/components/network/util.h"
#include "esphome/components/ota/ota_backend.h" #include "esphome/components/ota/ota_backend.h"
#include "esphome/components/ota/ota_backend_arduino_esp32.h" #include "esphome/components/ota/ota_backend_arduino_esp32.h"
@@ -95,6 +98,33 @@ void ESPHomeOTAComponent::loop() {
} }
static const uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01; static const uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01;
#ifdef USE_OTA_SHA256
static const uint8_t FEATURE_SUPPORTS_SHA256_AUTH = 0x02;
#endif
// Temporary flag to allow MD5 downgrade for ~3 versions (until 2026.1.0)
// This allows users to downgrade via OTA if they encounter issues after updating.
// Without this, users would need to do a serial flash to downgrade.
// TODO: Remove this flag and all associated code in 2026.1.0
#define ALLOW_OTA_DOWNGRADE_MD5
template<typename HashClass> struct HashTraits;
template<> struct HashTraits<md5::MD5Digest> {
static constexpr int NONCE_SIZE = 8;
static constexpr int HEX_SIZE = 32;
static constexpr const char *NAME = "MD5";
static constexpr ota::OTAResponseTypes AUTH_REQUEST = ota::OTA_RESPONSE_REQUEST_AUTH;
};
#ifdef USE_OTA_SHA256
template<> struct HashTraits<sha256::SHA256> {
static constexpr int NONCE_SIZE = 16;
static constexpr int HEX_SIZE = 64;
static constexpr const char *NAME = "SHA256";
static constexpr ota::OTAResponseTypes AUTH_REQUEST = ota::OTA_RESPONSE_REQUEST_SHA256_AUTH;
};
#endif
void ESPHomeOTAComponent::handle_handshake_() { void ESPHomeOTAComponent::handle_handshake_() {
/// Handle the initial OTA handshake. /// Handle the initial OTA handshake.
@@ -225,57 +255,55 @@ void ESPHomeOTAComponent::handle_data_() {
#ifdef USE_OTA_PASSWORD #ifdef USE_OTA_PASSWORD
if (!this->password_.empty()) { if (!this->password_.empty()) {
buf[0] = ota::OTA_RESPONSE_REQUEST_AUTH; bool auth_success = false;
this->writeall_(buf, 1);
md5::MD5Digest md5{};
md5.init();
sprintf(sbuf, "%08" PRIx32, random_uint32());
md5.add(sbuf, 8);
md5.calculate();
md5.get_hex(sbuf);
ESP_LOGV(TAG, "Auth: Nonce is %s", sbuf);
// Send nonce, 32 bytes hex MD5 #ifdef USE_OTA_SHA256
if (!this->writeall_(reinterpret_cast<uint8_t *>(sbuf), 32)) { // SECURITY HARDENING: Prefer SHA256 authentication on platforms that support it.
ESP_LOGW(TAG, "Auth: Writing nonce failed"); //
// This is a hardening measure to prevent future downgrade attacks where an attacker
// could force the use of MD5 authentication by manipulating the feature flags.
//
// While MD5 is currently still acceptable for our OTA authentication use case
// (where the password is a shared secret and we're only authenticating, not
// encrypting), at some point in the future MD5 will likely become so weak that
// it could be practically attacked.
//
// We enforce SHA256 now on capable platforms because:
// 1. We can't retroactively update device firmware in the field
// 2. Clients (like esphome CLI) can always be updated to support SHA256
// 3. This prevents any possibility of downgrade attacks in the future
//
// Devices that don't support SHA256 (due to platform limitations) will
// continue to use MD5 as their only option (see #else branch below).
bool client_supports_sha256 = (ota_features & FEATURE_SUPPORTS_SHA256_AUTH) != 0;
#ifdef ALLOW_OTA_DOWNGRADE_MD5
// Temporary compatibility mode: Allow MD5 for ~3 versions to enable OTA downgrades
// This prevents users from being locked out if they need to downgrade after updating
// TODO: Remove this entire ifdef block in 2026.1.0
if (client_supports_sha256) {
auth_success = this->perform_hash_auth_<sha256::SHA256>(this->password_);
} else {
ESP_LOGW(TAG, "Using MD5 auth for compatibility (deprecated)");
auth_success = this->perform_hash_auth_<md5::MD5Digest>(this->password_);
}
#else
// Strict mode: SHA256 required on capable platforms (future default)
if (!client_supports_sha256) {
ESP_LOGW(TAG, "Client requires SHA256");
error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
auth_success = this->perform_hash_auth_<sha256::SHA256>(this->password_);
#endif // ALLOW_OTA_DOWNGRADE_MD5
#else
// Platform only supports MD5 - use it as the only available option
// This is not a security downgrade as the platform cannot support SHA256
auth_success = this->perform_hash_auth_<md5::MD5Digest>(this->password_);
#endif // USE_OTA_SHA256
// prepare challenge if (!auth_success) {
md5.init();
md5.add(this->password_.c_str(), this->password_.length());
// add nonce
md5.add(sbuf, 32);
// Receive cnonce, 32 bytes hex MD5
if (!this->readall_(buf, 32)) {
ESP_LOGW(TAG, "Auth: Reading cnonce failed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
sbuf[32] = '\0';
ESP_LOGV(TAG, "Auth: CNonce is %s", sbuf);
// add cnonce
md5.add(sbuf, 32);
// calculate result
md5.calculate();
md5.get_hex(sbuf);
ESP_LOGV(TAG, "Auth: Result is %s", sbuf);
// Receive result, 32 bytes hex MD5
if (!this->readall_(buf + 64, 32)) {
ESP_LOGW(TAG, "Auth: Reading response failed");
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
}
sbuf[64 + 32] = '\0';
ESP_LOGV(TAG, "Auth: Response is %s", sbuf + 64);
bool matches = true;
for (uint8_t i = 0; i < 32; i++)
matches = matches && buf[i] == buf[64 + i];
if (!matches) {
ESP_LOGW(TAG, "Auth failed! Passwords do not match");
error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID; error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
goto error; // NOLINT(cppcoreguidelines-avoid-goto) goto error; // NOLINT(cppcoreguidelines-avoid-goto)
} }
@@ -499,5 +527,110 @@ void ESPHomeOTAComponent::yield_and_feed_watchdog_() {
delay(1); delay(1);
} }
// Template function definition - placed at end to ensure all types are complete
template<typename HashClass> bool ESPHomeOTAComponent::perform_hash_auth_(const std::string &password) {
using Traits = HashTraits<HashClass>;
// Minimize stack usage by reusing buffers
// We only need 2 buffers at most at the same time
constexpr size_t hex_buffer_size = Traits::HEX_SIZE + 1;
// These two buffers are reused throughout the function
char hex_buffer1[hex_buffer_size]; // Used for: nonce -> expected result
char hex_buffer2[hex_buffer_size]; // Used for: cnonce -> response
// Small stack buffer for auth request and nonce seed bytes
uint8_t buf[1];
uint8_t nonce_bytes[8]; // Max 8 bytes (2 x uint32_t for SHA256)
// Send auth request type
buf[0] = Traits::AUTH_REQUEST;
this->writeall_(buf, 1);
HashClass hasher;
hasher.init();
// Generate nonce seed bytes
uint32_t r1 = random_uint32();
// Convert first uint32 to bytes (always needed for MD5)
nonce_bytes[0] = (r1 >> 24) & 0xFF;
nonce_bytes[1] = (r1 >> 16) & 0xFF;
nonce_bytes[2] = (r1 >> 8) & 0xFF;
nonce_bytes[3] = r1 & 0xFF;
if (Traits::NONCE_SIZE == 8) {
// MD5: 8 chars = "%08x" format = 4 bytes from one random uint32
hasher.add(nonce_bytes, 4);
}
#ifdef USE_OTA_SHA256
else {
// SHA256: 16 chars = "%08x%08x" format = 8 bytes from two random uint32s
uint32_t r2 = random_uint32();
nonce_bytes[4] = (r2 >> 24) & 0xFF;
nonce_bytes[5] = (r2 >> 16) & 0xFF;
nonce_bytes[6] = (r2 >> 8) & 0xFF;
nonce_bytes[7] = r2 & 0xFF;
hasher.add(nonce_bytes, 8);
}
#endif
hasher.calculate();
// Use hex_buffer1 for nonce
hasher.get_hex(hex_buffer1);
hex_buffer1[Traits::HEX_SIZE] = '\0';
ESP_LOGV(TAG, "Auth: %s Nonce is %s", Traits::NAME, hex_buffer1);
// Send nonce
if (!this->writeall_(reinterpret_cast<uint8_t *>(hex_buffer1), Traits::HEX_SIZE)) {
ESP_LOGW(TAG, "Auth: Writing %s nonce failed", Traits::NAME);
return false;
}
// Prepare challenge
hasher.init();
hasher.add(password.c_str(), password.length());
hasher.add(hex_buffer1, Traits::HEX_SIZE); // Add nonce
// Receive cnonce into hex_buffer2
if (!this->readall_(reinterpret_cast<uint8_t *>(hex_buffer2), Traits::HEX_SIZE)) {
ESP_LOGW(TAG, "Auth: Reading %s cnonce failed", Traits::NAME);
return false;
}
hex_buffer2[Traits::HEX_SIZE] = '\0';
ESP_LOGV(TAG, "Auth: %s CNonce is %s", Traits::NAME, hex_buffer2);
// Add cnonce to hash
hasher.add(hex_buffer2, Traits::HEX_SIZE);
// Calculate result - reuse hex_buffer1 for expected
hasher.calculate();
hasher.get_hex(hex_buffer1);
hex_buffer1[Traits::HEX_SIZE] = '\0';
ESP_LOGV(TAG, "Auth: %s Result is %s", Traits::NAME, hex_buffer1);
// Receive response - reuse hex_buffer2
if (!this->readall_(reinterpret_cast<uint8_t *>(hex_buffer2), Traits::HEX_SIZE)) {
ESP_LOGW(TAG, "Auth: Reading %s response failed", Traits::NAME);
return false;
}
hex_buffer2[Traits::HEX_SIZE] = '\0';
ESP_LOGV(TAG, "Auth: %s Response is %s", Traits::NAME, hex_buffer2);
// Compare
bool matches = memcmp(hex_buffer1, hex_buffer2, Traits::HEX_SIZE) == 0;
if (!matches) {
ESP_LOGW(TAG, "Auth failed! %s passwords do not match", Traits::NAME);
}
return matches;
}
// Explicit template instantiations
template bool ESPHomeOTAComponent::perform_hash_auth_<md5::MD5Digest>(const std::string &);
#ifdef USE_OTA_SHA256
template bool ESPHomeOTAComponent::perform_hash_auth_<sha256::SHA256>(const std::string &);
#endif
} // namespace esphome } // namespace esphome
#endif #endif

View File

@@ -30,6 +30,7 @@ class ESPHomeOTAComponent : public ota::OTAComponent {
protected: protected:
void handle_handshake_(); void handle_handshake_();
void handle_data_(); void handle_data_();
template<typename HashClass> bool perform_hash_auth_(const std::string &password);
bool readall_(uint8_t *buf, size_t len); bool readall_(uint8_t *buf, size_t len);
bool writeall_(const uint8_t *buf, size_t len); bool writeall_(const uint8_t *buf, size_t len);
void log_socket_error_(const LogString *msg); void log_socket_error_(const LogString *msg);

View File

@@ -14,6 +14,7 @@ namespace ota {
enum OTAResponseTypes { enum OTAResponseTypes {
OTA_RESPONSE_OK = 0x00, OTA_RESPONSE_OK = 0x00,
OTA_RESPONSE_REQUEST_AUTH = 0x01, OTA_RESPONSE_REQUEST_AUTH = 0x01,
OTA_RESPONSE_REQUEST_SHA256_AUTH = 0x02,
OTA_RESPONSE_HEADER_OK = 0x40, OTA_RESPONSE_HEADER_OK = 0x40,
OTA_RESPONSE_AUTH_OK = 0x41, OTA_RESPONSE_AUTH_OK = 0x41,

View File

@@ -0,0 +1,24 @@
import esphome.codegen as cg
import esphome.config_validation as cv
from esphome.core import CORE
from esphome.helpers import IS_MACOS
from esphome.types import ConfigType
CODEOWNERS = ["@esphome/core"]
sha256_ns = cg.esphome_ns.namespace("sha256")
CONFIG_SCHEMA = cv.Schema({})
async def to_code(config: ConfigType) -> None:
# Add OpenSSL library for host platform
if CORE.is_host:
if IS_MACOS:
# macOS needs special handling for Homebrew OpenSSL
cg.add_build_flag("-I/opt/homebrew/opt/openssl/include")
cg.add_build_flag("-L/opt/homebrew/opt/openssl/lib")
cg.add_build_flag("-lcrypto")
else:
# Linux and other Unix systems usually have OpenSSL in standard paths
cg.add_build_flag("-lcrypto")

View File

@@ -0,0 +1,160 @@
#include "sha256.h"
// Only compile SHA256 implementation on platforms that support it
#if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_HOST)
#include "esphome/core/helpers.h"
#include <cstring>
namespace esphome::sha256 {
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
SHA256::~SHA256() {
if (this->ctx_) {
mbedtls_sha256_free(&this->ctx_->ctx);
}
}
void SHA256::init() {
if (!this->ctx_) {
this->ctx_ = std::make_unique<SHA256Context>();
}
mbedtls_sha256_init(&this->ctx_->ctx);
mbedtls_sha256_starts(&this->ctx_->ctx, 0); // 0 = SHA256, not SHA224
}
void SHA256::add(const uint8_t *data, size_t len) {
if (!this->ctx_) {
this->init();
}
mbedtls_sha256_update(&this->ctx_->ctx, data, len);
}
void SHA256::calculate() {
if (!this->ctx_) {
this->init();
}
mbedtls_sha256_finish(&this->ctx_->ctx, this->ctx_->hash);
}
#elif defined(USE_ESP8266) || defined(USE_RP2040)
SHA256::~SHA256() = default;
void SHA256::init() {
if (!this->ctx_) {
this->ctx_ = std::make_unique<SHA256Context>();
}
br_sha256_init(&this->ctx_->ctx);
this->ctx_->calculated = false;
}
void SHA256::add(const uint8_t *data, size_t len) {
if (!this->ctx_) {
this->init();
}
br_sha256_update(&this->ctx_->ctx, data, len);
}
void SHA256::calculate() {
if (!this->ctx_) {
this->init();
}
if (!this->ctx_->calculated) {
br_sha256_out(&this->ctx_->ctx, this->ctx_->hash);
this->ctx_->calculated = true;
}
}
#elif defined(USE_HOST)
SHA256::~SHA256() {
if (this->ctx_ && this->ctx_->ctx) {
EVP_MD_CTX_free(this->ctx_->ctx);
this->ctx_->ctx = nullptr;
}
}
void SHA256::init() {
if (!this->ctx_) {
this->ctx_ = std::make_unique<SHA256Context>();
}
if (this->ctx_->ctx) {
EVP_MD_CTX_free(this->ctx_->ctx);
}
this->ctx_->ctx = EVP_MD_CTX_new();
EVP_DigestInit_ex(this->ctx_->ctx, EVP_sha256(), nullptr);
this->ctx_->calculated = false;
}
void SHA256::add(const uint8_t *data, size_t len) {
if (!this->ctx_) {
this->init();
}
EVP_DigestUpdate(this->ctx_->ctx, data, len);
}
void SHA256::calculate() {
if (!this->ctx_) {
this->init();
}
if (!this->ctx_->calculated) {
unsigned int len = 32;
EVP_DigestFinal_ex(this->ctx_->ctx, this->ctx_->hash, &len);
this->ctx_->calculated = true;
}
}
#else
#error "SHA256 not supported on this platform"
#endif
void SHA256::get_bytes(uint8_t *output) {
if (!this->ctx_) {
memset(output, 0, 32);
return;
}
memcpy(output, this->ctx_->hash, 32);
}
void SHA256::get_hex(char *output) {
if (!this->ctx_) {
memset(output, '0', 64);
output[64] = '\0';
return;
}
for (size_t i = 0; i < 32; i++) {
uint8_t byte = this->ctx_->hash[i];
output[i * 2] = format_hex_char(byte >> 4);
output[i * 2 + 1] = format_hex_char(byte & 0x0F);
}
}
std::string SHA256::get_hex_string() {
char buf[65];
this->get_hex(buf);
return std::string(buf);
}
bool SHA256::equals_bytes(const uint8_t *expected) {
if (!this->ctx_) {
return false;
}
return memcmp(this->ctx_->hash, expected, 32) == 0;
}
bool SHA256::equals_hex(const char *expected) {
if (!this->ctx_) {
return false;
}
uint8_t parsed[32];
if (!parse_hex(expected, parsed, 32)) {
return false;
}
return this->equals_bytes(parsed);
}
} // namespace esphome::sha256
#endif // Platform check

View File

@@ -0,0 +1,69 @@
#pragma once
#include "esphome/core/defines.h"
// Only define SHA256 on platforms that support it
#if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_HOST)
#include <cstdint>
#include <string>
#include <memory>
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
#include "mbedtls/sha256.h"
#elif defined(USE_ESP8266) || defined(USE_RP2040)
#include <bearssl/bearssl_hash.h>
#elif defined(USE_HOST)
#include <openssl/evp.h>
#else
#error "SHA256 not supported on this platform"
#endif
namespace esphome::sha256 {
class SHA256 {
public:
SHA256() = default;
~SHA256();
void init();
void add(const uint8_t *data, size_t len);
void add(const char *data, size_t len) { this->add((const uint8_t *) data, len); }
void add(const std::string &data) { this->add(data.c_str(), data.length()); }
void calculate();
void get_bytes(uint8_t *output);
void get_hex(char *output);
std::string get_hex_string();
bool equals_bytes(const uint8_t *expected);
bool equals_hex(const char *expected);
protected:
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
struct SHA256Context {
mbedtls_sha256_context ctx;
uint8_t hash[32];
};
#elif defined(USE_ESP8266) || defined(USE_RP2040)
struct SHA256Context {
br_sha256_context ctx;
uint8_t hash[32];
bool calculated{false};
};
#elif defined(USE_HOST)
struct SHA256Context {
EVP_MD_CTX *ctx{nullptr};
uint8_t hash[32];
bool calculated{false};
};
#else
#error "SHA256 not supported on this platform"
#endif
std::unique_ptr<SHA256Context> ctx_;
};
} // namespace esphome::sha256
#endif // Platform check

View File

@@ -116,6 +116,7 @@
#define USE_API_PLAINTEXT #define USE_API_PLAINTEXT
#define USE_API_SERVICES #define USE_API_SERVICES
#define USE_MD5 #define USE_MD5
#define USE_SHA256
#define USE_MQTT #define USE_MQTT
#define USE_NETWORK #define USE_NETWORK
#define USE_ONLINE_IMAGE_BMP_SUPPORT #define USE_ONLINE_IMAGE_BMP_SUPPORT

View File

@@ -82,6 +82,16 @@ template<typename T> constexpr T byteswap(T n) {
return m; return m;
} }
template<> constexpr uint8_t byteswap(uint8_t n) { return n; } template<> constexpr uint8_t byteswap(uint8_t n) { return n; }
#ifdef USE_LIBRETINY
// LibreTiny's Beken framework redefines __builtin_bswap functions as non-constexpr
template<> inline uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); }
template<> inline uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); }
template<> inline uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); }
template<> inline int8_t byteswap(int8_t n) { return n; }
template<> inline int16_t byteswap(int16_t n) { return __builtin_bswap16(n); }
template<> inline int32_t byteswap(int32_t n) { return __builtin_bswap32(n); }
template<> inline int64_t byteswap(int64_t n) { return __builtin_bswap64(n); }
#else
template<> constexpr uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); } template<> constexpr uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); }
template<> constexpr uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); } template<> constexpr uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); }
template<> constexpr uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); } template<> constexpr uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); }
@@ -89,6 +99,7 @@ template<> constexpr int8_t byteswap(int8_t n) { return n; }
template<> constexpr int16_t byteswap(int16_t n) { return __builtin_bswap16(n); } template<> constexpr int16_t byteswap(int16_t n) { return __builtin_bswap16(n); }
template<> constexpr int32_t byteswap(int32_t n) { return __builtin_bswap32(n); } template<> constexpr int32_t byteswap(int32_t n) { return __builtin_bswap32(n); }
template<> constexpr int64_t byteswap(int64_t n) { return __builtin_bswap64(n); } template<> constexpr int64_t byteswap(int64_t n) { return __builtin_bswap64(n); }
#endif
///@} ///@}

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import gzip import gzip
import hashlib import hashlib
import io import io
@@ -9,12 +10,14 @@ import random
import socket import socket
import sys import sys
import time import time
from typing import Any
from esphome.core import EsphomeError from esphome.core import EsphomeError
from esphome.helpers import resolve_ip_address from esphome.helpers import resolve_ip_address
RESPONSE_OK = 0x00 RESPONSE_OK = 0x00
RESPONSE_REQUEST_AUTH = 0x01 RESPONSE_REQUEST_AUTH = 0x01
RESPONSE_REQUEST_SHA256_AUTH = 0x02
RESPONSE_HEADER_OK = 0x40 RESPONSE_HEADER_OK = 0x40
RESPONSE_AUTH_OK = 0x41 RESPONSE_AUTH_OK = 0x41
@@ -45,6 +48,7 @@ OTA_VERSION_2_0 = 2
MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45] MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45]
FEATURE_SUPPORTS_COMPRESSION = 0x01 FEATURE_SUPPORTS_COMPRESSION = 0x01
FEATURE_SUPPORTS_SHA256_AUTH = 0x02
UPLOAD_BLOCK_SIZE = 8192 UPLOAD_BLOCK_SIZE = 8192
@@ -52,6 +56,12 @@ UPLOAD_BUFFER_SIZE = UPLOAD_BLOCK_SIZE * 8
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Authentication method lookup table: response -> (hash_func, nonce_size, name)
_AUTH_METHODS: dict[int, tuple[Callable[..., Any], int, str]] = {
RESPONSE_REQUEST_SHA256_AUTH: (hashlib.sha256, 64, "SHA256"),
RESPONSE_REQUEST_AUTH: (hashlib.md5, 32, "MD5"),
}
class ProgressBar: class ProgressBar:
def __init__(self): def __init__(self):
@@ -81,18 +91,43 @@ class OTAError(EsphomeError):
pass pass
def recv_decode(sock, amount, decode=True): def recv_decode(
sock: socket.socket, amount: int, decode: bool = True
) -> bytes | list[int]:
"""Receive data from socket and optionally decode to list of integers.
:param sock: Socket to receive data from.
:param amount: Number of bytes to receive.
:param decode: If True, convert bytes to list of integers, otherwise return raw bytes.
:return: List of integers if decode=True, otherwise raw bytes.
"""
data = sock.recv(amount) data = sock.recv(amount)
if not decode: if not decode:
return data return data
return list(data) return list(data)
def receive_exactly(sock, amount, msg, expect, decode=True): def receive_exactly(
data = [] if decode else b"" sock: socket.socket,
amount: int,
msg: str,
expect: int | list[int] | None,
decode: bool = True,
) -> list[int] | bytes:
"""Receive exactly the specified amount of data from socket with error checking.
:param sock: Socket to receive data from.
:param amount: Exact number of bytes to receive.
:param msg: Description of what is being received for error messages.
:param expect: Expected response code(s) for validation, None to skip validation.
:param decode: If True, return list of integers, otherwise return raw bytes.
:return: List of integers if decode=True, otherwise raw bytes.
:raises OTAError: If receiving fails or response doesn't match expected.
"""
data: list[int] | bytes = [] if decode else b""
try: try:
data += recv_decode(sock, 1, decode=decode) data += recv_decode(sock, 1, decode=decode) # type: ignore[operator]
except OSError as err: except OSError as err:
raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err
@@ -104,13 +139,19 @@ def receive_exactly(sock, amount, msg, expect, decode=True):
while len(data) < amount: while len(data) < amount:
try: try:
data += recv_decode(sock, amount - len(data), decode=decode) data += recv_decode(sock, amount - len(data), decode=decode) # type: ignore[operator]
except OSError as err: except OSError as err:
raise OTAError(f"Error receiving {msg}: {err}") from err raise OTAError(f"Error receiving {msg}: {err}") from err
return data return data
def check_error(data, expect): def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None:
"""Check response data for error codes and validate against expected response.
:param data: Response data from device (first byte is the response code).
:param expect: Expected response code(s), None to skip validation.
:raises OTAError: If an error code is detected or response doesn't match expected.
"""
if not expect: if not expect:
return return
dat = data[0] dat = data[0]
@@ -177,7 +218,16 @@ def check_error(data, expect):
raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}") raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}")
def send_check(sock, data, msg): def send_check(
sock: socket.socket, data: list[int] | tuple[int, ...] | int | str | bytes, msg: str
) -> None:
"""Send data to socket with error handling.
:param sock: Socket to send data to.
:param data: Data to send (can be list/tuple of ints, single int, string, or bytes).
:param msg: Description of what is being sent for error messages.
:raises OTAError: If sending fails.
"""
try: try:
if isinstance(data, (list, tuple)): if isinstance(data, (list, tuple)):
data = bytes(data) data = bytes(data)
@@ -210,10 +260,14 @@ def perform_ota(
f"Device uses unsupported OTA version {version}, this ESPHome supports {supported_versions}" f"Device uses unsupported OTA version {version}, this ESPHome supports {supported_versions}"
) )
# Features # Features - send both compression and SHA256 auth support
send_check(sock, FEATURE_SUPPORTS_COMPRESSION, "features") features_to_send = FEATURE_SUPPORTS_COMPRESSION | FEATURE_SUPPORTS_SHA256_AUTH
send_check(sock, features_to_send, "features")
features = receive_exactly( features = receive_exactly(
sock, 1, "features", [RESPONSE_HEADER_OK, RESPONSE_SUPPORTS_COMPRESSION] sock,
1,
"features",
None, # Accept any response
)[0] )[0]
if features == RESPONSE_SUPPORTS_COMPRESSION: if features == RESPONSE_SUPPORTS_COMPRESSION:
@@ -222,31 +276,52 @@ def perform_ota(
else: else:
upload_contents = file_contents upload_contents = file_contents
(auth,) = receive_exactly( def perform_auth(
sock, 1, "auth", [RESPONSE_REQUEST_AUTH, RESPONSE_AUTH_OK] sock: socket.socket,
) password: str,
if auth == RESPONSE_REQUEST_AUTH: hash_func: Callable[..., Any],
nonce_size: int,
hash_name: str,
) -> None:
"""Perform challenge-response authentication using specified hash algorithm."""
if not password: if not password:
raise OTAError("ESP requests password, but no password given!") raise OTAError("ESP requests password, but no password given!")
nonce = receive_exactly(
sock, 32, "authentication nonce", [], decode=False nonce_bytes = receive_exactly(
).decode() sock, nonce_size, f"{hash_name} authentication nonce", [], decode=False
_LOGGER.debug("Auth: Nonce is %s", nonce) )
cnonce = hashlib.md5(str(random.random()).encode()).hexdigest() assert isinstance(nonce_bytes, bytes)
_LOGGER.debug("Auth: CNonce is %s", cnonce) nonce = nonce_bytes.decode()
_LOGGER.debug("Auth: %s Nonce is %s", hash_name, nonce)
# Generate cnonce
cnonce = hash_func(str(random.random()).encode()).hexdigest()
_LOGGER.debug("Auth: %s CNonce is %s", hash_name, cnonce)
send_check(sock, cnonce, "auth cnonce") send_check(sock, cnonce, "auth cnonce")
result_md5 = hashlib.md5() # Calculate challenge response
result_md5.update(password.encode("utf-8")) hasher = hash_func()
result_md5.update(nonce.encode()) hasher.update(password.encode("utf-8"))
result_md5.update(cnonce.encode()) hasher.update(nonce.encode())
result = result_md5.hexdigest() hasher.update(cnonce.encode())
_LOGGER.debug("Auth: Result is %s", result) result = hasher.hexdigest()
_LOGGER.debug("Auth: %s Result is %s", hash_name, result)
send_check(sock, result, "auth result") send_check(sock, result, "auth result")
receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK) receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK)
(auth,) = receive_exactly(
sock,
1,
"auth",
[RESPONSE_REQUEST_AUTH, RESPONSE_REQUEST_SHA256_AUTH, RESPONSE_AUTH_OK],
)
if auth != RESPONSE_AUTH_OK:
hash_func, nonce_size, hash_name = _AUTH_METHODS[auth]
perform_auth(sock, password, hash_func, nonce_size, hash_name)
# Set higher timeout during upload # Set higher timeout during upload
sock.settimeout(30.0) sock.settimeout(30.0)

View File

@@ -0,0 +1,32 @@
esphome:
on_boot:
- lambda: |-
// Test SHA256 functionality
#ifdef USE_SHA256
using esphome::sha256::SHA256;
SHA256 hasher;
hasher.init();
// Test with "Hello World" - known SHA256
const char* test_string = "Hello World";
hasher.add(test_string, strlen(test_string));
hasher.calculate();
char hex_output[65];
hasher.get_hex(hex_output);
hex_output[64] = '\0';
ESP_LOGD("SHA256", "SHA256('Hello World') = %s", hex_output);
// Expected: a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e
const char* expected = "a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e";
if (strcmp(hex_output, expected) == 0) {
ESP_LOGI("SHA256", "Test PASSED");
} else {
ESP_LOGE("SHA256", "Test FAILED. Expected %s", expected);
}
#else
ESP_LOGW("SHA256", "SHA256 not available on this platform");
#endif
sha256:

View File

@@ -0,0 +1 @@
<<: !include common.yaml

View File

@@ -0,0 +1 @@
<<: !include common.yaml

View File

@@ -0,0 +1 @@
<<: !include common.yaml

View File

@@ -0,0 +1 @@
<<: !include common.yaml

View File

@@ -0,0 +1 @@
<<: !include common.yaml

View File

@@ -0,0 +1,738 @@
"""Unit tests for esphome.espota2 module."""
from __future__ import annotations
from collections.abc import Generator
import gzip
import hashlib
import io
from pathlib import Path
import socket
import struct
from unittest.mock import Mock, call, patch
import pytest
from pytest import CaptureFixture
from esphome import espota2
from esphome.core import EsphomeError
# Test constants
MOCK_RANDOM_VALUE = 0.123456
MOCK_RANDOM_BYTES = b"0.123456"
MOCK_MD5_NONCE = b"12345678901234567890123456789012" # 32 char nonce for MD5
MOCK_SHA256_NONCE = b"1234567890123456789012345678901234567890123456789012345678901234" # 64 char nonce for SHA256
@pytest.fixture
def mock_socket() -> Mock:
"""Create a mock socket for testing."""
socket_mock = Mock()
socket_mock.close = Mock()
socket_mock.recv = Mock()
socket_mock.sendall = Mock()
socket_mock.settimeout = Mock()
socket_mock.connect = Mock()
socket_mock.setsockopt = Mock()
return socket_mock
@pytest.fixture
def mock_file() -> io.BytesIO:
"""Create a mock firmware file for testing."""
return io.BytesIO(b"firmware content here")
@pytest.fixture
def mock_time() -> Generator[None]:
"""Mock time-related functions for consistent testing."""
# Provide enough values for multiple calls (tests may call perform_ota multiple times)
with (
patch("time.sleep"),
patch("time.perf_counter", side_effect=[0, 1, 0, 1, 0, 1]),
):
yield
@pytest.fixture
def mock_random() -> Generator[Mock]:
"""Mock random for predictable test values."""
with patch("random.random", return_value=MOCK_RANDOM_VALUE) as mock_rand:
yield mock_rand
@pytest.fixture
def mock_resolve_ip() -> Generator[Mock]:
"""Mock resolve_ip_address for testing."""
with patch("esphome.espota2.resolve_ip_address") as mock:
mock.return_value = [
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.100", 3232))
]
yield mock
@pytest.fixture
def mock_perform_ota() -> Generator[Mock]:
"""Mock perform_ota function for testing."""
with patch("esphome.espota2.perform_ota") as mock:
yield mock
@pytest.fixture
def mock_run_ota_impl() -> Generator[Mock]:
"""Mock run_ota_impl_ function for testing."""
with patch("esphome.espota2.run_ota_impl_") as mock:
mock.return_value = (0, "192.168.1.100")
yield mock
@pytest.fixture
def mock_socket_constructor(mock_socket: Mock) -> Generator[Mock]:
"""Mock socket.socket constructor to return our mock socket."""
with patch("socket.socket", return_value=mock_socket) as mock_constructor:
yield mock_constructor
def test_recv_decode_with_decode(mock_socket: Mock) -> None:
"""Test recv_decode with decode=True returns list."""
mock_socket.recv.return_value = b"\x01\x02\x03"
result = espota2.recv_decode(mock_socket, 3, decode=True)
assert result == [1, 2, 3]
mock_socket.recv.assert_called_once_with(3)
def test_recv_decode_without_decode(mock_socket: Mock) -> None:
"""Test recv_decode with decode=False returns bytes."""
mock_socket.recv.return_value = b"\x01\x02\x03"
result = espota2.recv_decode(mock_socket, 3, decode=False)
assert result == b"\x01\x02\x03"
mock_socket.recv.assert_called_once_with(3)
def test_receive_exactly_success(mock_socket: Mock) -> None:
"""Test receive_exactly successfully receives expected data."""
mock_socket.recv.side_effect = [b"\x00", b"\x01\x02"]
result = espota2.receive_exactly(mock_socket, 3, "test", espota2.RESPONSE_OK)
assert result == [0, 1, 2]
assert mock_socket.recv.call_count == 2
def test_receive_exactly_with_error_response(mock_socket: Mock) -> None:
"""Test receive_exactly raises OTAError on error response."""
mock_socket.recv.return_value = bytes([espota2.RESPONSE_ERROR_AUTH_INVALID])
with pytest.raises(espota2.OTAError, match="Error auth:.*Authentication invalid"):
espota2.receive_exactly(mock_socket, 1, "auth", [espota2.RESPONSE_OK])
mock_socket.close.assert_called_once()
def test_receive_exactly_socket_error(mock_socket: Mock) -> None:
"""Test receive_exactly handles socket errors."""
mock_socket.recv.side_effect = OSError("Connection reset")
with pytest.raises(espota2.OTAError, match="Error receiving acknowledge test"):
espota2.receive_exactly(mock_socket, 1, "test", espota2.RESPONSE_OK)
@pytest.mark.parametrize(
("error_code", "expected_msg"),
[
(espota2.RESPONSE_ERROR_MAGIC, "Error: Invalid magic byte"),
(espota2.RESPONSE_ERROR_UPDATE_PREPARE, "Error: Couldn't prepare flash memory"),
(espota2.RESPONSE_ERROR_AUTH_INVALID, "Error: Authentication invalid"),
(
espota2.RESPONSE_ERROR_WRITING_FLASH,
"Error: Wring OTA data to flash memory failed",
),
(espota2.RESPONSE_ERROR_UPDATE_END, "Error: Finishing update failed"),
(
espota2.RESPONSE_ERROR_INVALID_BOOTSTRAPPING,
"Error: Please press the reset button",
),
(
espota2.RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG,
"Error: ESP has been flashed with wrong flash size",
),
(
espota2.RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG,
"Error: ESP does not have the requested flash size",
),
(
espota2.RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE,
"Error: ESP does not have enough space",
),
(
espota2.RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE,
"Error: The OTA partition on the ESP is too small",
),
(
espota2.RESPONSE_ERROR_NO_UPDATE_PARTITION,
"Error: The OTA partition on the ESP couldn't be found",
),
(espota2.RESPONSE_ERROR_MD5_MISMATCH, "Error: Application MD5 code mismatch"),
(espota2.RESPONSE_ERROR_UNKNOWN, "Unknown error from ESP"),
],
)
def test_check_error_with_various_errors(error_code: int, expected_msg: str) -> None:
"""Test check_error raises appropriate errors for different error codes."""
with pytest.raises(espota2.OTAError, match=expected_msg):
espota2.check_error([error_code], [espota2.RESPONSE_OK])
def test_check_error_unexpected_response() -> None:
"""Test check_error raises error for unexpected response."""
with pytest.raises(espota2.OTAError, match="Unexpected response from ESP: 0x7F"):
espota2.check_error([0x7F], [espota2.RESPONSE_OK, espota2.RESPONSE_AUTH_OK])
def test_send_check_with_various_data_types(mock_socket: Mock) -> None:
"""Test send_check handles different data types."""
# Test with list/tuple
espota2.send_check(mock_socket, [0x01, 0x02], "list")
mock_socket.sendall.assert_called_with(b"\x01\x02")
# Test with int
espota2.send_check(mock_socket, 0x42, "int")
mock_socket.sendall.assert_called_with(b"\x42")
# Test with string
espota2.send_check(mock_socket, "hello", "string")
mock_socket.sendall.assert_called_with(b"hello")
# Test with bytes (should pass through)
espota2.send_check(mock_socket, b"\xaa\xbb", "bytes")
mock_socket.sendall.assert_called_with(b"\xaa\xbb")
def test_send_check_socket_error(mock_socket: Mock) -> None:
"""Test send_check handles socket errors."""
mock_socket.sendall.side_effect = OSError("Broken pipe")
with pytest.raises(espota2.OTAError, match="Error sending test"):
espota2.send_check(mock_socket, b"data", "test")
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_successful_md5_auth(
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None:
"""Test successful OTA with MD5 authentication."""
# Setup socket responses for recv calls
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_2_0]), # Version number
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
bytes([espota2.RESPONSE_REQUEST_AUTH]), # Auth request
MOCK_MD5_NONCE, # 32 char hex nonce
bytes([espota2.RESPONSE_AUTH_OK]), # Auth result
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
]
mock_socket.recv.side_effect = recv_responses
# Run OTA
espota2.perform_ota(mock_socket, "testpass", mock_file, "test.bin")
# Verify magic bytes were sent
assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES))
# Verify features were sent (compression + SHA256 support)
assert mock_socket.sendall.call_args_list[1] == call(
bytes(
[
espota2.FEATURE_SUPPORTS_COMPRESSION
| espota2.FEATURE_SUPPORTS_SHA256_AUTH
]
)
)
# Verify cnonce was sent (MD5 of random.random())
cnonce = hashlib.md5(MOCK_RANDOM_BYTES).hexdigest()
assert mock_socket.sendall.call_args_list[2] == call(cnonce.encode())
# Verify auth result was computed correctly
expected_hash = hashlib.md5()
expected_hash.update(b"testpass")
expected_hash.update(MOCK_MD5_NONCE)
expected_hash.update(cnonce.encode())
expected_result = expected_hash.hexdigest()
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_no_auth(mock_socket: Mock, mock_file: io.BytesIO) -> None:
"""Test OTA without authentication."""
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_1_0]), # Version number
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
]
mock_socket.recv.side_effect = recv_responses
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
# Should not send any auth-related data
auth_calls = [
call
for call in mock_socket.sendall.call_args_list
if "cnonce" in str(call) or "result" in str(call)
]
assert len(auth_calls) == 0
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_with_compression(mock_socket: Mock) -> None:
"""Test OTA with compression support."""
original_content = b"firmware" * 100 # Repeating content for compression
mock_file = io.BytesIO(original_content)
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_2_0]), # Version number
bytes([espota2.RESPONSE_SUPPORTS_COMPRESSION]), # Device supports compression
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
]
mock_socket.recv.side_effect = recv_responses
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
# Verify compressed content was sent
# Get the binary size that was sent (4 bytes after features)
size_bytes = mock_socket.sendall.call_args_list[2][0][0]
sent_size = struct.unpack(">I", size_bytes)[0]
# Size should be less than original due to compression
assert sent_size < len(original_content)
# Verify the content sent was gzipped
compressed = gzip.compress(original_content, compresslevel=9)
assert sent_size == len(compressed)
def test_perform_ota_auth_without_password(mock_socket: Mock) -> None:
"""Test OTA fails when auth is required but no password provided."""
mock_file = io.BytesIO(b"firmware")
responses = [
bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]),
bytes([espota2.RESPONSE_HEADER_OK]),
bytes([espota2.RESPONSE_REQUEST_AUTH]),
]
mock_socket.recv.side_effect = responses
with pytest.raises(
espota2.OTAError, match="ESP requests password, but no password given"
):
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_md5_auth_wrong_password(
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None:
"""Test OTA fails when MD5 authentication is rejected due to wrong password."""
# Setup socket responses for recv calls
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_2_0]), # Version number
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
bytes([espota2.RESPONSE_REQUEST_AUTH]), # Auth request
MOCK_MD5_NONCE, # 32 char hex nonce
bytes([espota2.RESPONSE_ERROR_AUTH_INVALID]), # Auth rejected!
]
mock_socket.recv.side_effect = recv_responses
with pytest.raises(espota2.OTAError, match="Error auth.*Authentication invalid"):
espota2.perform_ota(mock_socket, "wrongpassword", mock_file, "test.bin")
# Verify the socket was closed after auth failure
mock_socket.close.assert_called()
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_sha256_auth_wrong_password(
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None:
"""Test OTA fails when SHA256 authentication is rejected due to wrong password."""
# Setup socket responses for recv calls
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_2_0]), # Version number
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
bytes([espota2.RESPONSE_REQUEST_SHA256_AUTH]), # SHA256 Auth request
MOCK_SHA256_NONCE, # 64 char hex nonce
bytes([espota2.RESPONSE_ERROR_AUTH_INVALID]), # Auth rejected!
]
mock_socket.recv.side_effect = recv_responses
with pytest.raises(espota2.OTAError, match="Error auth.*Authentication invalid"):
espota2.perform_ota(mock_socket, "wrongpassword", mock_file, "test.bin")
# Verify the socket was closed after auth failure
mock_socket.close.assert_called()
def test_perform_ota_sha256_auth_without_password(mock_socket: Mock) -> None:
"""Test OTA fails when SHA256 auth is required but no password provided."""
mock_file = io.BytesIO(b"firmware")
responses = [
bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]),
bytes([espota2.RESPONSE_HEADER_OK]),
bytes([espota2.RESPONSE_REQUEST_SHA256_AUTH]),
]
mock_socket.recv.side_effect = responses
with pytest.raises(
espota2.OTAError, match="ESP requests password, but no password given"
):
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
def test_perform_ota_unexpected_auth_response(mock_socket: Mock) -> None:
"""Test OTA fails when device sends an unexpected auth response."""
mock_file = io.BytesIO(b"firmware")
# Use 0x03 which is not in the expected auth responses
# This will be caught by check_error and raise "Unexpected response from ESP"
UNKNOWN_AUTH_METHOD = 0x03
responses = [
bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]),
bytes([espota2.RESPONSE_HEADER_OK]),
bytes([UNKNOWN_AUTH_METHOD]), # Unknown auth method
]
mock_socket.recv.side_effect = responses
# This will actually raise "Unexpected response from ESP" from check_error
with pytest.raises(
espota2.OTAError, match=r"Error auth: Unexpected response from ESP: 0x03"
):
espota2.perform_ota(mock_socket, "password", mock_file, "test.bin")
def test_perform_ota_unsupported_version(mock_socket: Mock) -> None:
"""Test OTA fails with unsupported version."""
mock_file = io.BytesIO(b"firmware")
responses = [
bytes([espota2.RESPONSE_OK, 99]), # Unsupported version
]
mock_socket.recv.side_effect = responses
with pytest.raises(espota2.OTAError, match="Device uses unsupported OTA version"):
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_upload_error(mock_socket: Mock, mock_file: io.BytesIO) -> None:
"""Test OTA handles upload errors."""
# Setup responses - provide enough for the recv calls
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_2_0]), # Version number
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
]
# Add OSError to recv to simulate connection loss during chunk read
recv_responses.append(OSError("Connection lost"))
mock_socket.recv.side_effect = recv_responses
with pytest.raises(espota2.OTAError, match="Error receiving acknowledge chunk OK"):
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
@pytest.mark.usefixtures("mock_socket_constructor", "mock_resolve_ip")
def test_run_ota_impl_successful(
mock_socket: Mock, tmp_path: Path, mock_perform_ota: Mock
) -> None:
"""Test run_ota_impl_ with successful upload."""
# Create a real firmware file
firmware_file = tmp_path / "firmware.bin"
firmware_file.write_bytes(b"firmware content")
# Run OTA with real file path
result_code, result_host = espota2.run_ota_impl_(
"test.local", 3232, "password", str(firmware_file)
)
# Verify success
assert result_code == 0
assert result_host == "192.168.1.100"
# Verify socket was configured correctly
mock_socket.settimeout.assert_called_with(10.0)
mock_socket.connect.assert_called_once_with(("192.168.1.100", 3232))
mock_socket.close.assert_called_once()
# Verify perform_ota was called with real file
mock_perform_ota.assert_called_once()
call_args = mock_perform_ota.call_args[0]
assert call_args[0] == mock_socket
assert call_args[1] == "password"
# Verify the file object is a proper file handle
assert isinstance(call_args[2], io.IOBase)
assert call_args[3] == str(firmware_file)
@pytest.mark.usefixtures("mock_socket_constructor", "mock_resolve_ip")
def test_run_ota_impl_connection_failed(mock_socket: Mock, tmp_path: Path) -> None:
"""Test run_ota_impl_ when connection fails."""
mock_socket.connect.side_effect = OSError("Connection refused")
# Create a real firmware file
firmware_file = tmp_path / "firmware.bin"
firmware_file.write_bytes(b"firmware content")
result_code, result_host = espota2.run_ota_impl_(
"test.local", 3232, "password", str(firmware_file)
)
assert result_code == 1
assert result_host is None
mock_socket.close.assert_called_once()
def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip: Mock) -> None:
"""Test run_ota_impl_ when DNS resolution fails."""
# Create a real firmware file
firmware_file = tmp_path / "firmware.bin"
firmware_file.write_bytes(b"firmware content")
mock_resolve_ip.side_effect = EsphomeError("DNS resolution failed")
with pytest.raises(espota2.OTAError, match="DNS resolution failed"):
result_code, result_host = espota2.run_ota_impl_(
"unknown.host", 3232, "password", str(firmware_file)
)
def test_run_ota_wrapper(mock_run_ota_impl: Mock) -> None:
"""Test run_ota wrapper function."""
# Test successful case
mock_run_ota_impl.return_value = (0, "192.168.1.100")
result = espota2.run_ota("test.local", 3232, "pass", "fw.bin")
assert result == (0, "192.168.1.100")
# Test error case
mock_run_ota_impl.side_effect = espota2.OTAError("Test error")
result = espota2.run_ota("test.local", 3232, "pass", "fw.bin")
assert result == (1, None)
def test_progress_bar(capsys: CaptureFixture[str]) -> None:
"""Test ProgressBar functionality."""
progress = espota2.ProgressBar()
# Test initial update
progress.update(0.0)
captured = capsys.readouterr()
assert "0%" in captured.err
assert "[" in captured.err
# Test progress update
progress.update(0.5)
captured = capsys.readouterr()
assert "50%" in captured.err
# Test completion
progress.update(1.0)
captured = capsys.readouterr()
assert "100%" in captured.err
assert "Done" in captured.err
# Test done method
progress.done()
captured = capsys.readouterr()
assert captured.err == "\n"
# Test same progress doesn't update
progress.update(0.5)
progress.update(0.5)
captured = capsys.readouterr()
# Should only see one update (second call shouldn't write)
assert captured.err.count("50%") == 1
# Tests for SHA256 authentication
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_successful_sha256_auth(
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None:
"""Test successful OTA with SHA256 authentication."""
# Setup socket responses for recv calls
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_2_0]), # Version number
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
bytes([espota2.RESPONSE_REQUEST_SHA256_AUTH]), # SHA256 Auth request
MOCK_SHA256_NONCE, # 64 char hex nonce
bytes([espota2.RESPONSE_AUTH_OK]), # Auth result
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
]
mock_socket.recv.side_effect = recv_responses
# Run OTA
espota2.perform_ota(mock_socket, "testpass", mock_file, "test.bin")
# Verify magic bytes were sent
assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES))
# Verify features were sent (compression + SHA256 support)
assert mock_socket.sendall.call_args_list[1] == call(
bytes(
[
espota2.FEATURE_SUPPORTS_COMPRESSION
| espota2.FEATURE_SUPPORTS_SHA256_AUTH
]
)
)
# Verify cnonce was sent (SHA256 of random.random())
cnonce = hashlib.sha256(MOCK_RANDOM_BYTES).hexdigest()
assert mock_socket.sendall.call_args_list[2] == call(cnonce.encode())
# Verify auth result was computed correctly with SHA256
expected_hash = hashlib.sha256()
expected_hash.update(b"testpass")
expected_hash.update(MOCK_SHA256_NONCE)
expected_hash.update(cnonce.encode())
expected_result = expected_hash.hexdigest()
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_sha256_fallback_to_md5(
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
) -> None:
"""Test SHA256-capable client falls back to MD5 for compatibility."""
# This test verifies the temporary backward compatibility
# where a SHA256-capable client can still authenticate with MD5
# This compatibility will be removed in 2026.1.0
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_2_0]), # Version number
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
bytes(
[espota2.RESPONSE_REQUEST_AUTH]
), # MD5 Auth request (device doesn't support SHA256)
MOCK_MD5_NONCE, # 32 char hex nonce for MD5
bytes([espota2.RESPONSE_AUTH_OK]), # Auth result
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
]
mock_socket.recv.side_effect = recv_responses
# Run OTA - should work even though device requested MD5
espota2.perform_ota(mock_socket, "testpass", mock_file, "test.bin")
# Verify client still advertised SHA256 support
assert mock_socket.sendall.call_args_list[1] == call(
bytes(
[
espota2.FEATURE_SUPPORTS_COMPRESSION
| espota2.FEATURE_SUPPORTS_SHA256_AUTH
]
)
)
# But authentication was done with MD5
cnonce = hashlib.md5(MOCK_RANDOM_BYTES).hexdigest()
expected_hash = hashlib.md5()
expected_hash.update(b"testpass")
expected_hash.update(MOCK_MD5_NONCE)
expected_hash.update(cnonce.encode())
expected_result = expected_hash.hexdigest()
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
@pytest.mark.usefixtures("mock_time")
def test_perform_ota_version_differences(
mock_socket: Mock, mock_file: io.BytesIO
) -> None:
"""Test OTA behavior differences between version 1.0 and 2.0."""
# Test version 1.0 - no chunk acknowledgments
recv_responses = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_1_0]), # Version number
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
# No RESPONSE_CHUNK_OK for v1
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
]
mock_socket.recv.side_effect = recv_responses
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
# For v1.0, verify that we only get the expected number of recv calls
# v1.0 doesn't have chunk acknowledgments, so fewer recv calls
assert mock_socket.recv.call_count == 8 # v1.0 has 8 recv calls
# Reset mock for v2.0 test
mock_socket.reset_mock()
# Reset file position for second test
mock_file.seek(0)
# Test version 2.0 - with chunk acknowledgments
recv_responses_v2 = [
bytes([espota2.RESPONSE_OK]), # First byte of version response
bytes([espota2.OTA_VERSION_2_0]), # Version number
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
bytes([espota2.RESPONSE_CHUNK_OK]), # v2.0 has chunk acknowledgment
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
]
mock_socket.recv.side_effect = recv_responses_v2
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
# For v2.0, verify more recv calls due to chunk acknowledgments
assert mock_socket.recv.call_count == 9 # v2.0 has 9 recv calls (includes chunk OK)