mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-26 04:33:47 +00:00 
			
		
		
		
	OTA: Fix IPv6 and multiple address support (#7414)
This commit is contained in:
		| @@ -38,7 +38,7 @@ from esphome.const import ( | |||||||
|     SECRETS_FILES, |     SECRETS_FILES, | ||||||
| ) | ) | ||||||
| from esphome.core import CORE, EsphomeError, coroutine | from esphome.core import CORE, EsphomeError, coroutine | ||||||
| from esphome.helpers import indent, is_ip_address, get_bool_env | from esphome.helpers import get_bool_env, indent, is_ip_address | ||||||
| from esphome.log import Fore, color, setup_log | from esphome.log import Fore, color, setup_log | ||||||
| from esphome.util import ( | from esphome.util import ( | ||||||
|     get_serial_ports, |     get_serial_ports, | ||||||
| @@ -378,7 +378,7 @@ def show_logs(config, args, port): | |||||||
|  |  | ||||||
|             port = mqtt.get_esphome_device_ip( |             port = mqtt.get_esphome_device_ip( | ||||||
|                 config, args.username, args.password, args.client_id |                 config, args.username, args.password, args.client_id | ||||||
|             ) |             )[0] | ||||||
|  |  | ||||||
|         from esphome.components.api.client import run_logs |         from esphome.components.api.client import run_logs | ||||||
|  |  | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ import sys | |||||||
| import time | import time | ||||||
|  |  | ||||||
| from esphome.core import EsphomeError | from esphome.core import EsphomeError | ||||||
| from esphome.helpers import is_ip_address, resolve_ip_address | from esphome.helpers import resolve_ip_address | ||||||
|  |  | ||||||
| RESPONSE_OK = 0x00 | RESPONSE_OK = 0x00 | ||||||
| RESPONSE_REQUEST_AUTH = 0x01 | RESPONSE_REQUEST_AUTH = 0x01 | ||||||
| @@ -311,13 +311,8 @@ def perform_ota( | |||||||
|  |  | ||||||
|  |  | ||||||
| def run_ota_impl_(remote_host, remote_port, password, filename): | def run_ota_impl_(remote_host, remote_port, password, filename): | ||||||
|     if is_ip_address(remote_host): |  | ||||||
|         _LOGGER.info("Connecting to %s", remote_host) |  | ||||||
|         ip = remote_host |  | ||||||
|     else: |  | ||||||
|         _LOGGER.info("Resolving IP address of %s", remote_host) |  | ||||||
|     try: |     try: | ||||||
|             ip = resolve_ip_address(remote_host) |         res = resolve_ip_address(remote_host, remote_port) | ||||||
|     except EsphomeError as err: |     except EsphomeError as err: | ||||||
|         _LOGGER.error( |         _LOGGER.error( | ||||||
|             "Error resolving IP address of %s. Is it connected to WiFi?", |             "Error resolving IP address of %s. Is it connected to WiFi?", | ||||||
| @@ -328,17 +323,20 @@ def run_ota_impl_(remote_host, remote_port, password, filename): | |||||||
|             "https://esphome.io/components/wifi.html#manual-ips)" |             "https://esphome.io/components/wifi.html#manual-ips)" | ||||||
|         ) |         ) | ||||||
|         raise OTAError(err) from err |         raise OTAError(err) from err | ||||||
|         _LOGGER.info(" -> %s", ip) |  | ||||||
|  |  | ||||||
|     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |     for r in res: | ||||||
|  |         af, socktype, _, _, sa = r | ||||||
|  |         _LOGGER.info("Connecting to %s port %s...", sa[0], sa[1]) | ||||||
|  |         sock = socket.socket(af, socktype) | ||||||
|         sock.settimeout(10.0) |         sock.settimeout(10.0) | ||||||
|         try: |         try: | ||||||
|         sock.connect((ip, remote_port)) |             sock.connect(sa) | ||||||
|         except OSError as err: |         except OSError as err: | ||||||
|             sock.close() |             sock.close() | ||||||
|         _LOGGER.error("Connecting to %s:%s failed: %s", remote_host, remote_port, err) |             _LOGGER.error("Connecting to %s port %s failed: %s", sa[0], sa[1], err) | ||||||
|         return 1 |             continue | ||||||
|  |  | ||||||
|  |         _LOGGER.info("Connected to %s", sa[0]) | ||||||
|         with open(filename, "rb") as file_handle: |         with open(filename, "rb") as file_handle: | ||||||
|             try: |             try: | ||||||
|                 perform_ota(sock, password, file_handle, filename) |                 perform_ota(sock, password, file_handle, filename) | ||||||
| @@ -350,6 +348,9 @@ def run_ota_impl_(remote_host, remote_port, password, filename): | |||||||
|  |  | ||||||
|         return 0 |         return 0 | ||||||
|  |  | ||||||
|  |     _LOGGER.error("Connection failed.") | ||||||
|  |     return 1 | ||||||
|  |  | ||||||
|  |  | ||||||
| def run_ota(remote_host, remote_port, password, filename): | def run_ota(remote_host, remote_port, password, filename): | ||||||
|     try: |     try: | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| import codecs | import codecs | ||||||
| from contextlib import suppress | from contextlib import suppress | ||||||
|  | import ipaddress | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @@ -91,12 +92,8 @@ def mkdir_p(path): | |||||||
|  |  | ||||||
|  |  | ||||||
| def is_ip_address(host): | def is_ip_address(host): | ||||||
|     parts = host.split(".") |  | ||||||
|     if len(parts) != 4: |  | ||||||
|         return False |  | ||||||
|     try: |     try: | ||||||
|         for p in parts: |         ipaddress.ip_address(host) | ||||||
|             int(p) |  | ||||||
|         return True |         return True | ||||||
|     except ValueError: |     except ValueError: | ||||||
|         return False |         return False | ||||||
| @@ -127,25 +124,80 @@ def _resolve_with_zeroconf(host): | |||||||
|     return info |     return info | ||||||
|  |  | ||||||
|  |  | ||||||
| def resolve_ip_address(host): | def addr_preference_(res): | ||||||
|  |     # Trivial alternative to RFC6724 sorting. Put sane IPv6 first, then | ||||||
|  |     # Legacy IP, then IPv6 link-local addresses without an actual link. | ||||||
|  |     sa = res[4] | ||||||
|  |     ip = ipaddress.ip_address(sa[0]) | ||||||
|  |     if ip.version == 4: | ||||||
|  |         return 2 | ||||||
|  |     if ip.is_link_local and sa[3] == 0: | ||||||
|  |         return 3 | ||||||
|  |     return 1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def resolve_ip_address(host, port): | ||||||
|     import socket |     import socket | ||||||
|  |  | ||||||
|     from esphome.core import EsphomeError |     from esphome.core import EsphomeError | ||||||
|  |  | ||||||
|     errs = [] |     # There are five cases here. The host argument could be one of: | ||||||
|  |     #  • a *list* of IP addresses discovered by MQTT, | ||||||
|  |     #  • a single IP address specified by the user, | ||||||
|  |     #  • a .local hostname to be resolved by mDNS, | ||||||
|  |     #  • a normal hostname to be resolved in DNS, or | ||||||
|  |     #  • A URL from which we should extract the hostname. | ||||||
|  |     # | ||||||
|  |     # In each of the first three cases, we end up with IP addresses in | ||||||
|  |     # string form which need to be converted to a 5-tuple to be used | ||||||
|  |     # for the socket connection attempt. The easiest way to construct | ||||||
|  |     # those is to pass the IP address string to getaddrinfo(). Which, | ||||||
|  |     # coincidentally, is how we do hostname lookups in the other cases | ||||||
|  |     # too. So first build a list which contains either IP addresses or | ||||||
|  |     # a single hostname, then call getaddrinfo() on each element of | ||||||
|  |     # that list. | ||||||
|  |  | ||||||
|  |     errs = [] | ||||||
|  |     if isinstance(host, list): | ||||||
|  |         addr_list = host | ||||||
|  |     elif is_ip_address(host): | ||||||
|  |         addr_list = [host] | ||||||
|  |     else: | ||||||
|  |         url = urlparse(host) | ||||||
|  |         if url.scheme != "": | ||||||
|  |             host = url.hostname | ||||||
|  |  | ||||||
|  |         addr_list = [] | ||||||
|         if host.endswith(".local"): |         if host.endswith(".local"): | ||||||
|             try: |             try: | ||||||
|             return _resolve_with_zeroconf(host) |                 _LOGGER.info("Resolving IP address of %s in mDNS", host) | ||||||
|  |                 addr_list = _resolve_with_zeroconf(host) | ||||||
|             except EsphomeError as err: |             except EsphomeError as err: | ||||||
|                 errs.append(str(err)) |                 errs.append(str(err)) | ||||||
|  |  | ||||||
|  |         # If not mDNS, or if mDNS failed, use normal DNS | ||||||
|  |         if not addr_list: | ||||||
|  |             addr_list = [host] | ||||||
|  |  | ||||||
|  |     # Now we have a list containing either IP addresses or a hostname | ||||||
|  |     res = [] | ||||||
|  |     for addr in addr_list: | ||||||
|  |         if not is_ip_address(addr): | ||||||
|  |             _LOGGER.info("Resolving IP address of %s", host) | ||||||
|         try: |         try: | ||||||
|         host_url = host if (urlparse(host).scheme != "") else "http://" + host |             r = socket.getaddrinfo(addr, port, proto=socket.IPPROTO_TCP) | ||||||
|         return socket.gethostbyname(urlparse(host_url).hostname) |  | ||||||
|         except OSError as err: |         except OSError as err: | ||||||
|             errs.append(str(err)) |             errs.append(str(err)) | ||||||
|         raise EsphomeError(f"Error resolving IP address: {', '.join(errs)}") from err |             raise EsphomeError( | ||||||
|  |                 f"Error resolving IP address: {', '.join(errs)}" | ||||||
|  |             ) from err | ||||||
|  |  | ||||||
|  |         res = res + r | ||||||
|  |  | ||||||
|  |     # Zeroconf tends to give us link-local IPv6 addresses without specifying | ||||||
|  |     # the link. Put those last in the list to be attempted. | ||||||
|  |     res.sort(key=addr_preference_) | ||||||
|  |     return res | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_bool_env(var, default=False): | def get_bool_env(var, default=False): | ||||||
|   | |||||||
| @@ -175,8 +175,15 @@ def get_esphome_device_ip( | |||||||
|                 _LOGGER.Warn("Wrong device answer") |                 _LOGGER.Warn("Wrong device answer") | ||||||
|                 return |                 return | ||||||
|  |  | ||||||
|             if "ip" in data: |             dev_ip = [] | ||||||
|                 dev_ip = data["ip"] |             key = "ip" | ||||||
|  |             n = 0 | ||||||
|  |             while key in data: | ||||||
|  |                 dev_ip.append(data[key]) | ||||||
|  |                 n = n + 1 | ||||||
|  |                 key = "ip" + str(n) | ||||||
|  |  | ||||||
|  |             if dev_ip: | ||||||
|                 client.disconnect() |                 client.disconnect() | ||||||
|  |  | ||||||
|     def on_connect(client, userdata, flags, return_code): |     def on_connect(client, userdata, flags, return_code): | ||||||
|   | |||||||
| @@ -182,8 +182,8 @@ class EsphomeZeroconf(Zeroconf): | |||||||
|         if ( |         if ( | ||||||
|             info.load_from_cache(self) |             info.load_from_cache(self) | ||||||
|             or (timeout and info.request(self, timeout * 1000)) |             or (timeout and info.request(self, timeout * 1000)) | ||||||
|         ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): |         ) and (addresses := info.parsed_scoped_addresses(IPVersion.All)): | ||||||
|             return str(addresses[0]) |             return addresses | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -194,6 +194,6 @@ class AsyncEsphomeZeroconf(AsyncZeroconf): | |||||||
|         if ( |         if ( | ||||||
|             info.load_from_cache(self.zeroconf) |             info.load_from_cache(self.zeroconf) | ||||||
|             or (timeout and await info.async_request(self.zeroconf, timeout * 1000)) |             or (timeout and await info.async_request(self.zeroconf, timeout * 1000)) | ||||||
|         ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)): |         ) and (addresses := info.parsed_scoped_addresses(IPVersion.All)): | ||||||
|             return str(addresses[0]) |             return addresses | ||||||
|         return None |         return None | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user