"""Shared helpers: registries, safe console I/O, external command runners and misc utilities."""

import collections
import io
import logging
from pathlib import Path
import re
import subprocess
import sys
from typing import Any

from esphome import const

_LOGGER = logging.getLogger(__name__)


class RegistryEntry:
    """One registered item: its name, handler function, type id and raw schema."""

    def __init__(self, name, fun, type_id, schema):
        self.name = name
        self.fun = fun
        self.type_id = type_id
        self.raw_schema = schema

    @property
    def coroutine_fun(self):
        """Return ``fun`` wrapped with ``esphome.core.coroutine``."""
        # Imported at call time rather than module import time
        # (presumably to avoid an import cycle with esphome.core).
        from esphome.core import coroutine

        return coroutine(self.fun)

    @property
    def schema(self):
        """Return ``raw_schema`` wrapped in a ``config_validation.Schema``."""
        from esphome.config_validation import Schema

        return Schema(self.raw_schema)


class Registry(dict[str, RegistryEntry]):
    """Mapping of name -> :class:`RegistryEntry`, filled through :meth:`register`."""

    def __init__(self, base_schema=None, type_id_key=None):
        super().__init__()
        self.base_schema = base_schema or {}
        self.type_id_key = type_id_key

    def register(self, name, type_id, schema):
        """Return a decorator that stores the decorated function under ``name``.

        The decorated function is returned unchanged.
        """

        def decorator(fun):
            self[name] = RegistryEntry(name, fun, type_id, schema)
            return fun

        return decorator


class SimpleRegistry(dict):
    """Mapping of name -> ``(function, data)`` tuples, filled through :meth:`register`."""

    def register(self, name, data):
        """Return a decorator that stores ``(fun, data)`` under ``name``.

        The decorated function is returned unchanged.
        """

        def decorator(fun):
            self[name] = (fun, data)
            return fun

        return decorator


def safe_print(message="", end="\n"):
    """Print ``message``, degrading gracefully when the locale cannot encode it.

    Falls back to printing the UTF-8 encoded bytes, then the ASCII encoded
    bytes (both with ``backslashreplace``), and finally a fixed notice.
    When running under the dashboard, ESC characters are replaced with the
    literal text ``\\033`` first.
    """
    from esphome.core import CORE

    if CORE.dashboard:
        try:  # noqa: SIM105
            message = message.replace("\033", "\\033")
        except UnicodeEncodeError:
            pass

    try:
        print(message, end=end)
        return
    except UnicodeEncodeError:
        pass

    try:
        print(message.encode("utf-8", "backslashreplace"), end=end)
    except UnicodeEncodeError:
        try:
            print(message.encode("ascii", "backslashreplace"), end=end)
        except UnicodeEncodeError:
            print("Cannot print line because of invalid locale!")


def safe_input(prompt=""):
    """Show ``prompt`` via :func:`safe_print` (no trailing newline) and read a line."""
    if prompt:
        safe_print(prompt, end="")

    return input()


def shlex_quote(s: str | Path) -> str:
    """Return ``s`` quoted so it can be shown as a single shell token.

    Strings made only of word characters and ``@%+=:,./-`` are returned
    as-is; an empty value becomes ``''``; anything else is wrapped in single
    quotes with embedded single quotes escaped.
    """
    # Convert Path objects to strings
    if isinstance(s, Path):
        s = str(s)
    if not s:
        return "''"
    if re.search(r"[^\w@%+=:,./-]", s) is None:
        return s

    return "'" + s.replace("'", "'\"'\"'") + "'"


# Matches ANSI escape sequences (ESC followed by the sequence bytes).
ANSI_ESCAPE = re.compile(r"\033[@-_][0-?]*[ -/]*[@-~]")


