mirror of
https://github.com/esphome/esphome.git
synced 2025-09-26 23:22:21 +01:00
446 lines
15 KiB
Python
446 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
import gzip
|
|
import hashlib
|
|
import io
|
|
import logging
|
|
from pathlib import Path
|
|
import random
|
|
import socket
|
|
import sys
|
|
import time
|
|
from typing import Any
|
|
|
|
from esphome.core import EsphomeError
|
|
from esphome.helpers import resolve_ip_address
|
|
|
|
# --- Response codes sent by the device --------------------------------------

# Handshake / authentication requests.
RESPONSE_OK = 0x00
RESPONSE_REQUEST_AUTH = 0x01
RESPONSE_REQUEST_SHA256_AUTH = 0x02

# Positive acknowledgements for the individual protocol steps.
RESPONSE_HEADER_OK = 0x40
RESPONSE_AUTH_OK = 0x41
RESPONSE_UPDATE_PREPARE_OK = 0x42
RESPONSE_BIN_MD5_OK = 0x43
RESPONSE_RECEIVE_OK = 0x44
RESPONSE_UPDATE_END_OK = 0x45
RESPONSE_SUPPORTS_COMPRESSION = 0x46
RESPONSE_CHUNK_OK = 0x47

# Error codes; check_error() turns each of these into a user-facing message.
RESPONSE_ERROR_MAGIC = 0x80
RESPONSE_ERROR_UPDATE_PREPARE = 0x81
RESPONSE_ERROR_AUTH_INVALID = 0x82
RESPONSE_ERROR_WRITING_FLASH = 0x83
RESPONSE_ERROR_UPDATE_END = 0x84
RESPONSE_ERROR_INVALID_BOOTSTRAPPING = 0x85
RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG = 0x86
RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG = 0x87
RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE = 0x88
RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE = 0x89
RESPONSE_ERROR_NO_UPDATE_PARTITION = 0x8A
RESPONSE_ERROR_MD5_MISMATCH = 0x8B
RESPONSE_ERROR_UNKNOWN = 0xFF

# --- Protocol parameters -----------------------------------------------------

# OTA protocol versions this client can talk to.
OTA_VERSION_1_0 = 1
OTA_VERSION_2_0 = 2

# Preamble sent first on every connection (a mismatch is reported by the
# device as RESPONSE_ERROR_MAGIC, "Invalid magic byte").
MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45]

# Feature bit flags the client advertises to the device.
FEATURE_SUPPORTS_COMPRESSION = 0x01
FEATURE_SUPPORTS_SHA256_AUTH = 0x02

# The firmware is sent in UPLOAD_BLOCK_SIZE-byte chunks; the socket send
# buffer is capped at UPLOAD_BUFFER_SIZE so the progress bar reflects what
# has actually left the host.
UPLOAD_BLOCK_SIZE = 8192
UPLOAD_BUFFER_SIZE = 8 * UPLOAD_BLOCK_SIZE

_LOGGER = logging.getLogger(__name__)

# Auth request code -> (hash constructor, nonce length in bytes, display name).
_AUTH_METHODS: dict[int, tuple[Callable[..., Any], int, str]] = {
    RESPONSE_REQUEST_SHA256_AUTH: (hashlib.sha256, 64, "SHA256"),
    RESPONSE_REQUEST_AUTH: (hashlib.md5, 32, "MD5"),
}
|
|
|
|
|
|
class ProgressBar:
    """Single-line upload progress indicator drawn on stderr."""

    def __init__(self):
        # Percentage drawn by the previous update(); None before the first draw.
        self.last_progress = None

    def update(self, progress):
        """Redraw the bar for *progress*, a fraction where values >= 1 mean done.

        Nothing is written if the integer percentage is unchanged since the
        previous call.
        """
        finished = progress >= 1
        if finished:
            progress = 1
        percent = int(progress * 100)
        if percent == self.last_progress:
            return
        self.last_progress = percent

        width = 60
        filled = int(round(width * progress))
        bar = "=" * filled + " " * (width - filled)
        suffix = "Done...\r\n" if finished else ""
        sys.stderr.write(f"\rUploading: [{bar}] {percent}% {suffix}")
        sys.stderr.flush()

    def done(self):
        """Terminate the progress line with a newline."""
        sys.stderr.write("\n")
        sys.stderr.flush()
