mirror of
https://github.com/esphome/esphome.git
synced 2025-10-01 17:42:22 +01:00
[ota] Add SHA256 password authentication with backward compatibility (#10809)
Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import gzip
|
||||
import hashlib
|
||||
import io
|
||||
@@ -9,12 +10,14 @@ import random
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from esphome.core import EsphomeError
|
||||
from esphome.helpers import resolve_ip_address
|
||||
|
||||
RESPONSE_OK = 0x00
|
||||
RESPONSE_REQUEST_AUTH = 0x01
|
||||
RESPONSE_REQUEST_SHA256_AUTH = 0x02
|
||||
|
||||
RESPONSE_HEADER_OK = 0x40
|
||||
RESPONSE_AUTH_OK = 0x41
|
||||
@@ -45,6 +48,7 @@ OTA_VERSION_2_0 = 2
|
||||
MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45]
|
||||
|
||||
FEATURE_SUPPORTS_COMPRESSION = 0x01
|
||||
FEATURE_SUPPORTS_SHA256_AUTH = 0x02
|
||||
|
||||
|
||||
UPLOAD_BLOCK_SIZE = 8192
|
||||
@@ -52,6 +56,12 @@ UPLOAD_BUFFER_SIZE = UPLOAD_BLOCK_SIZE * 8
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Authentication method lookup table: response -> (hash_func, nonce_size, name)
|
||||
_AUTH_METHODS: dict[int, tuple[Callable[..., Any], int, str]] = {
|
||||
RESPONSE_REQUEST_SHA256_AUTH: (hashlib.sha256, 64, "SHA256"),
|
||||
RESPONSE_REQUEST_AUTH: (hashlib.md5, 32, "MD5"),
|
||||
}
|
||||
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self):
|
||||
@@ -81,18 +91,43 @@ class OTAError(EsphomeError):
|
||||
pass
|
||||
|
||||
|
||||
def recv_decode(sock, amount, decode=True):
|
||||
def recv_decode(
|
||||
sock: socket.socket, amount: int, decode: bool = True
|
||||
) -> bytes | list[int]:
|
||||
"""Receive data from socket and optionally decode to list of integers.
|
||||
|
||||
:param sock: Socket to receive data from.
|
||||
:param amount: Number of bytes to receive.
|
||||
:param decode: If True, convert bytes to list of integers, otherwise return raw bytes.
|
||||
:return: List of integers if decode=True, otherwise raw bytes.
|
||||
"""
|
||||
data = sock.recv(amount)
|
||||
if not decode:
|
||||
return data
|
||||
return list(data)
|
||||
|
||||
|
||||
def receive_exactly(sock, amount, msg, expect, decode=True):
|
||||
data = [] if decode else b""
|
||||
def receive_exactly(
|
||||
sock: socket.socket,
|
||||
amount: int,
|
||||
msg: str,
|
||||
expect: int | list[int] | None,
|
||||
decode: bool = True,
|
||||
) -> list[int] | bytes:
|
||||
"""Receive exactly the specified amount of data from socket with error checking.
|
||||
|
||||
:param sock: Socket to receive data from.
|
||||
:param amount: Exact number of bytes to receive.
|
||||
:param msg: Description of what is being received for error messages.
|
||||
:param expect: Expected response code(s) for validation, None to skip validation.
|
||||
:param decode: If True, return list of integers, otherwise return raw bytes.
|
||||
:return: List of integers if decode=True, otherwise raw bytes.
|
||||
:raises OTAError: If receiving fails or response doesn't match expected.
|
||||
"""
|
||||
data: list[int] | bytes = [] if decode else b""
|
||||
|
||||
try:
|
||||
data += recv_decode(sock, 1, decode=decode)
|
||||
data += recv_decode(sock, 1, decode=decode) # type: ignore[operator]
|
||||
except OSError as err:
|
||||
raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err
|
||||
|
||||
@@ -104,13 +139,19 @@ def receive_exactly(sock, amount, msg, expect, decode=True):
|
||||
|
||||
while len(data) < amount:
|
||||
try:
|
||||
data += recv_decode(sock, amount - len(data), decode=decode)
|
||||
data += recv_decode(sock, amount - len(data), decode=decode) # type: ignore[operator]
|
||||
except OSError as err:
|
||||
raise OTAError(f"Error receiving {msg}: {err}") from err
|
||||
return data
|
||||
|
||||
|
||||
def check_error(data, expect):
|
||||
def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None:
|
||||
"""Check response data for error codes and validate against expected response.
|
||||
|
||||
:param data: Response data from device (first byte is the response code).
|
||||
:param expect: Expected response code(s), None to skip validation.
|
||||
:raises OTAError: If an error code is detected or response doesn't match expected.
|
||||
"""
|
||||
if not expect:
|
||||
return
|
||||
dat = data[0]
|
||||
@@ -125,7 +166,7 @@ def check_error(data, expect):
|
||||
raise OTAError("Error: Authentication invalid. Is the password correct?")
|
||||
if dat == RESPONSE_ERROR_WRITING_FLASH:
|
||||
raise OTAError(
|
||||
"Error: Wring OTA data to flash memory failed. See USB logs for more "
|
||||
"Error: Writing OTA data to flash memory failed. See USB logs for more "
|
||||
"information."
|
||||
)
|
||||
if dat == RESPONSE_ERROR_UPDATE_END:
|
||||
@@ -177,7 +218,16 @@ def check_error(data, expect):
|
||||
raise OTAError(f"Unexpected response from ESP: 0x{data[0]:02X}")
|
||||
|
||||
|
||||
def send_check(sock, data, msg):
|
||||
def send_check(
|
||||
sock: socket.socket, data: list[int] | tuple[int, ...] | int | str | bytes, msg: str
|
||||
) -> None:
|
||||
"""Send data to socket with error handling.
|
||||
|
||||
:param sock: Socket to send data to.
|
||||
:param data: Data to send (can be list/tuple of ints, single int, string, or bytes).
|
||||
:param msg: Description of what is being sent for error messages.
|
||||
:raises OTAError: If sending fails.
|
||||
"""
|
||||
try:
|
||||
if isinstance(data, (list, tuple)):
|
||||
data = bytes(data)
|
||||
@@ -210,10 +260,14 @@ def perform_ota(
|
||||
f"Device uses unsupported OTA version {version}, this ESPHome supports {supported_versions}"
|
||||
)
|
||||
|
||||
# Features
|
||||
send_check(sock, FEATURE_SUPPORTS_COMPRESSION, "features")
|
||||
# Features - send both compression and SHA256 auth support
|
||||
features_to_send = FEATURE_SUPPORTS_COMPRESSION | FEATURE_SUPPORTS_SHA256_AUTH
|
||||
send_check(sock, features_to_send, "features")
|
||||
features = receive_exactly(
|
||||
sock, 1, "features", [RESPONSE_HEADER_OK, RESPONSE_SUPPORTS_COMPRESSION]
|
||||
sock,
|
||||
1,
|
||||
"features",
|
||||
None, # Accept any response
|
||||
)[0]
|
||||
|
||||
if features == RESPONSE_SUPPORTS_COMPRESSION:
|
||||
@@ -222,31 +276,52 @@ def perform_ota(
|
||||
else:
|
||||
upload_contents = file_contents
|
||||
|
||||
(auth,) = receive_exactly(
|
||||
sock, 1, "auth", [RESPONSE_REQUEST_AUTH, RESPONSE_AUTH_OK]
|
||||
)
|
||||
if auth == RESPONSE_REQUEST_AUTH:
|
||||
def perform_auth(
|
||||
sock: socket.socket,
|
||||
password: str,
|
||||
hash_func: Callable[..., Any],
|
||||
nonce_size: int,
|
||||
hash_name: str,
|
||||
) -> None:
|
||||
"""Perform challenge-response authentication using specified hash algorithm."""
|
||||
if not password:
|
||||
raise OTAError("ESP requests password, but no password given!")
|
||||
nonce = receive_exactly(
|
||||
sock, 32, "authentication nonce", [], decode=False
|
||||
).decode()
|
||||
_LOGGER.debug("Auth: Nonce is %s", nonce)
|
||||
cnonce = hashlib.md5(str(random.random()).encode()).hexdigest()
|
||||
_LOGGER.debug("Auth: CNonce is %s", cnonce)
|
||||
|
||||
nonce_bytes = receive_exactly(
|
||||
sock, nonce_size, f"{hash_name} authentication nonce", [], decode=False
|
||||
)
|
||||
assert isinstance(nonce_bytes, bytes)
|
||||
nonce = nonce_bytes.decode()
|
||||
_LOGGER.debug("Auth: %s Nonce is %s", hash_name, nonce)
|
||||
|
||||
# Generate cnonce
|
||||
cnonce = hash_func(str(random.random()).encode()).hexdigest()
|
||||
_LOGGER.debug("Auth: %s CNonce is %s", hash_name, cnonce)
|
||||
|
||||
send_check(sock, cnonce, "auth cnonce")
|
||||
|
||||
result_md5 = hashlib.md5()
|
||||
result_md5.update(password.encode("utf-8"))
|
||||
result_md5.update(nonce.encode())
|
||||
result_md5.update(cnonce.encode())
|
||||
result = result_md5.hexdigest()
|
||||
_LOGGER.debug("Auth: Result is %s", result)
|
||||
# Calculate challenge response
|
||||
hasher = hash_func()
|
||||
hasher.update(password.encode("utf-8"))
|
||||
hasher.update(nonce.encode())
|
||||
hasher.update(cnonce.encode())
|
||||
result = hasher.hexdigest()
|
||||
_LOGGER.debug("Auth: %s Result is %s", hash_name, result)
|
||||
|
||||
send_check(sock, result, "auth result")
|
||||
receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK)
|
||||
|
||||
(auth,) = receive_exactly(
|
||||
sock,
|
||||
1,
|
||||
"auth",
|
||||
[RESPONSE_REQUEST_AUTH, RESPONSE_REQUEST_SHA256_AUTH, RESPONSE_AUTH_OK],
|
||||
)
|
||||
|
||||
if auth != RESPONSE_AUTH_OK:
|
||||
hash_func, nonce_size, hash_name = _AUTH_METHODS[auth]
|
||||
perform_auth(sock, password, hash_func, nonce_size, hash_name)
|
||||
|
||||
# Set higher timeout during upload
|
||||
sock.settimeout(30.0)
|
||||
|
||||
|
Reference in New Issue
Block a user