mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-30 22:53:59 +00:00 
			
		
		
		
	Rework OTA to be more stable (#177)
* Rework OTA to be more stable * Lint
This commit is contained in:
		| @@ -185,7 +185,6 @@ def upload_using_esptool(config, port): | |||||||
|  |  | ||||||
|  |  | ||||||
| def upload_program(config, args, port): | def upload_program(config, args, port): | ||||||
|     _LOGGER.info("Uploading binary...") |  | ||||||
|     build_path = relative_path(config[CONF_ESPHOMEYAML][CONF_BUILD_PATH]) |     build_path = relative_path(config[CONF_ESPHOMEYAML][CONF_BUILD_PATH]) | ||||||
|  |  | ||||||
|     # if upload is to a serial port use platformio, otherwise assume ota |     # if upload is to a serial port use platformio, otherwise assume ota | ||||||
| @@ -211,19 +210,23 @@ def upload_program(config, args, port): | |||||||
|         host = get_upload_host(config) |         host = get_upload_host(config) | ||||||
|  |  | ||||||
|     from esphomeyaml.components import ota |     from esphomeyaml.components import ota | ||||||
|     from esphomeyaml import espota |     from esphomeyaml import espota2 | ||||||
|  |  | ||||||
|     bin_file = os.path.join(build_path, '.pioenvs', core.NAME, 'firmware.bin') |     bin_file = os.path.join(build_path, '.pioenvs', core.NAME, 'firmware.bin') | ||||||
|     if args.host_port is not None: |     if args.host_port is not None: | ||||||
|         host_port = args.host_port |         host_port = args.host_port | ||||||
|     else: |     else: | ||||||
|         host_port = int(os.getenv('ESPHOMEYAML_OTA_HOST_PORT', random.randint(10000, 60000))) |         host_port = int(os.getenv('ESPHOMEYAML_OTA_HOST_PORT', random.randint(10000, 60000))) | ||||||
|     espota_args = ['espota.py', '--debug', '--progress', '-i', host, |  | ||||||
|                    '-p', str(ota.get_port(config)), '-f', bin_file, |     verbose = args.verbose | ||||||
|                    '-a', ota.get_auth(config), '-P', str(host_port)] |     remote_port = ota.get_port(config) | ||||||
|     if args.verbose: |     password = ota.get_auth(config) | ||||||
|         espota_args.append('-d') |  | ||||||
|     return espota.main(espota_args) |     res = espota2.run_ota(host, remote_port, password, bin_file) | ||||||
|  |     if res == 0: | ||||||
|  |         return res | ||||||
|  |     _LOGGER.warn("OTA v2 method failed. Trying with legacy OTA...") | ||||||
|  |     return espota2.run_legacy_ota(verbose, host_port, host, remote_port, password, bin_file) | ||||||
|  |  | ||||||
|  |  | ||||||
| def show_logs(config, args, port, escape=False): | def show_logs(config, args, port, escape=False): | ||||||
|   | |||||||
| @@ -1,10 +1,9 @@ | |||||||
| import hashlib |  | ||||||
| import logging | import logging | ||||||
|  |  | ||||||
| import voluptuous as vol | import voluptuous as vol | ||||||
|  |  | ||||||
| import esphomeyaml.config_validation as cv |  | ||||||
| from esphomeyaml import core | from esphomeyaml import core | ||||||
|  | import esphomeyaml.config_validation as cv | ||||||
| from esphomeyaml.const import CONF_ID, CONF_OTA, CONF_PASSWORD, CONF_PORT, CONF_SAFE_MODE, \ | from esphomeyaml.const import CONF_ID, CONF_OTA, CONF_PASSWORD, CONF_PORT, CONF_SAFE_MODE, \ | ||||||
|     ESP_PLATFORM_ESP32, ESP_PLATFORM_ESP8266 |     ESP_PLATFORM_ESP32, ESP_PLATFORM_ESP8266 | ||||||
| from esphomeyaml.core import ESPHomeYAMLError | from esphomeyaml.core import ESPHomeYAMLError | ||||||
| @@ -27,8 +26,7 @@ def to_code(config): | |||||||
|     rhs = App.init_ota() |     rhs = App.init_ota() | ||||||
|     ota = Pvariable(config[CONF_ID], rhs) |     ota = Pvariable(config[CONF_ID], rhs) | ||||||
|     if CONF_PASSWORD in config: |     if CONF_PASSWORD in config: | ||||||
|         hash_ = hashlib.md5(config[CONF_PASSWORD].encode()).hexdigest() |         add(ota.set_auth_password(config[CONF_PASSWORD])) | ||||||
|         add(ota.set_auth_password_hash(hash_)) |  | ||||||
|     if CONF_PORT in config: |     if CONF_PORT in config: | ||||||
|         add(ota.set_port(config[CONF_PORT])) |         add(ota.set_port(config[CONF_PORT])) | ||||||
|     if config[CONF_SAFE_MODE]: |     if config[CONF_SAFE_MODE]: | ||||||
| @@ -50,6 +48,7 @@ def get_auth(config): | |||||||
|  |  | ||||||
|  |  | ||||||
| BUILD_FLAGS = '-DUSE_OTA' | BUILD_FLAGS = '-DUSE_OTA' | ||||||
|  | REQUIRED_BUILD_FLAGS = '-DUSE_NEW_OTA' | ||||||
|  |  | ||||||
|  |  | ||||||
| def lib_deps(config): | def lib_deps(config): | ||||||
|   | |||||||
| @@ -60,7 +60,7 @@ def validate_simple_esphomelib_version(value): | |||||||
|             CONF_REPOSITORY: LIBRARY_URI_REPO, |             CONF_REPOSITORY: LIBRARY_URI_REPO, | ||||||
|             CONF_TAG: 'v' + value, |             CONF_TAG: 'v' + value, | ||||||
|         } |         } | ||||||
|     return value |     raise vol.Invalid("Only simple esphomelib versions!") | ||||||
|  |  | ||||||
|  |  | ||||||
| def validate_local_esphomelib_version(value): | def validate_local_esphomelib_version(value): | ||||||
|   | |||||||
							
								
								
									
										230
									
								
								esphomeyaml/espota2.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										230
									
								
								esphomeyaml/espota2.py
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,230 @@ | |||||||
|  | import hashlib | ||||||
|  | import logging | ||||||
|  | import random | ||||||
|  | import socket | ||||||
|  | import sys | ||||||
|  |  | ||||||
|  | RESPONSE_OK = 0 | ||||||
|  | RESPONSE_REQUEST_AUTH = 1 | ||||||
|  |  | ||||||
|  | RESPONSE_HEADER_OK = 64 | ||||||
|  | RESPONSE_AUTH_OK = 65 | ||||||
|  | RESPONSE_UPDATE_PREPARE_OK = 66 | ||||||
|  | RESPONSE_BIN_MD5_OK = 67 | ||||||
|  | RESPONSE_RECEIVE_OK = 68 | ||||||
|  | RESPONSE_UPDATE_END_OK = 69 | ||||||
|  |  | ||||||
|  | RESPONSE_ERROR_MAGIC = 128 | ||||||
|  | RESPONSE_ERROR_UPDATE_PREPARE = 129 | ||||||
|  | RESPONSE_ERROR_AUTH_INVALID = 130 | ||||||
|  | RESPONSE_ERROR_WRITING_FLASH = 131 | ||||||
|  | RESPONSE_ERROR_UPDATE_END = 132 | ||||||
|  | RESPONSE_ERROR_INVALID_BOOTSTRAPPING = 133 | ||||||
|  | RESPONSE_ERROR_UNKNOWN = 255 | ||||||
|  |  | ||||||
|  | OTA_VERSION_1_0 = 1 | ||||||
|  |  | ||||||
|  | MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45] | ||||||
|  |  | ||||||
|  | _LOGGER = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def update_progress(progress): | ||||||
|  |     bar_length = 60 | ||||||
|  |     status = "" | ||||||
|  |     if progress >= 1: | ||||||
|  |         progress = 1 | ||||||
|  |         status = "Done...\r\n" | ||||||
|  |     block = int(round(bar_length * progress)) | ||||||
|  |     text = "\rUploading: [{0}] {1}% {2}".format("=" * block + " " * (bar_length - block), | ||||||
|  |                                                 int(progress * 100), status) | ||||||
|  |     sys.stderr.write(text) | ||||||
|  |     sys.stderr.flush() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class OTAError(Exception): | ||||||
|  |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def recv_decode(sock, amount, decode=True): | ||||||
|  |     data = sock.recv(amount) | ||||||
|  |     if not decode: | ||||||
|  |         return data | ||||||
|  |     return [ord(x) for x in data] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def receive_exactly(sock, amount, msg, expect, decode=True): | ||||||
|  |     if decode: | ||||||
|  |         data = [] | ||||||
|  |     else: | ||||||
|  |         data = '' | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         data += recv_decode(sock, 1, decode=decode) | ||||||
|  |     except socket.error as err: | ||||||
|  |         raise OTAError("Error receiving acknowledge {}: {}".format(msg, err)) | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         check_error(data, expect) | ||||||
|  |     except OTAError: | ||||||
|  |         sock.close() | ||||||
|  |         raise | ||||||
|  |  | ||||||
|  |     while len(data) < amount: | ||||||
|  |         try: | ||||||
|  |             data += recv_decode(sock, amount - len(data), decode=decode) | ||||||
|  |         except socket.error as err: | ||||||
|  |             raise OTAError("Error receiving {}: {}".format(msg, err)) | ||||||
|  |     return data | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def check_error(data, expect): | ||||||
|  |     if not expect: | ||||||
|  |         return | ||||||
|  |     dat = data[0] | ||||||
|  |     if dat == RESPONSE_ERROR_MAGIC: | ||||||
|  |         raise OTAError("Error: Invalid magic byte") | ||||||
|  |     if dat == RESPONSE_ERROR_UPDATE_PREPARE: | ||||||
|  |         raise OTAError("Error: Couldn't prepare flash memory for update. Is the binary too big?") | ||||||
|  |     if dat == RESPONSE_ERROR_AUTH_INVALID: | ||||||
|  |         raise OTAError("Error: Authentication invalid. Is the password correct?") | ||||||
|  |     if dat == RESPONSE_ERROR_WRITING_FLASH: | ||||||
|  |         raise OTAError("Error: Wring OTA data to flash memory failed. See USB logs for more " | ||||||
|  |                        "information.") | ||||||
|  |     if dat == RESPONSE_ERROR_UPDATE_END: | ||||||
|  |         raise OTAError("Error: Finishing update failed. See the MQTT/USB logs for more " | ||||||
|  |                        "information.") | ||||||
|  |     if dat == RESPONSE_ERROR_INVALID_BOOTSTRAPPING: | ||||||
|  |         raise OTAError("Error: Please press the reset button on the ESP. A manual reset is " | ||||||
|  |                        "required on the first OTA-Update after flashing via USB.") | ||||||
|  |     if dat == RESPONSE_ERROR_UNKNOWN: | ||||||
|  |         raise OTAError("Unknown error from ESP") | ||||||
|  |     if not isinstance(expect, (list, tuple)): | ||||||
|  |         expect = [expect] | ||||||
|  |     if dat not in expect: | ||||||
|  |         raise OTAError("Unexpected response from ESP: 0x{:02X}".format(data[0])) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def send_check(sock, data, msg): | ||||||
|  |     try: | ||||||
|  |         if isinstance(data, (list, tuple)): | ||||||
|  |             data = ''.join([chr(x) for x in data]) | ||||||
|  |         elif isinstance(data, int): | ||||||
|  |             data = chr(data) | ||||||
|  |         sock.send(data) | ||||||
|  |     except socket.error as err: | ||||||
|  |         raise OTAError("Error sending {}: {}".format(msg, err)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def perform_ota(sock, password, file_handle, filename): | ||||||
|  |     file_md5 = hashlib.md5(file_handle.read()).hexdigest() | ||||||
|  |     file_size = file_handle.tell() | ||||||
|  |     _LOGGER.info('Uploading %s (%s bytes)', filename, file_size) | ||||||
|  |     file_handle.seek(0) | ||||||
|  |     _LOGGER.debug("MD5 of binary is %s", file_md5) | ||||||
|  |  | ||||||
|  |     send_check(sock, MAGIC_BYTES, 'magic bytes') | ||||||
|  |  | ||||||
|  |     _, version = receive_exactly(sock, 2, 'version', RESPONSE_OK) | ||||||
|  |     if version != OTA_VERSION_1_0: | ||||||
|  |         raise OTAError("Unsupported OTA version {}".format(version)) | ||||||
|  |  | ||||||
|  |     # Features | ||||||
|  |     send_check(sock, 0x00, 'features') | ||||||
|  |     receive_exactly(sock, 1, 'features', RESPONSE_HEADER_OK) | ||||||
|  |  | ||||||
|  |     auth, = receive_exactly(sock, 1, 'auth', [RESPONSE_REQUEST_AUTH, RESPONSE_AUTH_OK]) | ||||||
|  |     if auth == RESPONSE_REQUEST_AUTH: | ||||||
|  |         if not password: | ||||||
|  |             raise OTAError("ESP requests password, but no password given!") | ||||||
|  |         nonce = receive_exactly(sock, 32, 'authentication nonce', [], decode=False) | ||||||
|  |         _LOGGER.debug("Auth: Nonce is %s", nonce) | ||||||
|  |         cnonce = hashlib.md5(str(random.random()).encode()).hexdigest() | ||||||
|  |         _LOGGER.debug("Auth: CNonce is %s", cnonce) | ||||||
|  |  | ||||||
|  |         send_check(sock, cnonce, 'auth cnonce') | ||||||
|  |  | ||||||
|  |         result_md5 = hashlib.md5() | ||||||
|  |         result_md5.update(password.encode()) | ||||||
|  |         result_md5.update(nonce.encode()) | ||||||
|  |         result_md5.update(cnonce.encode()) | ||||||
|  |         result = result_md5.hexdigest() | ||||||
|  |         _LOGGER.debug("Auth: Result is %s", result) | ||||||
|  |  | ||||||
|  |         send_check(sock, result, 'auth result') | ||||||
|  |         receive_exactly(sock, 1, 'auth result', RESPONSE_AUTH_OK) | ||||||
|  |     else: | ||||||
|  |         if password: | ||||||
|  |             raise OTAError("Password specified, but ESP doesn't accept password!") | ||||||
|  |  | ||||||
|  |     file_size_encoded = [ | ||||||
|  |         (file_size >> 24) & 0xFF, | ||||||
|  |         (file_size >> 16) & 0xFF, | ||||||
|  |         (file_size >> 8) & 0xFF, | ||||||
|  |         (file_size >> 0) & 0xFF, | ||||||
|  |     ] | ||||||
|  |     send_check(sock, file_size_encoded, 'binary size') | ||||||
|  |     receive_exactly(sock, 1, 'binary size', RESPONSE_UPDATE_PREPARE_OK) | ||||||
|  |  | ||||||
|  |     send_check(sock, file_md5, 'file checksum') | ||||||
|  |     receive_exactly(sock, 1, 'file checksum', RESPONSE_BIN_MD5_OK) | ||||||
|  |  | ||||||
|  |     sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4096) | ||||||
|  |     offset = 0 | ||||||
|  |     update_progress(0.0) | ||||||
|  |     while True: | ||||||
|  |         chunk = file_handle.read(1024) | ||||||
|  |         if not chunk: | ||||||
|  |             break | ||||||
|  |         offset += len(chunk) | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             sock.sendall(chunk) | ||||||
|  |         except socket.error as err: | ||||||
|  |             sys.stderr.write('\n') | ||||||
|  |             raise OTAError("Error sending data: {}".format(err)) | ||||||
|  |  | ||||||
|  |         update_progress(offset / float(file_size)) | ||||||
|  |  | ||||||
|  |     sys.stderr.write('\n') | ||||||
|  |     _LOGGER.info("Waiting for result...") | ||||||
|  |  | ||||||
|  |     receive_exactly(sock, 1, 'receive OK', RESPONSE_RECEIVE_OK) | ||||||
|  |     receive_exactly(sock, 1, 'Update end', RESPONSE_UPDATE_END_OK) | ||||||
|  |  | ||||||
|  |     _LOGGER.info("OTA successful") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def run_ota(remote_host, remote_port, password, filename): | ||||||
|  |     _LOGGER.info("Connecting to %s:%s...", remote_host, remote_port) | ||||||
|  |     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||||||
|  |     sock.settimeout(5.0) | ||||||
|  |     try: | ||||||
|  |         sock.connect((remote_host, remote_port)) | ||||||
|  |     except socket.error as err: | ||||||
|  |         sock.close() | ||||||
|  |         _LOGGER.error("Connecting to %s:%s failed: %s", remote_host, remote_port, err) | ||||||
|  |         return 1 | ||||||
|  |  | ||||||
|  |     file_handle = open(filename, 'rb') | ||||||
|  |     try: | ||||||
|  |         perform_ota(sock, password, file_handle, filename) | ||||||
|  |     except OTAError as err: | ||||||
|  |         _LOGGER.error(str(err)) | ||||||
|  |         return 1 | ||||||
|  |     finally: | ||||||
|  |         sock.close() | ||||||
|  |         file_handle.close() | ||||||
|  |  | ||||||
|  |     return 0 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def run_legacy_ota(verbose, host_port, remote_host, remote_port, password, filename): | ||||||
|  |     from esphomeyaml import espota | ||||||
|  |  | ||||||
|  |     espota_args = ['espota.py', '--debug', '--progress', '-i', remote_host, | ||||||
|  |                    '-p', str(remote_port), '-f', filename, | ||||||
|  |                    '-a', password, '-P', str(host_port)] | ||||||
|  |     if verbose: | ||||||
|  |         espota_args.append('-d') | ||||||
|  |     return espota.main(espota_args) | ||||||
		Reference in New Issue
	
	Block a user