From 1204b4f1bd96b636086d7118e7f79550d86c37e0 Mon Sep 17 00:00:00 2001
From: Landon Rohatensky <landonroha@gmail.com>
Date: Thu, 23 Nov 2023 11:10:33 -0800
Subject: [PATCH] Allow images to be downloaded from URLs (#5214)

Co-authored-by: guillempages <guillempages@users.noreply.github.com>
Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
---
 docker/Dockerfile                    |   1 +
 esphome/components/image/__init__.py | 121 +++++++++++++++++++++------
 esphome/external_files.py            |  75 +++++++++++++++++
 requirements.txt                     |   1 +
 tests/test2.yaml                     |  12 +++
 5 files changed, 184 insertions(+), 26 deletions(-)
 create mode 100644 esphome/external_files.py

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 1bf754464d..ee7c70bb0f 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -38,6 +38,7 @@ RUN \
         openssh-client=1:9.2p1-2+deb12u1 \
         python3-cffi=1.15.1-5 \
         libcairo2=1.16.0-7 \
+        libmagic1=1:5.44-3 \
         patch=2.7.6-7; \
     if [ "$TARGETARCH$TARGETVARIANT" = "armv7" ]; then \
         apt-get install -y --no-install-recommends \
diff --git a/esphome/components/image/__init__.py b/esphome/components/image/__init__.py
index 1b7c654b0b..c11021fc9c 100644
--- a/esphome/components/image/__init__.py
+++ b/esphome/components/image/__init__.py
@@ -1,15 +1,23 @@
+from __future__ import annotations
+
 import logging
 
+import hashlib
 import io
 from pathlib import Path
 import re
 import requests
+from magic import Magic
+
+from PIL import Image
 
 from esphome import core
 from esphome.components import font
+from esphome import external_files
 import esphome.config_validation as cv
 import esphome.codegen as cg
 from esphome.const import (
+    __version__,
     CONF_DITHER,
     CONF_FILE,
     CONF_ICON,
@@ -19,6 +27,7 @@ from esphome.const import (
     CONF_RESIZE,
     CONF_SOURCE,
     CONF_TYPE,
+    CONF_URL,
 )
 from esphome.core import CORE, HexInt
 
@@ -43,34 +52,74 @@ IMAGE_TYPE = {
 CONF_USE_TRANSPARENCY = "use_transparency"
 
 # If the MDI file cannot be downloaded within this time, abort.
-MDI_DOWNLOAD_TIMEOUT = 30  # seconds
+IMAGE_DOWNLOAD_TIMEOUT = 30  # seconds
 
 SOURCE_LOCAL = "local"
 SOURCE_MDI = "mdi"
+SOURCE_WEB = "web"
+
 
 Image_ = image_ns.class_("Image")
 
 
-def _compute_local_icon_path(value) -> Path:
-    base_dir = Path(CORE.data_dir) / DOMAIN / "mdi"
+def _compute_local_icon_path(value: dict) -> Path:
+    base_dir = external_files.compute_local_file_dir(DOMAIN) / "mdi"
     return base_dir / f"{value[CONF_ICON]}.svg"
 
 
-def download_mdi(value):
-    mdi_id = value[CONF_ICON]
-    path = _compute_local_icon_path(value)
-    if path.is_file():
-        return value
-    url = f"https://raw.githubusercontent.com/Templarian/MaterialDesign/master/svg/{mdi_id}.svg"
-    _LOGGER.debug("Downloading %s MDI image from %s", mdi_id, url)
+def _compute_local_image_path(value: dict) -> Path:
+    """Return the cache file path for the web image at ``value[CONF_URL]``."""
+    # The file name is the first 8 hex digits of the URL's SHA-256 digest.
+    url_bytes = value[CONF_URL].encode()
+    digest = hashlib.sha256(url_bytes).hexdigest()
+    cache_dir = external_files.compute_local_file_dir(DOMAIN)
+    return cache_dir / digest[:8]
+
+
+def download_content(url: str, path: Path) -> None:
+    if not external_files.has_remote_file_changed(url, path):
+        _LOGGER.debug("Remote file has not changed %s", url)
+        return
+
+    _LOGGER.debug(
+        "Remote file has changed, downloading from %s to %s",
+        url,
+        path,
+    )
+
     try:
-        req = requests.get(url, timeout=MDI_DOWNLOAD_TIMEOUT)
+        req = requests.get(
+            url,
+            timeout=IMAGE_DOWNLOAD_TIMEOUT,
+            headers={"User-agent": f"ESPHome/{__version__} (https://esphome.io)"},
+        )
         req.raise_for_status()
     except requests.exceptions.RequestException as e:
-        raise cv.Invalid(f"Could not download MDI image {mdi_id} from {url}: {e}")
+        raise cv.Invalid(f"Could not download from {url}: {e}")
 
     path.parent.mkdir(parents=True, exist_ok=True)
     path.write_bytes(req.content)
+
+
+def download_mdi(value):
+    """Validator step: cache the Material Design icon named by CONF_ICON.
+
+    Returns value unchanged once download_content has been given the icon URL.
+    """
+    validate_cairosvg_installed(value)
+    mdi_id = value[CONF_ICON]
+    icon_path = _compute_local_icon_path(value)
+    url = f"https://raw.githubusercontent.com/Templarian/MaterialDesign/master/svg/{mdi_id}.svg"
+    download_content(url, icon_path)
+    return value
+
+
+def download_image(value):
+    """Validator step: cache the image found at ``value[CONF_URL]``."""
+    image_url = value[CONF_URL]
+    cache_path = _compute_local_image_path(value)
+    download_content(image_url, cache_path)
+
     return value
 
 
@@ -139,6 +188,13 @@ def validate_file_shorthand(value):
                 CONF_ICON: icon,
             }
         )
+    if value.startswith("http://") or value.startswith("https://"):
+        return FILE_SCHEMA(
+            {
+                CONF_SOURCE: SOURCE_WEB,
+                CONF_URL: value,
+            }
+        )
     return FILE_SCHEMA(
         {
             CONF_SOURCE: SOURCE_LOCAL,
@@ -160,10 +216,18 @@ MDI_SCHEMA = cv.All(
     download_mdi,
 )
 
+WEB_SCHEMA = cv.All(
+    {
+        cv.Required(CONF_URL): cv.string,  # URL of the image to fetch
+    },
+    download_image,  # downloads the file at CONF_URL into the local cache
+)
+
 TYPED_FILE_SCHEMA = cv.typed_schema(
     {
         SOURCE_LOCAL: LOCAL_SCHEMA,
         SOURCE_MDI: MDI_SCHEMA,
+        SOURCE_WEB: WEB_SCHEMA,
     },
     key=CONF_SOURCE,
 )
@@ -201,9 +265,7 @@ IMAGE_SCHEMA = cv.Schema(
 CONFIG_SCHEMA = cv.All(font.validate_pillow_installed, IMAGE_SCHEMA)
 
 
-def load_svg_image(file: str, resize: tuple[int, int]):
-    from PIL import Image
-
+def load_svg_image(file: bytes, resize: tuple[int, int]):
     # This import is only needed in case of SVG images; adding it
     # to the top would force configurations not using SVG to also have it
     # installed for no reason.
@@ -212,19 +274,17 @@ def load_svg_image(file: str, resize: tuple[int, int]):
     if resize:
         req_width, req_height = resize
         svg_image = svg2png(
-            url=file,
+            file,
             output_width=req_width,
             output_height=req_height,
         )
     else:
-        svg_image = svg2png(url=file)
+        svg_image = svg2png(file)
 
     return Image.open(io.BytesIO(svg_image))
 
 
 async def to_code(config):
-    from PIL import Image
-
     conf_file = config[CONF_FILE]
 
     if conf_file[CONF_SOURCE] == SOURCE_LOCAL:
@@ -233,17 +293,26 @@ async def to_code(config):
     elif conf_file[CONF_SOURCE] == SOURCE_MDI:
         path = _compute_local_icon_path(conf_file).as_posix()
 
+    elif conf_file[CONF_SOURCE] == SOURCE_WEB:
+        path = _compute_local_image_path(conf_file).as_posix()
+
     try:
-        resize = config.get(CONF_RESIZE)
-        if path.lower().endswith(".svg"):
-            image = load_svg_image(path, resize)
-        else:
-            image = Image.open(path)
-            if resize:
-                image.thumbnail(resize)
+        with open(path, "rb") as f:
+            file_contents = f.read()
     except Exception as e:
         raise core.EsphomeError(f"Could not load image file {path}: {e}")
 
+    mime = Magic(mime=True)
+    file_type = mime.from_buffer(file_contents)
+
+    resize = config.get(CONF_RESIZE)
+    if "svg" in file_type:
+        image = load_svg_image(file_contents, resize)
+    else:
+        image = Image.open(io.BytesIO(file_contents))
+        if resize:
+            image.thumbnail(resize)
+
     width, height = image.size
 
     if CONF_RESIZE not in config and (width > 500 or height > 500):
diff --git a/esphome/external_files.py b/esphome/external_files.py
new file mode 100644
index 0000000000..5b476286f3
--- /dev/null
+++ b/esphome/external_files.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+import os
+from datetime import datetime
+import requests
+import esphome.config_validation as cv
+from esphome.core import CORE, TimePeriodSeconds
+
+_LOGGER = logging.getLogger(__name__)
+CODEOWNERS = ["@landonr"]
+
+NETWORK_TIMEOUT = 30
+
+IF_MODIFIED_SINCE = "If-Modified-Since"
+CACHE_CONTROL = "Cache-Control"
+CACHE_CONTROL_MAX_AGE = "max-age="
+CONTENT_DISPOSITION = "content-disposition"
+TEMP_DIR = "temp"
+
+
+def has_remote_file_changed(url, local_file_path):
+    """Return False only if a HEAD request (If-Modified-Since) reports HTTP 304.
+
+    A missing local file counts as changed. Raises cv.Invalid if the request fails.
+    """
+    from email.utils import formatdate  # HTTP-date without utcfromtimestamp()
+    if os.path.exists(local_file_path):
+        _LOGGER.debug("has_remote_file_changed: File exists at %s", local_file_path)
+        local_modification_time_str = formatdate(
+            os.path.getmtime(local_file_path), usegmt=True
+        )
+        headers = {
+            IF_MODIFIED_SINCE: local_modification_time_str,
+            CACHE_CONTROL: CACHE_CONTROL_MAX_AGE + "3600",
+        }
+        try:
+            response = requests.head(url, headers=headers, timeout=NETWORK_TIMEOUT)
+        except requests.exceptions.RequestException as e:
+            raise cv.Invalid(
+                f"Could not check if {url} has changed, please check if file exists "
+                f"({e})"
+            ) from e
+        _LOGGER.debug(
+            "has_remote_file_changed: File %s, Local modified %s, response code %d",
+            local_file_path,
+            local_modification_time_str,
+            response.status_code,
+        )
+        if response.status_code == 304:
+            _LOGGER.debug(
+                "has_remote_file_changed: File not modified since %s",
+                local_modification_time_str,
+            )
+            return False
+        _LOGGER.debug("has_remote_file_changed: File modified")
+        return True
+    _LOGGER.debug("has_remote_file_changed: File doesn't exist at %s", local_file_path)
+    return True
+
+
+def is_file_recent(file_path: str, refresh: TimePeriodSeconds) -> bool:
+    """Return True if file_path exists and its ctime is at most refresh old."""
+    if not os.path.exists(file_path):
+        return False
+    changed_at = os.path.getctime(file_path)
+    return datetime.now().timestamp() - changed_at <= refresh.total_seconds
+
+
+def compute_local_file_dir(domain: str) -> Path:
+    """Return the ``domain`` folder under CORE.data_dir, creating it if missing."""
+    domain_dir = Path(CORE.data_dir).joinpath(domain)
+    domain_dir.mkdir(parents=True, exist_ok=True)
+    return domain_dir
diff --git a/requirements.txt b/requirements.txt
index 7a3bb421ed..2c2bf1ba19 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,6 +12,7 @@ click==8.1.7
 esphome-dashboard==20231107.0
 aioesphomeapi==18.5.2
 zeroconf==0.123.0
+python-magic==0.4.27
 
 # esp-idf requires this, but doesn't bundle it by default
 # https://github.com/espressif/esp-idf/blob/220590d599e134d7a5e7f1e683cc4550349ffbf8/requirements.txt#L24
diff --git a/tests/test2.yaml b/tests/test2.yaml
index bfc886eaa4..91fb554146 100644
--- a/tests/test2.yaml
+++ b/tests/test2.yaml
@@ -752,6 +752,18 @@ image:
     file: pnglogo.png
     type: RGB565
     use_transparency: no
+  - id: web_svg_image
+    file: https://raw.githubusercontent.com/esphome/esphome-docs/a62d7ab193c1a464ed791670170c7d518189109b/images/logo.svg
+    resize: 256x48
+    type: TRANSPARENT_BINARY
+  - id: web_tiff_image
+    file: https://upload.wikimedia.org/wikipedia/commons/b/b6/SIPI_Jelly_Beans_4.1.07.tiff
+    type: RGB24
+    resize: 48x48
+  - id: web_redirect_image
+    file: https://avatars.githubusercontent.com/u/3060199?s=48&v=4
+    type: RGB24
+    resize: 48x48
 
   - id: mdi_alert
     file: mdi:alert-circle-outline