Mirror of https://github.com/esphome/esphome.git, synced 2025-10-31 07:03:55 +00:00
			
		
		
		
	
		
			
				
	
	
		
			235 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			235 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import difflib
 | |
| import itertools
 | |
| 
 | |
| import voluptuous as vol
 | |
| 
 | |
| from esphome.schema_extractors import schema_extractor_extended
 | |
| 
 | |
| 
 | |
class ExtraKeysInvalid(vol.Invalid):
    """A ``vol.Invalid`` for an unknown mapping key, carrying suggested keys.

    The mandatory ``candidates`` keyword argument holds the similar key names
    (``_Schema`` fills it with the result of ``difflib.get_close_matches``).
    All remaining positional and keyword arguments go to ``vol.Invalid``.
    """

    def __init__(self, *arg, **kwargs):
        suggestions = kwargs.pop("candidates")
        super().__init__(*arg, **kwargs)
        self.candidates = suggestions
 | |
| 
 | |
| 
 | |
def ensure_multiple_invalid(err):
    """Normalize *err* into a ``vol.MultipleInvalid``.

    An existing ``vol.MultipleInvalid`` is returned unchanged. A list of errors
    is wrapped as-is, and any other single error is wrapped in a one-element
    list.
    """
    if isinstance(err, vol.MultipleInvalid):
        return err
    wrapped = err if isinstance(err, list) else [err]
    return vol.MultipleInvalid(wrapped)
 | |
| 
 | |
| 
 | |
