1
0
mirror of https://github.com/esphome/esphome.git synced 2025-09-26 23:22:21 +01:00
Files
esphome/esphome/espota2.py
2025-09-23 09:45:02 -05:00

446 lines
15 KiB
Python

from __future__ import annotations

from collections.abc import Callable
import gzip
import hashlib
import io
import logging
from pathlib import Path
import random
import secrets
import socket
import sys
import time
from typing import Any

from esphome.core import EsphomeError
from esphome.helpers import resolve_ip_address
# Response codes read from the device; check_error() compares the first byte
# of a reply against these values.
RESPONSE_OK = 0x00
RESPONSE_REQUEST_AUTH = 0x01
RESPONSE_REQUEST_SHA256_AUTH = 0x02
# Acknowledgement codes, most of which perform_ota() waits for at some step.
RESPONSE_HEADER_OK = 0x40
RESPONSE_AUTH_OK = 0x41
RESPONSE_UPDATE_PREPARE_OK = 0x42
RESPONSE_BIN_MD5_OK = 0x43
RESPONSE_RECEIVE_OK = 0x44
RESPONSE_UPDATE_END_OK = 0x45
RESPONSE_SUPPORTS_COMPRESSION = 0x46
RESPONSE_CHUNK_OK = 0x47
# Error codes; check_error() raises an OTAError with a dedicated message for
# each of these.
RESPONSE_ERROR_MAGIC = 0x80
RESPONSE_ERROR_UPDATE_PREPARE = 0x81
RESPONSE_ERROR_AUTH_INVALID = 0x82
RESPONSE_ERROR_WRITING_FLASH = 0x83
RESPONSE_ERROR_UPDATE_END = 0x84
RESPONSE_ERROR_INVALID_BOOTSTRAPPING = 0x85
RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG = 0x86
RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG = 0x87
RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE = 0x88
RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE = 0x89
RESPONSE_ERROR_NO_UPDATE_PARTITION = 0x8A
RESPONSE_ERROR_MD5_MISMATCH = 0x8B
RESPONSE_ERROR_UNKNOWN = 0xFF
# OTA protocol versions that perform_ota() accepts from the device.
OTA_VERSION_1_0 = 1
OTA_VERSION_2_0 = 2
# First bytes perform_ota() sends on a new connection.
MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45]
# Feature bit flags; perform_ota() sends them OR-ed together in a single byte.
FEATURE_SUPPORTS_COMPRESSION = 0x01
FEATURE_SUPPORTS_SHA256_AUTH = 0x02
# Number of bytes handed to sock.sendall() per upload chunk.
UPLOAD_BLOCK_SIZE = 8192
# Value passed to SO_SNDBUF for the duration of the transfer.
UPLOAD_BUFFER_SIZE = UPLOAD_BLOCK_SIZE * 8
_LOGGER = logging.getLogger(__name__)
# Authentication method lookup table: response -> (hash_func, nonce_size, name)
_AUTH_METHODS: dict[int, tuple[Callable[..., Any], int, str]] = {
    RESPONSE_REQUEST_SHA256_AUTH: (hashlib.sha256, 64, "SHA256"),
    RESPONSE_REQUEST_AUTH: (hashlib.md5, 32, "MD5"),
}
class ProgressBar:
    """Text progress bar for the upload, drawn on stderr."""

    def __init__(self):
        # Percentage written by the last update(); None until the first call.
        self.last_progress = None

    def update(self, progress):
        """Redraw the bar for ``progress``, a fraction where >= 1 means done.

        Nothing is written when the integer percentage is the same as on the
        previous call.
        """
        bar_length = 60
        finished = progress >= 1
        if finished:
            progress = 1

        percent = int(progress * 100)
        if percent == self.last_progress:
            return
        self.last_progress = percent

        status = "Done...\r\n" if finished else ""
        filled = int(round(bar_length * progress))
        bar = "=" * filled + " " * (bar_length - filled)
        # The leading carriage return makes each redraw overwrite the last one.
        sys.stderr.write(f"\rUploading: [{bar}] {percent}% {status}")
        sys.stderr.flush()

    def done(self):
        """Terminate the progress line with a newline."""
        sys.stderr.write("\n")
        sys.stderr.flush()
