mirror of
https://github.com/esphome/esphome.git
synced 2025-09-22 05:02:23 +01:00
Merge branch 'integration' into memory_api
This commit is contained in:
@@ -407,6 +407,7 @@ esphome/components/sensor/* @esphome/core
|
|||||||
esphome/components/sfa30/* @ghsensdev
|
esphome/components/sfa30/* @ghsensdev
|
||||||
esphome/components/sgp40/* @SenexCrenshaw
|
esphome/components/sgp40/* @SenexCrenshaw
|
||||||
esphome/components/sgp4x/* @martgras @SenexCrenshaw
|
esphome/components/sgp4x/* @martgras @SenexCrenshaw
|
||||||
|
esphome/components/sha256/* @esphome/core
|
||||||
esphome/components/shelly_dimmer/* @edge90 @rnauber
|
esphome/components/shelly_dimmer/* @edge90 @rnauber
|
||||||
esphome/components/sht3xd/* @mrtoy-me
|
esphome/components/sht3xd/* @mrtoy-me
|
||||||
esphome/components/sht4x/* @sjtrny
|
esphome/components/sht4x/* @sjtrny
|
||||||
|
@@ -16,7 +16,7 @@ from esphome.const import (
|
|||||||
CONF_SAFE_MODE,
|
CONF_SAFE_MODE,
|
||||||
CONF_VERSION,
|
CONF_VERSION,
|
||||||
)
|
)
|
||||||
from esphome.core import coroutine_with_priority
|
from esphome.core import CORE, coroutine_with_priority
|
||||||
from esphome.coroutine import CoroPriority
|
from esphome.coroutine import CoroPriority
|
||||||
import esphome.final_validate as fv
|
import esphome.final_validate as fv
|
||||||
|
|
||||||
@@ -24,9 +24,22 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
CODEOWNERS = ["@esphome/core"]
|
CODEOWNERS = ["@esphome/core"]
|
||||||
AUTO_LOAD = ["md5", "socket"]
|
|
||||||
DEPENDENCIES = ["network"]
|
DEPENDENCIES = ["network"]
|
||||||
|
|
||||||
|
|
||||||
|
def supports_sha256() -> bool:
|
||||||
|
"""Check if the current platform supports SHA256 for OTA authentication."""
|
||||||
|
return bool(CORE.is_esp32 or CORE.is_esp8266 or CORE.is_rp2040 or CORE.is_libretiny)
|
||||||
|
|
||||||
|
|
||||||
|
def AUTO_LOAD() -> list[str]:
|
||||||
|
"""Conditionally auto-load sha256 only on platforms that support it."""
|
||||||
|
base_components = ["md5", "socket"]
|
||||||
|
if supports_sha256():
|
||||||
|
return base_components + ["sha256"]
|
||||||
|
return base_components
|
||||||
|
|
||||||
|
|
||||||
esphome = cg.esphome_ns.namespace("esphome")
|
esphome = cg.esphome_ns.namespace("esphome")
|
||||||
ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", OTAComponent)
|
ESPHomeOTAComponent = esphome.class_("ESPHomeOTAComponent", OTAComponent)
|
||||||
|
|
||||||
@@ -126,6 +139,11 @@ FINAL_VALIDATE_SCHEMA = ota_esphome_final_validate
|
|||||||
async def to_code(config):
|
async def to_code(config):
|
||||||
var = cg.new_Pvariable(config[CONF_ID])
|
var = cg.new_Pvariable(config[CONF_ID])
|
||||||
cg.add(var.set_port(config[CONF_PORT]))
|
cg.add(var.set_port(config[CONF_PORT]))
|
||||||
|
|
||||||
|
# Only include SHA256 support on platforms that have it
|
||||||
|
if supports_sha256():
|
||||||
|
cg.add_define("USE_OTA_SHA256")
|
||||||
|
|
||||||
if CONF_PASSWORD in config:
|
if CONF_PASSWORD in config:
|
||||||
cg.add(var.set_auth_password(config[CONF_PASSWORD]))
|
cg.add(var.set_auth_password(config[CONF_PASSWORD]))
|
||||||
cg.add_define("USE_OTA_PASSWORD")
|
cg.add_define("USE_OTA_PASSWORD")
|
||||||
|
@@ -1,6 +1,9 @@
|
|||||||
#include "ota_esphome.h"
|
#include "ota_esphome.h"
|
||||||
#ifdef USE_OTA
|
#ifdef USE_OTA
|
||||||
#include "esphome/components/md5/md5.h"
|
#include "esphome/components/md5/md5.h"
|
||||||
|
#ifdef USE_OTA_SHA256
|
||||||
|
#include "esphome/components/sha256/sha256.h"
|
||||||
|
#endif
|
||||||
#include "esphome/components/network/util.h"
|
#include "esphome/components/network/util.h"
|
||||||
#include "esphome/components/ota/ota_backend.h"
|
#include "esphome/components/ota/ota_backend.h"
|
||||||
#include "esphome/components/ota/ota_backend_arduino_esp32.h"
|
#include "esphome/components/ota/ota_backend_arduino_esp32.h"
|
||||||
@@ -95,6 +98,33 @@ void ESPHomeOTAComponent::loop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static const uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01;
|
static const uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01;
|
||||||
|
#ifdef USE_OTA_SHA256
|
||||||
|
static const uint8_t FEATURE_SUPPORTS_SHA256_AUTH = 0x02;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Temporary flag to allow MD5 downgrade for ~3 versions (until 2026.1.0)
|
||||||
|
// This allows users to downgrade via OTA if they encounter issues after updating.
|
||||||
|
// Without this, users would need to do a serial flash to downgrade.
|
||||||
|
// TODO: Remove this flag and all associated code in 2026.1.0
|
||||||
|
#define ALLOW_OTA_DOWNGRADE_MD5
|
||||||
|
|
||||||
|
template<typename HashClass> struct HashTraits;
|
||||||
|
|
||||||
|
template<> struct HashTraits<md5::MD5Digest> {
|
||||||
|
static constexpr int NONCE_SIZE = 8;
|
||||||
|
static constexpr int HEX_SIZE = 32;
|
||||||
|
static constexpr const char *NAME = "MD5";
|
||||||
|
static constexpr ota::OTAResponseTypes AUTH_REQUEST = ota::OTA_RESPONSE_REQUEST_AUTH;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef USE_OTA_SHA256
|
||||||
|
template<> struct HashTraits<sha256::SHA256> {
|
||||||
|
static constexpr int NONCE_SIZE = 16;
|
||||||
|
static constexpr int HEX_SIZE = 64;
|
||||||
|
static constexpr const char *NAME = "SHA256";
|
||||||
|
static constexpr ota::OTAResponseTypes AUTH_REQUEST = ota::OTA_RESPONSE_REQUEST_SHA256_AUTH;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
void ESPHomeOTAComponent::handle_handshake_() {
|
void ESPHomeOTAComponent::handle_handshake_() {
|
||||||
/// Handle the initial OTA handshake.
|
/// Handle the initial OTA handshake.
|
||||||
@@ -225,57 +255,55 @@ void ESPHomeOTAComponent::handle_data_() {
|
|||||||
|
|
||||||
#ifdef USE_OTA_PASSWORD
|
#ifdef USE_OTA_PASSWORD
|
||||||
if (!this->password_.empty()) {
|
if (!this->password_.empty()) {
|
||||||
buf[0] = ota::OTA_RESPONSE_REQUEST_AUTH;
|
bool auth_success = false;
|
||||||
this->writeall_(buf, 1);
|
|
||||||
md5::MD5Digest md5{};
|
|
||||||
md5.init();
|
|
||||||
sprintf(sbuf, "%08" PRIx32, random_uint32());
|
|
||||||
md5.add(sbuf, 8);
|
|
||||||
md5.calculate();
|
|
||||||
md5.get_hex(sbuf);
|
|
||||||
ESP_LOGV(TAG, "Auth: Nonce is %s", sbuf);
|
|
||||||
|
|
||||||
// Send nonce, 32 bytes hex MD5
|
#ifdef USE_OTA_SHA256
|
||||||
if (!this->writeall_(reinterpret_cast<uint8_t *>(sbuf), 32)) {
|
// SECURITY HARDENING: Prefer SHA256 authentication on platforms that support it.
|
||||||
ESP_LOGW(TAG, "Auth: Writing nonce failed");
|
//
|
||||||
|
// This is a hardening measure to prevent future downgrade attacks where an attacker
|
||||||
|
// could force the use of MD5 authentication by manipulating the feature flags.
|
||||||
|
//
|
||||||
|
// While MD5 is currently still acceptable for our OTA authentication use case
|
||||||
|
// (where the password is a shared secret and we're only authenticating, not
|
||||||
|
// encrypting), at some point in the future MD5 will likely become so weak that
|
||||||
|
// it could be practically attacked.
|
||||||
|
//
|
||||||
|
// We enforce SHA256 now on capable platforms because:
|
||||||
|
// 1. We can't retroactively update device firmware in the field
|
||||||
|
// 2. Clients (like esphome CLI) can always be updated to support SHA256
|
||||||
|
// 3. This prevents any possibility of downgrade attacks in the future
|
||||||
|
//
|
||||||
|
// Devices that don't support SHA256 (due to platform limitations) will
|
||||||
|
// continue to use MD5 as their only option (see #else branch below).
|
||||||
|
|
||||||
|
bool client_supports_sha256 = (ota_features & FEATURE_SUPPORTS_SHA256_AUTH) != 0;
|
||||||
|
|
||||||
|
#ifdef ALLOW_OTA_DOWNGRADE_MD5
|
||||||
|
// Temporary compatibility mode: Allow MD5 for ~3 versions to enable OTA downgrades
|
||||||
|
// This prevents users from being locked out if they need to downgrade after updating
|
||||||
|
// TODO: Remove this entire ifdef block in 2026.1.0
|
||||||
|
if (client_supports_sha256) {
|
||||||
|
auth_success = this->perform_hash_auth_<sha256::SHA256>(this->password_);
|
||||||
|
} else {
|
||||||
|
ESP_LOGW(TAG, "Using MD5 auth for compatibility (deprecated)");
|
||||||
|
auth_success = this->perform_hash_auth_<md5::MD5Digest>(this->password_);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// Strict mode: SHA256 required on capable platforms (future default)
|
||||||
|
if (!client_supports_sha256) {
|
||||||
|
ESP_LOGW(TAG, "Client requires SHA256");
|
||||||
|
error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
|
||||||
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
|
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
|
||||||
}
|
}
|
||||||
|
auth_success = this->perform_hash_auth_<sha256::SHA256>(this->password_);
|
||||||
|
#endif // ALLOW_OTA_DOWNGRADE_MD5
|
||||||
|
#else
|
||||||
|
// Platform only supports MD5 - use it as the only available option
|
||||||
|
// This is not a security downgrade as the platform cannot support SHA256
|
||||||
|
auth_success = this->perform_hash_auth_<md5::MD5Digest>(this->password_);
|
||||||
|
#endif // USE_OTA_SHA256
|
||||||
|
|
||||||
// prepare challenge
|
if (!auth_success) {
|
||||||
md5.init();
|
|
||||||
md5.add(this->password_.c_str(), this->password_.length());
|
|
||||||
// add nonce
|
|
||||||
md5.add(sbuf, 32);
|
|
||||||
|
|
||||||
// Receive cnonce, 32 bytes hex MD5
|
|
||||||
if (!this->readall_(buf, 32)) {
|
|
||||||
ESP_LOGW(TAG, "Auth: Reading cnonce failed");
|
|
||||||
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
|
|
||||||
}
|
|
||||||
sbuf[32] = '\0';
|
|
||||||
ESP_LOGV(TAG, "Auth: CNonce is %s", sbuf);
|
|
||||||
// add cnonce
|
|
||||||
md5.add(sbuf, 32);
|
|
||||||
|
|
||||||
// calculate result
|
|
||||||
md5.calculate();
|
|
||||||
md5.get_hex(sbuf);
|
|
||||||
ESP_LOGV(TAG, "Auth: Result is %s", sbuf);
|
|
||||||
|
|
||||||
// Receive result, 32 bytes hex MD5
|
|
||||||
if (!this->readall_(buf + 64, 32)) {
|
|
||||||
ESP_LOGW(TAG, "Auth: Reading response failed");
|
|
||||||
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
|
|
||||||
}
|
|
||||||
sbuf[64 + 32] = '\0';
|
|
||||||
ESP_LOGV(TAG, "Auth: Response is %s", sbuf + 64);
|
|
||||||
|
|
||||||
bool matches = true;
|
|
||||||
for (uint8_t i = 0; i < 32; i++)
|
|
||||||
matches = matches && buf[i] == buf[64 + i];
|
|
||||||
|
|
||||||
if (!matches) {
|
|
||||||
ESP_LOGW(TAG, "Auth failed! Passwords do not match");
|
|
||||||
error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
|
error_code = ota::OTA_RESPONSE_ERROR_AUTH_INVALID;
|
||||||
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
|
goto error; // NOLINT(cppcoreguidelines-avoid-goto)
|
||||||
}
|
}
|
||||||
@@ -499,5 +527,110 @@ void ESPHomeOTAComponent::yield_and_feed_watchdog_() {
|
|||||||
delay(1);
|
delay(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Template function definition - placed at end to ensure all types are complete
|
||||||
|
template<typename HashClass> bool ESPHomeOTAComponent::perform_hash_auth_(const std::string &password) {
|
||||||
|
using Traits = HashTraits<HashClass>;
|
||||||
|
|
||||||
|
// Minimize stack usage by reusing buffers
|
||||||
|
// We only need 2 buffers at most at the same time
|
||||||
|
constexpr size_t hex_buffer_size = Traits::HEX_SIZE + 1;
|
||||||
|
|
||||||
|
// These two buffers are reused throughout the function
|
||||||
|
char hex_buffer1[hex_buffer_size]; // Used for: nonce -> expected result
|
||||||
|
char hex_buffer2[hex_buffer_size]; // Used for: cnonce -> response
|
||||||
|
|
||||||
|
// Small stack buffer for auth request and nonce seed bytes
|
||||||
|
uint8_t buf[1];
|
||||||
|
uint8_t nonce_bytes[8]; // Max 8 bytes (2 x uint32_t for SHA256)
|
||||||
|
|
||||||
|
// Send auth request type
|
||||||
|
buf[0] = Traits::AUTH_REQUEST;
|
||||||
|
this->writeall_(buf, 1);
|
||||||
|
|
||||||
|
HashClass hasher;
|
||||||
|
hasher.init();
|
||||||
|
|
||||||
|
// Generate nonce seed bytes
|
||||||
|
uint32_t r1 = random_uint32();
|
||||||
|
// Convert first uint32 to bytes (always needed for MD5)
|
||||||
|
nonce_bytes[0] = (r1 >> 24) & 0xFF;
|
||||||
|
nonce_bytes[1] = (r1 >> 16) & 0xFF;
|
||||||
|
nonce_bytes[2] = (r1 >> 8) & 0xFF;
|
||||||
|
nonce_bytes[3] = r1 & 0xFF;
|
||||||
|
|
||||||
|
if (Traits::NONCE_SIZE == 8) {
|
||||||
|
// MD5: 8 chars = "%08x" format = 4 bytes from one random uint32
|
||||||
|
hasher.add(nonce_bytes, 4);
|
||||||
|
}
|
||||||
|
#ifdef USE_OTA_SHA256
|
||||||
|
else {
|
||||||
|
// SHA256: 16 chars = "%08x%08x" format = 8 bytes from two random uint32s
|
||||||
|
uint32_t r2 = random_uint32();
|
||||||
|
nonce_bytes[4] = (r2 >> 24) & 0xFF;
|
||||||
|
nonce_bytes[5] = (r2 >> 16) & 0xFF;
|
||||||
|
nonce_bytes[6] = (r2 >> 8) & 0xFF;
|
||||||
|
nonce_bytes[7] = r2 & 0xFF;
|
||||||
|
hasher.add(nonce_bytes, 8);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
hasher.calculate();
|
||||||
|
|
||||||
|
// Use hex_buffer1 for nonce
|
||||||
|
hasher.get_hex(hex_buffer1);
|
||||||
|
hex_buffer1[Traits::HEX_SIZE] = '\0';
|
||||||
|
ESP_LOGV(TAG, "Auth: %s Nonce is %s", Traits::NAME, hex_buffer1);
|
||||||
|
|
||||||
|
// Send nonce
|
||||||
|
if (!this->writeall_(reinterpret_cast<uint8_t *>(hex_buffer1), Traits::HEX_SIZE)) {
|
||||||
|
ESP_LOGW(TAG, "Auth: Writing %s nonce failed", Traits::NAME);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare challenge
|
||||||
|
hasher.init();
|
||||||
|
hasher.add(password.c_str(), password.length());
|
||||||
|
hasher.add(hex_buffer1, Traits::HEX_SIZE); // Add nonce
|
||||||
|
|
||||||
|
// Receive cnonce into hex_buffer2
|
||||||
|
if (!this->readall_(reinterpret_cast<uint8_t *>(hex_buffer2), Traits::HEX_SIZE)) {
|
||||||
|
ESP_LOGW(TAG, "Auth: Reading %s cnonce failed", Traits::NAME);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
hex_buffer2[Traits::HEX_SIZE] = '\0';
|
||||||
|
ESP_LOGV(TAG, "Auth: %s CNonce is %s", Traits::NAME, hex_buffer2);
|
||||||
|
|
||||||
|
// Add cnonce to hash
|
||||||
|
hasher.add(hex_buffer2, Traits::HEX_SIZE);
|
||||||
|
|
||||||
|
// Calculate result - reuse hex_buffer1 for expected
|
||||||
|
hasher.calculate();
|
||||||
|
hasher.get_hex(hex_buffer1);
|
||||||
|
hex_buffer1[Traits::HEX_SIZE] = '\0';
|
||||||
|
ESP_LOGV(TAG, "Auth: %s Result is %s", Traits::NAME, hex_buffer1);
|
||||||
|
|
||||||
|
// Receive response - reuse hex_buffer2
|
||||||
|
if (!this->readall_(reinterpret_cast<uint8_t *>(hex_buffer2), Traits::HEX_SIZE)) {
|
||||||
|
ESP_LOGW(TAG, "Auth: Reading %s response failed", Traits::NAME);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
hex_buffer2[Traits::HEX_SIZE] = '\0';
|
||||||
|
ESP_LOGV(TAG, "Auth: %s Response is %s", Traits::NAME, hex_buffer2);
|
||||||
|
|
||||||
|
// Compare
|
||||||
|
bool matches = memcmp(hex_buffer1, hex_buffer2, Traits::HEX_SIZE) == 0;
|
||||||
|
|
||||||
|
if (!matches) {
|
||||||
|
ESP_LOGW(TAG, "Auth failed! %s passwords do not match", Traits::NAME);
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit template instantiations
|
||||||
|
template bool ESPHomeOTAComponent::perform_hash_auth_<md5::MD5Digest>(const std::string &);
|
||||||
|
#ifdef USE_OTA_SHA256
|
||||||
|
template bool ESPHomeOTAComponent::perform_hash_auth_<sha256::SHA256>(const std::string &);
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace esphome
|
} // namespace esphome
|
||||||
#endif
|
#endif
|
||||||
|
@@ -30,6 +30,7 @@ class ESPHomeOTAComponent : public ota::OTAComponent {
|
|||||||
protected:
|
protected:
|
||||||
void handle_handshake_();
|
void handle_handshake_();
|
||||||
void handle_data_();
|
void handle_data_();
|
||||||
|
template<typename HashClass> bool perform_hash_auth_(const std::string &password);
|
||||||
bool readall_(uint8_t *buf, size_t len);
|
bool readall_(uint8_t *buf, size_t len);
|
||||||
bool writeall_(const uint8_t *buf, size_t len);
|
bool writeall_(const uint8_t *buf, size_t len);
|
||||||
void log_socket_error_(const LogString *msg);
|
void log_socket_error_(const LogString *msg);
|
||||||
|
@@ -14,6 +14,7 @@ namespace ota {
|
|||||||
enum OTAResponseTypes {
|
enum OTAResponseTypes {
|
||||||
OTA_RESPONSE_OK = 0x00,
|
OTA_RESPONSE_OK = 0x00,
|
||||||
OTA_RESPONSE_REQUEST_AUTH = 0x01,
|
OTA_RESPONSE_REQUEST_AUTH = 0x01,
|
||||||
|
OTA_RESPONSE_REQUEST_SHA256_AUTH = 0x02,
|
||||||
|
|
||||||
OTA_RESPONSE_HEADER_OK = 0x40,
|
OTA_RESPONSE_HEADER_OK = 0x40,
|
||||||
OTA_RESPONSE_AUTH_OK = 0x41,
|
OTA_RESPONSE_AUTH_OK = 0x41,
|
||||||
|
24
esphome/components/sha256/__init__.py
Normal file
24
esphome/components/sha256/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import esphome.codegen as cg
|
||||||
|
import esphome.config_validation as cv
|
||||||
|
from esphome.core import CORE
|
||||||
|
from esphome.helpers import IS_MACOS
|
||||||
|
from esphome.types import ConfigType
|
||||||
|
|
||||||
|
CODEOWNERS = ["@esphome/core"]
|
||||||
|
|
||||||
|
sha256_ns = cg.esphome_ns.namespace("sha256")
|
||||||
|
|
||||||
|
CONFIG_SCHEMA = cv.Schema({})
|
||||||
|
|
||||||
|
|
||||||
|
async def to_code(config: ConfigType) -> None:
|
||||||
|
# Add OpenSSL library for host platform
|
||||||
|
if CORE.is_host:
|
||||||
|
if IS_MACOS:
|
||||||
|
# macOS needs special handling for Homebrew OpenSSL
|
||||||
|
cg.add_build_flag("-I/opt/homebrew/opt/openssl/include")
|
||||||
|
cg.add_build_flag("-L/opt/homebrew/opt/openssl/lib")
|
||||||
|
cg.add_build_flag("-lcrypto")
|
||||||
|
else:
|
||||||
|
# Linux and other Unix systems usually have OpenSSL in standard paths
|
||||||
|
cg.add_build_flag("-lcrypto")
|
160
esphome/components/sha256/sha256.cpp
Normal file
160
esphome/components/sha256/sha256.cpp
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
#include "sha256.h"
|
||||||
|
|
||||||
|
// Only compile SHA256 implementation on platforms that support it
|
||||||
|
#if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_HOST)
|
||||||
|
|
||||||
|
#include "esphome/core/helpers.h"
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
namespace esphome::sha256 {
|
||||||
|
|
||||||
|
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
|
||||||
|
|
||||||
|
SHA256::~SHA256() {
|
||||||
|
if (this->ctx_) {
|
||||||
|
mbedtls_sha256_free(&this->ctx_->ctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SHA256::init() {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
this->ctx_ = std::make_unique<SHA256Context>();
|
||||||
|
}
|
||||||
|
mbedtls_sha256_init(&this->ctx_->ctx);
|
||||||
|
mbedtls_sha256_starts(&this->ctx_->ctx, 0); // 0 = SHA256, not SHA224
|
||||||
|
}
|
||||||
|
|
||||||
|
void SHA256::add(const uint8_t *data, size_t len) {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
this->init();
|
||||||
|
}
|
||||||
|
mbedtls_sha256_update(&this->ctx_->ctx, data, len);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SHA256::calculate() {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
this->init();
|
||||||
|
}
|
||||||
|
mbedtls_sha256_finish(&this->ctx_->ctx, this->ctx_->hash);
|
||||||
|
}
|
||||||
|
|
||||||
|
#elif defined(USE_ESP8266) || defined(USE_RP2040)
|
||||||
|
|
||||||
|
SHA256::~SHA256() = default;
|
||||||
|
|
||||||
|
void SHA256::init() {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
this->ctx_ = std::make_unique<SHA256Context>();
|
||||||
|
}
|
||||||
|
br_sha256_init(&this->ctx_->ctx);
|
||||||
|
this->ctx_->calculated = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SHA256::add(const uint8_t *data, size_t len) {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
this->init();
|
||||||
|
}
|
||||||
|
br_sha256_update(&this->ctx_->ctx, data, len);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SHA256::calculate() {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
this->init();
|
||||||
|
}
|
||||||
|
if (!this->ctx_->calculated) {
|
||||||
|
br_sha256_out(&this->ctx_->ctx, this->ctx_->hash);
|
||||||
|
this->ctx_->calculated = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#elif defined(USE_HOST)
|
||||||
|
|
||||||
|
SHA256::~SHA256() {
|
||||||
|
if (this->ctx_ && this->ctx_->ctx) {
|
||||||
|
EVP_MD_CTX_free(this->ctx_->ctx);
|
||||||
|
this->ctx_->ctx = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SHA256::init() {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
this->ctx_ = std::make_unique<SHA256Context>();
|
||||||
|
}
|
||||||
|
if (this->ctx_->ctx) {
|
||||||
|
EVP_MD_CTX_free(this->ctx_->ctx);
|
||||||
|
}
|
||||||
|
this->ctx_->ctx = EVP_MD_CTX_new();
|
||||||
|
EVP_DigestInit_ex(this->ctx_->ctx, EVP_sha256(), nullptr);
|
||||||
|
this->ctx_->calculated = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SHA256::add(const uint8_t *data, size_t len) {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
this->init();
|
||||||
|
}
|
||||||
|
EVP_DigestUpdate(this->ctx_->ctx, data, len);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SHA256::calculate() {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
this->init();
|
||||||
|
}
|
||||||
|
if (!this->ctx_->calculated) {
|
||||||
|
unsigned int len = 32;
|
||||||
|
EVP_DigestFinal_ex(this->ctx_->ctx, this->ctx_->hash, &len);
|
||||||
|
this->ctx_->calculated = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
#error "SHA256 not supported on this platform"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void SHA256::get_bytes(uint8_t *output) {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
memset(output, 0, 32);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
memcpy(output, this->ctx_->hash, 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SHA256::get_hex(char *output) {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
memset(output, '0', 64);
|
||||||
|
output[64] = '\0';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < 32; i++) {
|
||||||
|
uint8_t byte = this->ctx_->hash[i];
|
||||||
|
output[i * 2] = format_hex_char(byte >> 4);
|
||||||
|
output[i * 2 + 1] = format_hex_char(byte & 0x0F);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SHA256::get_hex_string() {
|
||||||
|
char buf[65];
|
||||||
|
this->get_hex(buf);
|
||||||
|
return std::string(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SHA256::equals_bytes(const uint8_t *expected) {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return memcmp(this->ctx_->hash, expected, 32) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SHA256::equals_hex(const char *expected) {
|
||||||
|
if (!this->ctx_) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
uint8_t parsed[32];
|
||||||
|
if (!parse_hex(expected, parsed, 32)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return this->equals_bytes(parsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace esphome::sha256
|
||||||
|
|
||||||
|
#endif // Platform check
|
69
esphome/components/sha256/sha256.h
Normal file
69
esphome/components/sha256/sha256.h
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "esphome/core/defines.h"
|
||||||
|
|
||||||
|
// Only define SHA256 on platforms that support it
|
||||||
|
#if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) || defined(USE_HOST)
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
|
||||||
|
#include "mbedtls/sha256.h"
|
||||||
|
#elif defined(USE_ESP8266) || defined(USE_RP2040)
|
||||||
|
#include <bearssl/bearssl_hash.h>
|
||||||
|
#elif defined(USE_HOST)
|
||||||
|
#include <openssl/evp.h>
|
||||||
|
#else
|
||||||
|
#error "SHA256 not supported on this platform"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace esphome::sha256 {
|
||||||
|
|
||||||
|
class SHA256 {
|
||||||
|
public:
|
||||||
|
SHA256() = default;
|
||||||
|
~SHA256();
|
||||||
|
|
||||||
|
void init();
|
||||||
|
void add(const uint8_t *data, size_t len);
|
||||||
|
void add(const char *data, size_t len) { this->add((const uint8_t *) data, len); }
|
||||||
|
void add(const std::string &data) { this->add(data.c_str(), data.length()); }
|
||||||
|
|
||||||
|
void calculate();
|
||||||
|
|
||||||
|
void get_bytes(uint8_t *output);
|
||||||
|
void get_hex(char *output);
|
||||||
|
std::string get_hex_string();
|
||||||
|
|
||||||
|
bool equals_bytes(const uint8_t *expected);
|
||||||
|
bool equals_hex(const char *expected);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
|
||||||
|
struct SHA256Context {
|
||||||
|
mbedtls_sha256_context ctx;
|
||||||
|
uint8_t hash[32];
|
||||||
|
};
|
||||||
|
#elif defined(USE_ESP8266) || defined(USE_RP2040)
|
||||||
|
struct SHA256Context {
|
||||||
|
br_sha256_context ctx;
|
||||||
|
uint8_t hash[32];
|
||||||
|
bool calculated{false};
|
||||||
|
};
|
||||||
|
#elif defined(USE_HOST)
|
||||||
|
struct SHA256Context {
|
||||||
|
EVP_MD_CTX *ctx{nullptr};
|
||||||
|
uint8_t hash[32];
|
||||||
|
bool calculated{false};
|
||||||
|
};
|
||||||
|
#else
|
||||||
|
#error "SHA256 not supported on this platform"
|
||||||
|
#endif
|
||||||
|
std::unique_ptr<SHA256Context> ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace esphome::sha256
|
||||||
|
|
||||||
|
#endif // Platform check
|
@@ -116,6 +116,7 @@
|
|||||||
#define USE_API_PLAINTEXT
|
#define USE_API_PLAINTEXT
|
||||||
#define USE_API_SERVICES
|
#define USE_API_SERVICES
|
||||||
#define USE_MD5
|
#define USE_MD5
|
||||||
|
#define USE_SHA256
|
||||||
#define USE_MQTT
|
#define USE_MQTT
|
||||||
#define USE_NETWORK
|
#define USE_NETWORK
|
||||||
#define USE_ONLINE_IMAGE_BMP_SUPPORT
|
#define USE_ONLINE_IMAGE_BMP_SUPPORT
|
||||||
|
@@ -82,6 +82,16 @@ template<typename T> constexpr T byteswap(T n) {
|
|||||||
return m;
|
return m;
|
||||||
}
|
}
|
||||||
template<> constexpr uint8_t byteswap(uint8_t n) { return n; }
|
template<> constexpr uint8_t byteswap(uint8_t n) { return n; }
|
||||||
|
#ifdef USE_LIBRETINY
|
||||||
|
// LibreTiny's Beken framework redefines __builtin_bswap functions as non-constexpr
|
||||||
|
template<> inline uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); }
|
||||||
|
template<> inline uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); }
|
||||||
|
template<> inline uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); }
|
||||||
|
template<> inline int8_t byteswap(int8_t n) { return n; }
|
||||||
|
template<> inline int16_t byteswap(int16_t n) { return __builtin_bswap16(n); }
|
||||||
|
template<> inline int32_t byteswap(int32_t n) { return __builtin_bswap32(n); }
|
||||||
|
template<> inline int64_t byteswap(int64_t n) { return __builtin_bswap64(n); }
|
||||||
|
#else
|
||||||
template<> constexpr uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); }
|
template<> constexpr uint16_t byteswap(uint16_t n) { return __builtin_bswap16(n); }
|
||||||
template<> constexpr uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); }
|
template<> constexpr uint32_t byteswap(uint32_t n) { return __builtin_bswap32(n); }
|
||||||
template<> constexpr uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); }
|
template<> constexpr uint64_t byteswap(uint64_t n) { return __builtin_bswap64(n); }
|
||||||
@@ -89,6 +99,7 @@ template<> constexpr int8_t byteswap(int8_t n) { return n; }
|
|||||||
template<> constexpr int16_t byteswap(int16_t n) { return __builtin_bswap16(n); }
|
template<> constexpr int16_t byteswap(int16_t n) { return __builtin_bswap16(n); }
|
||||||
template<> constexpr int32_t byteswap(int32_t n) { return __builtin_bswap32(n); }
|
template<> constexpr int32_t byteswap(int32_t n) { return __builtin_bswap32(n); }
|
||||||
template<> constexpr int64_t byteswap(int64_t n) { return __builtin_bswap64(n); }
|
template<> constexpr int64_t byteswap(int64_t n) { return __builtin_bswap64(n); }
|
||||||
|
#endif
|
||||||
|
|
||||||
///@}
|
///@}
|
||||||
|
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
import gzip
|
import gzip
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
@@ -9,12 +10,14 @@ import random
|
|||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from esphome.core import EsphomeError
|
from esphome.core import EsphomeError
|
||||||
from esphome.helpers import resolve_ip_address
|
from esphome.helpers import resolve_ip_address
|
||||||
|
|
||||||
RESPONSE_OK = 0x00
|
RESPONSE_OK = 0x00
|
||||||
RESPONSE_REQUEST_AUTH = 0x01
|
RESPONSE_REQUEST_AUTH = 0x01
|
||||||
|
RESPONSE_REQUEST_SHA256_AUTH = 0x02
|
||||||
|
|
||||||
RESPONSE_HEADER_OK = 0x40
|
RESPONSE_HEADER_OK = 0x40
|
||||||
RESPONSE_AUTH_OK = 0x41
|
RESPONSE_AUTH_OK = 0x41
|
||||||
@@ -45,6 +48,7 @@ OTA_VERSION_2_0 = 2
|
|||||||
MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45]
|
MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45]
|
||||||
|
|
||||||
FEATURE_SUPPORTS_COMPRESSION = 0x01
|
FEATURE_SUPPORTS_COMPRESSION = 0x01
|
||||||
|
FEATURE_SUPPORTS_SHA256_AUTH = 0x02
|
||||||
|
|
||||||
|
|
||||||
UPLOAD_BLOCK_SIZE = 8192
|
UPLOAD_BLOCK_SIZE = 8192
|
||||||
@@ -52,6 +56,12 @@ UPLOAD_BUFFER_SIZE = UPLOAD_BLOCK_SIZE * 8
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Authentication method lookup table: response -> (hash_func, nonce_size, name)
|
||||||
|
_AUTH_METHODS: dict[int, tuple[Callable[..., Any], int, str]] = {
|
||||||
|
RESPONSE_REQUEST_SHA256_AUTH: (hashlib.sha256, 64, "SHA256"),
|
||||||
|
RESPONSE_REQUEST_AUTH: (hashlib.md5, 32, "MD5"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ProgressBar:
|
class ProgressBar:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -81,18 +91,43 @@ class OTAError(EsphomeError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def recv_decode(sock, amount, decode=True):
|
def recv_decode(
|
||||||
|
sock: socket.socket, amount: int, decode: bool = True
|
||||||
|
) -> bytes | list[int]:
|
||||||
|
"""Receive data from socket and optionally decode to list of integers.
|
||||||
|
|
||||||
|
:param sock: Socket to receive data from.
|
||||||
|
:param amount: Number of bytes to receive.
|
||||||
|
:param decode: If True, convert bytes to list of integers, otherwise return raw bytes.
|
||||||
|
:return: List of integers if decode=True, otherwise raw bytes.
|
||||||
|
"""
|
||||||
data = sock.recv(amount)
|
data = sock.recv(amount)
|
||||||
if not decode:
|
if not decode:
|
||||||
return data
|
return data
|
||||||
return list(data)
|
return list(data)
|
||||||
|
|
||||||
|
|
||||||
def receive_exactly(sock, amount, msg, expect, decode=True):
|
def receive_exactly(
|
||||||
data = [] if decode else b""
|
sock: socket.socket,
|
||||||
|
amount: int,
|
||||||
|
msg: str,
|
||||||
|
expect: int | list[int] | None,
|
||||||
|
decode: bool = True,
|
||||||
|
) -> list[int] | bytes:
|
||||||
|
"""Receive exactly the specified amount of data from socket with error checking.
|
||||||
|
|
||||||
|
:param sock: Socket to receive data from.
|
||||||
|
:param amount: Exact number of bytes to receive.
|
||||||
|
:param msg: Description of what is being received for error messages.
|
||||||
|
:param expect: Expected response code(s) for validation, None to skip validation.
|
||||||
|
:param decode: If True, return list of integers, otherwise return raw bytes.
|
||||||
|
:return: List of integers if decode=True, otherwise raw bytes.
|
||||||
|
:raises OTAError: If receiving fails or response doesn't match expected.
|
||||||
|
"""
|
||||||
|
data: list[int] | bytes = [] if decode else b""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data += recv_decode(sock, 1, decode=decode)
|
data += recv_decode(sock, 1, decode=decode) # type: ignore[operator]
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err
|
raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err
|
||||||
|
|
||||||
@@ -104,13 +139,19 @@ def receive_exactly(sock, amount, msg, expect, decode=True):
|
|||||||
|
|
||||||
while len(data) < amount:
|
while len(data) < amount:
|
||||||
try:
|
try:
|
||||||
data += recv_decode(sock, amount - len(data), decode=decode)
|
data += recv_decode(sock, amount - len(data), decode=decode) # type: ignore[operator]
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise OTAError(f"Error receiving {msg}: {err}") from err
|
raise OTAError(f"Error receiving {msg}: {err}") from err
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def check_error(data, expect):
|
def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None:
|
||||||
|
"""Check response data for error codes and validate against expected response.
|
||||||
|
|
||||||
|
:param data: Response data from device (first byte is the response code).
|
||||||
|
:param expect: Expected response code(s), None to skip validation.
|
||||||
|
:raises OTAError: If an error code is detected or response doesn't match expected.
|
||||||
|
"""
|
||||||
if not expect:
|
if not expect:
|
||||||
return
|
return
|
||||||
dat = data[0]
|
dat = data[0]
|
||||||
@@ -177,7 +218,16 @@ def check_error(data, expect):
|
|||||||
raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}")
|
raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}")
|
||||||
|
|
||||||
|
|
||||||
def send_check(sock, data, msg):
|
def send_check(
|
||||||
|
sock: socket.socket, data: list[int] | tuple[int, ...] | int | str | bytes, msg: str
|
||||||
|
) -> None:
|
||||||
|
"""Send data to socket with error handling.
|
||||||
|
|
||||||
|
:param sock: Socket to send data to.
|
||||||
|
:param data: Data to send (can be list/tuple of ints, single int, string, or bytes).
|
||||||
|
:param msg: Description of what is being sent for error messages.
|
||||||
|
:raises OTAError: If sending fails.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if isinstance(data, (list, tuple)):
|
if isinstance(data, (list, tuple)):
|
||||||
data = bytes(data)
|
data = bytes(data)
|
||||||
@@ -210,10 +260,14 @@ def perform_ota(
|
|||||||
f"Device uses unsupported OTA version {version}, this ESPHome supports {supported_versions}"
|
f"Device uses unsupported OTA version {version}, this ESPHome supports {supported_versions}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Features
|
# Features - send both compression and SHA256 auth support
|
||||||
send_check(sock, FEATURE_SUPPORTS_COMPRESSION, "features")
|
features_to_send = FEATURE_SUPPORTS_COMPRESSION | FEATURE_SUPPORTS_SHA256_AUTH
|
||||||
|
send_check(sock, features_to_send, "features")
|
||||||
features = receive_exactly(
|
features = receive_exactly(
|
||||||
sock, 1, "features", [RESPONSE_HEADER_OK, RESPONSE_SUPPORTS_COMPRESSION]
|
sock,
|
||||||
|
1,
|
||||||
|
"features",
|
||||||
|
None, # Accept any response
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
if features == RESPONSE_SUPPORTS_COMPRESSION:
|
if features == RESPONSE_SUPPORTS_COMPRESSION:
|
||||||
@@ -222,31 +276,52 @@ def perform_ota(
|
|||||||
else:
|
else:
|
||||||
upload_contents = file_contents
|
upload_contents = file_contents
|
||||||
|
|
||||||
(auth,) = receive_exactly(
|
def perform_auth(
|
||||||
sock, 1, "auth", [RESPONSE_REQUEST_AUTH, RESPONSE_AUTH_OK]
|
sock: socket.socket,
|
||||||
)
|
password: str,
|
||||||
if auth == RESPONSE_REQUEST_AUTH:
|
hash_func: Callable[..., Any],
|
||||||
|
nonce_size: int,
|
||||||
|
hash_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Perform challenge-response authentication using specified hash algorithm."""
|
||||||
if not password:
|
if not password:
|
||||||
raise OTAError("ESP requests password, but no password given!")
|
raise OTAError("ESP requests password, but no password given!")
|
||||||
nonce = receive_exactly(
|
|
||||||
sock, 32, "authentication nonce", [], decode=False
|
nonce_bytes = receive_exactly(
|
||||||
).decode()
|
sock, nonce_size, f"{hash_name} authentication nonce", [], decode=False
|
||||||
_LOGGER.debug("Auth: Nonce is %s", nonce)
|
)
|
||||||
cnonce = hashlib.md5(str(random.random()).encode()).hexdigest()
|
assert isinstance(nonce_bytes, bytes)
|
||||||
_LOGGER.debug("Auth: CNonce is %s", cnonce)
|
nonce = nonce_bytes.decode()
|
||||||
|
_LOGGER.debug("Auth: %s Nonce is %s", hash_name, nonce)
|
||||||
|
|
||||||
|
# Generate cnonce
|
||||||
|
cnonce = hash_func(str(random.random()).encode()).hexdigest()
|
||||||
|
_LOGGER.debug("Auth: %s CNonce is %s", hash_name, cnonce)
|
||||||
|
|
||||||
send_check(sock, cnonce, "auth cnonce")
|
send_check(sock, cnonce, "auth cnonce")
|
||||||
|
|
||||||
result_md5 = hashlib.md5()
|
# Calculate challenge response
|
||||||
result_md5.update(password.encode("utf-8"))
|
hasher = hash_func()
|
||||||
result_md5.update(nonce.encode())
|
hasher.update(password.encode("utf-8"))
|
||||||
result_md5.update(cnonce.encode())
|
hasher.update(nonce.encode())
|
||||||
result = result_md5.hexdigest()
|
hasher.update(cnonce.encode())
|
||||||
_LOGGER.debug("Auth: Result is %s", result)
|
result = hasher.hexdigest()
|
||||||
|
_LOGGER.debug("Auth: %s Result is %s", hash_name, result)
|
||||||
|
|
||||||
send_check(sock, result, "auth result")
|
send_check(sock, result, "auth result")
|
||||||
receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK)
|
receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK)
|
||||||
|
|
||||||
|
(auth,) = receive_exactly(
|
||||||
|
sock,
|
||||||
|
1,
|
||||||
|
"auth",
|
||||||
|
[RESPONSE_REQUEST_AUTH, RESPONSE_REQUEST_SHA256_AUTH, RESPONSE_AUTH_OK],
|
||||||
|
)
|
||||||
|
|
||||||
|
if auth != RESPONSE_AUTH_OK:
|
||||||
|
hash_func, nonce_size, hash_name = _AUTH_METHODS[auth]
|
||||||
|
perform_auth(sock, password, hash_func, nonce_size, hash_name)
|
||||||
|
|
||||||
# Set higher timeout during upload
|
# Set higher timeout during upload
|
||||||
sock.settimeout(30.0)
|
sock.settimeout(30.0)
|
||||||
|
|
||||||
|
32
tests/components/sha256/common.yaml
Normal file
32
tests/components/sha256/common.yaml
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
esphome:
|
||||||
|
on_boot:
|
||||||
|
- lambda: |-
|
||||||
|
// Test SHA256 functionality
|
||||||
|
#ifdef USE_SHA256
|
||||||
|
using esphome::sha256::SHA256;
|
||||||
|
SHA256 hasher;
|
||||||
|
hasher.init();
|
||||||
|
|
||||||
|
// Test with "Hello World" - known SHA256
|
||||||
|
const char* test_string = "Hello World";
|
||||||
|
hasher.add(test_string, strlen(test_string));
|
||||||
|
hasher.calculate();
|
||||||
|
|
||||||
|
char hex_output[65];
|
||||||
|
hasher.get_hex(hex_output);
|
||||||
|
hex_output[64] = '\0';
|
||||||
|
|
||||||
|
ESP_LOGD("SHA256", "SHA256('Hello World') = %s", hex_output);
|
||||||
|
|
||||||
|
// Expected: a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e
|
||||||
|
const char* expected = "a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e";
|
||||||
|
if (strcmp(hex_output, expected) == 0) {
|
||||||
|
ESP_LOGI("SHA256", "Test PASSED");
|
||||||
|
} else {
|
||||||
|
ESP_LOGE("SHA256", "Test FAILED. Expected %s", expected);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
ESP_LOGW("SHA256", "SHA256 not available on this platform");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
sha256:
|
1
tests/components/sha256/test.bk72xx-ard.yaml
Normal file
1
tests/components/sha256/test.bk72xx-ard.yaml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<<: !include common.yaml
|
1
tests/components/sha256/test.esp32-idf.yaml
Normal file
1
tests/components/sha256/test.esp32-idf.yaml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<<: !include common.yaml
|
1
tests/components/sha256/test.esp8266-ard.yaml
Normal file
1
tests/components/sha256/test.esp8266-ard.yaml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<<: !include common.yaml
|
1
tests/components/sha256/test.host.yaml
Normal file
1
tests/components/sha256/test.host.yaml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<<: !include common.yaml
|
1
tests/components/sha256/test.rp2040-ard.yaml
Normal file
1
tests/components/sha256/test.rp2040-ard.yaml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<<: !include common.yaml
|
738
tests/unit_tests/test_espota2.py
Normal file
738
tests/unit_tests/test_espota2.py
Normal file
@@ -0,0 +1,738 @@
|
|||||||
|
"""Unit tests for esphome.espota2 module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
import gzip
|
||||||
|
import hashlib
|
||||||
|
import io
|
||||||
|
from pathlib import Path
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
from unittest.mock import Mock, call, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest import CaptureFixture
|
||||||
|
|
||||||
|
from esphome import espota2
|
||||||
|
from esphome.core import EsphomeError
|
||||||
|
|
||||||
|
# Test constants
|
||||||
|
MOCK_RANDOM_VALUE = 0.123456
|
||||||
|
MOCK_RANDOM_BYTES = b"0.123456"
|
||||||
|
MOCK_MD5_NONCE = b"12345678901234567890123456789012" # 32 char nonce for MD5
|
||||||
|
MOCK_SHA256_NONCE = b"1234567890123456789012345678901234567890123456789012345678901234" # 64 char nonce for SHA256
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_socket() -> Mock:
|
||||||
|
"""Create a mock socket for testing."""
|
||||||
|
socket_mock = Mock()
|
||||||
|
socket_mock.close = Mock()
|
||||||
|
socket_mock.recv = Mock()
|
||||||
|
socket_mock.sendall = Mock()
|
||||||
|
socket_mock.settimeout = Mock()
|
||||||
|
socket_mock.connect = Mock()
|
||||||
|
socket_mock.setsockopt = Mock()
|
||||||
|
return socket_mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_file() -> io.BytesIO:
|
||||||
|
"""Create a mock firmware file for testing."""
|
||||||
|
return io.BytesIO(b"firmware content here")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_time() -> Generator[None]:
|
||||||
|
"""Mock time-related functions for consistent testing."""
|
||||||
|
# Provide enough values for multiple calls (tests may call perform_ota multiple times)
|
||||||
|
with (
|
||||||
|
patch("time.sleep"),
|
||||||
|
patch("time.perf_counter", side_effect=[0, 1, 0, 1, 0, 1]),
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_random() -> Generator[Mock]:
|
||||||
|
"""Mock random for predictable test values."""
|
||||||
|
with patch("random.random", return_value=MOCK_RANDOM_VALUE) as mock_rand:
|
||||||
|
yield mock_rand
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_resolve_ip() -> Generator[Mock]:
|
||||||
|
"""Mock resolve_ip_address for testing."""
|
||||||
|
with patch("esphome.espota2.resolve_ip_address") as mock:
|
||||||
|
mock.return_value = [
|
||||||
|
(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.100", 3232))
|
||||||
|
]
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_perform_ota() -> Generator[Mock]:
|
||||||
|
"""Mock perform_ota function for testing."""
|
||||||
|
with patch("esphome.espota2.perform_ota") as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_run_ota_impl() -> Generator[Mock]:
|
||||||
|
"""Mock run_ota_impl_ function for testing."""
|
||||||
|
with patch("esphome.espota2.run_ota_impl_") as mock:
|
||||||
|
mock.return_value = (0, "192.168.1.100")
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_socket_constructor(mock_socket: Mock) -> Generator[Mock]:
|
||||||
|
"""Mock socket.socket constructor to return our mock socket."""
|
||||||
|
with patch("socket.socket", return_value=mock_socket) as mock_constructor:
|
||||||
|
yield mock_constructor
|
||||||
|
|
||||||
|
|
||||||
|
def test_recv_decode_with_decode(mock_socket: Mock) -> None:
|
||||||
|
"""Test recv_decode with decode=True returns list."""
|
||||||
|
mock_socket.recv.return_value = b"\x01\x02\x03"
|
||||||
|
|
||||||
|
result = espota2.recv_decode(mock_socket, 3, decode=True)
|
||||||
|
|
||||||
|
assert result == [1, 2, 3]
|
||||||
|
mock_socket.recv.assert_called_once_with(3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recv_decode_without_decode(mock_socket: Mock) -> None:
|
||||||
|
"""Test recv_decode with decode=False returns bytes."""
|
||||||
|
mock_socket.recv.return_value = b"\x01\x02\x03"
|
||||||
|
|
||||||
|
result = espota2.recv_decode(mock_socket, 3, decode=False)
|
||||||
|
|
||||||
|
assert result == b"\x01\x02\x03"
|
||||||
|
mock_socket.recv.assert_called_once_with(3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_receive_exactly_success(mock_socket: Mock) -> None:
|
||||||
|
"""Test receive_exactly successfully receives expected data."""
|
||||||
|
mock_socket.recv.side_effect = [b"\x00", b"\x01\x02"]
|
||||||
|
|
||||||
|
result = espota2.receive_exactly(mock_socket, 3, "test", espota2.RESPONSE_OK)
|
||||||
|
|
||||||
|
assert result == [0, 1, 2]
|
||||||
|
assert mock_socket.recv.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_receive_exactly_with_error_response(mock_socket: Mock) -> None:
|
||||||
|
"""Test receive_exactly raises OTAError on error response."""
|
||||||
|
mock_socket.recv.return_value = bytes([espota2.RESPONSE_ERROR_AUTH_INVALID])
|
||||||
|
|
||||||
|
with pytest.raises(espota2.OTAError, match="Error auth:.*Authentication invalid"):
|
||||||
|
espota2.receive_exactly(mock_socket, 1, "auth", [espota2.RESPONSE_OK])
|
||||||
|
|
||||||
|
mock_socket.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_receive_exactly_socket_error(mock_socket: Mock) -> None:
|
||||||
|
"""Test receive_exactly handles socket errors."""
|
||||||
|
mock_socket.recv.side_effect = OSError("Connection reset")
|
||||||
|
|
||||||
|
with pytest.raises(espota2.OTAError, match="Error receiving acknowledge test"):
|
||||||
|
espota2.receive_exactly(mock_socket, 1, "test", espota2.RESPONSE_OK)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("error_code", "expected_msg"),
|
||||||
|
[
|
||||||
|
(espota2.RESPONSE_ERROR_MAGIC, "Error: Invalid magic byte"),
|
||||||
|
(espota2.RESPONSE_ERROR_UPDATE_PREPARE, "Error: Couldn't prepare flash memory"),
|
||||||
|
(espota2.RESPONSE_ERROR_AUTH_INVALID, "Error: Authentication invalid"),
|
||||||
|
(
|
||||||
|
espota2.RESPONSE_ERROR_WRITING_FLASH,
|
||||||
|
"Error: Wring OTA data to flash memory failed",
|
||||||
|
),
|
||||||
|
(espota2.RESPONSE_ERROR_UPDATE_END, "Error: Finishing update failed"),
|
||||||
|
(
|
||||||
|
espota2.RESPONSE_ERROR_INVALID_BOOTSTRAPPING,
|
||||||
|
"Error: Please press the reset button",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
espota2.RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG,
|
||||||
|
"Error: ESP has been flashed with wrong flash size",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
espota2.RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG,
|
||||||
|
"Error: ESP does not have the requested flash size",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
espota2.RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE,
|
||||||
|
"Error: ESP does not have enough space",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
espota2.RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE,
|
||||||
|
"Error: The OTA partition on the ESP is too small",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
espota2.RESPONSE_ERROR_NO_UPDATE_PARTITION,
|
||||||
|
"Error: The OTA partition on the ESP couldn't be found",
|
||||||
|
),
|
||||||
|
(espota2.RESPONSE_ERROR_MD5_MISMATCH, "Error: Application MD5 code mismatch"),
|
||||||
|
(espota2.RESPONSE_ERROR_UNKNOWN, "Unknown error from ESP"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_check_error_with_various_errors(error_code: int, expected_msg: str) -> None:
|
||||||
|
"""Test check_error raises appropriate errors for different error codes."""
|
||||||
|
with pytest.raises(espota2.OTAError, match=expected_msg):
|
||||||
|
espota2.check_error([error_code], [espota2.RESPONSE_OK])
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_error_unexpected_response() -> None:
|
||||||
|
"""Test check_error raises error for unexpected response."""
|
||||||
|
with pytest.raises(espota2.OTAError, match="Unexpected response from ESP: 0x7F"):
|
||||||
|
espota2.check_error([0x7F], [espota2.RESPONSE_OK, espota2.RESPONSE_AUTH_OK])
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_check_with_various_data_types(mock_socket: Mock) -> None:
|
||||||
|
"""Test send_check handles different data types."""
|
||||||
|
|
||||||
|
# Test with list/tuple
|
||||||
|
espota2.send_check(mock_socket, [0x01, 0x02], "list")
|
||||||
|
mock_socket.sendall.assert_called_with(b"\x01\x02")
|
||||||
|
|
||||||
|
# Test with int
|
||||||
|
espota2.send_check(mock_socket, 0x42, "int")
|
||||||
|
mock_socket.sendall.assert_called_with(b"\x42")
|
||||||
|
|
||||||
|
# Test with string
|
||||||
|
espota2.send_check(mock_socket, "hello", "string")
|
||||||
|
mock_socket.sendall.assert_called_with(b"hello")
|
||||||
|
|
||||||
|
# Test with bytes (should pass through)
|
||||||
|
espota2.send_check(mock_socket, b"\xaa\xbb", "bytes")
|
||||||
|
mock_socket.sendall.assert_called_with(b"\xaa\xbb")
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_check_socket_error(mock_socket: Mock) -> None:
|
||||||
|
"""Test send_check handles socket errors."""
|
||||||
|
mock_socket.sendall.side_effect = OSError("Broken pipe")
|
||||||
|
|
||||||
|
with pytest.raises(espota2.OTAError, match="Error sending test"):
|
||||||
|
espota2.send_check(mock_socket, b"data", "test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_time")
|
||||||
|
def test_perform_ota_successful_md5_auth(
|
||||||
|
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
|
||||||
|
) -> None:
|
||||||
|
"""Test successful OTA with MD5 authentication."""
|
||||||
|
# Setup socket responses for recv calls
|
||||||
|
recv_responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_2_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
|
||||||
|
bytes([espota2.RESPONSE_REQUEST_AUTH]), # Auth request
|
||||||
|
MOCK_MD5_NONCE, # 32 char hex nonce
|
||||||
|
bytes([espota2.RESPONSE_AUTH_OK]), # Auth result
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
|
||||||
|
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
|
||||||
|
bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK
|
||||||
|
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses
|
||||||
|
|
||||||
|
# Run OTA
|
||||||
|
espota2.perform_ota(mock_socket, "testpass", mock_file, "test.bin")
|
||||||
|
|
||||||
|
# Verify magic bytes were sent
|
||||||
|
assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES))
|
||||||
|
|
||||||
|
# Verify features were sent (compression + SHA256 support)
|
||||||
|
assert mock_socket.sendall.call_args_list[1] == call(
|
||||||
|
bytes(
|
||||||
|
[
|
||||||
|
espota2.FEATURE_SUPPORTS_COMPRESSION
|
||||||
|
| espota2.FEATURE_SUPPORTS_SHA256_AUTH
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify cnonce was sent (MD5 of random.random())
|
||||||
|
cnonce = hashlib.md5(MOCK_RANDOM_BYTES).hexdigest()
|
||||||
|
assert mock_socket.sendall.call_args_list[2] == call(cnonce.encode())
|
||||||
|
|
||||||
|
# Verify auth result was computed correctly
|
||||||
|
expected_hash = hashlib.md5()
|
||||||
|
expected_hash.update(b"testpass")
|
||||||
|
expected_hash.update(MOCK_MD5_NONCE)
|
||||||
|
expected_hash.update(cnonce.encode())
|
||||||
|
expected_result = expected_hash.hexdigest()
|
||||||
|
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_time")
|
||||||
|
def test_perform_ota_no_auth(mock_socket: Mock, mock_file: io.BytesIO) -> None:
|
||||||
|
"""Test OTA without authentication."""
|
||||||
|
recv_responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_1_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
|
||||||
|
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
|
||||||
|
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
|
||||||
|
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses
|
||||||
|
|
||||||
|
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||||
|
|
||||||
|
# Should not send any auth-related data
|
||||||
|
auth_calls = [
|
||||||
|
call
|
||||||
|
for call in mock_socket.sendall.call_args_list
|
||||||
|
if "cnonce" in str(call) or "result" in str(call)
|
||||||
|
]
|
||||||
|
assert len(auth_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_time")
|
||||||
|
def test_perform_ota_with_compression(mock_socket: Mock) -> None:
|
||||||
|
"""Test OTA with compression support."""
|
||||||
|
original_content = b"firmware" * 100 # Repeating content for compression
|
||||||
|
mock_file = io.BytesIO(original_content)
|
||||||
|
recv_responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_2_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_SUPPORTS_COMPRESSION]), # Device supports compression
|
||||||
|
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
|
||||||
|
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
|
||||||
|
bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK
|
||||||
|
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses
|
||||||
|
|
||||||
|
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||||
|
|
||||||
|
# Verify compressed content was sent
|
||||||
|
# Get the binary size that was sent (4 bytes after features)
|
||||||
|
size_bytes = mock_socket.sendall.call_args_list[2][0][0]
|
||||||
|
sent_size = struct.unpack(">I", size_bytes)[0]
|
||||||
|
|
||||||
|
# Size should be less than original due to compression
|
||||||
|
assert sent_size < len(original_content)
|
||||||
|
|
||||||
|
# Verify the content sent was gzipped
|
||||||
|
compressed = gzip.compress(original_content, compresslevel=9)
|
||||||
|
assert sent_size == len(compressed)
|
||||||
|
|
||||||
|
|
||||||
|
def test_perform_ota_auth_without_password(mock_socket: Mock) -> None:
|
||||||
|
"""Test OTA fails when auth is required but no password provided."""
|
||||||
|
mock_file = io.BytesIO(b"firmware")
|
||||||
|
|
||||||
|
responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]),
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]),
|
||||||
|
bytes([espota2.RESPONSE_REQUEST_AUTH]),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = responses
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
espota2.OTAError, match="ESP requests password, but no password given"
|
||||||
|
):
|
||||||
|
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_time")
|
||||||
|
def test_perform_ota_md5_auth_wrong_password(
|
||||||
|
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
|
||||||
|
) -> None:
|
||||||
|
"""Test OTA fails when MD5 authentication is rejected due to wrong password."""
|
||||||
|
# Setup socket responses for recv calls
|
||||||
|
recv_responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_2_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
|
||||||
|
bytes([espota2.RESPONSE_REQUEST_AUTH]), # Auth request
|
||||||
|
MOCK_MD5_NONCE, # 32 char hex nonce
|
||||||
|
bytes([espota2.RESPONSE_ERROR_AUTH_INVALID]), # Auth rejected!
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses
|
||||||
|
|
||||||
|
with pytest.raises(espota2.OTAError, match="Error auth.*Authentication invalid"):
|
||||||
|
espota2.perform_ota(mock_socket, "wrongpassword", mock_file, "test.bin")
|
||||||
|
|
||||||
|
# Verify the socket was closed after auth failure
|
||||||
|
mock_socket.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_time")
|
||||||
|
def test_perform_ota_sha256_auth_wrong_password(
|
||||||
|
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
|
||||||
|
) -> None:
|
||||||
|
"""Test OTA fails when SHA256 authentication is rejected due to wrong password."""
|
||||||
|
# Setup socket responses for recv calls
|
||||||
|
recv_responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_2_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
|
||||||
|
bytes([espota2.RESPONSE_REQUEST_SHA256_AUTH]), # SHA256 Auth request
|
||||||
|
MOCK_SHA256_NONCE, # 64 char hex nonce
|
||||||
|
bytes([espota2.RESPONSE_ERROR_AUTH_INVALID]), # Auth rejected!
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses
|
||||||
|
|
||||||
|
with pytest.raises(espota2.OTAError, match="Error auth.*Authentication invalid"):
|
||||||
|
espota2.perform_ota(mock_socket, "wrongpassword", mock_file, "test.bin")
|
||||||
|
|
||||||
|
# Verify the socket was closed after auth failure
|
||||||
|
mock_socket.close.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_perform_ota_sha256_auth_without_password(mock_socket: Mock) -> None:
|
||||||
|
"""Test OTA fails when SHA256 auth is required but no password provided."""
|
||||||
|
mock_file = io.BytesIO(b"firmware")
|
||||||
|
|
||||||
|
responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]),
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]),
|
||||||
|
bytes([espota2.RESPONSE_REQUEST_SHA256_AUTH]),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = responses
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
espota2.OTAError, match="ESP requests password, but no password given"
|
||||||
|
):
|
||||||
|
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||||
|
|
||||||
|
|
||||||
|
def test_perform_ota_unexpected_auth_response(mock_socket: Mock) -> None:
|
||||||
|
"""Test OTA fails when device sends an unexpected auth response."""
|
||||||
|
mock_file = io.BytesIO(b"firmware")
|
||||||
|
|
||||||
|
# Use 0x03 which is not in the expected auth responses
|
||||||
|
# This will be caught by check_error and raise "Unexpected response from ESP"
|
||||||
|
UNKNOWN_AUTH_METHOD = 0x03
|
||||||
|
|
||||||
|
responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK, espota2.OTA_VERSION_2_0]),
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]),
|
||||||
|
bytes([UNKNOWN_AUTH_METHOD]), # Unknown auth method
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = responses
|
||||||
|
|
||||||
|
# This will actually raise "Unexpected response from ESP" from check_error
|
||||||
|
with pytest.raises(
|
||||||
|
espota2.OTAError, match=r"Error auth: Unexpected response from ESP: 0x03"
|
||||||
|
):
|
||||||
|
espota2.perform_ota(mock_socket, "password", mock_file, "test.bin")
|
||||||
|
|
||||||
|
|
||||||
|
def test_perform_ota_unsupported_version(mock_socket: Mock) -> None:
|
||||||
|
"""Test OTA fails with unsupported version."""
|
||||||
|
mock_file = io.BytesIO(b"firmware")
|
||||||
|
|
||||||
|
responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK, 99]), # Unsupported version
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = responses
|
||||||
|
|
||||||
|
with pytest.raises(espota2.OTAError, match="Device uses unsupported OTA version"):
|
||||||
|
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_time")
|
||||||
|
def test_perform_ota_upload_error(mock_socket: Mock, mock_file: io.BytesIO) -> None:
|
||||||
|
"""Test OTA handles upload errors."""
|
||||||
|
# Setup responses - provide enough for the recv calls
|
||||||
|
recv_responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_2_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
|
||||||
|
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
|
||||||
|
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
|
||||||
|
]
|
||||||
|
# Add OSError to recv to simulate connection loss during chunk read
|
||||||
|
recv_responses.append(OSError("Connection lost"))
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses
|
||||||
|
|
||||||
|
with pytest.raises(espota2.OTAError, match="Error receiving acknowledge chunk OK"):
|
||||||
|
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_socket_constructor", "mock_resolve_ip")
|
||||||
|
def test_run_ota_impl_successful(
|
||||||
|
mock_socket: Mock, tmp_path: Path, mock_perform_ota: Mock
|
||||||
|
) -> None:
|
||||||
|
"""Test run_ota_impl_ with successful upload."""
|
||||||
|
# Create a real firmware file
|
||||||
|
firmware_file = tmp_path / "firmware.bin"
|
||||||
|
firmware_file.write_bytes(b"firmware content")
|
||||||
|
|
||||||
|
# Run OTA with real file path
|
||||||
|
result_code, result_host = espota2.run_ota_impl_(
|
||||||
|
"test.local", 3232, "password", str(firmware_file)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify success
|
||||||
|
assert result_code == 0
|
||||||
|
assert result_host == "192.168.1.100"
|
||||||
|
|
||||||
|
# Verify socket was configured correctly
|
||||||
|
mock_socket.settimeout.assert_called_with(10.0)
|
||||||
|
mock_socket.connect.assert_called_once_with(("192.168.1.100", 3232))
|
||||||
|
mock_socket.close.assert_called_once()
|
||||||
|
|
||||||
|
# Verify perform_ota was called with real file
|
||||||
|
mock_perform_ota.assert_called_once()
|
||||||
|
call_args = mock_perform_ota.call_args[0]
|
||||||
|
assert call_args[0] == mock_socket
|
||||||
|
assert call_args[1] == "password"
|
||||||
|
# Verify the file object is a proper file handle
|
||||||
|
assert isinstance(call_args[2], io.IOBase)
|
||||||
|
assert call_args[3] == str(firmware_file)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_socket_constructor", "mock_resolve_ip")
|
||||||
|
def test_run_ota_impl_connection_failed(mock_socket: Mock, tmp_path: Path) -> None:
|
||||||
|
"""Test run_ota_impl_ when connection fails."""
|
||||||
|
mock_socket.connect.side_effect = OSError("Connection refused")
|
||||||
|
|
||||||
|
# Create a real firmware file
|
||||||
|
firmware_file = tmp_path / "firmware.bin"
|
||||||
|
firmware_file.write_bytes(b"firmware content")
|
||||||
|
|
||||||
|
result_code, result_host = espota2.run_ota_impl_(
|
||||||
|
"test.local", 3232, "password", str(firmware_file)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result_code == 1
|
||||||
|
assert result_host is None
|
||||||
|
mock_socket.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_ota_impl_resolve_failed(tmp_path: Path, mock_resolve_ip: Mock) -> None:
|
||||||
|
"""Test run_ota_impl_ when DNS resolution fails."""
|
||||||
|
# Create a real firmware file
|
||||||
|
firmware_file = tmp_path / "firmware.bin"
|
||||||
|
firmware_file.write_bytes(b"firmware content")
|
||||||
|
|
||||||
|
mock_resolve_ip.side_effect = EsphomeError("DNS resolution failed")
|
||||||
|
|
||||||
|
with pytest.raises(espota2.OTAError, match="DNS resolution failed"):
|
||||||
|
result_code, result_host = espota2.run_ota_impl_(
|
||||||
|
"unknown.host", 3232, "password", str(firmware_file)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_ota_wrapper(mock_run_ota_impl: Mock) -> None:
|
||||||
|
"""Test run_ota wrapper function."""
|
||||||
|
# Test successful case
|
||||||
|
mock_run_ota_impl.return_value = (0, "192.168.1.100")
|
||||||
|
result = espota2.run_ota("test.local", 3232, "pass", "fw.bin")
|
||||||
|
assert result == (0, "192.168.1.100")
|
||||||
|
|
||||||
|
# Test error case
|
||||||
|
mock_run_ota_impl.side_effect = espota2.OTAError("Test error")
|
||||||
|
result = espota2.run_ota("test.local", 3232, "pass", "fw.bin")
|
||||||
|
assert result == (1, None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_progress_bar(capsys: CaptureFixture[str]) -> None:
|
||||||
|
"""Test ProgressBar functionality."""
|
||||||
|
progress = espota2.ProgressBar()
|
||||||
|
|
||||||
|
# Test initial update
|
||||||
|
progress.update(0.0)
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "0%" in captured.err
|
||||||
|
assert "[" in captured.err
|
||||||
|
|
||||||
|
# Test progress update
|
||||||
|
progress.update(0.5)
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "50%" in captured.err
|
||||||
|
|
||||||
|
# Test completion
|
||||||
|
progress.update(1.0)
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "100%" in captured.err
|
||||||
|
assert "Done" in captured.err
|
||||||
|
|
||||||
|
# Test done method
|
||||||
|
progress.done()
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.err == "\n"
|
||||||
|
|
||||||
|
# Test same progress doesn't update
|
||||||
|
progress.update(0.5)
|
||||||
|
progress.update(0.5)
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
# Should only see one update (second call shouldn't write)
|
||||||
|
assert captured.err.count("50%") == 1
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for SHA256 authentication
|
||||||
|
@pytest.mark.usefixtures("mock_time")
|
||||||
|
def test_perform_ota_successful_sha256_auth(
|
||||||
|
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
|
||||||
|
) -> None:
|
||||||
|
"""Test successful OTA with SHA256 authentication."""
|
||||||
|
# Setup socket responses for recv calls
|
||||||
|
recv_responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_2_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
|
||||||
|
bytes([espota2.RESPONSE_REQUEST_SHA256_AUTH]), # SHA256 Auth request
|
||||||
|
MOCK_SHA256_NONCE, # 64 char hex nonce
|
||||||
|
bytes([espota2.RESPONSE_AUTH_OK]), # Auth result
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
|
||||||
|
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
|
||||||
|
bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK
|
||||||
|
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses
|
||||||
|
|
||||||
|
# Run OTA
|
||||||
|
espota2.perform_ota(mock_socket, "testpass", mock_file, "test.bin")
|
||||||
|
|
||||||
|
# Verify magic bytes were sent
|
||||||
|
assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES))
|
||||||
|
|
||||||
|
# Verify features were sent (compression + SHA256 support)
|
||||||
|
assert mock_socket.sendall.call_args_list[1] == call(
|
||||||
|
bytes(
|
||||||
|
[
|
||||||
|
espota2.FEATURE_SUPPORTS_COMPRESSION
|
||||||
|
| espota2.FEATURE_SUPPORTS_SHA256_AUTH
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify cnonce was sent (SHA256 of random.random())
|
||||||
|
cnonce = hashlib.sha256(MOCK_RANDOM_BYTES).hexdigest()
|
||||||
|
assert mock_socket.sendall.call_args_list[2] == call(cnonce.encode())
|
||||||
|
|
||||||
|
# Verify auth result was computed correctly with SHA256
|
||||||
|
expected_hash = hashlib.sha256()
|
||||||
|
expected_hash.update(b"testpass")
|
||||||
|
expected_hash.update(MOCK_SHA256_NONCE)
|
||||||
|
expected_hash.update(cnonce.encode())
|
||||||
|
expected_result = expected_hash.hexdigest()
|
||||||
|
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_time")
|
||||||
|
def test_perform_ota_sha256_fallback_to_md5(
|
||||||
|
mock_socket: Mock, mock_file: io.BytesIO, mock_random: Mock
|
||||||
|
) -> None:
|
||||||
|
"""Test SHA256-capable client falls back to MD5 for compatibility."""
|
||||||
|
# This test verifies the temporary backward compatibility
|
||||||
|
# where a SHA256-capable client can still authenticate with MD5
|
||||||
|
# This compatibility will be removed in 2026.1.0
|
||||||
|
recv_responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_2_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
|
||||||
|
bytes(
|
||||||
|
[espota2.RESPONSE_REQUEST_AUTH]
|
||||||
|
), # MD5 Auth request (device doesn't support SHA256)
|
||||||
|
MOCK_MD5_NONCE, # 32 char hex nonce for MD5
|
||||||
|
bytes([espota2.RESPONSE_AUTH_OK]), # Auth result
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
|
||||||
|
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
|
||||||
|
bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK
|
||||||
|
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses
|
||||||
|
|
||||||
|
# Run OTA - should work even though device requested MD5
|
||||||
|
espota2.perform_ota(mock_socket, "testpass", mock_file, "test.bin")
|
||||||
|
|
||||||
|
# Verify client still advertised SHA256 support
|
||||||
|
assert mock_socket.sendall.call_args_list[1] == call(
|
||||||
|
bytes(
|
||||||
|
[
|
||||||
|
espota2.FEATURE_SUPPORTS_COMPRESSION
|
||||||
|
| espota2.FEATURE_SUPPORTS_SHA256_AUTH
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# But authentication was done with MD5
|
||||||
|
cnonce = hashlib.md5(MOCK_RANDOM_BYTES).hexdigest()
|
||||||
|
expected_hash = hashlib.md5()
|
||||||
|
expected_hash.update(b"testpass")
|
||||||
|
expected_hash.update(MOCK_MD5_NONCE)
|
||||||
|
expected_hash.update(cnonce.encode())
|
||||||
|
expected_result = expected_hash.hexdigest()
|
||||||
|
assert mock_socket.sendall.call_args_list[3] == call(expected_result.encode())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_time")
|
||||||
|
def test_perform_ota_version_differences(
|
||||||
|
mock_socket: Mock, mock_file: io.BytesIO
|
||||||
|
) -> None:
|
||||||
|
"""Test OTA behavior differences between version 1.0 and 2.0."""
|
||||||
|
# Test version 1.0 - no chunk acknowledgments
|
||||||
|
recv_responses = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_1_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
|
||||||
|
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
|
||||||
|
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
|
||||||
|
# No RESPONSE_CHUNK_OK for v1
|
||||||
|
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses
|
||||||
|
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||||
|
|
||||||
|
# For v1.0, verify that we only get the expected number of recv calls
|
||||||
|
# v1.0 doesn't have chunk acknowledgments, so fewer recv calls
|
||||||
|
assert mock_socket.recv.call_count == 8 # v1.0 has 8 recv calls
|
||||||
|
|
||||||
|
# Reset mock for v2.0 test
|
||||||
|
mock_socket.reset_mock()
|
||||||
|
|
||||||
|
# Reset file position for second test
|
||||||
|
mock_file.seek(0)
|
||||||
|
|
||||||
|
# Test version 2.0 - with chunk acknowledgments
|
||||||
|
recv_responses_v2 = [
|
||||||
|
bytes([espota2.RESPONSE_OK]), # First byte of version response
|
||||||
|
bytes([espota2.OTA_VERSION_2_0]), # Version number
|
||||||
|
bytes([espota2.RESPONSE_HEADER_OK]), # Features response
|
||||||
|
bytes([espota2.RESPONSE_AUTH_OK]), # No auth required
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK
|
||||||
|
bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK
|
||||||
|
bytes([espota2.RESPONSE_CHUNK_OK]), # v2.0 has chunk acknowledgment
|
||||||
|
bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK
|
||||||
|
bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_socket.recv.side_effect = recv_responses_v2
|
||||||
|
espota2.perform_ota(mock_socket, "", mock_file, "test.bin")
|
||||||
|
|
||||||
|
# For v2.0, verify more recv calls due to chunk acknowledgments
|
||||||
|
assert mock_socket.recv.call_count == 9 # v2.0 has 9 recv calls (includes chunk OK)
|
Reference in New Issue
Block a user