|
|
|
|
|
|
class OTAError(EsphomeError):
    """Raised when any step of an over-the-air upload fails."""
|
|
|
|
|
|
def recv_decode(
    sock: socket.socket, amount: int, decode: bool = True
) -> bytes | list[int]:
    """Read up to *amount* bytes from *sock* with a single ``recv`` call.

    :param sock: Socket to read from.
    :param amount: Maximum number of bytes to read.
    :param decode: When true, return the bytes as a list of ints.
    :return: ``list[int]`` if *decode* is true, otherwise the raw ``bytes``.
        An empty result means ``recv`` returned no data.
    """
    raw = sock.recv(amount)
    return list(raw) if decode else raw
|
|
|
|
|
|
def receive_exactly(
    sock: socket.socket,
    amount: int,
    msg: str,
    expect: int | list[int] | None,
    decode: bool = True,
) -> list[int] | bytes:
    """Receive exactly *amount* bytes from *sock*, validating the first byte.

    The first byte is read on its own and passed to ``check_error`` so device
    error codes are reported before the rest of the payload is awaited.

    :param sock: Socket to receive data from.
    :param amount: Exact number of bytes to receive.
    :param msg: Description of what is being received, used in error messages.
    :param expect: Expected response code(s) for the first byte; ``None`` or an
        empty list skips validation.
    :param decode: If True, return a list of ints, otherwise raw bytes.
    :return: List of integers if decode=True, otherwise raw bytes.
    :raises OTAError: If receiving fails, the peer closes the connection before
        *amount* bytes arrived, or the first byte does not match *expect*. On a
        validation failure the socket is closed before raising.
    """
    try:
        data = recv_decode(sock, 1, decode=decode)
    except OSError as err:
        raise OTAError(f"Error receiving acknowledge {msg}: {err}") from err
    if not data:
        # recv() yields no data only once the peer has closed the connection;
        # previously this caused an IndexError in check_error or an endless
        # loop below.
        raise OTAError(f"Error receiving acknowledge {msg}: connection closed")

    try:
        check_error(data, expect)
    except OTAError as err:
        sock.close()
        raise OTAError(f"Error {msg}: {err}") from err

    while len(data) < amount:
        try:
            chunk = recv_decode(sock, amount - len(data), decode=decode)
        except OSError as err:
            raise OTAError(f"Error receiving {msg}: {err}") from err
        if not chunk:
            # Peer closed mid-payload: fail instead of spinning forever.
            raise OTAError(f"Error receiving {msg}: connection closed")
        data += chunk  # type: ignore[operator]
    return data
|
|
|
|
|
|
# User-facing explanation for every error code the device can report.
_ERROR_MESSAGES: dict[int, str] = {
    RESPONSE_ERROR_MAGIC: "Error: Invalid magic byte",
    RESPONSE_ERROR_UPDATE_PREPARE: (
        "Error: Couldn't prepare flash memory for update. Is the binary too big? "
        "Please try restarting the ESP."
    ),
    RESPONSE_ERROR_AUTH_INVALID: "Error: Authentication invalid. Is the password correct?",
    RESPONSE_ERROR_WRITING_FLASH: (
        "Error: Writing OTA data to flash memory failed. See USB logs for more "
        "information."
    ),
    RESPONSE_ERROR_UPDATE_END: (
        "Error: Finishing update failed. See the MQTT/USB logs for more "
        "information."
    ),
    RESPONSE_ERROR_INVALID_BOOTSTRAPPING: (
        "Error: Please press the reset button on the ESP. A manual reset is "
        "required on the first OTA-Update after flashing via USB."
    ),
    RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG: (
        "Error: ESP has been flashed with wrong flash size. Please choose the "
        "correct 'board' option (esp01_1m always works) and then flash over USB."
    ),
    RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG: (
        "Error: ESP does not have the requested flash size (wrong board). Please "
        "choose the correct 'board' option (esp01_1m always works) and try "
        "uploading again."
    ),
    RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE: (
        "Error: ESP does not have enough space to store OTA file. Please try "
        "flashing a minimal firmware (remove everything except ota)"
    ),
    RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE: (
        "Error: The OTA partition on the ESP is too small. ESPHome needs to resize "
        "this partition, please flash over USB."
    ),
    RESPONSE_ERROR_NO_UPDATE_PARTITION: (
        "Error: The OTA partition on the ESP couldn't be found. ESPHome needs to create "
        "this partition, please flash over USB."
    ),
    RESPONSE_ERROR_MD5_MISMATCH: (
        "Error: Application MD5 code mismatch. Please try again "
        "or flash over USB with a good quality cable."
    ),
    RESPONSE_ERROR_UNKNOWN: "Unknown error from ESP",
}