class OTAError(EsphomeError):
    """Error raised when a step of the OTA upload fails."""
def recv_decode(
    sock: socket.socket, amount: int, decode: bool = True
) -> bytes | list[int]:
    """Read from a socket once, optionally converting the bytes to integers.

    :param sock: Socket to read from.
    :param amount: Maximum number of bytes requested from the socket.
    :param decode: When True the bytes are returned as a list of integers,
        otherwise they are returned unchanged.
    :return: List of integers if decode=True, otherwise raw bytes.
    """
    received = sock.recv(amount)
    return list(received) if decode else received
def receive_exactly(
    sock: socket.socket,
    amount: int,
    msg: str,
    expect: int | list[int] | None,
    decode: bool = True,
) -> list[int] | bytes:
    """Receive exactly the specified amount of data from socket with error checking.

    The first byte is read on its own and passed to check_error() so that an
    error code can be reported before waiting for the rest of the data.

    :param sock: Socket to receive data from.
    :param amount: Exact number of bytes to receive.
    :param msg: Description of what is being received for error messages.
    :param expect: Expected response code(s) for validation, None to skip validation.
    :param decode: If True, return list of integers, otherwise return raw bytes.
    :return: List of integers if decode=True, otherwise raw bytes.
    :raises OTAError: If receiving fails, the connection is closed before all
        data arrived, or the response doesn't match expected.
    """
    data: list[int] | bytes = [] if decode else b""

    try:
        data += recv_decode(sock, 1, decode=decode)  # type: ignore[operator]
    except OSError as err:
        raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err
    if not data:
        # recv() returns nothing once the peer has closed the connection;
        # there is no first byte to validate.
        raise OTAError(
            f"Error receiving acknowledge {msg}: connection closed by device"
        )

    try:
        check_error(data, expect)
    except OTAError as err:
        sock.close()
        raise OTAError(f"Error {msg}: {err}") from err

    while len(data) < amount:
        try:
            chunk = recv_decode(sock, amount - len(data), decode=decode)
        except OSError as err:
            raise OTAError(f"Error receiving {msg}: {err}") from err
        if not chunk:
            # Without this check a closed connection would make the loop spin
            # forever, since every further recv() also returns nothing.
            raise OTAError(f"Error receiving {msg}: connection closed by device")
        data += chunk  # type: ignore[operator]
    return data
