mirror of
synced 2025-03-03 17:28:16 +00:00
Co-authored-by: Paul Monigatti <pm@paul.pm> Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
523 lines
19 KiB
523 lines
19 KiB
import fnmatch
import functools
import inspect
import logging
import math
import os
import uuid
import yaml
import yaml.constructor
from esphome import core
from esphome.config_helpers import read_config_file, Extend
from esphome.core import (
from esphome.helpers import add_class_to_obj
from esphome.util import OrderedDict, filter_yaml_files
_LOGGER = logging.getLogger(__name__)

# Mostly copied from Home Assistant because that code works fine and
# let's not reinvent the wheel here

# File name of the secrets file, resolved relative to the YAML file being loaded.
SECRET_YAML = "secrets.yaml"

# Maps str(secret value) -> secret name. Written when a secret is constructed
# by the loader and read back when dumping so values can be shown as secrets.
# NOTE(review): no definition of this name is visible in the file even though
# it is read and written further down — confirm it is not defined elsewhere.
_SECRET_VALUES = {}
class ESPHomeDataBase:
    """Mixin that remembers where in a YAML document a value came from.

    The position data lives in ``_esp_range`` and ``_content_offset``; both
    are optional and fall back to ``None`` / ``0`` when never set.
    """

    @property
    def esp_range(self):
        """Return the stored document range, or ``None`` if none was recorded."""
        return getattr(self, "_esp_range", None)

    @property
    def content_offset(self):
        """Return the stored content offset, or ``0`` if none was recorded."""
        return getattr(self, "_content_offset", 0)

    def from_node(self, node):
        """Record the position of ``node`` (a PyYAML node) on this object."""
        # pylint: disable=attribute-defined-outside-init
        self._esp_range = DocumentRange.from_marks(node.start_mark, node.end_mark)
        if isinstance(node, yaml.ScalarNode):
            # Block scalars ("|" literal and ">" folded styles) get an offset of 1.
            if node.style is not None and node.style in "|>":
                self._content_offset = 1

    def from_database(self, database):
        """Copy the position information from another ``ESPHomeDataBase``."""
        # pylint: disable=attribute-defined-outside-init
        # These are read as attributes, which is why the accessors above
        # must be properties and not plain methods.
        self._esp_range = database.esp_range
        self._content_offset = database.content_offset
class ESPForceValue:
    """Marker class attached to scalars constructed by ``construct_force``."""
def make_data_base(value, from_database: ESPHomeDataBase = None):
    """Attach :class:`ESPHomeDataBase` to ``value`` when possible.

    Args:
        value: The object to augment via ``add_class_to_obj``.
        from_database: Optional object whose position information is copied
            onto the augmented value.

    Returns:
        The augmented value, or ``value`` unchanged when ``add_class_to_obj``
        raises ``TypeError``.
    """
    try:
        value = add_class_to_obj(value, ESPHomeDataBase)
        if from_database is not None:
            value.from_database(from_database)
        return value
    except TypeError:
        # Adding class failed, ignore error
        return value
def _add_data_ref(fn):
    """Decorate a ``(loader, node)`` constructor so its result records ``node``'s position.

    The wrapped constructor's result is passed through ``make_data_base`` and,
    if that yields an ``ESPHomeDataBase``, ``from_node(node)`` is called on it.
    """

    @functools.wraps(fn)
    def wrapped(loader, node):
        res = fn(loader, node)
        # newer PyYAML versions use generators, resolve them
        if inspect.isgenerator(res):
            generator = res
            res = next(generator)
            # Let generator finish
            for _ in generator:
                pass
        res = make_data_base(res)
        if isinstance(res, ESPHomeDataBase):
            res.from_node(node)
        return res

    return wrapped
class ESPHomeLoader(yaml.SafeLoader):
    """Loader class that keeps track of line numbers."""

    # The overrides below only delegate to the parent class; they exist so
    # that ``_add_data_ref`` can attach position information to the results.

    @_add_data_ref
    def construct_yaml_int(self, node):
        return super().construct_yaml_int(node)

    @_add_data_ref
    def construct_yaml_float(self, node):
        return super().construct_yaml_float(node)

    @_add_data_ref
    def construct_yaml_binary(self, node):
        return super().construct_yaml_binary(node)

    @_add_data_ref
    def construct_yaml_omap(self, node):
        return super().construct_yaml_omap(node)

    @_add_data_ref
    def construct_yaml_str(self, node):
        return super().construct_yaml_str(node)

    @_add_data_ref
    def construct_yaml_seq(self, node):
        return super().construct_yaml_seq(node)

    @_add_data_ref
    def construct_yaml_map(self, node):
        """Build an ``OrderedDict`` from a mapping node, resolving ``<<`` merge keys.

        Raises:
            yaml.constructor.ConstructorError: On an unhashable key, a
                duplicate key, or a merge value that is neither a mapping
                nor a list of mappings.
        """
        assert isinstance(node, yaml.MappingNode)
        # A list of key-value pairs we find in the current mapping
        pairs = []
        # A list of key-value pairs we find while resolving merges ('<<' key), will be
        # added to pairs in a second pass
        merge_pairs = []
        # A dict of seen keys so far, used to alert the user of duplicate keys and checking
        # which keys to merge.
        # Value of dict items is the start mark of the previous declaration.
        seen_keys = {}

        for key_node, value_node in node.value:
            # merge key is '<<'
            is_merge_key = key_node.tag == "tag:yaml.org,2002:merge"
            # key has no explicit tag set
            is_default_tag = key_node.tag == "tag:yaml.org,2002:value"

            if is_default_tag:
                # Default tag for mapping keys is string
                key_node.tag = "tag:yaml.org,2002:str"

            if not is_merge_key:
                # base case, this is a simple key-value pair
                key = self.construct_object(key_node)
                value = self.construct_object(value_node)

                # Check if key is hashable
                try:
                    hash(key)
                except TypeError:
                    # pylint: disable=raise-missing-from
                    raise yaml.constructor.ConstructorError(
                        f'Invalid key "{key}" (not hashable)', key_node.start_mark
                    )

                key = make_data_base(str(key))
                if isinstance(key, ESPHomeDataBase):
                    key.from_node(key_node)

                # Check if it is a duplicate key
                if key in seen_keys:
                    raise yaml.constructor.ConstructorError(
                        f'Duplicate key "{key}"',
                        key_node.start_mark,
                        "NOTE: Previous declaration here:",
                        seen_keys[key],
                    )
                seen_keys[key] = key_node.start_mark

                # Add to pairs
                pairs.append((key, value))
                continue

            # This is a merge key, resolve value and add to merge_pairs
            value = self.construct_object(value_node)
            if isinstance(value, dict):
                # base case, copy directly to merge_pairs
                # direct merge, like "<<: {some_key: some_value}"
                merge_pairs.extend(value.items())
            elif isinstance(value, list):
                # sequence merge, like "<<: [{some_key: some_value}, {other_key: some_value}]"
                for item in value:
                    if not isinstance(item, dict):
                        raise yaml.constructor.ConstructorError(
                            "While constructing a mapping",
                            node.start_mark,
                            f"Expected a mapping for merging, but found {type(item)}",
                            value_node.start_mark,
                        )
                    merge_pairs.extend(item.items())
            else:
                raise yaml.constructor.ConstructorError(
                    "While constructing a mapping",
                    node.start_mark,
                    f"Expected a mapping or list of mappings for merging, but found {type(value)}",
                    value_node.start_mark,
                )

        if merge_pairs:
            # We found some merge keys along the way, merge them into base pairs
            # https://yaml.org/type/merge.html
            # Construct a new merge set with values overridden by current mapping or earlier
            # sequence entries removed
            for key, value in merge_pairs:
                if key in seen_keys:
                    # key already in the current map or from an earlier merge sequence entry,
                    # do not override
                    #
                    # "... each of its key/value pairs is inserted into the current mapping,
                    # unless the key already exists in it."
                    #
                    # "If the value associated with the merge key is a sequence, then this sequence
                    # is expected to contain mapping nodes and each of these nodes is merged in
                    # turn according to its order in the sequence. Keys in mapping nodes earlier
                    # in the sequence override keys specified in later mapping nodes."
                    continue
                pairs.append((key, value))
                # Add key node to seen keys, for sequence merge values.
                seen_keys[key] = None

        return OrderedDict(pairs)

    @_add_data_ref
    def construct_env_var(self, node):
        """Resolve an environment variable, with an optional inline default.

        The node value is split on whitespace: the first token is the
        variable name, any remaining tokens (joined by single spaces) are the
        default used when the variable is unset.

        Raises:
            yaml.MarkedYAMLError: If the variable is unset and no default
                was given.
        """
        args = node.value.split()
        # Check for a default value
        if len(args) > 1:
            return os.getenv(args[0], " ".join(args[1:]))
        if args[0] in os.environ:
            return os.environ[args[0]]
        raise yaml.MarkedYAMLError(
            f"Environment variable '{node.value}' not defined", node.start_mark
        )

    @property
    def _directory(self):
        """Directory of the file being loaded (derived from ``self.name``)."""
        return os.path.dirname(self.name)

    def _rel_path(self, *args):
        """Join ``args`` onto the directory of the file being loaded."""
        return os.path.join(self._directory, *args)

    @_add_data_ref
    def construct_secret(self, node):
        """Look up ``node.value`` in the secrets file next to the current file.

        The resolved value is remembered in ``_SECRET_VALUES`` keyed by its
        string form.

        Raises:
            yaml.MarkedYAMLError: If the secret is not defined.
        """
        secrets = _load_yaml_internal(self._rel_path(SECRET_YAML))
        if node.value not in secrets:
            raise yaml.MarkedYAMLError(
                f"Secret '{node.value}' not defined", node.start_mark
            )
        val = secrets[node.value]
        _SECRET_VALUES[str(val)] = node.value
        return val

    @_add_data_ref
    def construct_include(self, node):
        """Load another YAML file, optionally substituting ``vars`` into it.

        ``node`` is either a scalar holding the file path, or a mapping with a
        required ``file`` key and an optional ``vars`` mapping.
        """

        def extract_file_vars(node):
            # Mapping form: pull out the "file" and "vars" entries.
            fields = self.construct_yaml_map(node)
            file = fields.get("file")
            if file is None:
                raise yaml.MarkedYAMLError("Must include 'file'", node.start_mark)
            vars = fields.get("vars")
            if vars:
                vars = {k: str(v) for k, v in vars.items()}
            return file, vars

        def substitute_vars(config, vars):
            from esphome.const import CONF_SUBSTITUTIONS
            from esphome.components import substitutions

            org_subs = None
            result = config
            if not isinstance(config, dict):
                # when the included yaml contains a list or a scalar
                # wrap it into an OrderedDict because do_substitution_pass expects it
                result = OrderedDict([("yaml", config)])
            elif CONF_SUBSTITUTIONS in result:
                org_subs = result.pop(CONF_SUBSTITUTIONS)

            # Temporarily expose the include's vars as the substitutions to apply.
            # NOTE(review): assumes do_substitution_pass reads them from this
            # key of the config — confirm against the substitutions component.
            result[CONF_SUBSTITUTIONS] = vars
            # Ignore missing vars that refer to the top level substitutions
            substitutions.do_substitution_pass(result, None, ignore_missing=True)
            result.pop(CONF_SUBSTITUTIONS, None)

            if not isinstance(config, dict):
                result = result["yaml"]  # unwrap the result
            elif org_subs:
                result[CONF_SUBSTITUTIONS] = org_subs
            return result

        if isinstance(node, yaml.nodes.MappingNode):
            file, vars = extract_file_vars(node)
        else:
            file, vars = node.value, None

        result = _load_yaml_internal(self._rel_path(file))
        if vars:
            result = substitute_vars(result, vars)
        return result

    @_add_data_ref
    def construct_include_dir_list(self, node):
        """Return a list with the loaded content of every ``*.yaml`` file found."""
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        return [_load_yaml_internal(f) for f in files]

    @_add_data_ref
    def construct_include_dir_merge_list(self, node):
        """Concatenate every found file whose content is a list; others are skipped."""
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        merged_list = []
        for fname in files:
            loaded_yaml = _load_yaml_internal(fname)
            if isinstance(loaded_yaml, list):
                merged_list.extend(loaded_yaml)
        return merged_list

    @_add_data_ref
    def construct_include_dir_named(self, node):
        """Return a mapping of file base name (no extension) to loaded content."""
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        mapping = OrderedDict()
        for fname in files:
            filename = os.path.splitext(os.path.basename(fname))[0]
            mapping[filename] = _load_yaml_internal(fname)
        return mapping

    @_add_data_ref
    def construct_include_dir_merge_named(self, node):
        """Merge every found file whose content is a dict; others are skipped."""
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        mapping = OrderedDict()
        for fname in files:
            loaded_yaml = _load_yaml_internal(fname)
            if isinstance(loaded_yaml, dict):
                mapping.update(loaded_yaml)
        return mapping

    @_add_data_ref
    def construct_lambda(self, node):
        """Wrap the node's text in a ``Lambda``."""
        return Lambda(str(node.value))

    @_add_data_ref
    def construct_force(self, node):
        """Construct the scalar and tag it with ``ESPForceValue``."""
        obj = self.construct_scalar(node)
        return add_class_to_obj(obj, ESPForceValue)

    @_add_data_ref
    def construct_extend(self, node):
        """Wrap the node's text in an ``Extend``."""
        return Extend(str(node.value))
# Register the position-tracking constructors for the standard YAML tags...
ESPHomeLoader.add_constructor("tag:yaml.org,2002:int", ESPHomeLoader.construct_yaml_int)
ESPHomeLoader.add_constructor(
    "tag:yaml.org,2002:float", ESPHomeLoader.construct_yaml_float
)
ESPHomeLoader.add_constructor(
    "tag:yaml.org,2002:binary", ESPHomeLoader.construct_yaml_binary
)
ESPHomeLoader.add_constructor(
    "tag:yaml.org,2002:omap", ESPHomeLoader.construct_yaml_omap
)
ESPHomeLoader.add_constructor("tag:yaml.org,2002:str", ESPHomeLoader.construct_yaml_str)
ESPHomeLoader.add_constructor("tag:yaml.org,2002:seq", ESPHomeLoader.construct_yaml_seq)
ESPHomeLoader.add_constructor("tag:yaml.org,2002:map", ESPHomeLoader.construct_yaml_map)
# ...and the custom application tags.
ESPHomeLoader.add_constructor("!env_var", ESPHomeLoader.construct_env_var)
ESPHomeLoader.add_constructor("!secret", ESPHomeLoader.construct_secret)
ESPHomeLoader.add_constructor("!include", ESPHomeLoader.construct_include)
ESPHomeLoader.add_constructor(
    "!include_dir_list", ESPHomeLoader.construct_include_dir_list
)
ESPHomeLoader.add_constructor(
    "!include_dir_merge_list", ESPHomeLoader.construct_include_dir_merge_list
)
ESPHomeLoader.add_constructor(
    "!include_dir_named", ESPHomeLoader.construct_include_dir_named
)
ESPHomeLoader.add_constructor(
    "!include_dir_merge_named", ESPHomeLoader.construct_include_dir_merge_named
)
ESPHomeLoader.add_constructor("!lambda", ESPHomeLoader.construct_lambda)
ESPHomeLoader.add_constructor("!force", ESPHomeLoader.construct_force)
ESPHomeLoader.add_constructor("!extend", ESPHomeLoader.construct_extend)
def load_yaml(fname, clear_secrets=True):
    """Load the YAML file ``fname`` and return its parsed content.

    Args:
        fname: Path of the file to load.
        clear_secrets: When true, forget previously recorded secret values
            before loading.
    """
    if clear_secrets:
        _SECRET_VALUES.clear()
    return _load_yaml_internal(fname)
def _load_yaml_internal(fname):
    """Parse ``fname`` with :class:`ESPHomeLoader`.

    Returns:
        The single document in the file, or an empty ``OrderedDict`` when the
        document is empty/falsy.

    Raises:
        EsphomeError: Wrapping any ``yaml.YAMLError`` raised while parsing.
    """
    content = read_config_file(fname)
    loader = ESPHomeLoader(content)
    # The loader uses its name to resolve paths relative to this file.
    loader.name = fname
    try:
        return loader.get_single_data() or OrderedDict()
    except yaml.YAMLError as exc:
        raise EsphomeError(exc) from exc
    finally:
        # Release the loader's parser state, as PyYAML's own load() does.
        loader.dispose()
def dump(dict_, show_secrets=False):
    """Dump YAML to a string and remove null.

    Args:
        dict_: The data to serialize.
        show_secrets: When true, the recorded secret values are forgotten
            first, so the dumper no longer replaces them with ``!secret``
            references.
    """
    if show_secrets:
        _SECRET_VALUES.clear()
    return yaml.dump(
        dict_, default_flow_style=False, allow_unicode=True, Dumper=ESPHomeDumper
    )
def _is_file_valid(name):
    """Decide if a file is valid.

    A name is valid unless it starts with a dot (hidden file or directory).
    """
    return not name.startswith(".")
def _find_files(directory, pattern):
    """Recursively yield paths of files under ``directory`` matching ``pattern``.

    Hidden directories (names starting with ".") are not descended into and
    hidden files are skipped. ``pattern`` is an ``fnmatch`` pattern applied to
    the file's base name.
    """
    for root, dirs, files in os.walk(directory, topdown=True):
        # Prune in place so os.walk (top-down) skips hidden directories.
        dirs[:] = [d for d in dirs if _is_file_valid(d)]
        for basename in files:
            if _is_file_valid(basename) and fnmatch.fnmatch(basename, pattern):
                filename = os.path.join(root, basename)
                yield filename
def is_secret(value):
    """Return the secret name recorded for ``value``, or ``None``.

    The lookup is keyed on ``str(value)`` in ``_SECRET_VALUES``.
    """
    try:
        return _SECRET_VALUES[str(value)]
    except (KeyError, ValueError):
        return None
class ESPHomeDumper(yaml.SafeDumper):
    """Dumper that emits recorded secret values as ``!secret`` references."""

    def represent_mapping(self, tag, mapping, flow_style=None):
        """Represent ``mapping`` as a mapping node, keeping its iteration order."""
        value = []
        node = yaml.MappingNode(tag, value, flow_style=flow_style)
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        # Stays True only while every key and value is a plain (unstyled) scalar.
        best_style = True
        if hasattr(mapping, "items"):
            mapping = list(mapping.items())
        for item_key, item_value in mapping:
            node_key = self.represent_data(item_key)
            node_value = self.represent_data(item_value)
            if not (isinstance(node_key, yaml.ScalarNode) and not node_key.style):
                best_style = False
            if not (isinstance(node_value, yaml.ScalarNode) and not node_value.style):
                best_style = False
            value.append((node_key, node_value))
        if flow_style is None:
            if self.default_flow_style is not None:
                node.flow_style = self.default_flow_style
            else:
                node.flow_style = best_style
        return node

    def represent_secret(self, value):
        """Emit ``!secret <name>`` for a value recorded in ``_SECRET_VALUES``."""
        return self.represent_scalar(tag="!secret", value=_SECRET_VALUES[str(value)])

    def represent_stringify(self, value):
        """Emit ``str(value)`` as a string scalar, or a secret reference."""
        if is_secret(value):
            return self.represent_secret(value)
        return self.represent_scalar(tag="tag:yaml.org,2002:str", value=str(value))

    # pylint: disable=arguments-renamed
    def represent_bool(self, value):
        """Emit booleans as lowercase ``true`` / ``false``."""
        return self.represent_scalar(
            "tag:yaml.org,2002:bool", "true" if value else "false"
        )

    # pylint: disable=arguments-renamed
    def represent_int(self, value):
        """Emit an integer scalar, or a secret reference."""
        if is_secret(value):
            return self.represent_secret(value)
        return self.represent_scalar(tag="tag:yaml.org,2002:int", value=str(value))

    # pylint: disable=arguments-renamed
    def represent_float(self, value):
        """Emit a float scalar (``.nan`` / ``.inf`` / ``-.inf`` for specials), or a secret reference."""
        if is_secret(value):
            return self.represent_secret(value)
        if math.isnan(value):
            value = ".nan"
        elif math.isinf(value):
            value = ".inf" if value > 0 else "-.inf"
        else:
            value = str(repr(value)).lower()
            # Note that in some cases `repr(data)` represents a float number
            # without the decimal parts. For instance:
            # >>> repr(1e17)
            # '1e17'
            # Unfortunately, this is not a valid float representation according
            # to the definition of the `!!float` tag. We fix this by adding
            # '.0' before the 'e' symbol.
            if "." not in value and "e" in value:
                value = value.replace("e", ".0e", 1)
        return self.represent_scalar(tag="tag:yaml.org,2002:float", value=value)

    def represent_lambda(self, value):
        """Emit ``value.value`` as a ``!lambda`` literal block, or a secret reference."""
        if is_secret(value.value):
            return self.represent_secret(value.value)
        return self.represent_scalar(tag="!lambda", value=value.value, style="|")

    def represent_id(self, value):
        """Emit ``value.id`` as a string, or a secret reference."""
        if is_secret(value.id):
            return self.represent_secret(value.id)
        return self.represent_stringify(value.id)
# Register representers; multi-representers also apply to subclasses of the type.
ESPHomeDumper.add_multi_representer(
    dict, lambda dumper, value: dumper.represent_mapping("tag:yaml.org,2002:map", value)
)
ESPHomeDumper.add_multi_representer(
    list,
    lambda dumper, value: dumper.represent_sequence("tag:yaml.org,2002:seq", value),
)
ESPHomeDumper.add_multi_representer(bool, ESPHomeDumper.represent_bool)
ESPHomeDumper.add_multi_representer(str, ESPHomeDumper.represent_stringify)
ESPHomeDumper.add_multi_representer(int, ESPHomeDumper.represent_int)
ESPHomeDumper.add_multi_representer(float, ESPHomeDumper.represent_float)
ESPHomeDumper.add_multi_representer(IPAddress, ESPHomeDumper.represent_stringify)
ESPHomeDumper.add_multi_representer(MACAddress, ESPHomeDumper.represent_stringify)
ESPHomeDumper.add_multi_representer(TimePeriod, ESPHomeDumper.represent_stringify)
ESPHomeDumper.add_multi_representer(Lambda, ESPHomeDumper.represent_lambda)
ESPHomeDumper.add_multi_representer(core.ID, ESPHomeDumper.represent_id)
ESPHomeDumper.add_multi_representer(uuid.UUID, ESPHomeDumper.represent_stringify)