class RedirectText:
    """File-like wrapper around a text stream that can drop lines matching a regex.

    Unknown attributes are delegated to the wrapped stream.
    """

    def __init__(self, out, filter_lines=None):
        """
        :param out: Underlying stream that receives the (unfiltered) output.
        :param filter_lines: A regular expression, or an iterable of regular
            expressions; complete lines matching any of them (from the start
            of the line, ANSI codes and trailing whitespace removed) are
            dropped. ``None`` or an empty value disables filtering.
        """
        self._out = out
        # A lone string is one pattern; iterating it would yield characters.
        if isinstance(filter_lines, str):
            filter_lines = [filter_lines]
        if not filter_lines:
            # An empty pattern list would compile to "" which matches (and so
            # would suppress) every line; treat it as "no filter" instead.
            self._filter_pattern = None
        else:
            combined = r"|".join(r"(?:" + pattern + r")" for pattern in filter_lines)
            self._filter_pattern = re.compile(combined)
        self._line_buffer = ""

    def __getattr__(self, item):
        return getattr(self._out, item)

    def _write_color_replace(self, s: str | bytes) -> None:
        """Write ``s`` to the wrapped stream, escaping ESC when in the dashboard."""
        from esphome.core import CORE

        if CORE.dashboard:
            # With the dashboard, we must create a little hack to make color output
            # work. The shell we create in the dashboard is not a tty, so python removes
            # all color codes from the resulting stream. We just convert them to something
            # we can easily recognize later here.
            s = s.replace("\033", "\\033")
        self._out.write(s)

    def write(self, s: str | bytes) -> int:
        """Write ``s``, filtering complete lines when a filter pattern is set.

        :return: Number of characters in ``s`` (after decoding), regardless of
            how much was actually forwarded, so callers are not confused.
        """
        # s is usually a str already (self._out is of type TextIOWrapper)
        # However, s is sometimes also a bytes object in python3. Let's make sure it's a
        # str
        # If the conversion fails, we will create an exception, which is okay because we won't
        # be able to print it anyway.
        if not isinstance(s, str):
            s = s.decode()

        if self._filter_pattern is None:
            self._write_color_replace(s)
            return len(s)

        pending = self._line_buffer + s
        self._line_buffer = ""
        lines = pending.splitlines(True)
        last_index = len(lines) - 1
        for index, line in enumerate(lines):
            # Only the final fragment can be an unfinished line. splitlines()
            # also splits on separators other than \n/\r (e.g. form feed);
            # such a line in the middle must be processed, not treated as
            # incomplete, otherwise everything after it would be lost.
            if index == last_index and "\n" not in line and "\r" not in line:
                # Not a complete line, set line buffer
                self._line_buffer = line
                break

            visible = ANSI_ESCAPE.sub("", line).rstrip()
            if self._filter_pattern.match(visible) is not None:
                # Filter pattern matched, ignore the line
                continue

            self._write_color_replace(line)
            # Check for flash size error and provide helpful guidance
            if (
                "Error: The program size" in line
                and "is greater than maximum allowed" in line
                and (help_msg := get_esp32_arduino_flash_error_help())
            ):
                self._write_color_replace(help_msg)

        # write() returns the number of characters written
        # Let's print the number of characters of the original string in order to not confuse
        # any caller.
        return len(s)

    def isatty(self):
        """Always report a tty."""
        return True