def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None:
    """Check response data for error codes and validate against expected response.

    Validation is skipped only when ``expect`` is None or an empty list/tuple.
    A single integer is always validated, including RESPONSE_OK (0x00).

    :param data: Response data from device (first byte is the response code).
    :param expect: Expected response code(s), None to skip validation.
    :raises OTAError: If no data was received, an error code is detected or
        the response doesn't match expected.
    """
    # Test for "no expectation" explicitly: a plain truthiness test would also
    # skip validation for the legitimate expected code 0x00.
    if expect is None or (isinstance(expect, (list, tuple)) and not expect):
        return
    if not data:
        raise OTAError("Error: No response received from ESP")
    dat = data[0]
    if dat == RESPONSE_ERROR_MAGIC:
        raise OTAError("Error: Invalid magic byte")
    if dat == RESPONSE_ERROR_UPDATE_PREPARE:
        raise OTAError(
            "Error: Couldn't prepare flash memory for update. Is the binary too big? "
            "Please try restarting the ESP."
        )
    if dat == RESPONSE_ERROR_AUTH_INVALID:
        raise OTAError("Error: Authentication invalid. Is the password correct?")
    if dat == RESPONSE_ERROR_WRITING_FLASH:
        raise OTAError(
            "Error: Writing OTA data to flash memory failed. See USB logs for more "
            "information."
        )
    if dat == RESPONSE_ERROR_UPDATE_END:
        raise OTAError(
            "Error: Finishing update failed. See the MQTT/USB logs for more "
            "information."
        )
    if dat == RESPONSE_ERROR_INVALID_BOOTSTRAPPING:
        raise OTAError(
            "Error: Please press the reset button on the ESP. A manual reset is "
            "required on the first OTA-Update after flashing via USB."
        )
    if dat == RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG:
        raise OTAError(
            "Error: ESP has been flashed with wrong flash size. Please choose the "
            "correct 'board' option (esp01_1m always works) and then flash over USB."
        )
    if dat == RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG:
        raise OTAError(
            "Error: ESP does not have the requested flash size (wrong board). Please "
            "choose the correct 'board' option (esp01_1m always works) and try "
            "uploading again."
        )
    if dat == RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE:
        raise OTAError(
            "Error: ESP does not have enough space to store OTA file. Please try "
            "flashing a minimal firmware (remove everything except ota)"
        )
    if dat == RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE:
        raise OTAError(
            "Error: The OTA partition on the ESP is too small. ESPHome needs to resize "
            "this partition, please flash over USB."
        )
    if dat == RESPONSE_ERROR_NO_UPDATE_PARTITION:
        raise OTAError(
            "Error: The OTA partition on the ESP couldn't be found. ESPHome needs to create "
            "this partition, please flash over USB."
        )
    if dat == RESPONSE_ERROR_MD5_MISMATCH:
        raise OTAError(
            "Error: Application MD5 code mismatch. Please try again "
            "or flash over USB with a good quality cable."
        )
    if dat == RESPONSE_ERROR_UNKNOWN:
        raise OTAError("Unknown error from ESP")
    # Normalise a single expected code to a list for the membership test.
    if not isinstance(expect, (list, tuple)):
        expect = [expect]
    if dat not in expect:
        raise OTAError(f"Unexpected response from ESP: 0x{dat:02X}")
def send_check(
    sock: socket.socket, data: list[int] | tuple[int, ...] | int | str | bytes, msg: str
) -> None:
    """Convert ``data`` to bytes and send all of it, wrapping socket errors.

    :param sock: Socket to send data to.
    :param data: Payload; a list/tuple of ints, a single int, a string
        (encoded as UTF-8) or bytes (sent unchanged).
    :param msg: Description of what is being sent for error messages.
    :raises OTAError: If sending fails.
    """
    if isinstance(data, str):
        payload = data.encode("utf8")
    elif isinstance(data, int):
        payload = bytes([data])
    elif isinstance(data, (list, tuple)):
        payload = bytes(data)
    else:
        payload = data

    try:
        sock.sendall(payload)
    except OSError as err:
        raise OTAError(f"Error sending {msg}: {err}") from err
def _perform_auth(
    sock: socket.socket,
    password: str,
    hash_func: Callable[..., Any],
    nonce_size: int,
    hash_name: str,
) -> None:
    """Perform challenge-response authentication using specified hash algorithm.

    :param sock: Connected socket to the device.
    :param password: OTA password; must be non-empty.
    :param hash_func: hashlib-style constructor used for the response digest.
    :param nonce_size: Number of nonce bytes to read from the device.
    :param hash_name: Algorithm name used in log and error messages.
    :raises OTAError: If no password is given or a protocol step fails.
    """
    if not password:
        raise OTAError("ESP requests password, but no password given!")

    # The nonce is raw data rather than a response code, so an empty expect
    # list is passed to skip response validation.
    nonce_bytes = receive_exactly(
        sock, nonce_size, f"{hash_name} authentication nonce", [], decode=False
    )
    assert isinstance(nonce_bytes, bytes)
    nonce = nonce_bytes.decode()
    _LOGGER.debug("Auth: %s Nonce is %s", hash_name, nonce)

    # Generate cnonce from a cryptographically secure source. It has the same
    # length and format (lowercase hex) as a hex digest of hash_func.
    cnonce = secrets.token_hex(hash_func().digest_size)
    _LOGGER.debug("Auth: %s CNonce is %s", hash_name, cnonce)

    send_check(sock, cnonce, "auth cnonce")

    # Calculate challenge response
    hasher = hash_func()
    hasher.update(password.encode("utf-8"))
    hasher.update(nonce.encode())
    hasher.update(cnonce.encode())
    result = hasher.hexdigest()
    _LOGGER.debug("Auth: %s Result is %s", hash_name, result)

    send_check(sock, result, "auth result")
    receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK)


