From 2f77d316905f688fef7d899dfedc004ba6a73a49 Mon Sep 17 00:00:00 2001
From: David Woodhouse <dwmw2@infradead.org>
Date: Fri, 8 Nov 2024 03:38:13 +0000
Subject: [PATCH] OTA: Fix IPv6 and multiple address support (#7414)

---
 esphome/__main__.py |  4 +--
 esphome/espota2.py  | 69 +++++++++++++++++++-------------------
 esphome/helpers.py  | 82 ++++++++++++++++++++++++++++++++++++---------
 esphome/mqtt.py     | 11 ++++--
 esphome/zeroconf.py |  8 ++---
 5 files changed, 117 insertions(+), 57 deletions(-)

diff --git a/esphome/__main__.py b/esphome/__main__.py
index cf2741dbdb..85ab3cc00c 100644
--- a/esphome/__main__.py
+++ b/esphome/__main__.py
@@ -38,7 +38,7 @@ from esphome.const import (
     SECRETS_FILES,
 )
 from esphome.core import CORE, EsphomeError, coroutine
-from esphome.helpers import indent, is_ip_address, get_bool_env
+from esphome.helpers import get_bool_env, indent, is_ip_address
 from esphome.log import Fore, color, setup_log
 from esphome.util import (
     get_serial_ports,
@@ -378,7 +378,7 @@ def show_logs(config, args, port):
 
             port = mqtt.get_esphome_device_ip(
                 config, args.username, args.password, args.client_id
-            )
+            )[0]
 
         from esphome.components.api.client import run_logs
 
diff --git a/esphome/espota2.py b/esphome/espota2.py
index 580536153a..94b845b246 100644
--- a/esphome/espota2.py
+++ b/esphome/espota2.py
@@ -10,7 +10,7 @@ import sys
 import time
 
 from esphome.core import EsphomeError
-from esphome.helpers import is_ip_address, resolve_ip_address
+from esphome.helpers import resolve_ip_address
 
 RESPONSE_OK = 0x00
 RESPONSE_REQUEST_AUTH = 0x01
@@ -311,44 +311,45 @@ def perform_ota(
 
 
 def run_ota_impl_(remote_host, remote_port, password, filename):
-    if is_ip_address(remote_host):
-        _LOGGER.info("Connecting to %s", remote_host)
-        ip = remote_host
-    else:
-        _LOGGER.info("Resolving IP address of %s", remote_host)
-        try:
-            ip = resolve_ip_address(remote_host)
-        except EsphomeError as err:
-            _LOGGER.error(
-                "Error resolving IP address of %s. Is it connected to WiFi?",
-                remote_host,
-            )
-            _LOGGER.error(
-                "(If this error persists, please set a static IP address: "
-                "https://esphome.io/components/wifi.html#manual-ips)"
-            )
-            raise OTAError(err) from err
-        _LOGGER.info(" -> %s", ip)
-
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.settimeout(10.0)
     try:
-        sock.connect((ip, remote_port))
-    except OSError as err:
-        sock.close()
-        _LOGGER.error("Connecting to %s:%s failed: %s", remote_host, remote_port, err)
-        return 1
+        res = resolve_ip_address(remote_host, remote_port)
+    except EsphomeError as err:
+        _LOGGER.error(
+            "Error resolving IP address of %s. Is it connected to WiFi?",
+            remote_host,
+        )
+        _LOGGER.error(
+            "(If this error persists, please set a static IP address: "
+            "https://esphome.io/components/wifi.html#manual-ips)"
+        )
+        raise OTAError(err) from err
 
-    with open(filename, "rb") as file_handle:
+    for r in res:
+        af, socktype, _, _, sa = r
+        _LOGGER.info("Connecting to %s port %s...", sa[0], sa[1])
+        sock = socket.socket(af, socktype)
+        sock.settimeout(10.0)
         try:
-            perform_ota(sock, password, file_handle, filename)
-        except OTAError as err:
-            _LOGGER.error(str(err))
-            return 1
-        finally:
+            sock.connect(sa)
+        except OSError as err:
             sock.close()
+            _LOGGER.error("Connecting to %s port %s failed: %s", sa[0], sa[1], err)
+            continue
 
-    return 0
+        _LOGGER.info("Connected to %s", sa[0])
+        with open(filename, "rb") as file_handle:
+            try:
+                perform_ota(sock, password, file_handle, filename)
+            except OTAError as err:
+                _LOGGER.error(str(err))
+                return 1
+            finally:
+                sock.close()
+
+        return 0
+
+    _LOGGER.error("Connection failed.")
+    return 1
 
 
 def run_ota(remote_host, remote_port, password, filename):
diff --git a/esphome/helpers.py b/esphome/helpers.py
index 2a7e5cd9b6..8aae43c2bb 100644
--- a/esphome/helpers.py
+++ b/esphome/helpers.py
@@ -1,5 +1,6 @@
 import codecs
 from contextlib import suppress
+import ipaddress
 import logging
 import os
 from pathlib import Path
@@ -91,12 +92,8 @@ def mkdir_p(path):
 
 
 def is_ip_address(host):
-    parts = host.split(".")
-    if len(parts) != 4:
-        return False
     try:
-        for p in parts:
-            int(p)
+        ipaddress.ip_address(host)
         return True
     except ValueError:
         return False
@@ -127,25 +124,80 @@ def _resolve_with_zeroconf(host):
     return info
 
 
-def resolve_ip_address(host):
+def addr_preference_(res):
+    # Trivial alternative to RFC6724 sorting. Put sane IPv6 first, then
+    # Legacy IP, then IPv6 link-local addresses without an actual link.
+    sa = res[4]
+    ip = ipaddress.ip_address(sa[0])
+    if ip.version == 4:
+        return 2
+    if ip.is_link_local and sa[3] == 0:
+        return 3
+    return 1
+
+
+def resolve_ip_address(host, port):
     import socket
 
     from esphome.core import EsphomeError
 
+    # There are five cases here. The host argument could be one of:
+    #  • a *list* of IP addresses discovered by MQTT,
+    #  • a single IP address specified by the user,
+    #  • a .local hostname to be resolved by mDNS,
+    #  • a normal hostname to be resolved in DNS, or
+    #  • A URL from which we should extract the hostname.
+    #
+    # In each of the first three cases, we end up with IP addresses in
+    # string form which need to be converted to a 5-tuple to be used
+    # for the socket connection attempt. The easiest way to construct
+    # those is to pass the IP address string to getaddrinfo(). Which,
+    # coincidentally, is how we do hostname lookups in the other cases
+    # too. So first build a list which contains either IP addresses or
+    # a single hostname, then call getaddrinfo() on each element of
+    # that list.
+
     errs = []
+    if isinstance(host, list):
+        addr_list = host
+    elif is_ip_address(host):
+        addr_list = [host]
+    else:
+        url = urlparse(host)
+        if url.scheme != "":
+            host = url.hostname
 
-    if host.endswith(".local"):
+        addr_list = []
+        if host.endswith(".local"):
+            try:
+                _LOGGER.info("Resolving IP address of %s in mDNS", host)
+                addr_list = _resolve_with_zeroconf(host)
+            except EsphomeError as err:
+                errs.append(str(err))
+
+        # If not mDNS, or if mDNS failed, use normal DNS
+        if not addr_list:
+            addr_list = [host]
+
+    # Now we have a list containing either IP addresses or a hostname
+    res = []
+    for addr in addr_list:
+        if not is_ip_address(addr):
+            _LOGGER.info("Resolving IP address of %s", host)
         try:
-            return _resolve_with_zeroconf(host)
-        except EsphomeError as err:
+            r = socket.getaddrinfo(addr, port, proto=socket.IPPROTO_TCP)
+        except OSError as err:
             errs.append(str(err))
+            raise EsphomeError(
+                f"Error resolving IP address: {', '.join(errs)}"
+            ) from err
 
-    try:
-        host_url = host if (urlparse(host).scheme != "") else "http://" + host
-        return socket.gethostbyname(urlparse(host_url).hostname)
-    except OSError as err:
-        errs.append(str(err))
-        raise EsphomeError(f"Error resolving IP address: {', '.join(errs)}") from err
+        res = res + r
+
+    # Zeroconf tends to give us link-local IPv6 addresses without specifying
+    # the link. Put those last in the list to be attempted.
+    res.sort(key=addr_preference_)
+    return res
 
 
 def get_bool_env(var, default=False):
diff --git a/esphome/mqtt.py b/esphome/mqtt.py
index d55fb0202d..2f90c49025 100644
--- a/esphome/mqtt.py
+++ b/esphome/mqtt.py
@@ -175,8 +175,15 @@ def get_esphome_device_ip(
                 _LOGGER.Warn("Wrong device answer")
                 return
 
-            if "ip" in data:
-                dev_ip = data["ip"]
+            dev_ip = []
+            key = "ip"
+            n = 0
+            while key in data:
+                dev_ip.append(data[key])
+                n = n + 1
+                key = "ip" + str(n)
+
+            if dev_ip:
                 client.disconnect()
 
     def on_connect(client, userdata, flags, return_code):
diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py
index b3ee64e259..76049fa776 100644
--- a/esphome/zeroconf.py
+++ b/esphome/zeroconf.py
@@ -182,8 +182,8 @@ class EsphomeZeroconf(Zeroconf):
         if (
             info.load_from_cache(self)
             or (timeout and info.request(self, timeout * 1000))
-        ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)):
-            return str(addresses[0])
+        ) and (addresses := info.parsed_scoped_addresses(IPVersion.All)):
+            return addresses
         return None
 
 
@@ -194,6 +194,6 @@ class AsyncEsphomeZeroconf(AsyncZeroconf):
         if (
             info.load_from_cache(self.zeroconf)
             or (timeout and await info.async_request(self.zeroconf, timeout * 1000))
-        ) and (addresses := info.ip_addresses_by_version(IPVersion.V4Only)):
-            return str(addresses[0])
+        ) and (addresses := info.parsed_scoped_addresses(IPVersion.All)):
+            return addresses
         return None