From c2a0c22bd9c78454195d762c2deb77019fca4e49 Mon Sep 17 00:00:00 2001
From: Otto Winter <otto@otto-winter.com>
Date: Tue, 26 Feb 2019 18:32:20 +0100
Subject: [PATCH] Automatically hide secrets in validation (#455)

* Hide secrets in validation

* Lint
---
 esphome/config.py    | 18 ++++++++------
 esphome/yaml_util.py | 57 +++++++++++++++++++++++++++++++++-----------
 2 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/esphome/config.py b/esphome/config.py
index e358fa0228..2b517b6301 100644
--- a/esphome/config.py
+++ b/esphome/config.py
@@ -2,7 +2,6 @@ from __future__ import print_function
 
 from collections import OrderedDict
 import importlib
-import json
 import logging
 import re
 
@@ -19,6 +18,7 @@ from esphome.util import safe_print
 # pylint: disable=unused-import, wrong-import-order
 from typing import List, Optional, Tuple, Union  # noqa
 from esphome.core import ConfigType  # noqa
+from esphome.yaml_util import is_secret
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -397,10 +397,7 @@ def _nested_getitem(data, path):
 def humanize_error(config, validation_error):
     offending_item_summary = _nested_getitem(config, validation_error.path)
     if isinstance(offending_item_summary, dict):
-        try:
-            offending_item_summary = json.dumps(offending_item_summary)
-        except (TypeError, ValueError):
-            pass
+        offending_item_summary = None
     validation_error = text_type(validation_error)
     m = re.match(r'^(.*?)\s*(?:for dictionary value )?@ data\[.*$', validation_error)
     if m is not None:
@@ -408,8 +405,9 @@ def humanize_error(config, validation_error):
     validation_error = validation_error.strip()
     if not validation_error.endswith(u'.'):
         validation_error += u'.'
-    if offending_item_summary is None:
+    if offending_item_summary is None or is_secret(offending_item_summary):
         return validation_error
+
     return u"{} Got '{}'".format(validation_error, offending_item_summary)
 
 
@@ -438,7 +436,8 @@ def load_config():
     try:
         config = yaml_util.load_yaml(CORE.config_path)
     except OSError:
-        raise EsphomeError(u"Invalid YAML at {}".format(CORE.config_path))
+        raise EsphomeError(u"Invalid YAML at {}. Please see YAML syntax reference or use an online "
+                           u"YAML syntax validator".format(CORE.config_path))
     CORE.raw_config = config
     config = substitutions.do_substitution_pass(config)
     core_config.preload_core_config(config)
@@ -536,6 +535,8 @@ def dump_dict(config, path, at_root=True):
                     msg = msg + u' ' + inf
             ret += st + msg + u'\n'
     elif isinstance(conf, str):
+        if is_secret(conf):
+            conf = u'!secret {}'.format(is_secret(conf))
         if not conf:
             conf += u"''"
 
@@ -545,6 +546,9 @@ def dump_dict(config, path, at_root=True):
         col = 'bold_red' if error else 'white'
         ret += color(col, text_type(conf))
     elif isinstance(conf, core.Lambda):
+        if is_secret(conf):
+            conf = u'!secret {}'.format(is_secret(conf))
+
         conf = u'!lambda |-\n' + indent(text_type(conf.value))
         error = config.get_error_for_path(path)
         col = 'bold_red' if error else 'white'
diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py
index 750dfc993d..a47d6d8122 100644
--- a/esphome/yaml_util.py
+++ b/esphome/yaml_util.py
@@ -12,7 +12,7 @@ import yaml.constructor
 
 from esphome import core
 from esphome.core import EsphomeError, HexInt, IPAddress, Lambda, MACAddress, TimePeriod
-from esphome.py_compat import string_types, text_type
+from esphome.py_compat import string_types, text_type, IS_PY2
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -20,6 +20,8 @@ _LOGGER = logging.getLogger(__name__)
 # let's not reinvent the wheel here
 
 SECRET_YAML = u'secrets.yaml'
+_SECRET_CACHE = {}
+_SECRET_VALUES = {}
 
 
 class NodeListClass(list):
@@ -42,6 +44,12 @@ class SafeLineLoader(yaml.SafeLoader):  # pylint: disable=too-many-ancestors
 
 
 def load_yaml(fname):
+    _SECRET_VALUES.clear()
+    _SECRET_CACHE.clear()
+    return _load_yaml_internal(fname)
+
+
+def _load_yaml_internal(fname):
     """Load a YAML file."""
     try:
         with codecs.open(fname, encoding='utf-8') as conf_file:
@@ -193,7 +201,7 @@ def _include_yaml(loader, node):
         device_tracker: !include device_tracker.yaml
     """
     fname = os.path.join(os.path.dirname(loader.name), node.value)
-    return _add_reference(load_yaml(fname), loader, node)
+    return _add_reference(_load_yaml_internal(fname), loader, node)
 
 
 def _is_file_valid(name):
@@ -217,7 +225,7 @@ def _include_dir_named_yaml(loader, node):
     loc = os.path.join(os.path.dirname(loader.name), node.value)
     for fname in _find_files(loc, '*.yaml'):
         filename = os.path.splitext(os.path.basename(fname))[0]
-        mapping[filename] = load_yaml(fname)
+        mapping[filename] = _load_yaml_internal(fname)
     return _add_reference(mapping, loader, node)
 
 
@@ -228,7 +236,7 @@ def _include_dir_merge_named_yaml(loader, node):
     for fname in _find_files(loc, '*.yaml'):
         if os.path.basename(fname) == SECRET_YAML:
             continue
-        loaded_yaml = load_yaml(fname)
+        loaded_yaml = _load_yaml_internal(fname)
         if isinstance(loaded_yaml, dict):
             mapping.update(loaded_yaml)
     return _add_reference(mapping, loader, node)
@@ -237,7 +245,7 @@ def _include_dir_merge_named_yaml(loader, node):
 def _include_dir_list_yaml(loader, node):
     """Load multiple files from directory as a list."""
     loc = os.path.join(os.path.dirname(loader.name), node.value)
-    return [load_yaml(f) for f in _find_files(loc, '*.yaml')
+    return [_load_yaml_internal(f) for f in _find_files(loc, '*.yaml')
             if os.path.basename(f) != SECRET_YAML]
 
 
@@ -248,20 +256,29 @@ def _include_dir_merge_list_yaml(loader, node):
     for fname in _find_files(path, '*.yaml'):
         if os.path.basename(fname) == SECRET_YAML:
             continue
-        loaded_yaml = load_yaml(fname)
+        loaded_yaml = _load_yaml_internal(fname)
         if isinstance(loaded_yaml, list):
             merged_list.extend(loaded_yaml)
     return _add_reference(merged_list, loader, node)
 
 
+def is_secret(value):
+    try:
+        return _SECRET_VALUES[text_type(value)]
+    except (KeyError, ValueError):
+        return None
+
+
 # pylint: disable=protected-access
 def _secret_yaml(loader, node):
     """Load secrets and embed it into the configuration YAML."""
     secret_path = os.path.join(os.path.dirname(loader.name), SECRET_YAML)
-    secrets = load_yaml(secret_path)
+    secrets = _load_yaml_internal(secret_path)
     if node.value not in secrets:
         raise EsphomeError(u"Secret {} not defined".format(node.value))
-    return secrets[node.value]
+    val = secrets[node.value]
+    _SECRET_VALUES[text_type(val)] = node.value
+    return val
 
 
 def _lambda(loader, node):
@@ -310,17 +327,27 @@ def represent_odict(dump, tag, mapping, flow_style=None):
     return node
 
 
+def represent_secret(value):
+    return yaml.ScalarNode(tag=u'!secret', value=_SECRET_VALUES[value])
+
+
 def unicode_representer(_, uni):
+    if is_secret(uni):
+        return represent_secret(uni)
     node = yaml.ScalarNode(tag=u'tag:yaml.org,2002:str', value=uni)
     return node
 
 
 def hex_int_representer(_, data):
+    if is_secret(data):
+        return represent_secret(data)
     node = yaml.ScalarNode(tag=u'tag:yaml.org,2002:int', value=str(data))
     return node
 
 
 def stringify_representer(_, data):
+    if is_secret(data):
+        return represent_secret(data)
     node = yaml.ScalarNode(tag=u'tag:yaml.org,2002:str', value=str(data))
     return node
 
@@ -345,18 +372,18 @@ def represent_time_period(dumper, data):
 
 
 def represent_lambda(_, data):
+    if is_secret(data.value):
+        return represent_secret(data.value)
     node = yaml.ScalarNode(tag='!lambda', value=data.value, style='|')
     return node
 
 
 def represent_id(_, data):
+    if is_secret(data.id):
+        return represent_secret(data.id)
     return yaml.ScalarNode(tag=u'tag:yaml.org,2002:str', value=data.id)
 
 
-def represent_uuid(_, data):
-    return yaml.ScalarNode(tag=u'tag:yaml.org,2002:str', value=str(data))
-
-
 yaml.SafeDumper.add_representer(
     OrderedDict,
     lambda dumper, value:
@@ -369,11 +396,13 @@ yaml.SafeDumper.add_representer(
     dumper.represent_sequence('tag:yaml.org,2002:seq', value)
 )
 
-yaml.SafeDumper.add_representer(text_type, unicode_representer)
+yaml.SafeDumper.add_representer(str, unicode_representer)
+if IS_PY2:
+    yaml.SafeDumper.add_representer(unicode, unicode_representer)
 yaml.SafeDumper.add_representer(HexInt, hex_int_representer)
 yaml.SafeDumper.add_representer(IPAddress, stringify_representer)
 yaml.SafeDumper.add_representer(MACAddress, stringify_representer)
 yaml.SafeDumper.add_multi_representer(TimePeriod, represent_time_period)
 yaml.SafeDumper.add_multi_representer(Lambda, represent_lambda)
 yaml.SafeDumper.add_multi_representer(core.ID, represent_id)
-yaml.SafeDumper.add_multi_representer(uuid.UUID, represent_uuid)
+yaml.SafeDumper.add_multi_representer(uuid.UUID, stringify_representer)