mirror of
https://github.com/ARM-software/workload-automation.git
synced 2025-09-06 05:12:36 +01:00
Initial commit of open source Workload Automation.
This commit is contained in:
703
wlauto/utils/misc.py
Normal file
703
wlauto/utils/misc.py
Normal file
@@ -0,0 +1,703 @@
|
||||
# Copyright 2013-2015 ARM Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
"""
|
||||
Miscellaneous functions that don't fit anywhere else.
|
||||
|
||||
"""
|
||||
from __future__ import division
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import math
|
||||
import imp
|
||||
import string
|
||||
import threading
|
||||
import signal
|
||||
import subprocess
|
||||
import pkgutil
|
||||
import traceback
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from operator import mul
|
||||
from StringIO import StringIO
|
||||
from itertools import cycle
|
||||
from functools import partial
|
||||
from distutils.spawn import find_executable
|
||||
|
||||
import yaml
|
||||
from dateutil import tz
|
||||
|
||||
|
||||
def preexec_function():
|
||||
# Ignore the SIGINT signal by setting the handler to the standard
|
||||
# signal handler SIG_IGN.
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
# Change process group in case we have to kill the subprocess and all of
|
||||
# its children later.
|
||||
# TODO: this is Unix-specific; would be good to find an OS-agnostic way
|
||||
# to do this in case we wanna port WA to Windows.
|
||||
os.setpgrp()
|
||||
|
||||
|
||||
check_output_logger = logging.getLogger('check_output')
|
||||
|
||||
|
||||
# Defined here rather than in wlauto.exceptions due to module load dependencies
|
||||
class TimeoutError(Exception):
|
||||
"""Raised when a subprocess command times out. This is basically a ``WAError``-derived version
|
||||
of ``subprocess.CalledProcessError``, the thinking being that while a timeout could be due to
|
||||
programming error (e.g. not setting long enough timers), it is often due to some failure in the
|
||||
environment, and there fore should be classed as a "user error"."""
|
||||
|
||||
def __init__(self, command, output):
|
||||
super(TimeoutError, self).__init__('Timed out: {}'.format(command))
|
||||
self.command = command
|
||||
self.output = output
|
||||
|
||||
def __str__(self):
|
||||
return '\n'.join([self.message, 'OUTPUT:', self.output or ''])
|
||||
|
||||
|
||||
def check_output(command, timeout=None, **kwargs):
|
||||
"""This is a version of subprocess.check_output that adds a timeout parameter to kill
|
||||
the subprocess if it does not return within the specified time."""
|
||||
|
||||
def callback(pid):
|
||||
try:
|
||||
check_output_logger.debug('{} timed out; sending SIGKILL'.format(pid))
|
||||
os.killpg(pid, signal.SIGKILL)
|
||||
except OSError:
|
||||
pass # process may have already terminated.
|
||||
|
||||
if 'stdout' in kwargs:
|
||||
raise ValueError('stdout argument not allowed, it will be overridden.')
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
preexec_fn=preexec_function, **kwargs)
|
||||
|
||||
if timeout:
|
||||
timer = threading.Timer(timeout, callback, [process.pid, ])
|
||||
timer.start()
|
||||
|
||||
try:
|
||||
output, error = process.communicate()
|
||||
finally:
|
||||
if timeout:
|
||||
timer.cancel()
|
||||
|
||||
retcode = process.poll()
|
||||
if retcode:
|
||||
if retcode == -9: # killed, assume due to timeout callback
|
||||
raise TimeoutError(command, output='\n'.join([output, error]))
|
||||
else:
|
||||
raise subprocess.CalledProcessError(retcode, command, output='\n'.join([output, error]))
|
||||
return output, error
|
||||
|
||||
|
||||
def walk_modules(path):
|
||||
"""
|
||||
Given package name, return a list of all modules (including submodules, etc)
|
||||
in that package.
|
||||
|
||||
"""
|
||||
root_mod = __import__(path, {}, {}, [''])
|
||||
mods = [root_mod]
|
||||
for _, name, ispkg in pkgutil.iter_modules(root_mod.__path__):
|
||||
submod_path = '.'.join([path, name])
|
||||
if ispkg:
|
||||
mods.extend(walk_modules(submod_path))
|
||||
else:
|
||||
submod = __import__(submod_path, {}, {}, [''])
|
||||
mods.append(submod)
|
||||
return mods
|
||||
|
||||
|
||||
def ensure_directory_exists(dirpath):
|
||||
"""A filter for directory paths to ensure they exist."""
|
||||
if not os.path.isdir(dirpath):
|
||||
os.makedirs(dirpath)
|
||||
return dirpath
|
||||
|
||||
|
||||
def ensure_file_directory_exists(filepath):
|
||||
"""
|
||||
A filter for file paths to ensure the directory of the
|
||||
file exists and the file can be created there. The file
|
||||
itself is *not* going to be created if it doesn't already
|
||||
exist.
|
||||
|
||||
"""
|
||||
ensure_directory_exists(os.path.dirname(filepath))
|
||||
return filepath
|
||||
|
||||
|
||||
def diff_tokens(before_token, after_token):
|
||||
"""
|
||||
Creates a diff of two tokens.
|
||||
|
||||
If the two tokens are the same it just returns returns the token
|
||||
(whitespace tokens are considered the same irrespective of type/number
|
||||
of whitespace characters in the token).
|
||||
|
||||
If the tokens are numeric, the difference between the two values
|
||||
is returned.
|
||||
|
||||
Otherwise, a string in the form [before -> after] is returned.
|
||||
|
||||
"""
|
||||
if before_token.isspace() and after_token.isspace():
|
||||
return after_token
|
||||
elif before_token.isdigit() and after_token.isdigit():
|
||||
try:
|
||||
diff = int(after_token) - int(before_token)
|
||||
return str(diff)
|
||||
except ValueError:
|
||||
return "[%s -> %s]" % (before_token, after_token)
|
||||
elif before_token == after_token:
|
||||
return after_token
|
||||
else:
|
||||
return "[%s -> %s]" % (before_token, after_token)
|
||||
|
||||
|
||||
def prepare_table_rows(rows):
|
||||
"""Given a list of lists, make sure they are prepared to be formatted into a table
|
||||
by making sure each row has the same number of columns and stringifying all values."""
|
||||
rows = [map(str, r) for r in rows]
|
||||
max_cols = max(map(len, rows))
|
||||
for row in rows:
|
||||
pad = max_cols - len(row)
|
||||
for _ in xrange(pad):
|
||||
row.append('')
|
||||
return rows
|
||||
|
||||
|
||||
def write_table(rows, wfh, align='>', headers=None): # pylint: disable=R0914
|
||||
"""Write a column-aligned table to the specified file object."""
|
||||
if not rows:
|
||||
return
|
||||
rows = prepare_table_rows(rows)
|
||||
num_cols = len(rows[0])
|
||||
|
||||
# cycle specified alignments until we have max_cols of them. This is
|
||||
# consitent with how such cases are handled in R, pandas, etc.
|
||||
it = cycle(align)
|
||||
align = [it.next() for _ in xrange(num_cols)]
|
||||
|
||||
cols = zip(*rows)
|
||||
col_widths = [max(map(len, c)) for c in cols]
|
||||
row_format = ' '.join(['{:%s%s}' % (align[i], w) for i, w in enumerate(col_widths)])
|
||||
row_format += '\n'
|
||||
|
||||
if headers:
|
||||
wfh.write(row_format.format(*headers))
|
||||
underlines = ['-' * len(h) for h in headers]
|
||||
wfh.write(row_format.format(*underlines))
|
||||
|
||||
for row in rows:
|
||||
wfh.write(row_format.format(*row))
|
||||
|
||||
|
||||
def get_null():
|
||||
"""Returns the correct null sink based on the OS."""
|
||||
return 'NUL' if os.name == 'nt' else '/dev/null'
|
||||
|
||||
|
||||
def get_traceback(exc=None):
|
||||
"""
|
||||
Returns the string with the traceback for the specifiec exc
|
||||
object, or for the current exception exc is not specified.
|
||||
|
||||
"""
|
||||
if exc is None:
|
||||
exc = sys.exc_info()
|
||||
if not exc:
|
||||
return None
|
||||
tb = exc[2]
|
||||
sio = StringIO()
|
||||
traceback.print_tb(tb, file=sio)
|
||||
del tb # needs to be done explicitly see: http://docs.python.org/2/library/sys.html#sys.exc_info
|
||||
return sio.getvalue()
|
||||
|
||||
|
||||
def merge_dicts(*args, **kwargs):
|
||||
if not len(args) >= 2:
|
||||
raise ValueError('Must specify at least two dicts to merge.')
|
||||
func = partial(_merge_two_dicts, **kwargs)
|
||||
return reduce(func, args)
|
||||
|
||||
|
||||
def _merge_two_dicts(base, other, list_duplicates='all', match_types=False, # pylint: disable=R0912,R0914
|
||||
dict_type=dict, should_normalize=True, should_merge_lists=True):
|
||||
"""Merge dicts normalizing their keys."""
|
||||
merged = dict_type()
|
||||
base_keys = base.keys()
|
||||
other_keys = other.keys()
|
||||
norm = normalize if should_normalize else lambda x, y: x
|
||||
|
||||
base_only = []
|
||||
other_only = []
|
||||
both = []
|
||||
union = []
|
||||
for k in base_keys:
|
||||
if k in other_keys:
|
||||
both.append(k)
|
||||
else:
|
||||
base_only.append(k)
|
||||
union.append(k)
|
||||
for k in other_keys:
|
||||
if k in base_keys:
|
||||
union.append(k)
|
||||
else:
|
||||
union.append(k)
|
||||
other_only.append(k)
|
||||
|
||||
for k in union:
|
||||
if k in base_only:
|
||||
merged[k] = norm(base[k], dict_type)
|
||||
elif k in other_only:
|
||||
merged[k] = norm(other[k], dict_type)
|
||||
elif k in both:
|
||||
base_value = base[k]
|
||||
other_value = other[k]
|
||||
base_type = type(base_value)
|
||||
other_type = type(other_value)
|
||||
if (match_types and (base_type != other_type) and
|
||||
(base_value is not None) and (other_value is not None)):
|
||||
raise ValueError('Type mismatch for {} got {} ({}) and {} ({})'.format(k, base_value, base_type,
|
||||
other_value, other_type))
|
||||
if isinstance(base_value, dict):
|
||||
merged[k] = _merge_two_dicts(base_value, other_value, list_duplicates, match_types, dict_type)
|
||||
elif isinstance(base_value, list):
|
||||
if should_merge_lists:
|
||||
merged[k] = _merge_two_lists(base_value, other_value, list_duplicates, dict_type)
|
||||
else:
|
||||
merged[k] = _merge_two_lists([], other_value, list_duplicates, dict_type)
|
||||
|
||||
elif isinstance(base_value, set):
|
||||
merged[k] = norm(base_value.union(other_value), dict_type)
|
||||
else:
|
||||
merged[k] = norm(other_value, dict_type)
|
||||
else: # Should never get here
|
||||
raise AssertionError('Unexpected merge key: {}'.format(k))
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def merge_lists(*args, **kwargs):
|
||||
if not len(args) >= 2:
|
||||
raise ValueError('Must specify at least two lists to merge.')
|
||||
func = partial(_merge_two_lists, **kwargs)
|
||||
return reduce(func, args)
|
||||
|
||||
|
||||
def _merge_two_lists(base, other, duplicates='all', dict_type=dict): # pylint: disable=R0912
|
||||
"""Merge lists, normalizing their entries."""
|
||||
if duplicates == 'all':
|
||||
merged_list = []
|
||||
for v in normalize(base, dict_type) + normalize(other, dict_type):
|
||||
if not _check_remove_item(merged_list, v):
|
||||
merged_list.append(v)
|
||||
return merged_list
|
||||
elif duplicates == 'first':
|
||||
merged_list = []
|
||||
for v in normalize(base + other, dict_type):
|
||||
if not _check_remove_item(merged_list, v):
|
||||
if v not in merged_list:
|
||||
merged_list.append(v)
|
||||
return merged_list
|
||||
elif duplicates == 'last':
|
||||
merged_list = []
|
||||
for v in normalize(base + other, dict_type):
|
||||
if not _check_remove_item(merged_list, v):
|
||||
if v in merged_list:
|
||||
del merged_list[merged_list.index(v)]
|
||||
merged_list.append(v)
|
||||
return merged_list
|
||||
else:
|
||||
raise ValueError('Unexpected value for list duplcates argument: {}. '.format(duplicates) +
|
||||
'Must be in {"all", "first", "last"}.')
|
||||
|
||||
|
||||
def _check_remove_item(the_list, item):
|
||||
"""Helper function for merge_lists that implements checking wether an items
|
||||
should be removed from the list and doing so if needed. Returns ``True`` if
|
||||
the item has been removed and ``False`` otherwise."""
|
||||
if not isinstance(item, basestring):
|
||||
return False
|
||||
if not item.startswith('~'):
|
||||
return False
|
||||
actual_item = item[1:]
|
||||
if actual_item in the_list:
|
||||
del the_list[the_list.index(actual_item)]
|
||||
return True
|
||||
|
||||
|
||||
def normalize(value, dict_type=dict):
|
||||
"""Normalize values. Recursively normalizes dict keys to be lower case,
|
||||
no surrounding whitespace, underscore-delimited strings."""
|
||||
if isinstance(value, dict):
|
||||
normalized = dict_type()
|
||||
for k, v in value.iteritems():
|
||||
key = k.strip().lower().replace(' ', '_')
|
||||
normalized[key] = normalize(v, dict_type)
|
||||
return normalized
|
||||
elif isinstance(value, list):
|
||||
return [normalize(v, dict_type) for v in value]
|
||||
elif isinstance(value, tuple):
|
||||
return tuple([normalize(v, dict_type) for v in value])
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
VALUE_REGEX = re.compile(r'(\d+(?:\.\d+)?)\s*(\w*)')
|
||||
|
||||
UNITS_MAP = {
|
||||
's': 'seconds',
|
||||
'ms': 'milliseconds',
|
||||
'us': 'microseconds',
|
||||
'ns': 'nanoseconds',
|
||||
'V': 'volts',
|
||||
'A': 'amps',
|
||||
'mA': 'milliamps',
|
||||
'J': 'joules',
|
||||
}
|
||||
|
||||
|
||||
def parse_value(value_string):
|
||||
"""parses a string representing a numerical value and returns
|
||||
a tuple (value, units), where value will be either int or float,
|
||||
and units will be a string representing the units or None."""
|
||||
match = VALUE_REGEX.search(value_string)
|
||||
if match:
|
||||
vs = match.group(1)
|
||||
value = float(vs) if '.' in vs else int(vs)
|
||||
us = match.group(2)
|
||||
units = UNITS_MAP.get(us, us)
|
||||
return (value, units)
|
||||
else:
|
||||
return (value_string, None)
|
||||
|
||||
|
||||
def get_meansd(values):
|
||||
"""Returns mean and standard deviation of the specified values."""
|
||||
if not values:
|
||||
return float('nan'), float('nan')
|
||||
mean = sum(values) / len(values)
|
||||
sd = math.sqrt(sum([v ** 2 for v in values]) / len(values) - mean ** 2)
|
||||
return mean, sd
|
||||
|
||||
|
||||
def geomean(values):
|
||||
"""Returns the geometric mean of the values."""
|
||||
return reduce(mul, values) ** (1.0 / len(values))
|
||||
|
||||
|
||||
def capitalize(text):
|
||||
"""Capitalises the specified text: first letter upper case,
|
||||
all subsequent letters lower case."""
|
||||
if not text:
|
||||
return ''
|
||||
return text[0].upper() + text[1:].lower()
|
||||
|
||||
|
||||
def convert_new_lines(text):
|
||||
""" Convert new lines to a common format. """
|
||||
return text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
|
||||
def escape_quotes(text):
|
||||
"""Escape quotes, and escaped quotes, in the specified text."""
|
||||
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\\\'').replace('\"', '\\\"')
|
||||
|
||||
|
||||
def escape_single_quotes(text):
|
||||
"""Escape single quotes, and escaped single quotes, in the specified text."""
|
||||
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\'\\\'\'')
|
||||
|
||||
|
||||
def escape_double_quotes(text):
|
||||
"""Escape double quotes, and escaped double quotes, in the specified text."""
|
||||
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\"', '\\\"')
|
||||
|
||||
|
||||
def getch(count=1):
|
||||
"""Read ``count`` characters from standard input."""
|
||||
if os.name == 'nt':
|
||||
import msvcrt # pylint: disable=F0401
|
||||
return ''.join([msvcrt.getch() for _ in xrange(count)])
|
||||
else: # assume Unix
|
||||
import tty # NOQA
|
||||
import termios # NOQA
|
||||
fd = sys.stdin.fileno()
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
try:
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
ch = sys.stdin.read(count)
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
||||
return ch
|
||||
|
||||
|
||||
def isiterable(obj):
|
||||
"""Returns ``True`` if the specified object is iterable and
|
||||
*is not a string type*, ``False`` otherwise."""
|
||||
return hasattr(obj, '__iter__') and not isinstance(obj, basestring)
|
||||
|
||||
|
||||
def utc_to_local(dt):
|
||||
"""Convert naive datetime to local time zone, assuming UTC."""
|
||||
return dt.replace(tzinfo=tz.tzutc()).astimezone(tz.tzlocal())
|
||||
|
||||
|
||||
def local_to_utc(dt):
|
||||
"""Convert naive datetime to UTC, assuming local time zone."""
|
||||
return dt.replace(tzinfo=tz.tzlocal()).astimezone(tz.tzutc())
|
||||
|
||||
|
||||
def as_relative(path):
|
||||
"""Convert path to relative by stripping away the leading '/' on UNIX or
|
||||
the equivant on other platforms."""
|
||||
path = os.path.splitdrive(path)[1]
|
||||
return path.lstrip(os.sep)
|
||||
|
||||
|
||||
def get_cpu_mask(cores):
|
||||
"""Return a string with the hex for the cpu mask for the specified core numbers."""
|
||||
mask = 0
|
||||
for i in cores:
|
||||
mask |= 1 << i
|
||||
return '0x{0:x}'.format(mask)
|
||||
|
||||
|
||||
def load_class(classpath):
|
||||
"""Loads the specified Python class. ``classpath`` must be a fully-qualified
|
||||
class name (i.e. namspaced under module/package)."""
|
||||
modname, clsname = classpath.rsplit('.', 1)
|
||||
return getattr(__import__(modname), clsname)
|
||||
|
||||
|
||||
def get_pager():
|
||||
"""Returns the name of the system pager program."""
|
||||
pager = os.getenv('PAGER')
|
||||
if not pager:
|
||||
pager = find_executable('less')
|
||||
if not pager:
|
||||
pager = find_executable('more')
|
||||
return pager
|
||||
|
||||
|
||||
def enum_metaclass(enum_param, return_name=False, start=0):
|
||||
"""
|
||||
Returns a ``type`` subclass that may be used as a metaclass for
|
||||
an enum.
|
||||
|
||||
Paremeters:
|
||||
|
||||
:enum_param: the name of class attribute that defines enum values.
|
||||
The metaclass will add a class attribute for each value in
|
||||
``enum_param``. The value of the attribute depends on the type
|
||||
of ``enum_param`` and on the values of ``return_name``. If
|
||||
``return_name`` is ``True``, then the value of the new attribute is
|
||||
the name of that attribute; otherwise, if ``enum_param`` is a ``list``
|
||||
or a ``tuple``, the value will be the index of that param in
|
||||
``enum_param``, optionally offset by ``start``, otherwise, it will
|
||||
be assumed that ``enum_param`` implementa a dict-like inteface and
|
||||
the value will be ``enum_param[attr_name]``.
|
||||
:return_name: If ``True``, the enum values will the names of enum attributes. If
|
||||
``False``, the default, the values will depend on the type of
|
||||
``enum_param`` (see above).
|
||||
:start: If ``enum_param`` is a list or a tuple, and ``return_name`` is ``False``,
|
||||
this specifies an "offset" that will be added to the index of the attribute
|
||||
within ``enum_param`` to form the value.
|
||||
|
||||
|
||||
"""
|
||||
class __EnumMeta(type):
|
||||
def __new__(mcs, clsname, bases, attrs):
|
||||
cls = type.__new__(mcs, clsname, bases, attrs)
|
||||
values = getattr(cls, enum_param, [])
|
||||
if return_name:
|
||||
for name in values:
|
||||
setattr(cls, name, name)
|
||||
else:
|
||||
if isinstance(values, list) or isinstance(values, tuple):
|
||||
for i, name in enumerate(values):
|
||||
setattr(cls, name, i + start)
|
||||
else: # assume dict-like
|
||||
for name in values:
|
||||
setattr(cls, name, values[name])
|
||||
return cls
|
||||
return __EnumMeta
|
||||
|
||||
|
||||
def which(name):
|
||||
"""Platform-independent version of UNIX which utility."""
|
||||
if os.name == 'nt':
|
||||
paths = os.getenv('PATH').split(os.pathsep)
|
||||
exts = os.getenv('PATHEXT').split(os.pathsep)
|
||||
for path in paths:
|
||||
testpath = os.path.join(path, name)
|
||||
if os.path.isfile(testpath):
|
||||
return testpath
|
||||
for ext in exts:
|
||||
testpathext = testpath + ext
|
||||
if os.path.isfile(testpathext):
|
||||
return testpathext
|
||||
return None
|
||||
else: # assume UNIX-like
|
||||
try:
|
||||
return check_output(['which', name])[0].strip()
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
|
||||
|
||||
_bash_color_regex = re.compile('\x1b\[[0-9;]+m')
|
||||
|
||||
|
||||
def strip_bash_colors(text):
|
||||
return _bash_color_regex.sub('', text)
|
||||
|
||||
|
||||
def format_duration(seconds, sep=' ', order=['day', 'hour', 'minute', 'second']): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Formats the specified number of seconds into human-readable duration.
|
||||
|
||||
"""
|
||||
if isinstance(seconds, timedelta):
|
||||
td = seconds
|
||||
else:
|
||||
td = timedelta(seconds=seconds)
|
||||
dt = datetime(1, 1, 1) + td
|
||||
result = []
|
||||
for item in order:
|
||||
value = getattr(dt, item, None)
|
||||
if item is 'day':
|
||||
value -= 1
|
||||
if not value:
|
||||
continue
|
||||
suffix = '' if value == 1 else 's'
|
||||
result.append('{} {}{}'.format(value, item, suffix))
|
||||
return sep.join(result)
|
||||
|
||||
|
||||
def get_article(word):
|
||||
"""
|
||||
Returns the appropriate indefinite article for the word (ish).
|
||||
|
||||
.. note:: Indefinite article assignment in English is based on
|
||||
sound rather than spelling, so this will not work correctly
|
||||
in all case; e.g. this will return ``"a hour"``.
|
||||
|
||||
"""
|
||||
return'an' if word[0] in 'aoeiu' else 'a'
|
||||
|
||||
|
||||
def get_random_string(length):
|
||||
"""Returns a random ASCII string of the specified length)."""
|
||||
return ''.join(random.choice(string.ascii_letters + string.digits) for _ in xrange(length))
|
||||
|
||||
|
||||
class LoadSyntaxError(Exception):
|
||||
|
||||
def __init__(self, message, filepath, lineno):
|
||||
super(LoadSyntaxError, self).__init__(message)
|
||||
self.filepath = filepath
|
||||
self.lineno = lineno
|
||||
|
||||
def __str__(self):
|
||||
message = 'Syntax Error in {}, line {}:\n\t{}'
|
||||
return message.format(self.filepath, self.lineno, self.message)
|
||||
|
||||
|
||||
RAND_MOD_NAME_LEN = 30
|
||||
BAD_CHARS = string.punctuation + string.whitespace
|
||||
TRANS_TABLE = string.maketrans(BAD_CHARS, '_' * len(BAD_CHARS))
|
||||
|
||||
|
||||
def to_identifier(text):
|
||||
"""Converts text to a valid Python identifier by replacing all
|
||||
whitespace and punctuation."""
|
||||
return re.sub('_+', '_', text.translate(TRANS_TABLE))
|
||||
|
||||
|
||||
def load_struct_from_python(filepath=None, text=None):
|
||||
"""Parses a config structure from a .py file. The structure should be composed
|
||||
of basic Python types (strings, ints, lists, dicts, etc.)."""
|
||||
if not (filepath or text) or (filepath and text):
|
||||
raise ValueError('Exactly one of filepath or text must be specified.')
|
||||
try:
|
||||
if filepath:
|
||||
modname = to_identifier(filepath)
|
||||
mod = imp.load_source(modname, filepath)
|
||||
else:
|
||||
modname = get_random_string(RAND_MOD_NAME_LEN)
|
||||
while modname in sys.modules: # highly unlikely, but...
|
||||
modname = get_random_string(RAND_MOD_NAME_LEN)
|
||||
mod = imp.new_module(modname)
|
||||
exec text in mod.__dict__ # pylint: disable=exec-used
|
||||
return dict((k, v)
|
||||
for k, v in mod.__dict__.iteritems()
|
||||
if not k.startswith('_'))
|
||||
except SyntaxError as e:
|
||||
raise LoadSyntaxError(e.message, e.filepath, e.lineno)
|
||||
|
||||
|
||||
def load_struct_from_yaml(filepath=None, text=None):
|
||||
"""Parses a config structure from a .yaml file. The structure should be composed
|
||||
of basic Python types (strings, ints, lists, dicts, etc.)."""
|
||||
if not (filepath or text) or (filepath and text):
|
||||
raise ValueError('Exactly one of filepath or text must be specified.')
|
||||
try:
|
||||
if filepath:
|
||||
with open(filepath) as fh:
|
||||
return yaml.load(fh)
|
||||
else:
|
||||
return yaml.load(text)
|
||||
except yaml.YAMLError as e:
|
||||
lineno = None
|
||||
if hasattr(e, 'problem_mark'):
|
||||
lineno = e.problem_mark.line
|
||||
raise LoadSyntaxError(e.message, filepath=filepath, lineno=lineno)
|
||||
|
||||
|
||||
def load_struct_from_file(filepath):
|
||||
"""
|
||||
Attempts to parse a Python structure consisting of basic types from the specified file.
|
||||
Raises a ``ValueError`` if the specified file is of unkown format; ``LoadSyntaxError`` if
|
||||
there is an issue parsing the file.
|
||||
|
||||
"""
|
||||
extn = os.path.splitext(filepath)[1].lower()
|
||||
if (extn == '.py') or (extn == '.pyc') or (extn == '.pyo'):
|
||||
return load_struct_from_python(filepath)
|
||||
elif extn == '.yaml':
|
||||
return load_struct_from_yaml(filepath)
|
||||
else:
|
||||
raise ValueError('Unknown format "{}": {}'.format(extn, filepath))
|
||||
|
||||
|
||||
def unique(alist):
|
||||
"""
|
||||
Returns a list containing only unique elements from the input list (but preserves
|
||||
order, unlike sets).
|
||||
|
||||
"""
|
||||
result = []
|
||||
for item in alist:
|
||||
if item not in result:
|
||||
result.append(item)
|
||||
return result
|
Reference in New Issue
Block a user