# pylint: disable=protected-access, unidiomatic-typecheck
class _Schema(vol.Schema):
    """Custom cv.Schema that prints similar keys on error.

    Differences from the plain ``vol.Schema`` visible in this class:

    * Every mapping key must be wrapped in ``cv.Required``/``cv.Optional``.
      ``vol.Extra``, ``vol.Remove`` and bare primitive keys are rejected with
      ``ValueError`` when the schema is compiled.
    * An unknown string key produces an ``ExtraKeysInvalid`` whose
      ``candidates`` are the closest known key names (``difflib``).
    * A list of "extra schemas" may be attached. They run, in order, on the
      output of the main validation.
    """

    def __init__(
        self, schema, required=False, extra=vol.PREVENT_EXTRA, extra_schemas=None
    ):
        super().__init__(schema, required=required, extra=extra)
        # List of extra schemas to apply after validation.
        # Should be used sparingly, as it's not a very voluptuous-way/clean way of
        # doing things.
        # Copy the list so add_extra()/prepend_extra() never mutate a list
        # that belongs to the caller.
        self._extra_schemas = list(extra_schemas) if extra_schemas else []

    def __call__(self, data):
        """Validate *data*, then feed the result through each extra schema.

        Any ``vol.Invalid`` raised by an extra schema is re-raised as a
        ``vol.MultipleInvalid``.
        """
        res = super().__call__(data)
        for extra in self._extra_schemas:
            try:
                res = extra(res)
            except vol.Invalid as err:
                # Deliberately not chained with ``from err``: ensure_multiple_invalid
                # may return *err* itself, which would make it its own __cause__.
                # pylint: disable=raise-missing-from
                raise ensure_multiple_invalid(err)
        return res

    def _compile_mapping(self, schema, invalid_msg=None):
        """Compile a dict *schema* into a ``validate_mapping(path, iterable, out)`` closure.

        :param schema: mapping of marker keys (``cv.Required``/``cv.Optional``)
            to value validators.
        :param invalid_msg: ``error_type`` assigned to value errors raised at
            this mapping level. Defaults to ``"mapping value"``.
        :raises ValueError: if *schema* contains ``vol.Extra``, a ``vol.Remove``
            key, or a key that is not wrapped in a marker.
        """
        invalid_msg = invalid_msg or "mapping value"

        # Check some things that ESPHome's schemas do not allow
        # mostly to keep the logic in this method sane (so these may be re-added if needed).
        for key in schema:
            if key is vol.Extra:
                raise ValueError("ESPHome does not allow vol.Extra")
            if isinstance(key, vol.Remove):
                raise ValueError("ESPHome does not allow vol.Remove")
            if isinstance(key, vol.primitive_types):
                raise ValueError(
                    "All schema keys must be wrapped in cv.Required or cv.Optional"
                )

        # Keys that may be required
        all_required_keys = {key for key in schema if isinstance(key, vol.Required)}

        # Keys that may have defaults
        # This is a list because sets do not guarantee insertion order
        all_default_keys = [key for key in schema if isinstance(key, vol.Optional)]

        # Recursively compile schema (key first, then value, for each item)
        _compiled_schema = {
            skey: (self._compile(skey), self._compile(svalue))
            for skey, svalue in schema.items()
        }

        # Sort compiled schema (probably not necessary for esphome, but leave it here just in case)
        candidates = list(
            vol.schema_builder._iterate_mapping_candidates(_compiled_schema)
        )

        # After we have the list of candidates in the correct order, we want to apply some
        # optimization so that each key in the data being validated will be matched
        # against the relevant schema keys only. No point in matching against
        # different keys.
        additional_candidates = []
        candidates_by_key = {}
        for skey, (ckey, cvalue) in candidates:
            if type(skey) in vol.primitive_types:
                candidates_by_key.setdefault(skey, []).append((skey, (ckey, cvalue)))
            elif (
                isinstance(skey, vol.Marker)
                and type(skey.schema) in vol.primitive_types
            ):
                candidates_by_key.setdefault(skey.schema, []).append(
                    (skey, (ckey, cvalue))
                )
            else:
                # These are wildcards such as 'int', 'str', 'Remove' and others which should be
                # applied to all keys
                additional_candidates.append((skey, (ckey, cvalue)))

        # Known string key names, used to suggest close matches for unknown keys.
        key_names = []
        for skey in schema:
            if isinstance(skey, str):
                key_names.append(skey)
            elif isinstance(skey, vol.Marker) and isinstance(skey.schema, str):
                key_names.append(skey.schema)

        def validate_mapping(path, iterable, out):
            """Validate the ``(key, value)`` pairs of *iterable* into *out*.

            *path* is the list of keys leading to this mapping. The type of
            *out* is reused for the intermediate key/value map. Defaults of
            missing optional keys are inserted before validation.

            :raises vol.MultipleInvalid: with every collected error.
            """
            required_keys = all_required_keys.copy()

            # Build a map of all provided key-value pairs.
            # The type(out) is used to retain ordering in case an ordered
            # map type is provided as input.
            key_value_map = type(out)()
            for key, value in iterable:
                key_value_map[key] = value

            # Insert default values for non-existing keys.
            for key in all_default_keys:
                if (
                    not isinstance(key.default, vol.Undefined)
                    and key.schema not in key_value_map
                ):
                    # A default value has been specified for this missing key, insert it.
                    key_value_map[key.schema] = key.default()

            errors = []
            for key, value in key_value_map.items():
                key_path = path + [key]
                # Optimization. Validate against the matching key first, then fallback to the rest
                relevant_candidates = itertools.chain(
                    candidates_by_key.get(key, []), additional_candidates
                )

                # compare each given key/value against all compiled key/values
                # schema key, (compiled key, compiled value)
                for skey, (ckey, cvalue) in relevant_candidates:
                    try:
                        new_key = ckey(key_path, key)
                    except vol.Invalid as e:
                        if len(e.path) > len(key_path):
                            raise
                        # This candidate does not accept the key; try the next.
                        # The key-validator error itself is not reported: a key
                        # that matches no candidate is handled by the ``else``
                        # branch below.
                        continue

                    # Backtracking is not performed once a key is selected, so a
                    # value error is recorded against this key instead of
                    # trying the remaining candidates.
                    exception_errors = []
                    try:
                        cval = cvalue(key_path, value)
                    except vol.MultipleInvalid as e:
                        exception_errors.extend(e.errors)
                    except vol.Invalid as e:
                        exception_errors.append(e)
                    else:
                        out[new_key] = cval

                    for err in exception_errors:
                        if len(err.path) <= len(key_path):
                            err.error_type = invalid_msg
                        errors.append(err)

                    # The key was provided (whether or not its value validated),
                    # so discard it from the required set. This avoids an extra,
                    # noisy "required key not provided" error.
                    required_keys.discard(skey)
                    break
                else:
                    if self.extra == vol.ALLOW_EXTRA:
                        out[key] = value
                    elif self.extra != vol.REMOVE_EXTRA:
                        if isinstance(key, str) and key_names:
                            matches = difflib.get_close_matches(key, key_names)
                            errors.append(
                                ExtraKeysInvalid(
                                    "extra keys not allowed",
                                    key_path,
                                    candidates=matches,
                                )
                            )
                        else:
                            errors.append(
                                vol.Invalid("extra keys not allowed", key_path)
                            )

            # for any required keys left that weren't found and don't have defaults:
            for key in required_keys:
                msg = getattr(key, "msg", None) or "required key not provided"
                errors.append(vol.RequiredFieldInvalid(msg, path + [key]))
            if errors:
                raise vol.MultipleInvalid(errors)

            return out

        return validate_mapping

    def add_extra(self, validator):
        """Append *validator* (wrapped in ``_Schema``) to the post-validation chain.

        Mutates this schema in place and returns ``self`` for chaining.
        """
        validator = _Schema(validator)
        self._extra_schemas.append(validator)
        return self

    def prepend_extra(self, validator):
        """Insert *validator* (wrapped in ``_Schema``) at the front of the post-validation chain.

        Mutates this schema in place and returns ``self`` for chaining.
        """
        validator = _Schema(validator)
        self._extra_schemas.insert(0, validator)
        return self

    @schema_extractor_extended
    def extend(self, *schemas, **kwargs):
        """Return a new ``_Schema`` with each of *schemas* merged into this one, in order.

        Extra schemas of ``self`` are kept, followed by those of any merged
        ``_Schema``. ``self`` is not modified.

        :param schemas: dicts or ``vol.Schema`` instances to merge. With no
            argument, an unchanged copy is returned.
        :param extra: optional keyword-only override of the extra-keys policy
            of the result. It is honored for any number of *schemas*.
        :raises ValueError: for any keyword argument other than ``extra``.
        """
        extra = kwargs.pop("extra", None)
        if kwargs:
            raise ValueError(
                f"Unexpected keyword arguments for extend(): {', '.join(kwargs)}"
            )
        if not schemas:
            # Forward ``extra`` so the override is not silently dropped.
            return self.extend({}, extra=extra)
        if len(schemas) != 1:
            ret = self
            for schema in schemas:
                # Forward ``extra`` at every step so the final result carries it.
                ret = ret.extend(schema, extra=extra)
            return ret

        schema = schemas[0]
        extra_schemas = self._extra_schemas.copy()
        if isinstance(schema, _Schema):
            extra_schemas.extend(schema._extra_schemas)
        if isinstance(schema, vol.Schema):
            schema = schema.schema
        ret = super().extend(schema, extra=extra)
        return _Schema(ret.schema, extra=ret.extra, extra_schemas=extra_schemas)
 |