def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None:
    """Check response data for error codes and validate against expected response.

    :param data: Response data from device (first byte is the response code).
    :param expect: Expected response code(s). ``None`` or an empty list/tuple
        skips validation entirely. A single int (including ``0``,
        i.e. RESPONSE_OK) is validated like a one-element list.
    :raises OTAError: If *data* is empty, the first byte is a known error code,
        or it is not one of the expected codes.
    """
    # Must be an explicit None test: RESPONSE_OK is 0 and therefore falsy, so a
    # plain truthiness test would silently skip validating it.
    if expect is None:
        return
    if not isinstance(expect, (list, tuple)):
        expect = [expect]
    if not expect:
        # Empty expectation list: caller accepts any byte (e.g. an auth nonce).
        return
    if not data:
        raise OTAError("Error: No response received from ESP")

    dat = data[0]
    message = _ERROR_MESSAGES.get(dat)
    if message is not None:
        raise OTAError(message)
    if dat not in expect:
        raise OTAError(f"Unexpected response from ESP: 0x{dat:02X}")
|
|
|
|
|
|
def send_check(
    sock: socket.socket, data: list[int] | tuple[int, ...] | int | str | bytes, msg: str
) -> None:
    """Convert *data* to bytes and write all of it to *sock*.

    Conversion rules: ``str`` is UTF-8 encoded, a single ``int`` becomes one
    byte, a list/tuple of ints becomes the corresponding bytes, and anything
    else is sent unchanged.

    :param sock: Socket to send data to.
    :param data: Payload to send.
    :param msg: Description of what is being sent, used in error messages.
    :raises OTAError: If the socket reports an ``OSError`` while sending.
    """
    try:
        if isinstance(data, str):
            payload = data.encode("utf8")
        elif isinstance(data, int):
            payload = bytes((data,))
        elif isinstance(data, (list, tuple)):
            payload = bytes(data)
        else:
            payload = data
        sock.sendall(payload)
    except OSError as err:
        raise OTAError(f"Error sending {msg}: {err}") from err