def perform_ota(
    sock: socket.socket, password: str, file_handle: io.IOBase, filename: Path
) -> None:
    """Upload a firmware image to a device over an already-connected socket.

    The exchange is: magic bytes and version check, feature negotiation
    (optional gzip compression), optional password authentication, binary
    size and MD5 checksum, the chunked payload, then the final
    acknowledgements.

    :param sock: Connected socket to the device.
    :param password: OTA password, used only if the device requests authentication.
    :param file_handle: Binary file object whose entire content is uploaded.
    :param filename: Path of the firmware file, used for logging only.
    :raises OTAError: If any step fails or the device reports an error.
    """
    file_contents = file_handle.read()
    file_size = len(file_contents)
    _LOGGER.info("Uploading %s (%s bytes)", filename, file_size)

    # Enable nodelay, we need it for phase 1
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    send_check(sock, MAGIC_BYTES, "magic bytes")

    _, version = receive_exactly(sock, 2, "version", RESPONSE_OK)
    _LOGGER.debug("Device support OTA version: %s", version)
    supported_versions = (OTA_VERSION_1_0, OTA_VERSION_2_0)
    if version not in supported_versions:
        raise OTAError(
            f"Device uses unsupported OTA version {version}, this ESPHome supports {supported_versions}"
        )

    # Features - send both compression and SHA256 auth support
    features_to_send = FEATURE_SUPPORTS_COMPRESSION | FEATURE_SUPPORTS_SHA256_AUTH
    send_check(sock, features_to_send, "features")
    features = receive_exactly(
        sock,
        1,
        "features",
        None,  # Accept any response
    )[0]

    if features == RESPONSE_SUPPORTS_COMPRESSION:
        upload_contents = gzip.compress(file_contents, compresslevel=9)
        _LOGGER.info("Compressed to %s bytes", len(upload_contents))
    else:
        upload_contents = file_contents

    (auth,) = receive_exactly(
        sock,
        1,
        "auth",
        [RESPONSE_REQUEST_AUTH, RESPONSE_REQUEST_SHA256_AUTH, RESPONSE_AUTH_OK],
    )

    if auth != RESPONSE_AUTH_OK:
        hash_func, nonce_size, hash_name = _AUTH_METHODS[auth]
        _perform_auth(sock, password, hash_func, nonce_size, hash_name)

    # Set higher timeout during upload
    sock.settimeout(30.0)

    # The size is sent as four bytes, most significant byte first.
    upload_size = len(upload_contents)
    upload_size_encoded = [
        (upload_size >> 24) & 0xFF,
        (upload_size >> 16) & 0xFF,
        (upload_size >> 8) & 0xFF,
        (upload_size >> 0) & 0xFF,
    ]
    send_check(sock, upload_size_encoded, "binary size")
    receive_exactly(sock, 1, "binary size", RESPONSE_UPDATE_PREPARE_OK)

    upload_md5 = hashlib.md5(upload_contents).hexdigest()
    _LOGGER.debug("MD5 of upload is %s", upload_md5)

    send_check(sock, upload_md5, "file checksum")
    receive_exactly(sock, 1, "file checksum", RESPONSE_BIN_MD5_OK)

    # Disable nodelay for transfer
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
    # Limit send buffer (usually around 100kB) in order to have progress bar
    # show the actual progress
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, UPLOAD_BUFFER_SIZE)

    start_time = time.perf_counter()

    progress = ProgressBar()
    for start in range(0, upload_size, UPLOAD_BLOCK_SIZE):
        chunk = upload_contents[start : start + UPLOAD_BLOCK_SIZE]
        try:
            sock.sendall(chunk)
            if version >= OTA_VERSION_2_0:
                receive_exactly(sock, 1, "chunk OK", RESPONSE_CHUNK_OK)
        except OTAError:
            # Terminate the progress bar line so the error is logged on its
            # own line, as is done for socket errors below.
            sys.stderr.write("\n")
            raise
        except OSError as err:
            sys.stderr.write("\n")
            raise OTAError(f"Error sending data: {err}") from err

        progress.update((start + len(chunk)) / upload_size)
    progress.done()

    # Enable nodelay for last checks
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

    duration = time.perf_counter() - start_time
    _LOGGER.info("Upload took %.2f seconds, waiting for result...", duration)

    receive_exactly(sock, 1, "receive OK", RESPONSE_RECEIVE_OK)
    receive_exactly(sock, 1, "Update end", RESPONSE_UPDATE_END_OK)
    send_check(sock, RESPONSE_OK, "end acknowledgement")

    _LOGGER.info("OTA successful")

    # Do not connect logs until it is fully on
    time.sleep(1)
