mirror of
https://github.com/ARM-software/workload-automation.git
synced 2025-09-03 03:42:35 +01:00
New target description + moving target stuff under "framework"
Changing the way target descriptions work from a static mapping to something that is dynamically generated and is extensible via plugins. Also moving core target implementation stuff under "framework".
This commit is contained in:
544
wa/utils/misc.py
544
wa/utils/misc.py
@@ -24,7 +24,6 @@ import sys
|
||||
import re
|
||||
import math
|
||||
import imp
|
||||
import uuid
|
||||
import string
|
||||
import threading
|
||||
import signal
|
||||
@@ -33,154 +32,28 @@ import pkgutil
|
||||
import traceback
|
||||
import logging
|
||||
import random
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from operator import mul, itemgetter
|
||||
from StringIO import StringIO
|
||||
from itertools import cycle, groupby
|
||||
from itertools import cycle, groupby, chain
|
||||
from functools import partial
|
||||
from distutils.spawn import find_executable
|
||||
|
||||
import yaml
|
||||
from dateutil import tz
|
||||
|
||||
from wa.framework.version import get_wa_version
|
||||
|
||||
|
||||
# ABI --> architectures list
|
||||
ABI_MAP = {
|
||||
'armeabi': ['armeabi', 'armv7', 'armv7l', 'armv7el', 'armv7lh'],
|
||||
'arm64': ['arm64', 'armv8', 'arm64-v8a'],
|
||||
}
|
||||
|
||||
|
||||
def preexec_function():
|
||||
# Ignore the SIGINT signal by setting the handler to the standard
|
||||
# signal handler SIG_IGN.
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
# Change process group in case we have to kill the subprocess and all of
|
||||
# its children later.
|
||||
# TODO: this is Unix-specific; would be good to find an OS-agnostic way
|
||||
# to do this in case we wanna port WA to Windows.
|
||||
os.setpgrp()
|
||||
|
||||
from devlib.utils.misc import (ABI_MAP, check_output, walk_modules,
|
||||
ensure_directory_exists, ensure_file_directory_exists,
|
||||
normalize, convert_new_lines, get_cpu_mask, unique,
|
||||
escape_quotes, escape_single_quotes, escape_double_quotes,
|
||||
isiterable, getch, as_relative, ranges_to_list,
|
||||
list_to_ranges, list_to_mask, mask_to_list, which)
|
||||
|
||||
check_output_logger = logging.getLogger('check_output')
|
||||
|
||||
|
||||
# Defined here rather than in wlauto.exceptions due to module load dependencies
|
||||
class TimeoutError(Exception):
|
||||
"""Raised when a subprocess command times out. This is basically a ``WAError``-derived version
|
||||
of ``subprocess.CalledProcessError``, the thinking being that while a timeout could be due to
|
||||
programming error (e.g. not setting long enough timers), it is often due to some failure in the
|
||||
environment, and there fore should be classed as a "user error"."""
|
||||
|
||||
def __init__(self, command, output):
|
||||
super(TimeoutError, self).__init__('Timed out: {}'.format(command))
|
||||
self.command = command
|
||||
self.output = output
|
||||
|
||||
def __str__(self):
|
||||
return '\n'.join([self.message, 'OUTPUT:', self.output or ''])
|
||||
|
||||
|
||||
def check_output(command, timeout=None, ignore=None, **kwargs):
|
||||
"""This is a version of subprocess.check_output that adds a timeout parameter to kill
|
||||
the subprocess if it does not return within the specified time."""
|
||||
# pylint: disable=too-many-branches
|
||||
if ignore is None:
|
||||
ignore = []
|
||||
elif isinstance(ignore, int):
|
||||
ignore = [ignore]
|
||||
elif not isinstance(ignore, list) and ignore != 'all':
|
||||
message = 'Invalid value for ignore parameter: "{}"; must be an int or a list'
|
||||
raise ValueError(message.format(ignore))
|
||||
if 'stdout' in kwargs:
|
||||
raise ValueError('stdout argument not allowed, it will be overridden.')
|
||||
|
||||
def callback(pid):
|
||||
try:
|
||||
check_output_logger.debug('{} timed out; sending SIGKILL'.format(pid))
|
||||
os.killpg(pid, signal.SIGKILL)
|
||||
except OSError:
|
||||
pass # process may have already terminated.
|
||||
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
preexec_fn=preexec_function, **kwargs)
|
||||
|
||||
if timeout:
|
||||
timer = threading.Timer(timeout, callback, [process.pid, ])
|
||||
timer.start()
|
||||
|
||||
try:
|
||||
output, error = process.communicate()
|
||||
finally:
|
||||
if timeout:
|
||||
timer.cancel()
|
||||
|
||||
retcode = process.poll()
|
||||
if retcode:
|
||||
if retcode == -9: # killed, assume due to timeout callback
|
||||
raise TimeoutError(command, output='\n'.join([output, error]))
|
||||
elif ignore != 'all' and retcode not in ignore:
|
||||
raise subprocess.CalledProcessError(retcode, command, output='\n'.join([output, error]))
|
||||
return output, error
|
||||
|
||||
|
||||
def init_argument_parser(parser):
|
||||
parser.add_argument('-c', '--config', help='specify an additional config.py')
|
||||
parser.add_argument('-v', '--verbose', action='count',
|
||||
help='The scripts will produce verbose output.')
|
||||
parser.add_argument('--debug', action='store_true',
|
||||
help='Enable debug mode. Note: this implies --verbose.')
|
||||
parser.add_argument('--version', action='version', version='%(prog)s {}'.format(get_wa_version()))
|
||||
return parser
|
||||
|
||||
|
||||
def walk_modules(path):
|
||||
"""
|
||||
Given a path to a Python package, iterate over all the modules and
|
||||
sub-packages in that package.
|
||||
|
||||
"""
|
||||
try:
|
||||
root_mod = __import__(path, {}, {}, [''])
|
||||
yield root_mod
|
||||
except ImportError as e:
|
||||
e.path = path
|
||||
raise e
|
||||
if not hasattr(root_mod, '__path__'): # module, not package
|
||||
return
|
||||
for _, name, ispkg in pkgutil.iter_modules(root_mod.__path__):
|
||||
try:
|
||||
submod_path = '.'.join([path, name])
|
||||
if ispkg:
|
||||
for submod in walk_modules(submod_path):
|
||||
yield submod
|
||||
else:
|
||||
yield __import__(submod_path, {}, {}, [''])
|
||||
except ImportError as e:
|
||||
e.path = submod_path
|
||||
raise e
|
||||
|
||||
|
||||
def ensure_directory_exists(dirpath):
|
||||
"""A filter for directory paths to ensure they exist."""
|
||||
if not os.path.isdir(dirpath):
|
||||
os.makedirs(dirpath)
|
||||
return dirpath
|
||||
|
||||
|
||||
def ensure_file_directory_exists(filepath):
|
||||
"""
|
||||
A filter for file paths to ensure the directory of the
|
||||
file exists and the file can be created there. The file
|
||||
itself is *not* going to be created if it doesn't already
|
||||
exist.
|
||||
|
||||
"""
|
||||
ensure_directory_exists(os.path.dirname(filepath))
|
||||
return filepath
|
||||
|
||||
|
||||
def diff_tokens(before_token, after_token):
|
||||
"""
|
||||
Creates a diff of two tokens.
|
||||
@@ -269,22 +142,18 @@ def get_traceback(exc=None):
|
||||
return sio.getvalue()
|
||||
|
||||
|
||||
def normalize(value, dict_type=dict):
|
||||
"""Normalize values. Recursively normalizes dict keys to be lower case,
|
||||
no surrounding whitespace, underscore-delimited strings."""
|
||||
if isinstance(value, dict):
|
||||
normalized = dict_type()
|
||||
for k, v in value.iteritems():
|
||||
if isinstance(k, basestring):
|
||||
k = k.strip().lower().replace(' ', '_')
|
||||
normalized[k] = normalize(v, dict_type)
|
||||
return normalized
|
||||
elif isinstance(value, list):
|
||||
return [normalize(v, dict_type) for v in value]
|
||||
elif isinstance(value, tuple):
|
||||
return tuple([normalize(v, dict_type) for v in value])
|
||||
else:
|
||||
return value
|
||||
def _check_remove_item(the_list, item):
|
||||
"""Helper function for merge_lists that implements checking wether an items
|
||||
should be removed from the list and doing so if needed. Returns ``True`` if
|
||||
the item has been removed and ``False`` otherwise."""
|
||||
if not isinstance(item, basestring):
|
||||
return False
|
||||
if not item.startswith('~'):
|
||||
return False
|
||||
actual_item = item[1:]
|
||||
if actual_item in the_list:
|
||||
del the_list[the_list.index(actual_item)]
|
||||
return True
|
||||
|
||||
|
||||
VALUE_REGEX = re.compile(r'(\d+(?:\.\d+)?)\s*(\w*)')
|
||||
@@ -338,50 +207,6 @@ def capitalize(text):
|
||||
return text[0].upper() + text[1:].lower()
|
||||
|
||||
|
||||
def convert_new_lines(text):
|
||||
""" Convert new lines to a common format. """
|
||||
return text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
|
||||
def escape_quotes(text):
|
||||
"""Escape quotes, and escaped quotes, in the specified text."""
|
||||
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\\\'').replace('\"', '\\\"')
|
||||
|
||||
|
||||
def escape_single_quotes(text):
|
||||
"""Escape single quotes, and escaped single quotes, in the specified text."""
|
||||
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\'\\\'\'')
|
||||
|
||||
|
||||
def escape_double_quotes(text):
|
||||
"""Escape double quotes, and escaped double quotes, in the specified text."""
|
||||
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\"', '\\\"')
|
||||
|
||||
|
||||
def getch(count=1):
|
||||
"""Read ``count`` characters from standard input."""
|
||||
if os.name == 'nt':
|
||||
import msvcrt # pylint: disable=F0401
|
||||
return ''.join([msvcrt.getch() for _ in xrange(count)])
|
||||
else: # assume Unix
|
||||
import tty # NOQA
|
||||
import termios # NOQA
|
||||
fd = sys.stdin.fileno()
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
try:
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
ch = sys.stdin.read(count)
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
||||
return ch
|
||||
|
||||
|
||||
def isiterable(obj):
|
||||
"""Returns ``True`` if the specified object is iterable and
|
||||
*is not a string type*, ``False`` otherwise."""
|
||||
return hasattr(obj, '__iter__') and not isinstance(obj, basestring)
|
||||
|
||||
|
||||
def utc_to_local(dt):
|
||||
"""Convert naive datetime to local time zone, assuming UTC."""
|
||||
return dt.replace(tzinfo=tz.tzutc()).astimezone(tz.tzlocal())
|
||||
@@ -392,21 +217,6 @@ def local_to_utc(dt):
|
||||
return dt.replace(tzinfo=tz.tzlocal()).astimezone(tz.tzutc())
|
||||
|
||||
|
||||
def as_relative(path):
|
||||
"""Convert path to relative by stripping away the leading '/' on UNIX or
|
||||
the equivant on other platforms."""
|
||||
path = os.path.splitdrive(path)[1]
|
||||
return path.lstrip(os.sep)
|
||||
|
||||
|
||||
def get_cpu_mask(cores):
|
||||
"""Return a string with the hex for the cpu mask for the specified core numbers."""
|
||||
mask = 0
|
||||
for i in cores:
|
||||
mask |= 1 << i
|
||||
return '0x{0:x}'.format(mask)
|
||||
|
||||
|
||||
def load_class(classpath):
|
||||
"""Loads the specified Python class. ``classpath`` must be a fully-qualified
|
||||
class name (i.e. namspaced under module/package)."""
|
||||
@@ -468,29 +278,7 @@ def enum_metaclass(enum_param, return_name=False, start=0):
|
||||
return __EnumMeta
|
||||
|
||||
|
||||
def which(name):
|
||||
"""Platform-independent version of UNIX which utility."""
|
||||
if os.name == 'nt':
|
||||
paths = os.getenv('PATH').split(os.pathsep)
|
||||
exts = os.getenv('PATHEXT').split(os.pathsep)
|
||||
for path in paths:
|
||||
testpath = os.path.join(path, name)
|
||||
if os.path.isfile(testpath):
|
||||
return testpath
|
||||
for ext in exts:
|
||||
testpathext = testpath + ext
|
||||
if os.path.isfile(testpathext):
|
||||
return testpathext
|
||||
return None
|
||||
else: # assume UNIX-like
|
||||
try:
|
||||
result = check_output(['which', name])[0]
|
||||
return result.strip() # pylint: disable=E1103
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
|
||||
|
||||
_bash_color_regex = re.compile('\x1b\\[[0-9;]+m')
|
||||
_bash_color_regex = re.compile('\x1b\[[0-9;]+m')
|
||||
|
||||
|
||||
def strip_bash_colors(text):
|
||||
@@ -536,6 +324,18 @@ def get_random_string(length):
|
||||
return ''.join(random.choice(string.ascii_letters + string.digits) for _ in xrange(length))
|
||||
|
||||
|
||||
class LoadSyntaxError(Exception):
|
||||
|
||||
def __init__(self, message, filepath, lineno):
|
||||
super(LoadSyntaxError, self).__init__(message)
|
||||
self.filepath = filepath
|
||||
self.lineno = lineno
|
||||
|
||||
def __str__(self):
|
||||
message = 'Syntax Error in {}, line {}:\n\t{}'
|
||||
return message.format(self.filepath, self.lineno, self.message)
|
||||
|
||||
|
||||
RAND_MOD_NAME_LEN = 30
|
||||
BAD_CHARS = string.punctuation + string.whitespace
|
||||
TRANS_TABLE = string.maketrans(BAD_CHARS, '_' * len(BAD_CHARS))
|
||||
@@ -544,23 +344,63 @@ TRANS_TABLE = string.maketrans(BAD_CHARS, '_' * len(BAD_CHARS))
|
||||
def to_identifier(text):
|
||||
"""Converts text to a valid Python identifier by replacing all
|
||||
whitespace and punctuation."""
|
||||
result = re.sub('_+', '_', text.translate(TRANS_TABLE))
|
||||
if result and result[0] in string.digits:
|
||||
result = '_' + result
|
||||
return result
|
||||
return re.sub('_+', '_', text.translate(TRANS_TABLE))
|
||||
|
||||
|
||||
def unique(alist):
|
||||
def load_struct_from_python(filepath=None, text=None):
|
||||
"""Parses a config structure from a .py file. The structure should be composed
|
||||
of basic Python types (strings, ints, lists, dicts, etc.)."""
|
||||
if not (filepath or text) or (filepath and text):
|
||||
raise ValueError('Exactly one of filepath or text must be specified.')
|
||||
try:
|
||||
if filepath:
|
||||
modname = to_identifier(filepath)
|
||||
mod = imp.load_source(modname, filepath)
|
||||
else:
|
||||
modname = get_random_string(RAND_MOD_NAME_LEN)
|
||||
while modname in sys.modules: # highly unlikely, but...
|
||||
modname = get_random_string(RAND_MOD_NAME_LEN)
|
||||
mod = imp.new_module(modname)
|
||||
exec text in mod.__dict__ # pylint: disable=exec-used
|
||||
return dict((k, v)
|
||||
for k, v in mod.__dict__.iteritems()
|
||||
if not k.startswith('_'))
|
||||
except SyntaxError as e:
|
||||
raise LoadSyntaxError(e.message, filepath, e.lineno)
|
||||
|
||||
|
||||
def load_struct_from_yaml(filepath=None, text=None):
|
||||
"""Parses a config structure from a .yaml file. The structure should be composed
|
||||
of basic Python types (strings, ints, lists, dicts, etc.)."""
|
||||
if not (filepath or text) or (filepath and text):
|
||||
raise ValueError('Exactly one of filepath or text must be specified.')
|
||||
try:
|
||||
if filepath:
|
||||
with open(filepath) as fh:
|
||||
return yaml.load(fh)
|
||||
else:
|
||||
return yaml.load(text)
|
||||
except yaml.YAMLError as e:
|
||||
lineno = None
|
||||
if hasattr(e, 'problem_mark'):
|
||||
lineno = e.problem_mark.line # pylint: disable=no-member
|
||||
raise LoadSyntaxError(e.message, filepath=filepath, lineno=lineno)
|
||||
|
||||
|
||||
def load_struct_from_file(filepath):
|
||||
"""
|
||||
Returns a list containing only unique elements from the input list (but preserves
|
||||
order, unlike sets).
|
||||
Attempts to parse a Python structure consisting of basic types from the specified file.
|
||||
Raises a ``ValueError`` if the specified file is of unkown format; ``LoadSyntaxError`` if
|
||||
there is an issue parsing the file.
|
||||
|
||||
"""
|
||||
result = []
|
||||
for item in alist:
|
||||
if item not in result:
|
||||
result.append(item)
|
||||
return result
|
||||
extn = os.path.splitext(filepath)[1].lower()
|
||||
if (extn == '.py') or (extn == '.pyc') or (extn == '.pyo'):
|
||||
return load_struct_from_python(filepath)
|
||||
elif extn == '.yaml':
|
||||
return load_struct_from_yaml(filepath)
|
||||
else:
|
||||
raise ValueError('Unknown format "{}": {}'.format(extn, filepath))
|
||||
|
||||
|
||||
def open_file(filepath):
|
||||
@@ -576,68 +416,170 @@ def open_file(filepath):
|
||||
return subprocess.call(['xdg-open', filepath])
|
||||
|
||||
|
||||
def ranges_to_list(ranges_string):
|
||||
"""Converts a sysfs-style ranges string, e.g. ``"0,2-4"``, into a list ,e.g ``[0,2,3,4]``"""
|
||||
values = []
|
||||
for rg in ranges_string.split(','):
|
||||
if '-' in rg:
|
||||
first, last = map(int, rg.split('-'))
|
||||
values.extend(xrange(first, last + 1))
|
||||
else:
|
||||
values.append(int(rg))
|
||||
return values
|
||||
def sha256(path, chunk=2048):
|
||||
"""Calculates SHA256 hexdigest of the file at the specified path."""
|
||||
h = hashlib.sha256()
|
||||
with open(path, 'rb') as fh:
|
||||
buf = fh.read(chunk)
|
||||
while buf:
|
||||
h.update(buf)
|
||||
buf = fh.read(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def list_to_ranges(values):
|
||||
"""Converts a list, e.g ``[0,2,3,4]``, into a sysfs-style ranges string, e.g. ``"0,2-4"``"""
|
||||
range_groups = []
|
||||
for _, g in groupby(enumerate(values), lambda (i, x): i - x):
|
||||
range_groups.append(map(itemgetter(1), g))
|
||||
range_strings = []
|
||||
for group in range_groups:
|
||||
if len(group) == 1:
|
||||
range_strings.append(str(group[0]))
|
||||
else:
|
||||
range_strings.append('{}-{}'.format(group[0], group[-1]))
|
||||
return ','.join(range_strings)
|
||||
def urljoin(*parts):
|
||||
return '/'.join(p.rstrip('/') for p in parts)
|
||||
|
||||
|
||||
def list_to_mask(values, base=0x0):
|
||||
"""Converts the specified list of integer values into
|
||||
a bit mask for those values. Optinally, the list can be
|
||||
applied to an existing mask."""
|
||||
for v in values:
|
||||
base |= (1 << v)
|
||||
return base
|
||||
|
||||
|
||||
def mask_to_list(mask):
|
||||
"""Converts the specfied integer bitmask into a list of
|
||||
indexes of bits that are set in the mask."""
|
||||
size = len(bin(mask)) - 2 # because of "0b"
|
||||
return [size - i - 1 for i in xrange(size)
|
||||
if mask & (1 << size - i - 1)]
|
||||
|
||||
|
||||
class Namespace(dict):
|
||||
# From: http://eli.thegreenplace.net/2011/10/19/perls-guess-if-file-is-text-or-binary-implemented-in-python/
|
||||
def istextfile(fileobj, blocksize=512):
|
||||
""" Uses heuristics to guess whether the given file is text or binary,
|
||||
by reading a single block of bytes from the file.
|
||||
If more than 30% of the chars in the block are non-text, or there
|
||||
are NUL ('\x00') bytes in the block, assume this is a binary file.
|
||||
"""
|
||||
A dict-like object that allows treating keys and attributes
|
||||
interchangeably (this means that keys are restricted to strings
|
||||
that are valid Python identifiers).
|
||||
_text_characters = (b''.join(chr(i) for i in range(32, 127)) +
|
||||
b'\n\r\t\f\b')
|
||||
|
||||
block = fileobj.read(blocksize)
|
||||
if b'\x00' in block:
|
||||
# Files with null bytes are binary
|
||||
return False
|
||||
elif not block:
|
||||
# An empty file is considered a valid text file
|
||||
return True
|
||||
|
||||
# Use translate's 'deletechars' argument to efficiently remove all
|
||||
# occurrences of _text_characters from the block
|
||||
nontext = block.translate(None, _text_characters)
|
||||
return float(len(nontext)) / len(block) <= 0.30
|
||||
|
||||
|
||||
def categorize(v):
|
||||
if hasattr(v, 'merge_with') and hasattr(v, 'merge_into'):
|
||||
return 'o'
|
||||
elif hasattr(v, 'iteritems'):
|
||||
return 'm'
|
||||
elif isiterable(v):
|
||||
return 's'
|
||||
elif v is None:
|
||||
return 'n'
|
||||
else:
|
||||
return 'c'
|
||||
|
||||
|
||||
def merge_config_values(base, other):
|
||||
"""
|
||||
This is used to merge two objects, typically when setting the value of a
|
||||
``ConfigurationPoint``. First, both objects are categorized into
|
||||
|
||||
c: A scalar value. Basically, most objects. These values
|
||||
are treated as atomic, and not mergeable.
|
||||
s: A sequence. Anything iterable that is not a dict or
|
||||
a string (strings are considered scalars).
|
||||
m: A key-value mapping. ``dict`` and its derivatives.
|
||||
n: ``None``.
|
||||
o: A mergeable object; this is an object that implements both
|
||||
``merge_with`` and ``merge_into`` methods.
|
||||
|
||||
The merge rules based on the two categories are then as follows:
|
||||
|
||||
(c1, c2) --> c2
|
||||
(s1, s2) --> s1 . s2
|
||||
(m1, m2) --> m1 . m2
|
||||
(c, s) --> [c] . s
|
||||
(s, c) --> s . [c]
|
||||
(s, m) --> s . [m]
|
||||
(m, s) --> [m] . s
|
||||
(m, c) --> ERROR
|
||||
(c, m) --> ERROR
|
||||
(o, X) --> o.merge_with(X)
|
||||
(X, o) --> o.merge_into(X)
|
||||
(X, n) --> X
|
||||
(n, X) --> X
|
||||
|
||||
where:
|
||||
|
||||
'.' means concatenation (for maps, contcationation of (k, v) streams
|
||||
then converted back into a map). If the types of the two objects
|
||||
differ, the type of ``other`` is used for the result.
|
||||
'X' means "any category"
|
||||
'[]' used to indicate a literal sequence (not necessarily a ``list``).
|
||||
when this is concatenated with an actual sequence, that sequencies
|
||||
type is used.
|
||||
|
||||
notes:
|
||||
|
||||
- When a mapping is combined with a sequence, that mapping is
|
||||
treated as a scalar value.
|
||||
- When combining two mergeable objects, they're combined using
|
||||
``o1.merge_with(o2)`` (_not_ using o2.merge_into(o1)).
|
||||
- Combining anything with ``None`` yields that value, irrespective
|
||||
of the order. So a ``None`` value is eqivalent to the corresponding
|
||||
item being omitted.
|
||||
- When both values are scalars, merging is equivalent to overwriting.
|
||||
- There is no recursion (e.g. if map values are lists, they will not
|
||||
be merged; ``other`` will overwrite ``base`` values). If complicated
|
||||
merging semantics (such as recursion) are required, they should be
|
||||
implemented within custom mergeable types (i.e. those that implement
|
||||
``merge_with`` and ``merge_into``).
|
||||
|
||||
While this can be used as a generic "combine any two arbitry objects"
|
||||
function, the semantics have been selected specifically for merging
|
||||
configuration point values.
|
||||
|
||||
"""
|
||||
cat_base = categorize(base)
|
||||
cat_other = categorize(other)
|
||||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
return self[name]
|
||||
except KeyError:
|
||||
raise AttributeError(name)
|
||||
if cat_base == 'n':
|
||||
return other
|
||||
elif cat_other == 'n':
|
||||
return base
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
self[name] = value
|
||||
if cat_base == 'o':
|
||||
return base.merge_with(other)
|
||||
elif cat_other == 'o':
|
||||
return other.merge_into(base)
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
if to_identifier(name) != name:
|
||||
message = 'Key must be a valid identifier; got "{}"'
|
||||
raise ValueError(message.format(name))
|
||||
dict.__setitem__(self, name, value)
|
||||
if cat_base == 'm':
|
||||
if cat_other == 's':
|
||||
return merge_sequencies([base], other)
|
||||
elif cat_other == 'm':
|
||||
return merge_maps(base, other)
|
||||
else:
|
||||
message = 'merge error ({}, {}): "{}" and "{}"'
|
||||
raise ValueError(message.format(cat_base, cat_other, base, other))
|
||||
elif cat_base == 's':
|
||||
if cat_other == 's':
|
||||
return merge_sequencies(base, other)
|
||||
else:
|
||||
return merge_sequencies(base, [other])
|
||||
else: # cat_base == 'c'
|
||||
if cat_other == 's':
|
||||
return merge_sequencies([base], other)
|
||||
elif cat_other == 'm':
|
||||
message = 'merge error ({}, {}): "{}" and "{}"'
|
||||
raise ValueError(message.format(cat_base, cat_other, base, other))
|
||||
else:
|
||||
return other
|
||||
|
||||
|
||||
def merge_sequencies(s1, s2):
|
||||
return type(s2)(unique(chain(s1, s2)))
|
||||
|
||||
|
||||
def merge_maps(m1, m2):
|
||||
return type(m2)(chain(m1.iteritems(), m2.iteritems()))
|
||||
|
||||
|
||||
def merge_dicts_simple(base, other):
|
||||
result = base.copy()
|
||||
for key, value in (base or {}).iteritems():
|
||||
result[key] = merge_config_values(result.get(key), value)
|
||||
return result
|
||||
|
||||
|
||||
def touch(path):
|
||||
with open(path, 'w'):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user