|
|
|
|
|
|
def perform_ota(
    sock: socket.socket, password: str, file_handle: io.IOBase, filename: Path
) -> None:
    """Upload the firmware read from *file_handle* over an already connected socket.

    Protocol steps:
    1. Magic bytes.
    2. Protocol version.
    3. Feature negotiation (gzip compression).
    4. Optional MD5/SHA256 challenge-response authentication.
    5. Upload size.
    6. MD5 of the upload.
    7. The data in ``UPLOAD_BLOCK_SIZE`` chunks; each chunk is acknowledged
       when the device speaks OTA v2.0.
    8. The final receive/update-end acknowledgements.

    :param sock: Socket connected to the device's OTA port.
    :param password: OTA password; only used if the device requests auth.
    :param file_handle: Binary file object whose full contents are uploaded.
    :param filename: Firmware path, used only for log output.
    :raises OTAError: If the device reports an error or answers unexpectedly,
        or a protocol send/receive fails.
    """
    file_contents = file_handle.read()
    file_size = len(file_contents)
    _LOGGER.info("Uploading %s (%s bytes)", filename, file_size)

    # Enable nodelay, we need it for phase 1
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    send_check(sock, MAGIC_BYTES, "magic bytes")

    _, version = receive_exactly(sock, 2, "version", RESPONSE_OK)
    _LOGGER.debug("Device support OTA version: %s", version)
    supported_versions = (OTA_VERSION_1_0, OTA_VERSION_2_0)
    if version not in supported_versions:
        raise OTAError(
            f"Device uses unsupported OTA version {version}, this ESPHome supports {supported_versions}"
        )

    # Features - send both compression and SHA256 auth support
    features_to_send = FEATURE_SUPPORTS_COMPRESSION | FEATURE_SUPPORTS_SHA256_AUTH
    send_check(sock, features_to_send, "features")
    features = receive_exactly(
        sock,
        1,
        "features",
        None,  # Accept any response
    )[0]

    # Only gzip the image when the device answered that it supports it.
    if features == RESPONSE_SUPPORTS_COMPRESSION:
        upload_contents = gzip.compress(file_contents, compresslevel=9)
        _LOGGER.info("Compressed to %s bytes", len(upload_contents))
    else:
        upload_contents = file_contents

    def perform_auth(
        sock: socket.socket,
        password: str,
        hash_func: Callable[..., Any],
        nonce_size: int,
        hash_name: str,
    ) -> None:
        """Perform challenge-response authentication using specified hash algorithm.

        Reads a *nonce_size*-byte nonce from the device.
        Sends a client nonce (cnonce), then
        hex(hash(password + nonce + cnonce)).
        Finally expects RESPONSE_AUTH_OK.
        """
        import secrets

        if not password:
            raise OTAError("ESP requests password, but no password given!")

        nonce_bytes = receive_exactly(
            sock, nonce_size, f"{hash_name} authentication nonce", [], decode=False
        )
        assert isinstance(nonce_bytes, bytes)
        nonce = nonce_bytes.decode()
        _LOGGER.debug("Auth: %s Nonce is %s", hash_name, nonce)

        # Generate cnonce from a CSPRNG: the ``random`` module is predictable
        # and unsuitable for authentication material. The cnonce stays a
        # hex digest of hash_func, so its format and length are unchanged.
        cnonce = hash_func(secrets.token_bytes(32)).hexdigest()
        _LOGGER.debug("Auth: %s CNonce is %s", hash_name, cnonce)

        send_check(sock, cnonce, "auth cnonce")

        # Calculate challenge response
        hasher = hash_func()
        hasher.update(password.encode("utf-8"))
        hasher.update(nonce.encode())
        hasher.update(cnonce.encode())
        result = hasher.hexdigest()
        _LOGGER.debug("Auth: %s Result is %s", hash_name, result)

        send_check(sock, result, "auth result")
        receive_exactly(sock, 1, "auth result", RESPONSE_AUTH_OK)

    (auth,) = receive_exactly(
        sock,
        1,
        "auth",
        [RESPONSE_REQUEST_AUTH, RESPONSE_REQUEST_SHA256_AUTH, RESPONSE_AUTH_OK],
    )

    # Any answer other than AUTH_OK is one of the two auth requests in
    # _AUTH_METHODS (guaranteed by the expect list above).
    if auth != RESPONSE_AUTH_OK:
        hash_func, nonce_size, hash_name = _AUTH_METHODS[auth]
        perform_auth(sock, password, hash_func, nonce_size, hash_name)

    # Set higher timeout during upload
    sock.settimeout(30.0)

    # Upload size is sent as a 4-byte big-endian integer.
    upload_size = len(upload_contents)
    upload_size_encoded = [
        (upload_size >> 24) & 0xFF,
        (upload_size >> 16) & 0xFF,
        (upload_size >> 8) & 0xFF,
        (upload_size >> 0) & 0xFF,
    ]
    send_check(sock, upload_size_encoded, "binary size")
    receive_exactly(sock, 1, "binary size", RESPONSE_UPDATE_PREPARE_OK)

    upload_md5 = hashlib.md5(upload_contents).hexdigest()
    _LOGGER.debug("MD5 of upload is %s", upload_md5)

    send_check(sock, upload_md5, "file checksum")
    receive_exactly(sock, 1, "file checksum", RESPONSE_BIN_MD5_OK)

    # Disable nodelay for transfer
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
    # Limit send buffer (usually around 100kB) in order to have progress bar
    # show the actual progress
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, UPLOAD_BUFFER_SIZE)
    start_time = time.perf_counter()

    offset = 0
    progress = ProgressBar()
    while True:
        chunk = upload_contents[offset : offset + UPLOAD_BLOCK_SIZE]
        if not chunk:
            break
        offset += len(chunk)

        try:
            sock.sendall(chunk)
            # OTA v2.0 devices acknowledge every chunk.
            if version >= OTA_VERSION_2_0:
                receive_exactly(sock, 1, "chunk OK", RESPONSE_CHUNK_OK)
        except OSError as err:
            sys.stderr.write("\n")
            raise OTAError(f"Error sending data: {err}") from err

        progress.update(offset / upload_size)
    progress.done()

    # Enable nodelay for last checks
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    duration = time.perf_counter() - start_time

    _LOGGER.info("Upload took %.2f seconds, waiting for result...", duration)

    receive_exactly(sock, 1, "receive OK", RESPONSE_RECEIVE_OK)
    receive_exactly(sock, 1, "Update end", RESPONSE_UPDATE_END_OK)
    send_check(sock, RESPONSE_OK, "end acknowledgement")

    _LOGGER.info("OTA successful")

    # Do not connect logs until it is fully on
    time.sleep(1)