def run_external_command(
    func,
    *cmd,
    capture_stdout: bool = False,
    filter_lines: str | list[str] | None = None,
) -> int | str:
    """
    Run a function from an external package that acts like a main method.

    Temporarily replaces stdin/stderr/stdout, sys.argv and sys.exit handler during the run.

    :param func: Function to execute
    :param cmd: Command to run as (eg first element of sys.argv)
    :param capture_stdout: Capture text from stdout and return that.
    :param filter_lines: Regular expression (or list of regular expressions)
      used to filter captured output.
    :return: str if `capture_stdout` is set else int exit code. If the
      function exits via ``SystemExit`` the exit code is returned even when
      `capture_stdout` is set; on any other exception 1 is returned.
    """

    def mock_exit(return_code=None):
        # Same signature as sys.exit(): the status is optional.
        raise SystemExit(return_code)

    orig_argv = sys.argv
    orig_exit = sys.exit  # mock sys.exit
    full_cmd = " ".join(shlex_quote(x) for x in cmd)
    _LOGGER.debug("Running: %s", full_cmd)

    orig_stdout = sys.stdout
    sys.stdout = RedirectText(sys.stdout, filter_lines=filter_lines)
    orig_stderr = sys.stderr
    sys.stderr = RedirectText(sys.stderr, filter_lines=filter_lines)

    if capture_stdout:
        cap_stdout = sys.stdout = io.StringIO()

    try:
        sys.argv = list(cmd)
        sys.exit = mock_exit
        retval = func() or 0
    except KeyboardInterrupt:  # pylint: disable=try-except-raise
        raise
    except SystemExit as err:
        # A bare ``raise SystemExit`` has no args, and an exit status of None
        # means success; report both as exit code 0.
        code = err.args[0] if err.args else None
        return 0 if code is None else code
    except Exception as err:  # pylint: disable=broad-except
        _LOGGER.error("Running command failed: %s", err)
        _LOGGER.error("Please try running %s locally.", full_cmd)
        return 1
    finally:
        sys.argv = orig_argv
        sys.exit = orig_exit
        sys.stdout = orig_stdout
        sys.stderr = orig_stderr

    if capture_stdout:
        return cap_stdout.getvalue()
    return retval


def run_external_process(*cmd: str, **kwargs: Any) -> int | str:
    """Run ``cmd`` as a subprocess.

    Recognised keyword arguments:

    :param filter_lines: Passed to the :class:`RedirectText` wrappers used
      for stdout/stderr.
    :param capture_stdout: When true, stdout is piped and returned as text.
    :return: Captured stdout if ``capture_stdout`` is set, otherwise the
      process return code; 1 if starting/running the process raised.
    """
    full_cmd = " ".join(shlex_quote(x) for x in cmd)
    _LOGGER.debug("Running: %s", full_cmd)
    filter_lines = kwargs.get("filter_lines")

    capture_stdout = kwargs.get("capture_stdout", False)
    sub_stdout = (
        subprocess.PIPE
        if capture_stdout
        else RedirectText(sys.stdout, filter_lines=filter_lines)
    )
    sub_stderr = RedirectText(sys.stderr, filter_lines=filter_lines)

    try:
        proc = subprocess.run(
            cmd,
            stdout=sub_stdout,
            stderr=sub_stderr,
            encoding="utf-8",
            check=False,
            close_fds=False,
        )
        return proc.stdout if capture_stdout else proc.returncode
    except KeyboardInterrupt:  # pylint: disable=try-except-raise
        raise
    except Exception as err:  # pylint: disable=broad-except
        _LOGGER.error("Running command failed: %s", err)
        _LOGGER.error("Please try running %s locally.", full_cmd)
        return 1


def is_dev_esphome_version():
    """Return True when the version string contains ``dev``."""
    return "dev" in const.__version__


def parse_esphome_version() -> tuple[int, int, int]:
    """Parse ``const.__version__`` into a ``(major, minor, patch)`` tuple.

    Accepts an optional ``-dev<N>`` or ``b<N>`` suffix.

    :raises ValueError: If the version string does not have that form.
    """
    # The dots are escaped so that only a literal "." separates the parts.
    match = re.match(r"^(\d+)\.(\d+)\.(\d+)(-dev\d*|b\d*)?$", const.__version__)
    if match is None:
        raise ValueError(f"Failed to parse ESPHome version '{const.__version__}'")
    return int(match.group(1)), int(match.group(2)), int(match.group(3))


# Custom OrderedDict with nicer repr method for debugging
class OrderedDict(collections.OrderedDict):
    def __repr__(self):
        return dict(self).__repr__()