def run_ota_impl_(
    remote_host: str | list[str], remote_port: int, password: str, filename: Path
) -> tuple[int, str | None]:
    """Resolve the target, connect to the first reachable address and upload.

    Addresses are tried in the order returned by the resolver. The upload is
    attempted on the first one that accepts a connection; if that upload
    fails, the remaining addresses are not tried.

    :param remote_host: Host name/address, or a list of them.
    :param remote_port: TCP port of the OTA service.
    :param password: OTA password passed on to perform_ota().
    :param filename: Path of the firmware file to upload.
    :return: ``(0, address)`` on success, ``(1, None)`` on failure.
    :raises OTAError: If the host cannot be resolved.
    """
    from esphome.core import CORE

    # Handle both single host and list of hosts
    try:
        # Resolve all hosts at once for parallel DNS resolution
        res = resolve_ip_address(
            remote_host, remote_port, address_cache=CORE.address_cache
        )
    except EsphomeError as err:
        _LOGGER.error(
            "Error resolving IP address of %s. Is it connected to WiFi?",
            remote_host,
        )
        _LOGGER.error(
            "(If this error persists, please set a static IP address: "
            "https://esphome.io/components/wifi.html#manual-ips)"
        )
        raise OTAError(err) from err

    for r in res:
        af, socktype, _, _, sa = r
        _LOGGER.info("Connecting to %s port %s...", sa[0], sa[1])
        sock = socket.socket(af, socktype)
        sock.settimeout(10.0)
        try:
            sock.connect(sa)
        except OSError as err:
            sock.close()
            _LOGGER.error("Connecting to %s port %s failed: %s", sa[0], sa[1], err)
            continue

        _LOGGER.info("Connected to %s", sa[0])
        # The finally clause also covers open(), so the socket is closed even
        # when the firmware file cannot be opened.
        try:
            with open(filename, "rb") as file_handle:
                perform_ota(sock, password, file_handle, filename)
        except OTAError as err:
            _LOGGER.error(str(err))
            return 1, None
        finally:
            sock.close()

        # Successfully uploaded to sa[0]
        return 0, sa[0]

    _LOGGER.error("Connection failed.")
    return 1, None
def run_ota(
    remote_host: str | list[str], remote_port: int, password: str, filename: Path
) -> tuple[int, str | None]:
    """Run an OTA upload, logging an OTAError instead of propagating it.

    :param remote_host: Host name/address, or a list of them.
    :param remote_port: TCP port of the OTA service.
    :param password: OTA password.
    :param filename: Path of the firmware file to upload.
    :return: The result of run_ota_impl_(), or ``(1, None)`` if it raised OTAError.
    """
    try:
        outcome = run_ota_impl_(remote_host, remote_port, password, filename)
    except OTAError as err:
        _LOGGER.error(err)
        outcome = 1, None
    return outcome