|
|
|
|
|
|
def run_ota_impl_(
    remote_host: str | list[str], remote_port: int, password: str, filename: Path
) -> tuple[int, str | None]:
    """Resolve *remote_host* and upload *filename* to the first address that connects.

    Addresses that refuse the TCP connection are skipped. Once connected, the
    result of that single upload attempt is returned.

    :param remote_host: Host name/IP, or a list of them, of the device.
    :param remote_port: OTA TCP port.
    :param password: OTA password passed to perform_ota.
    :param filename: Firmware file to upload.
    :return: ``(0, address)`` on success, ``(1, None)`` on failure.
    :raises OTAError: If address resolution fails. Errors opening *filename*
        propagate unchanged (the socket is still closed).
    """
    from esphome.core import CORE

    # Handle both single host and list of hosts
    try:
        # Resolve all hosts at once for parallel DNS resolution
        res = resolve_ip_address(
            remote_host, remote_port, address_cache=CORE.address_cache
        )
    except EsphomeError as err:
        _LOGGER.error(
            "Error resolving IP address of %s. Is it connected to WiFi?",
            remote_host,
        )
        _LOGGER.error(
            "(If this error persists, please set a static IP address: "
            "https://esphome.io/components/wifi.html#manual-ips)"
        )
        raise OTAError(err) from err

    for r in res:
        af, socktype, _, _, sa = r
        _LOGGER.info("Connecting to %s port %s...", sa[0], sa[1])
        sock = socket.socket(af, socktype)
        sock.settimeout(10.0)
        try:
            sock.connect(sa)
        except OSError as err:
            sock.close()
            _LOGGER.error("Connecting to %s port %s failed: %s", sa[0], sa[1], err)
            continue

        _LOGGER.info("Connected to %s", sa[0])
        try:
            # Open the firmware inside the try so the connected socket is
            # closed even when the file cannot be opened.
            with open(filename, "rb") as file_handle:
                perform_ota(sock, password, file_handle, filename)
        except OTAError as err:
            _LOGGER.error(str(err))
            return 1, None
        finally:
            sock.close()

        # Successfully uploaded to sa[0]
        return 0, sa[0]

    _LOGGER.error("Connection failed.")
    return 1, None
|
|
|
|
|
|
def run_ota(
    remote_host: str | list[str], remote_port: int, password: str, filename: Path
) -> tuple[int, str | None]:
    """Upload *filename* to the device via OTA, logging OTAError instead of raising.

    :return: The result of run_ota_impl_, or ``(1, None)`` if it raised OTAError.
    """
    outcome: tuple[int, str | None]
    try:
        outcome = run_ota_impl_(remote_host, remote_port, password, filename)
    except OTAError as err:
        _LOGGER.error(err)
        outcome = (1, None)
    return outcome
|