def list_yaml_files(configs: list[str | Path]) -> list[Path]:
    """Collect YAML files from the given files and directories, sorted.

    Directories contribute their direct children (not recursive). The
    combined list is passed through :func:`filter_yaml_files`.

    :raises FileNotFoundError: If one of ``configs`` does not exist.
    """
    files: list[Path] = []
    for config in configs:
        config = Path(config)
        if not config.exists():
            raise FileNotFoundError(f"Config path '{config}' does not exist!")
        if config.is_file():
            files.append(config)
        else:
            files.extend(config.glob("*"))
    files = filter_yaml_files(files)
    return sorted(files)


def filter_yaml_files(files: list[Path]) -> list[Path]:
    """Keep ``.yaml``/``.yml`` paths that are neither secrets files nor dotfiles."""
    return [
        f
        for f in files
        if (
            f.suffix in (".yaml", ".yml")
            and f.name not in ("secrets.yaml", "secrets.yml")
            and not f.name.startswith(".")
        )
    ]


class SerialPort:
    """A serial port path together with a human-readable description."""

    def __init__(self, path: str, description: str):
        self.path = path
        self.description = description


# from https://github.com/pyserial/pyserial/blob/master/serial/tools/list_ports.py
def get_serial_ports() -> list[SerialPort]:
    """Return the available serial ports, sorted by path.

    Includes every port reported by pyserial whose info string contains
    ``VID:PID`` and, on Linux, entries of ``/dev/serial/by-id`` whose
    subsystem is ``platform``.
    """
    from serial.tools.list_ports import comports

    result = []
    for port, desc, info in comports(include_links=True):
        if not port:
            continue
        if "VID:PID" in info:
            result.append(SerialPort(path=port, description=desc))
    # Also add objects in /dev/serial/by-id/
    # ref: https://github.com/esphome/issues/issues/1346

    by_id_path = Path("/dev/serial/by-id")
    if sys.platform.lower().startswith("linux") and by_id_path.exists():
        from serial.tools.list_ports_linux import SysFS

        for path in by_id_path.glob("*"):
            device = SysFS(path)
            if device.subsystem == "platform":
                # Use the device's own description. The previous code read
                # ``info`` left over from the comports() loop above, which
                # belongs to an unrelated port and is undefined when
                # comports() yields nothing.
                result.append(
                    SerialPort(path=str(path), description=device.description)
                )

    result.sort(key=lambda x: x.path)
    return result


def get_esp32_arduino_flash_error_help() -> str | None:
    """Returns helpful message when ESP32 with Arduino runs out of flash space.

    :return: The help text, or ``None`` when the current target is not an
      ESP32 using the Arduino framework.
    """
    from esphome.core import CORE

    if not (CORE.is_esp32 and CORE.using_arduino):
        return None

    from esphome.log import AnsiFore, color

    # The YAML snippet is indented so the keys are properly nested and line
    # up under the text of the numbered step.
    return (
        "\n"
        + color(
            AnsiFore.YELLOW,
            "💡 TIP: Your ESP32 with Arduino framework has run out of flash space.\n",
        )
        + "\n"
        + "To fix this, switch to the ESP-IDF framework which is more memory efficient:\n"
        + "\n"
        + "1. In your YAML configuration, modify the framework section:\n"
        + "\n"
        + "   esp32:\n"
        + "     framework:\n"
        + "       type: esp-idf\n"
        + "\n"
        + "2. Clean build files and compile again\n"
        + "\n"
        + "Note: ESP-IDF uses less flash space and provides better performance.\n"
        + "Some Arduino-specific libraries may need alternatives.\n"
        + "\n"
        + "For detailed migration instructions, see:\n"
        + color(
            AnsiFore.BLUE,
            "https://esphome.io/guides/esp32_arduino_to_idf.html\n\n",
        )
    )