1
0
mirror of https://github.com/esphome/esphome.git synced 2025-01-21 21:34:05 +00:00
esphome/esphome/util.py
Guillermo Ruffino 69879920eb
add-black (#1593)
* Add black

Update pre commit

Update pre commit

add empty line

* Format with black
2021-03-07 16:03:16 -03:00

293 lines
8.6 KiB
Python

from typing import Union, List
import collections
import io
import logging
import os
import re
import subprocess
import sys
from pathlib import Path
from esphome import const
# Logger for this module, named after the module's import path.
_LOGGER = logging.getLogger(__name__)
class RegistryEntry:
    """One named entry of a :class:`Registry`.

    Keeps the registered callable together with its type id and the schema it
    was registered with (stored unwrapped as ``raw_schema``).
    """

    def __init__(self, name, fun, type_id, schema):
        self.name, self.fun = name, fun
        self.type_id = type_id
        # Kept as given; the ``schema`` property wraps it on access.
        self.raw_schema = schema

    @property
    def coroutine_fun(self):
        """``fun`` wrapped by ``esphome.core.coroutine`` (re-wrapped on every access)."""
        from esphome.core import coroutine

        return coroutine(self.fun)

    @property
    def schema(self):
        """``raw_schema`` wrapped in ``esphome.config_validation.Schema``."""
        from esphome.config_validation import Schema

        return Schema(self.raw_schema)
class Registry(dict):
    """Dict of name -> ``RegistryEntry``, populated through :meth:`register`."""

    def __init__(self, base_schema=None, type_id_key=None):
        super().__init__()
        # Any falsy base schema (None, {}) is normalized to a fresh empty dict.
        self.base_schema = base_schema if base_schema else {}
        self.type_id_key = type_id_key

    def register(self, name, type_id, schema):
        """Build a decorator that stores the decorated function under ``name``.

        The decorated function is returned unchanged.
        """

        def decorator(fun):
            entry = RegistryEntry(name, fun, type_id, schema)
            self[name] = entry
            return fun

        return decorator
class SimpleRegistry(dict):
    """Dict of name -> ``(function, data)`` tuples, populated through :meth:`register`."""

    def register(self, name, data):
        """Build a decorator storing ``(fun, data)`` under ``name``; ``fun`` is returned as-is."""

        def decorator(fun):
            self[name] = fun, data
            return fun

        return decorator
def safe_print(message=""):
    """Print ``message``, degrading step by step if the locale cannot encode it.

    In dashboard mode, raw ESC characters are first replaced by the literal text
    ``\\033``. If printing raises ``UnicodeEncodeError``, the message is printed
    as UTF-8 bytes, then as ASCII bytes (both with ``backslashreplace``). If
    both fail, a fixed notice is printed instead.
    """
    from esphome.core import CORE

    if CORE.dashboard:
        try:
            message = message.replace("\033", "\\033")
        except UnicodeEncodeError:
            pass

    try:
        print(message)
        return
    except UnicodeEncodeError:
        pass

    for encoding in ("utf-8", "ascii"):
        try:
            print(message.encode(encoding, "backslashreplace"))
            return
        except UnicodeEncodeError:
            pass

    print("Cannot print line because of invalid locale!")
def shlex_quote(s):
    """Quote ``s`` so that a POSIX shell reads it back as one literal word.

    Modeled on ``shlex.quote``:
    - Empty input becomes ``''``.
    - Strings made only of safe characters are returned unchanged.
    - Anything else is wrapped in single quotes, with each embedded single
      quote spliced in as ``'"'"'``.

    Unlike the stdlib version, the ``\\w`` safe class here is not restricted
    to ASCII.
    """
    if not s:
        return "''"
    needs_quoting = re.search(r"[^\w@%+=:,./-]", s) is not None
    if not needs_quoting:
        return s
    escaped = s.replace("'", "'\"'\"'")
    return "'" + escaped + "'"
# Matches ANSI escape sequences (e.g. terminal color codes); stripped from a line
# before filter patterns are applied to it.
ANSI_ESCAPE = re.compile(r"\033[@-_][0-?]*[ -/]*[@-~]")


class RedirectText:
    """File-like wrapper that forwards text to ``out``, optionally dropping lines.

    With ``filter_lines`` set, text is buffered until a line is complete (ends in
    ``\\n`` or ``\\r``). Each complete line is discarded when its text, with ANSI
    escapes and trailing whitespace removed, matches one of the patterns at its
    start. Attributes not defined here are delegated to ``out``.
    """

    def __init__(self, out, filter_lines=None):
        """
        :param out: Stream receiving the (filtered) text.
        :param filter_lines: Iterable of regular expressions, or a single pattern
            string. ``None`` or an empty collection disables filtering.
        """
        self._out = out
        if filter_lines is None:
            patterns = []
        elif isinstance(filter_lines, str):
            # A bare string must not be joined character by character.
            patterns = [filter_lines]
        else:
            patterns = list(filter_lines)
        if patterns:
            pattern = r"|".join(r"(?:" + p + r")" for p in patterns)
            self._filter_pattern = re.compile(pattern)
        else:
            # An empty alternation compiles to a pattern matching every line,
            # which would silently swallow all output.
            self._filter_pattern = None
        self._line_buffer = ""

    def __getattr__(self, item):
        return getattr(self._out, item)

    def _write_color_replace(self, s):
        from esphome.core import CORE

        if CORE.dashboard:
            # With the dashboard, we must create a little hack to make color output
            # work. The shell we create in the dashboard is not a tty, so python removes
            # all color codes from the resulting stream. We just convert them to something
            # we can easily recognize later here.
            s = s.replace("\033", "\\033")
        self._out.write(s)

    def write(self, s):
        """Write ``s`` (``str``, or ``bytes`` decoded with the default codec).

        :return: Length of ``s`` after decoding, independent of filtering, so
            callers are not confused by dropped lines.
        """
        # s is usually a str already (self._out is of type TextIOWrapper)
        # However, s is sometimes also a bytes object in python3. Let's make sure it's a
        # str
        # If the conversion fails, we will create an exception, which is okay because we won't
        # be able to print it anyway.
        if not isinstance(s, str):
            s = s.decode()

        if self._filter_pattern is None:
            self._write_color_replace(s)
            return len(s)

        self._line_buffer += s
        lines = self._line_buffer.splitlines(True)
        self._line_buffer = ""
        # Only the last piece can be an unfinished line. Earlier pieces end in some
        # splitlines() separator (possibly \x0c, \x85, ...) and must not be lost.
        if lines and not lines[-1].endswith(("\n", "\r")):
            self._line_buffer = lines.pop()

        for line in lines:
            line_without_ansi = ANSI_ESCAPE.sub("", line)
            line_without_end = line_without_ansi.rstrip()
            if self._filter_pattern.match(line_without_end) is not None:
                # Filter pattern matched, ignore the line
                continue
            self._write_color_replace(line)
        return len(s)

    # pylint: disable=no-self-use
    def isatty(self):
        return True
def run_external_command(
    func,
    *cmd,
    capture_stdout: bool = False,
    filter_lines: Union[List[str], None] = None,
) -> Union[int, str]:
    """
    Run a function from an external package that acts like a main method.
    Temporarily replaces sys.stdout/sys.stderr, sys.argv and sys.exit during the run.
    :param func: Function to execute
    :param cmd: Command to run as (eg first element of sys.argv)
    :param capture_stdout: Capture text from stdout and return that.
    :param filter_lines: Regular expressions; lines written to the redirected
        stdout/stderr that match one of them are dropped. Text captured via
        ``capture_stdout`` is not filtered.
    :return: str if `capture_stdout` is set else int exit code.
    """

    def mock_exit(return_code=None):
        # Like sys.exit, the exit code argument is optional.
        raise SystemExit(return_code)

    orig_argv = sys.argv
    orig_exit = sys.exit  # mock sys.exit
    full_cmd = " ".join(shlex_quote(x) for x in cmd)
    _LOGGER.info("Running: %s", full_cmd)

    orig_stdout = sys.stdout
    sys.stdout = RedirectText(sys.stdout, filter_lines=filter_lines)
    orig_stderr = sys.stderr
    sys.stderr = RedirectText(sys.stderr, filter_lines=filter_lines)

    if capture_stdout:
        # Replaces the filtering wrapper, so captured text bypasses filter_lines.
        cap_stdout = sys.stdout = io.StringIO()

    try:
        sys.argv = list(cmd)
        sys.exit = mock_exit
        retval = func() or 0
    except KeyboardInterrupt:
        retval = 1
    except SystemExit as err:
        # err.code is None for a bare sys.exit()/SystemExit, which means success;
        # err.args[0] would raise IndexError in that case.
        retval = 0 if err.code is None else err.code
    except Exception as err:  # pylint: disable=broad-except
        _LOGGER.error("Running command failed: %s", err)
        _LOGGER.error("Please try running %s locally.", full_cmd)
        retval = 1
    finally:
        sys.argv = orig_argv
        sys.exit = orig_exit
        sys.stdout = orig_stdout
        sys.stderr = orig_stderr

    # Returned outside ``finally`` so unexpected errors are not silently swallowed.
    if capture_stdout:
        return cap_stdout.getvalue()
    return retval
def run_external_process(*cmd, **kwargs):
    """Run ``cmd`` as a subprocess and wait for it to finish.

    Keyword arguments:
      filter_lines: regular expressions handed to ``RedirectText`` for the child's
        stdout/stderr wrappers.
      capture_stdout: if true, return the child's stdout as ``bytes`` instead of
        the exit code.

    :return: The exit code (``1`` if the process could not be run), or, when
        ``capture_stdout`` is set, the captured stdout (``b""`` if the process
        could not be run).
    """
    full_cmd = " ".join(shlex_quote(x) for x in cmd)
    _LOGGER.info("Running: %s", full_cmd)
    filter_lines = kwargs.get("filter_lines")
    capture_stdout = kwargs.get("capture_stdout", False)

    if capture_stdout:
        # subprocess needs a real file descriptor; an in-memory io.BytesIO has
        # none (its fileno() raises), so collect the output through a pipe.
        sub_stdout = subprocess.PIPE
    else:
        sub_stdout = RedirectText(sys.stdout, filter_lines=filter_lines)
    # NOTE(review): subprocess writes to the descriptor returned by fileno(),
    # which RedirectText forwards to the wrapped stream via __getattr__, so
    # filter_lines likely has no effect on the child's output — confirm.
    sub_stderr = RedirectText(sys.stderr, filter_lines=filter_lines)

    try:
        proc = subprocess.run(cmd, stdout=sub_stdout, stderr=sub_stderr, check=False)
    except Exception as err:  # pylint: disable=broad-except
        _LOGGER.error("Running command failed: %s", err)
        _LOGGER.error("Please try running %s locally.", full_cmd)
        return b"" if capture_stdout else 1

    if capture_stdout:
        return proc.stdout
    return proc.returncode
def is_dev_esphome_version():
    """Return True when the ESPHome version string contains ``dev``."""
    version = const.__version__
    return "dev" in version
class OrderedDict(collections.OrderedDict):
    """``collections.OrderedDict`` whose repr is the plain-``dict`` repr.

    Keeps debug output short by omitting the ``OrderedDict(...)`` wrapper.
    """

    def __repr__(self):
        return repr(dict(self))
def list_yaml_files(folder):
    """Return the sorted configuration YAML paths found directly inside ``folder``.

    Directory entries are joined with ``folder`` and then screened by
    ``filter_yaml_files``.
    """
    candidates = [os.path.join(folder, entry) for entry in os.listdir(folder)]
    return sorted(filter_yaml_files(candidates))
def filter_yaml_files(files):
    """Keep ``.yaml`` paths, dropping ``secrets.yaml`` and hidden (dot-prefixed) files.

    Input order is preserved.
    """

    def _is_config(path):
        base = os.path.basename(path)
        return (
            os.path.splitext(path)[1] == ".yaml"
            and base != "secrets.yaml"
            and not base.startswith(".")
        )

    return [path for path in files if _is_config(path)]
class SerialPort:
    """A serial device: its device ``path`` plus a human-readable ``description``."""

    def __init__(self, path: str, description: str):
        self.path, self.description = path, description
# from https://github.com/pyserial/pyserial/blob/master/serial/tools/list_ports.py
def get_serial_ports() -> List[SerialPort]:
    """Return the serial ports available on this machine, sorted by path.

    Ports reported by pyserial's ``comports()`` are kept only when their
    hardware id contains ``VID:PID``. On Linux, entries in ``/dev/serial/by-id/``
    whose sysfs subsystem is ``platform`` are added as well.
    """
    from serial.tools.list_ports import comports

    result = []
    for port, desc, info in comports(include_links=True):
        if not port:
            continue
        if "VID:PID" in info:
            result.append(SerialPort(path=port, description=desc))

    # Also add objects in /dev/serial/by-id/
    # ref: https://github.com/esphome/issues/issues/1346
    by_id_path = Path("/dev/serial/by-id")
    if sys.platform.lower().startswith("linux") and by_id_path.exists():
        from serial.tools.list_ports_linux import SysFS

        for path in by_id_path.glob("*"):
            device = SysFS(path)
            if device.subsystem == "platform":
                # Use this device's own description: ``info`` is left over from
                # the comports() loop above and describes a different port (and
                # is undefined when that loop yielded nothing).
                result.append(
                    SerialPort(path=str(path), description=device.description)
                )

    result.sort(key=lambda x: x.path)
    return result