
Initial commit of open source Workload Automation.

This commit is contained in:
Sergei Trofimov
2015-03-10 13:09:31 +00:00
commit a747ec7e4c
412 changed files with 41401 additions and 0 deletions

16
wlauto/utils/__init__.py Normal file

@@ -0,0 +1,16 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

368
wlauto/utils/android.py Normal file

@@ -0,0 +1,368 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Utility functions for working with Android devices through adb.
"""
# pylint: disable=E1103
import os
import time
import subprocess
import logging
import re
from wlauto.exceptions import DeviceError, ConfigError, HostError
from wlauto.utils.misc import check_output, escape_single_quotes, escape_double_quotes, get_null
MAX_TRIES = 5
logger = logging.getLogger('android')
# See:
# http://developer.android.com/guide/topics/manifest/uses-sdk-element.html#ApiLevels
ANDROID_VERSION_MAP = {
19: 'KITKAT',
18: 'JELLY_BEAN_MR2',
17: 'JELLY_BEAN_MR1',
16: 'JELLY_BEAN',
15: 'ICE_CREAM_SANDWICH_MR1',
14: 'ICE_CREAM_SANDWICH',
13: 'HONEYCOMB_MR2',
12: 'HONEYCOMB_MR1',
11: 'HONEYCOMB',
10: 'GINGERBREAD_MR1',
9: 'GINGERBREAD',
8: 'FROYO',
7: 'ECLAIR_MR1',
6: 'ECLAIR_0_1',
5: 'ECLAIR',
4: 'DONUT',
3: 'CUPCAKE',
2: 'BASE_1_1',
1: 'BASE',
}
# TODO: these are set to their actual values near the bottom of the file. There
# is some HACKery involved to ensure that ANDROID_HOME does not need to be set
# or adb added to path for root when installing as root, and the whole
# implementation is kinda clunky and messier than I'd like. The only file that
# rivals this one in levels of mess is bootstrap.py (for very much the same
# reasons). There must be a neater way to ensure that environmental dependencies
# are met when they are needed, and are not imposed when they are not.
android_home = None
platform_tools = None
adb = None
aapt = None
fastboot = None
class _AndroidEnvironment(object):
def __init__(self):
self.android_home = None
self.platform_tools = None
self.adb = None
self.aapt = None
self.fastboot = None
class AndroidProperties(object):
def __init__(self, text):
self._properties = {}
self.parse(text)
def parse(self, text):
self._properties = dict(re.findall(r'\[(.*?)\]:\s+\[(.*?)\]', text))
def __iter__(self):
return iter(self._properties)
def __getattr__(self, name):
return self._properties.get(name)
__getitem__ = __getattr__
class ApkInfo(object):
version_regex = re.compile(r"name='(?P<name>[^']+)' versionCode='(?P<vcode>[^']+)' versionName='(?P<vname>[^']+)'")
name_regex = re.compile(r"name='(?P<name>[^']+)'")
def __init__(self, path=None):
self.path = path
self.package = None
self.activity = None
self.label = None
self.version_name = None
self.version_code = None
self.parse(path)
def parse(self, apk_path):
_check_env()
command = [aapt, 'dump', 'badging', apk_path]
logger.debug(' '.join(command))
output = subprocess.check_output(command)
for line in output.split('\n'):
if line.startswith('application-label:'):
self.label = line.split(':')[1].strip().replace('\'', '')
elif line.startswith('package:'):
match = self.version_regex.search(line)
if match:
self.package = match.group('name')
self.version_code = match.group('vcode')
self.version_name = match.group('vname')
elif line.startswith('launchable-activity:'):
match = self.name_regex.search(line)
self.activity = match.group('name')
else:
pass # not interested
def fastboot_command(command, timeout=None):
_check_env()
full_command = "fastboot {}".format(command)
logger.debug(full_command)
output, _ = check_output(full_command, timeout, shell=True)
return output
def fastboot_flash_partition(partition, path_to_image):
command = 'flash {} {}'.format(partition, path_to_image)
fastboot_command(command)
def adb_get_device():
"""
Returns the serial number of a connected android device.
If more than one device is connected to the machine, or no connected device
could be found, :class:`wlauto.exceptions.ConfigError` is raised.
"""
_check_env()
# TODO this is a hacky way to issue an adb command to all listed devices
# The output of calling adb devices consists of a heading line then
# a list of the devices separated by new lines.
# The last line is a blank new line. In other words, if there is a device found
# then the output length is 2 + (1 for each device)
output = adb_command('0', "devices").splitlines() # pylint: disable=E1103
output_length = len(output)
if output_length == 3:
# output[1] is the 2nd line in the output which has the device name
# Splitting the line on '\t' gives a two-element list with the
# device serial at index 0 and the device type at index 1.
return output[1].split('\t')[0]
elif output_length > 3:
raise ConfigError('Number of discovered devices is {}, it should be 1'.format(output_length - 2))
else:
raise ConfigError('No device is connected and available')
def adb_connect(device, timeout=None):
_check_env()
command = "adb connect " + device
if ":5555" in device:
logger.debug(command)
output, _ = check_output(command, shell=True, timeout=timeout)
logger.debug(output)
#### due to a rare adb bug sometimes an extra :5555 is appended to the IP address
if output.find('5555:5555') != -1:
logger.debug('ADB BUG with extra 5555')
command = "adb connect " + device.replace(':5555', '')
tries = 0
while not poll_for_file(device, "/proc/cpuinfo"):
logger.debug("adb connect failed, retrying now...")
tries += 1
if tries > MAX_TRIES:
raise DeviceError('Cannot connect to adb server on the device.')
logger.debug(command)
output, _ = check_output(command, shell=True, timeout=timeout)
time.sleep(10)
if output.find('connected to') == -1:
raise DeviceError('Could not connect to {}'.format(device))
def adb_disconnect(device):
_check_env()
if ":5555" in device:
command = "adb disconnect " + device
logger.debug(command)
retval = subprocess.call(command, stdout=open(os.devnull, 'wb'), shell=True)
if retval:
raise DeviceError('"{}" returned {}'.format(command, retval))
def poll_for_file(device, dfile):
_check_env()
device_string = '-s {}'.format(device) if device else ''
command = "adb " + device_string + " shell \" if [ -f " + dfile + " ] ; then true ; else false ; fi\" "
logger.debug(command)
result = subprocess.call(command, stderr=subprocess.PIPE, shell=True)
if not result:
return True
else:
return False
am_start_error = re.compile(r"Error: Activity class {[\w|.|/]*} does not exist")
def adb_shell(device, command, timeout=None, check_exit_code=False, as_root=False): # NOQA
_check_env()
if as_root:
command = 'echo "{}" | su'.format(escape_double_quotes(command))
device_string = '-s {}'.format(device) if device else ''
full_command = 'adb {} shell "{}"'.format(device_string, escape_double_quotes(command))
logger.debug(full_command)
if check_exit_code:
actual_command = "adb {} shell '({}); echo $?'".format(device_string, escape_single_quotes(command))
raw_output, error = check_output(actual_command, timeout, shell=True)
if raw_output:
try:
output, exit_code, _ = raw_output.rsplit('\n', 2)
except ValueError:
exit_code, _ = raw_output.rsplit('\n', 1)
output = ''
else: # raw_output is empty
exit_code = '969696' # just because
output = ''
exit_code = exit_code.strip()
if exit_code.isdigit():
if int(exit_code):
message = 'Got exit code {}\nfrom: {}\nSTDOUT: {}\nSTDERR: {}'.format(exit_code, full_command,
output, error)
raise DeviceError(message)
elif am_start_error.findall(output):
message = 'Could not start activity; got the following:'
message += '\n{}'.format(am_start_error.findall(output)[0])
raise DeviceError(message)
else: # not all digits
if am_start_error.findall(output):
message = 'Could not start activity; got the following:'
message += '\n{}'.format(am_start_error.findall(output)[0])
raise DeviceError(message)
else:
raise DeviceError('adb has returned early; did not get an exit code. Was kill-server invoked?')
else: # do not check exit code
output, _ = check_output(full_command, timeout, shell=True)
return output
def adb_background_shell(device, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False):
"""Runs the sepcified command in a subprocess, returning the the Popen object."""
_check_env()
if as_root:
command = 'echo \'{}\' | su'.format(escape_single_quotes(command))
device_string = '-s {}'.format(device) if device else ''
full_command = 'adb {} shell "{}"'.format(device_string, escape_double_quotes(command))
logger.debug(full_command)
return subprocess.Popen(full_command, stdout=stdout, stderr=stderr, shell=True)
class AdbDevice(object):
def __init__(self, name, status):
self.name = name
self.status = status
def __cmp__(self, other):
if isinstance(other, AdbDevice):
return cmp(self.name, other.name)
else:
return cmp(self.name, other)
def adb_list_devices():
_check_env()
output = adb_command(None, 'devices')
devices = []
for line in output.splitlines():
parts = [p.strip() for p in line.split()]
if len(parts) == 2:
devices.append(AdbDevice(*parts))
return devices
def adb_command(device, command, timeout=None):
_check_env()
device_string = '-s {}'.format(device) if device else ''
full_command = "adb {} {}".format(device_string, command)
logger.debug(full_command)
output, _ = check_output(full_command, timeout, shell=True)
return output
# Messy environment initialisation stuff...
def _initialize_with_android_home(env):
logger.debug('Using ANDROID_HOME from the environment.')
env.android_home = android_home
env.platform_tools = os.path.join(android_home, 'platform-tools')
os.environ['PATH'] += os.pathsep + env.platform_tools
_init_common(env)
return env
def _initialize_without_android_home(env):
if os.name == 'nt':
raise HostError('Please set ANDROID_HOME to point to the location of the Android SDK.')
# Assuming Unix in what follows.
if subprocess.call('adb version >{}'.format(get_null()), shell=True):
raise HostError('ANDROID_HOME is not set and adb is not in PATH. Have you installed Android SDK?')
logger.debug('Discovering ANDROID_HOME from adb path.')
env.platform_tools = os.path.dirname(subprocess.check_output('which adb', shell=True))
env.android_home = os.path.dirname(env.platform_tools)
_init_common(env)
return env
def _init_common(env):
logger.debug('ANDROID_HOME: {}'.format(env.android_home))
build_tools_directory = os.path.join(env.android_home, 'build-tools')
if not os.path.isdir(build_tools_directory):
msg = 'ANDROID_HOME ({}) does not appear to contain a valid Android SDK install (cannot find build-tools)'
raise HostError(msg.format(env.android_home))
versions = os.listdir(build_tools_directory)
for version in reversed(sorted(versions)):
aapt_path = os.path.join(build_tools_directory, version, 'aapt')
if os.path.isfile(aapt_path):
logger.debug('Using aapt for version {}'.format(version))
env.aapt = aapt_path
break
else:
raise HostError('aapt not found. Please make sure at least one Android platform is installed.')
def _check_env():
global android_home, platform_tools, adb, aapt # pylint: disable=W0603
if not android_home:
android_home = os.getenv('ANDROID_HOME')
if android_home:
_env = _initialize_with_android_home(_AndroidEnvironment())
else:
_env = _initialize_without_android_home(_AndroidEnvironment())
android_home = _env.android_home
platform_tools = _env.platform_tools
adb = _env.adb
aapt = _env.aapt
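A minimal usage sketch for the helpers above, assuming adb and aapt are installed and a device is attached; the serial number and apk path are hypothetical:

from wlauto.utils.android import ApkInfo, adb_list_devices, adb_shell

# Pick up the first attached device (falls back to a made-up serial).
devices = adb_list_devices()
serial = devices[0].name if devices else '0123456789ABCDEF'

# Inspect a (hypothetical) APK via aapt.
apk = ApkInfo('/path/to/example.apk')
print apk.package, apk.version_name, apk.activity

# Run a shell command on the device and verify its exit code.
print adb_shell(serial, 'getprop ro.build.version.sdk', timeout=30, check_exit_code=True)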

27
wlauto/utils/cli.py Normal file

@@ -0,0 +1,27 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from wlauto.core.version import get_wa_version
def init_argument_parser(parser):
parser.add_argument('-c', '--config', help='specify an additional config.py')
parser.add_argument('-v', '--verbose', action='count',
help='The scripts will produce verbose output.')
parser.add_argument('--debug', action='store_true',
help='Enable debug mode. Note: this implies --verbose.')
parser.add_argument('--version', action='version', version='%(prog)s {}'.format(get_wa_version()))
return parser
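A short sketch of wiring init_argument_parser into a command-line tool; the 'example' program name is illustrative:

import argparse
from wlauto.utils.cli import init_argument_parser

parser = init_argument_parser(argparse.ArgumentParser(prog='example'))
args = parser.parse_args(['-v', '--debug'])
print args.verbose, args.debug  # 1 True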

44
wlauto/utils/cpuinfo.py Normal file

@@ -0,0 +1,44 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class Cpuinfo(object):
@property
def architecture(self):
for section in self.sections:
if 'CPU architecture' in section:
return section['CPU architecture']
if 'architecture' in section:
return section['architecture']
def __init__(self, text):
self.sections = None
self.text = None
self.parse(text)
def parse(self, text):
self.sections = []
current_section = {}
self.text = text.strip()
for line in self.text.split('\n'):
line = line.strip()
if line:
key, value = line.split(':', 1)
current_section[key.strip()] = value.strip()
else: # not line
self.sections.append(current_section)
current_section = {}
self.sections.append(current_section)
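A small example of how the parser above splits /proc/cpuinfo-style text into sections; the input is made up:

from wlauto.utils.cpuinfo import Cpuinfo

text = 'processor\t: 0\nCPU architecture: 7\n\nprocessor\t: 1\nCPU architecture: 7'
info = Cpuinfo(text)
print len(info.sections)  # 2 -- one dict per blank-line-separated block
print info.architecture   # '7', from the first section containing the key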

305
wlauto/utils/doc.py Normal file

@@ -0,0 +1,305 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Utilities for working with and formatting documentation.
"""
import os
import re
import inspect
from itertools import cycle
USER_HOME = os.path.expanduser('~')
BULLET_CHARS = '-*'
def get_summary(aclass):
"""
Returns the summary description for an extension class. The summary is the
first paragraph (separated by blank line) of the description taken either from
the ``description`` attribute of the class, or if that is not present, from the
class' docstring.
"""
return get_description(aclass).split('\n\n')[0]
def get_description(aclass):
"""
Return the description of the specified extension class. The description is taken
either from ``description`` attribute of the class or its docstring.
"""
if hasattr(aclass, 'description') and aclass.description:
return inspect.cleandoc(aclass.description)
if aclass.__doc__:
return inspect.getdoc(aclass)
else:
return 'no documentation found for {}'.format(aclass.__name__)
def get_type_name(obj):
"""Returns the name of the type object or function specified. In case of a lambda,
the definition is returned with the parameter replaced by "value"."""
match = re.search(r"<(type|class|function) '?(.*?)'?>", str(obj))
if isinstance(obj, tuple):
name = obj[1]
elif match.group(1) == 'function':
text = str(obj)
name = text.split()[1]
if name == '<lambda>':
source = inspect.getsource(obj).strip().replace('\n', ' ')
match = re.search(r'lambda\s+(\w+)\s*:\s*(.*?)\s*[\n,]', source)
if not match:
raise ValueError('could not get name for {}'.format(obj))
name = match.group(2).replace(match.group(1), 'value')
else:
name = match.group(2)
if '.' in name:
name = name.split('.')[-1]
return name
def count_leading_spaces(text):
"""
Counts the number of leading space characters in a string.
TODO: may need to update this to handle other whitespace (e.g. tabs), but shouldn't
be necessary as there should be no tabs in Python source.
"""
nspaces = 0
for c in text:
if c == ' ':
nspaces += 1
else:
break
return nspaces
def format_column(text, width):
"""
Formats text into a column of specified width. If a line is too long,
it will be broken on a word boundary. The new lines will have the same
number of leading spaces as the original line.
Note: this will not attempt to join up lines that are too short.
"""
formatted = []
for line in text.split('\n'):
line_len = len(line)
if line_len <= width:
formatted.append(line)
else:
words = line.split(' ')
new_line = words.pop(0)
while words:
next_word = words.pop(0)
if (len(new_line) + len(next_word) + 1) < width:
new_line += ' ' + next_word
else:
formatted.append(new_line)
new_line = ' ' * count_leading_spaces(new_line) + next_word
formatted.append(new_line)
return '\n'.join(formatted)
def format_bullets(text, width, char='-', shift=3, outchar=None):
"""
Formats text into bulleted list. Assumes each line of input that starts with
``char`` (possibly preceded with whitespace) is a new bullet point. Note: leading
whitespace in the input will *not* be preserved. Instead, it will be determined by
``shift`` parameter.
:text: the text to be formatted
:width: format width (note: must be at least ``shift`` + 4).
:char: character that indicates a new bullet point in the input text.
:shift: How far bulleted entries will be indented. This indicates the indentation
level of the bullet point. Text indentation level will be ``shift`` + 3.
:outchar: character that will be used to mark bullet points in the output. If
left as ``None``, ``char`` will be used.
"""
bullet_lines = []
output = ''
def __process_bullet(bullet_lines):
if bullet_lines:
bullet = format_paragraph(indent(' '.join(bullet_lines), shift + 2), width)
bullet = bullet[:3] + outchar + bullet[4:]
del bullet_lines[:]
return bullet + '\n'
else:
return ''
if outchar is None:
outchar = char
for line in text.split('\n'):
line = line.strip()
if line.startswith(char): # new bullet
output += __process_bullet(bullet_lines)
line = line[1:].strip()
bullet_lines.append(line)
output += __process_bullet(bullet_lines)
return output
def format_simple_table(rows, headers=None, align='>', show_borders=True, borderchar='='): # pylint: disable=R0914
"""Formats a simple table."""
if not rows:
return ''
rows = [map(str, r) for r in rows]
num_cols = len(rows[0])
# cycle specified alignments until we have num_cols of them. This is
# consistent with how such cases are handled in R, pandas, etc.
it = cycle(align)
align = [it.next() for _ in xrange(num_cols)]
cols = zip(*rows)
col_widths = [max(map(len, c)) for c in cols]
row_format = ' '.join(['{:%s%s}' % (align[i], w) for i, w in enumerate(col_widths)])
row_format += '\n'
border = row_format.format(*[borderchar * cw for cw in col_widths])
result = border if show_borders else ''
if headers:
result += row_format.format(*headers)
result += border
for row in rows:
result += row_format.format(*row)
if show_borders:
result += border
return result
def format_paragraph(text, width):
"""
Format the specified text into a column of specified width. The text is
assumed to be a single paragraph and existing line breaks will not be preserved.
Leading spaces (of the initial line), on the other hand, will be preserved.
"""
text = re.sub('\n\n*\\s*', ' ', text.strip('\n'))
return format_column(text, width)
def format_body(text, width):
"""
Format the specified text into a column of specified width. The text is
assumed to be a "body" of one or more paragraphs separated by one or more
blank lines. The initial indentation of the first line of each paragraph
will be preserved, but any other formatting may be clobbered.
"""
text = re.sub('\n\\s*\n', '\n\n', text.strip('\n')) # get rid of all-whitespace lines
paragraphs = re.split('\n\n+', text)
formatted_paragraphs = []
for p in paragraphs:
if p.strip() and p.strip()[0] in BULLET_CHARS:
formatted_paragraphs.append(format_bullets(p, width))
else:
formatted_paragraphs.append(format_paragraph(p, width))
return '\n\n'.join(formatted_paragraphs)
def strip_inlined_text(text):
"""
This function processes multiline inlined text (e.g. from docstrings)
to strip away leading spaces and leading and trailing new lines.
"""
text = text.strip('\n')
lines = [ln.rstrip() for ln in text.split('\n')]
# first line is special as it may not have the indent that the
# others have, e.g. if it starts on the same line as the multiline quote (""").
nspaces = count_leading_spaces(lines[0])
if len([ln for ln in lines if ln]) > 1:
to_strip = min(count_leading_spaces(ln) for ln in lines[1:] if ln)
if nspaces >= to_strip:
stripped = [lines[0][to_strip:]]
else:
stripped = [lines[0][nspaces:]]
stripped += [ln[to_strip:] for ln in lines[1:]]
else:
stripped = [lines[0][nspaces:]]
return '\n'.join(stripped).strip('\n')
def indent(text, spaces=4):
"""Indent the lines i the specified text by ``spaces`` spaces."""
indented = []
for line in text.split('\n'):
if line:
indented.append(' ' * spaces + line)
else: # do not indent empty lines
indented.append(line)
return '\n'.join(indented)
def format_literal(lit):
if isinstance(lit, basestring):
return '``\'{}\'``'.format(lit)
elif hasattr(lit, 'pattern'): # regex
return '``r\'{}\'``'.format(lit.pattern)
else:
return '``{}``'.format(lit)
def get_params_rst(ext):
text = ''
for param in ext.parameters:
text += '{} : {} {}\n'.format(param.name, get_type_name(param.kind),
param.mandatory and '(mandatory)' or ' ')
desc = strip_inlined_text(param.description or '')
text += indent('{}\n'.format(desc))
if param.allowed_values:
text += indent('\nallowed values: {}\n'.format(', '.join(map(format_literal, param.allowed_values))))
elif param.constraint:
text += indent('\nconstraint: ``{}``\n'.format(get_type_name(param.constraint)))
if param.default:
value = param.default
if isinstance(value, basestring) and value.startswith(USER_HOME):
value = value.replace(USER_HOME, '~')
text += indent('\ndefault: {}\n'.format(format_literal(value)))
text += '\n'
return text
def underline(text, symbol='='):
return '{}\n{}\n\n'.format(text, symbol * len(text))
def get_rst_from_extension(ext):
text = underline(ext.name, '-')
if hasattr(ext, 'description'):
desc = strip_inlined_text(ext.description or '')
elif ext.__doc__:
desc = strip_inlined_text(ext.__doc__)
else:
desc = ''
text += desc + '\n\n'
params_rst = get_params_rst(ext)
if params_rst:
text += underline('parameters', '~') + params_rst
return text + '\n'
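A brief sketch of the lower-level formatting helpers above; the inputs are arbitrary examples:

from wlauto.utils.doc import format_column, format_simple_table, indent, underline

print underline('Example', '=')
print format_column('a fairly long line of text that will be wrapped on a word boundary', 30)
print indent('an indented\nblock of text', spaces=4)
print format_simple_table([[1, 'one'], [2, 'two']], headers=['num', 'name'])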

148
wlauto/utils/formatter.py Normal file

@@ -0,0 +1,148 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
INDENTATION_FROM_TITLE = 4
class TextFormatter(object):
"""
This is a base class for text formatting. It mainly requires implementing two
methods, add_item and format_data. The former adds new text to
the formatter, whereas the latter returns the formatted text. The name
attribute represents the name of the formatter.
"""
name = None
data = None
def __init__(self):
pass
def add_item(self, new_data, item_title):
"""
Add new item to the text formatter.
:param new_data: The data to be added
:param item_title: A title for the added data
"""
raise NotImplementedError()
def format_data(self):
"""
Returns the formatted text.
"""
raise NotImplementedError()
class DescriptionListFormatter(TextFormatter):
name = 'description_list_formatter'
data = None
def get_text_width(self):
if not self._text_width:
_, width = os.popen('stty size', 'r').read().split()
self._text_width = int(width)
return self._text_width
def set_text_width(self, value):
self._text_width = value
text_width = property(get_text_width, set_text_width)
def __init__(self, title=None, width=None):
super(DescriptionListFormatter, self).__init__()
self.data_title = title
self._text_width = width
self.longest_word_length = 0
self.data = []
def add_item(self, new_data, item_title):
if len(item_title) > self.longest_word_length:
self.longest_word_length = len(item_title)
self.data[len(self.data):] = [(item_title, self._remove_newlines(new_data))]
def format_data(self):
parag_indentation = self.longest_word_length + INDENTATION_FROM_TITLE
string_formatter = '{}:<{}{} {}'.format('{', parag_indentation, '}', '{}')
formatted_data = ''
if self.data_title:
formatted_data += self.data_title
line_width = self.text_width - parag_indentation
for title, paragraph in self.data:
if paragraph:
formatted_data += '\n'
title_len = self.longest_word_length - len(title)
title += ':'
if title_len > 0:
title = (' ' * title_len) + title
parag_lines = self._break_lines(paragraph, line_width).splitlines()
if parag_lines:
formatted_data += string_formatter.format(title, parag_lines[0])
for line in parag_lines[1:]:
formatted_data += '\n' + string_formatter.format('', line)
self.text_width = None
return formatted_data
# Return the text's paragraphs separated into a list, such that each index in the
# list is a single text paragraph with no new lines
def _remove_newlines(self, new_data): # pylint: disable=R0201
parag_list = ['']
parag_num = 0
prv_parag = None
# For each paragraph separated by a new line
for paragraph in new_data.splitlines():
if paragraph:
parag_list[parag_num] += ' ' + paragraph
# if the previous line is NOT empty, then add new empty index for
# the next paragraph
elif prv_parag:
parag_num += 1
parag_list.append('')
prv_parag = paragraph
# sometimes, we end up with an empty string as the last item so we remove it
if not parag_list[-1]:
return parag_list[:-1]
return parag_list
def _break_lines(self, parag_list, line_width): # pylint: disable=R0201
formatted_paragraphs = []
for para in parag_list:
words = para.split()
if words:
formatted_text = words.pop(0)
current_width = len(formatted_text)
# for each word in the paragraph, line width is an accumulation of
# word length + 1 (1 is for the space after each word).
for word in words:
word = word.strip()
if current_width + len(word) + 1 >= line_width:
formatted_text += '\n' + word
current_width = len(word)
else:
formatted_text += ' ' + word
current_width += len(word) + 1
formatted_paragraphs.append(formatted_text)
return '\n\n'.join(formatted_paragraphs)
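A usage sketch for DescriptionListFormatter; the width is passed explicitly so the example does not depend on querying the terminal with stty:

from wlauto.utils.formatter import DescriptionListFormatter

formatter = DescriptionListFormatter(title='Available extensions:\n', width=60)
formatter.add_item('Collects ftrace output while the workload runs.', 'trace-cmd')
formatter.add_item('Reads energy measurements from hwmon sensors.', 'hwmon')
print formatter.format_data()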

77
wlauto/utils/hwmon.py Normal file

@@ -0,0 +1,77 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from wlauto.exceptions import DeviceError
HWMON_ROOT = '/sys/class/hwmon'
class HwmonSensor(object):
def __init__(self, device, kind, label, filepath):
self.device = device
self.kind = kind
self.label = label
self.filepath = filepath
self.readings = []
def take_reading(self):
reading = self.device.get_sysfile_value(self.filepath, int)
self.readings.append(reading)
def clear_readings(self):
self.readings = []
def discover_sensors(device, sensor_kinds):
"""
Discovers HWMON sensors available on the device.
:device: Device on which to discover HWMON sensors. Must be an instance
of :class:`AndroidDevice`.
:sensor_kinds: A list of names of sensor types to be discovered. The names
must be as they appear prefixed to ``*_input`` files in
sysfs. E.g. ``'energy'``.
:returns: A list of :class:`HwmonSensor` instances for each found sensor. If
no sensors of the specified types were discovered, an empty list
will be returned.
"""
hwmon_devices = device.listdir(HWMON_ROOT)
path = device.path
sensors = []
for hwmon_device in hwmon_devices:
try:
device_path = path.join(HWMON_ROOT, hwmon_device, 'device')
name = device.get_sysfile_value(path.join(device_path, 'name'))
except DeviceError: # probably a virtual device
device_path = path.join(HWMON_ROOT, hwmon_device)
name = device.get_sysfile_value(path.join(device_path, 'name'))
for sensor_kind in sensor_kinds:
i = 1
input_path = path.join(device_path, '{}{}_input'.format(sensor_kind, i))
while device.file_exists(input_path):
label_path = path.join(device_path, '{}{}_label'.format(sensor_kind, i))
if device.file_exists(label_path):
name += ' ' + device.get_sysfile_value(label_path)
sensors.append(HwmonSensor(device, sensor_kind, name, input_path))
i += 1
input_path = path.join(device_path, '{}{}_input'.format(sensor_kind, i))
return sensors
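A sketch of the call pattern discover_sensors() expects; the stub below only illustrates the device interface it relies on (listdir, path, get_sysfile_value, file_exists) and is not a real AndroidDevice:

import posixpath

from wlauto.utils.hwmon import discover_sensors

class StubDevice(object):
    # Stand-in for an AndroidDevice; a real run would pass an actual device.
    path = posixpath
    def listdir(self, directory):
        return []  # no hwmon devices on this stub
    def get_sysfile_value(self, filepath, kind=str):
        return kind()
    def file_exists(self, filepath):
        return False

print discover_sensors(StubDevice(), ['energy', 'temp'])  # -> [] with the empty stub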

223
wlauto/utils/log.py Normal file

@@ -0,0 +1,223 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=E1101
import logging
import string
import threading
import colorama
from wlauto.core.bootstrap import settings
import wlauto.core.signal as signal
COLOR_MAP = {
logging.DEBUG: colorama.Fore.BLUE,
logging.INFO: colorama.Fore.GREEN,
logging.WARNING: colorama.Fore.YELLOW,
logging.ERROR: colorama.Fore.RED,
logging.CRITICAL: colorama.Style.BRIGHT + colorama.Fore.RED,
}
RESET_COLOR = colorama.Style.RESET_ALL
def init_logging(verbosity):
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)
error_handler = ErrorSignalHandler(logging.DEBUG)
root_logger.addHandler(error_handler)
console_handler = logging.StreamHandler()
if verbosity == 1:
console_handler.setLevel(logging.DEBUG)
if 'colour_enabled' in settings.logging and not settings.logging['colour_enabled']:
console_handler.setFormatter(LineFormatter(settings.logging['verbose_format']))
else:
console_handler.setFormatter(ColorFormatter(settings.logging['verbose_format']))
else:
console_handler.setLevel(logging.INFO)
if 'colour_enabled' in settings.logging and not settings.logging['colour_enabled']:
console_handler.setFormatter(LineFormatter(settings.logging['regular_format']))
else:
console_handler.setFormatter(ColorFormatter(settings.logging['regular_format']))
root_logger.addHandler(console_handler)
logging.basicConfig(level=logging.DEBUG)
def add_log_file(filepath, level=logging.DEBUG):
root_logger = logging.getLogger()
file_handler = logging.FileHandler(filepath)
file_handler.setLevel(level)
file_handler.setFormatter(LineFormatter(settings.logging['file_format']))
root_logger.addHandler(file_handler)
class ErrorSignalHandler(logging.Handler):
"""
Emits signals for ERROR and WARNING level traces.
"""
def emit(self, record):
if record.levelno == logging.ERROR:
signal.send(signal.ERROR_LOGGED, self)
elif record.levelno == logging.WARNING:
signal.send(signal.WARNING_LOGGED, self)
class ColorFormatter(logging.Formatter):
"""
Formats logging records with color and prepends record info
to each line of the message.
BLUE for DEBUG logging level
GREEN for INFO logging level
YELLOW for WARNING logging level
RED for ERROR logging level
BOLD RED for CRITICAL logging level
"""
def __init__(self, fmt=None, datefmt=None):
super(ColorFormatter, self).__init__(fmt, datefmt)
template_text = self._fmt.replace('%(message)s', RESET_COLOR + '%(message)s${color}')
template_text = '${color}' + template_text + RESET_COLOR
self.fmt_template = string.Template(template_text)
def format(self, record):
self._set_color(COLOR_MAP[record.levelno])
record.message = record.getMessage()
if self.usesTime():
record.asctime = self.formatTime(record, self.datefmt)
d = record.__dict__
parts = []
for line in record.message.split('\n'):
d.update({'message': line.strip('\r')})
parts.append(self._fmt % d)
return '\n'.join(parts)
def _set_color(self, color):
self._fmt = self.fmt_template.substitute(color=color)
class LineFormatter(logging.Formatter):
"""
Logs each line of the message separately.
"""
def __init__(self, fmt=None, datefmt=None):
super(LineFormatter, self).__init__(fmt, datefmt)
def format(self, record):
record.message = record.getMessage()
if self.usesTime():
record.asctime = self.formatTime(record, self.datefmt)
d = record.__dict__
parts = []
for line in record.message.split('\n'):
d.update({'message': line.strip('\r')})
parts.append(self._fmt % d)
return '\n'.join(parts)
class BaseLogWriter(object):
def __init__(self, name, level=logging.DEBUG):
"""
File-like object class designed to be used for logging from streams.
Each complete line (terminated by a new line character) gets logged
at the specified level (DEBUG by default). Incomplete lines are buffered until the next new line.
:param name: The name of the logger that will be used.
"""
self.logger = logging.getLogger(name)
self.buffer = ''
if level == logging.DEBUG:
self.do_write = self.logger.debug
elif level == logging.INFO:
self.do_write = self.logger.info
elif level == logging.WARNING:
self.do_write = self.logger.warning
elif level == logging.ERROR:
self.do_write = self.logger.error
else:
raise Exception('Unknown logging level: {}'.format(level))
def flush(self):
# Defined to match the interface expected by pexpect.
return self
def close(self):
if self.buffer:
self.logger.debug(self.buffer)
self.buffer = ''
return self
def __del__(self):
# Ensure we don't lose buffered output
self.close()
class LogWriter(BaseLogWriter):
def write(self, data):
data = data.replace('\r\n', '\n').replace('\r', '\n')
if '\n' in data:
parts = data.split('\n')
parts[0] = self.buffer + parts[0]
for part in parts[:-1]:
self.do_write(part)
self.buffer = parts[-1]
else:
self.buffer += data
return self
class LineLogWriter(BaseLogWriter):
def write(self, data):
self.do_write(data)
class StreamLogger(threading.Thread):
"""
Logs output from a stream in a thread.
"""
def __init__(self, name, stream, level=logging.DEBUG, klass=LogWriter):
super(StreamLogger, self).__init__()
self.writer = klass(name, level)
self.stream = stream
self.daemon = True
def run(self):
line = self.stream.readline()
while line:
self.writer.write(line.rstrip('\n'))
line = self.stream.readline()
self.writer.close()
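A sketch of capturing a subprocess's output through StreamLogger; logging.basicConfig stands in for init_logging, which needs WA settings:

import logging
import subprocess
from wlauto.utils.log import LineLogWriter, StreamLogger

logging.basicConfig(level=logging.DEBUG)
proc = subprocess.Popen(['echo', 'hello world'], stdout=subprocess.PIPE)
stream_logger = StreamLogger('echo', proc.stdout, level=logging.INFO, klass=LineLogWriter)
stream_logger.start()
proc.wait()
stream_logger.join()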

703
wlauto/utils/misc.py Normal file

@@ -0,0 +1,703 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Miscellaneous functions that don't fit anywhere else.
"""
from __future__ import division
import os
import sys
import re
import math
import imp
import string
import threading
import signal
import subprocess
import pkgutil
import traceback
import logging
import random
from datetime import datetime, timedelta
from operator import mul
from StringIO import StringIO
from itertools import cycle
from functools import partial
from distutils.spawn import find_executable
import yaml
from dateutil import tz
def preexec_function():
# Ignore the SIGINT signal by setting the handler to the standard
# signal handler SIG_IGN.
signal.signal(signal.SIGINT, signal.SIG_IGN)
# Change process group in case we have to kill the subprocess and all of
# its children later.
# TODO: this is Unix-specific; would be good to find an OS-agnostic way
# to do this in case we wanna port WA to Windows.
os.setpgrp()
check_output_logger = logging.getLogger('check_output')
# Defined here rather than in wlauto.exceptions due to module load dependencies
class TimeoutError(Exception):
"""Raised when a subprocess command times out. This is basically a ``WAError``-derived version
of ``subprocess.CalledProcessError``, the thinking being that while a timeout could be due to
programming error (e.g. not setting long enough timers), it is often due to some failure in the
environment, and therefore should be classed as a "user error"."""
def __init__(self, command, output):
super(TimeoutError, self).__init__('Timed out: {}'.format(command))
self.command = command
self.output = output
def __str__(self):
return '\n'.join([self.message, 'OUTPUT:', self.output or ''])
def check_output(command, timeout=None, **kwargs):
"""This is a version of subprocess.check_output that adds a timeout parameter to kill
the subprocess if it does not return within the specified time."""
def callback(pid):
try:
check_output_logger.debug('{} timed out; sending SIGKILL'.format(pid))
os.killpg(pid, signal.SIGKILL)
except OSError:
pass # process may have already terminated.
if 'stdout' in kwargs:
raise ValueError('stdout argument not allowed, it will be overridden.')
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
preexec_fn=preexec_function, **kwargs)
if timeout:
timer = threading.Timer(timeout, callback, [process.pid, ])
timer.start()
try:
output, error = process.communicate()
finally:
if timeout:
timer.cancel()
retcode = process.poll()
if retcode:
if retcode == -9: # killed, assume due to timeout callback
raise TimeoutError(command, output='\n'.join([output, error]))
else:
raise subprocess.CalledProcessError(retcode, command, output='\n'.join([output, error]))
return output, error
def walk_modules(path):
"""
Given package name, return a list of all modules (including submodules, etc)
in that package.
"""
root_mod = __import__(path, {}, {}, [''])
mods = [root_mod]
for _, name, ispkg in pkgutil.iter_modules(root_mod.__path__):
submod_path = '.'.join([path, name])
if ispkg:
mods.extend(walk_modules(submod_path))
else:
submod = __import__(submod_path, {}, {}, [''])
mods.append(submod)
return mods
def ensure_directory_exists(dirpath):
"""A filter for directory paths to ensure they exist."""
if not os.path.isdir(dirpath):
os.makedirs(dirpath)
return dirpath
def ensure_file_directory_exists(filepath):
"""
A filter for file paths to ensure the directory of the
file exists and the file can be created there. The file
itself is *not* going to be created if it doesn't already
exist.
"""
ensure_directory_exists(os.path.dirname(filepath))
return filepath
def diff_tokens(before_token, after_token):
"""
Creates a diff of two tokens.
If the two tokens are the same it just returns the token
(whitespace tokens are considered the same irrespective of type/number
of whitespace characters in the token).
If the tokens are numeric, the difference between the two values
is returned.
Otherwise, a string in the form [before -> after] is returned.
"""
if before_token.isspace() and after_token.isspace():
return after_token
elif before_token.isdigit() and after_token.isdigit():
try:
diff = int(after_token) - int(before_token)
return str(diff)
except ValueError:
return "[%s -> %s]" % (before_token, after_token)
elif before_token == after_token:
return after_token
else:
return "[%s -> %s]" % (before_token, after_token)
def prepare_table_rows(rows):
"""Given a list of lists, make sure they are prepared to be formatted into a table
by making sure each row has the same number of columns and stringifying all values."""
rows = [map(str, r) for r in rows]
max_cols = max(map(len, rows))
for row in rows:
pad = max_cols - len(row)
for _ in xrange(pad):
row.append('')
return rows
def write_table(rows, wfh, align='>', headers=None): # pylint: disable=R0914
"""Write a column-aligned table to the specified file object."""
if not rows:
return
rows = prepare_table_rows(rows)
num_cols = len(rows[0])
# cycle specified alignments until we have max_cols of them. This is
# consistent with how such cases are handled in R, pandas, etc.
it = cycle(align)
align = [it.next() for _ in xrange(num_cols)]
cols = zip(*rows)
col_widths = [max(map(len, c)) for c in cols]
row_format = ' '.join(['{:%s%s}' % (align[i], w) for i, w in enumerate(col_widths)])
row_format += '\n'
if headers:
wfh.write(row_format.format(*headers))
underlines = ['-' * len(h) for h in headers]
wfh.write(row_format.format(*underlines))
for row in rows:
wfh.write(row_format.format(*row))
def get_null():
"""Returns the correct null sink based on the OS."""
return 'NUL' if os.name == 'nt' else '/dev/null'
def get_traceback(exc=None):
"""
Returns the string with the traceback for the specified exc
object, or for the current exception if exc is not specified.
"""
if exc is None:
exc = sys.exc_info()
if not exc:
return None
tb = exc[2]
sio = StringIO()
traceback.print_tb(tb, file=sio)
del tb # needs to be done explicitly see: http://docs.python.org/2/library/sys.html#sys.exc_info
return sio.getvalue()
def merge_dicts(*args, **kwargs):
if not len(args) >= 2:
raise ValueError('Must specify at least two dicts to merge.')
func = partial(_merge_two_dicts, **kwargs)
return reduce(func, args)
def _merge_two_dicts(base, other, list_duplicates='all', match_types=False, # pylint: disable=R0912,R0914
dict_type=dict, should_normalize=True, should_merge_lists=True):
"""Merge dicts normalizing their keys."""
merged = dict_type()
base_keys = base.keys()
other_keys = other.keys()
norm = normalize if should_normalize else lambda x, y: x
base_only = []
other_only = []
both = []
union = []
for k in base_keys:
if k in other_keys:
both.append(k)
else:
base_only.append(k)
union.append(k)
for k in other_keys:
if k in base_keys:
union.append(k)
else:
union.append(k)
other_only.append(k)
for k in union:
if k in base_only:
merged[k] = norm(base[k], dict_type)
elif k in other_only:
merged[k] = norm(other[k], dict_type)
elif k in both:
base_value = base[k]
other_value = other[k]
base_type = type(base_value)
other_type = type(other_value)
if (match_types and (base_type != other_type) and
(base_value is not None) and (other_value is not None)):
raise ValueError('Type mismatch for {} got {} ({}) and {} ({})'.format(k, base_value, base_type,
other_value, other_type))
if isinstance(base_value, dict):
merged[k] = _merge_two_dicts(base_value, other_value, list_duplicates, match_types, dict_type)
elif isinstance(base_value, list):
if should_merge_lists:
merged[k] = _merge_two_lists(base_value, other_value, list_duplicates, dict_type)
else:
merged[k] = _merge_two_lists([], other_value, list_duplicates, dict_type)
elif isinstance(base_value, set):
merged[k] = norm(base_value.union(other_value), dict_type)
else:
merged[k] = norm(other_value, dict_type)
else: # Should never get here
raise AssertionError('Unexpected merge key: {}'.format(k))
return merged
def merge_lists(*args, **kwargs):
if not len(args) >= 2:
raise ValueError('Must specify at least two lists to merge.')
func = partial(_merge_two_lists, **kwargs)
return reduce(func, args)
def _merge_two_lists(base, other, duplicates='all', dict_type=dict): # pylint: disable=R0912
"""Merge lists, normalizing their entries."""
if duplicates == 'all':
merged_list = []
for v in normalize(base, dict_type) + normalize(other, dict_type):
if not _check_remove_item(merged_list, v):
merged_list.append(v)
return merged_list
elif duplicates == 'first':
merged_list = []
for v in normalize(base + other, dict_type):
if not _check_remove_item(merged_list, v):
if v not in merged_list:
merged_list.append(v)
return merged_list
elif duplicates == 'last':
merged_list = []
for v in normalize(base + other, dict_type):
if not _check_remove_item(merged_list, v):
if v in merged_list:
del merged_list[merged_list.index(v)]
merged_list.append(v)
return merged_list
else:
raise ValueError('Unexpected value for list duplicates argument: {}. '.format(duplicates) +
'Must be in {"all", "first", "last"}.')
def _check_remove_item(the_list, item):
"""Helper function for merge_lists that implements checking wether an items
should be removed from the list and doing so if needed. Returns ``True`` if
the item has been removed and ``False`` otherwise."""
if not isinstance(item, basestring):
return False
if not item.startswith('~'):
return False
actual_item = item[1:]
if actual_item in the_list:
del the_list[the_list.index(actual_item)]
return True
def normalize(value, dict_type=dict):
"""Normalize values. Recursively normalizes dict keys to be lower case,
no surrounding whitespace, underscore-delimited strings."""
if isinstance(value, dict):
normalized = dict_type()
for k, v in value.iteritems():
key = k.strip().lower().replace(' ', '_')
normalized[key] = normalize(v, dict_type)
return normalized
elif isinstance(value, list):
return [normalize(v, dict_type) for v in value]
elif isinstance(value, tuple):
return tuple([normalize(v, dict_type) for v in value])
else:
return value
VALUE_REGEX = re.compile(r'(\d+(?:\.\d+)?)\s*(\w*)')
UNITS_MAP = {
's': 'seconds',
'ms': 'milliseconds',
'us': 'microseconds',
'ns': 'nanoseconds',
'V': 'volts',
'A': 'amps',
'mA': 'milliamps',
'J': 'joules',
}
def parse_value(value_string):
"""parses a string representing a numerical value and returns
a tuple (value, units), where value will be either int or float,
and units will be a string representing the units or None."""
match = VALUE_REGEX.search(value_string)
if match:
vs = match.group(1)
value = float(vs) if '.' in vs else int(vs)
us = match.group(2)
units = UNITS_MAP.get(us, us)
return (value, units)
else:
return (value_string, None)
def get_meansd(values):
"""Returns mean and standard deviation of the specified values."""
if not values:
return float('nan'), float('nan')
mean = sum(values) / len(values)
sd = math.sqrt(sum([v ** 2 for v in values]) / len(values) - mean ** 2)
return mean, sd
def geomean(values):
"""Returns the geometric mean of the values."""
return reduce(mul, values) ** (1.0 / len(values))
def capitalize(text):
"""Capitalises the specified text: first letter upper case,
all subsequent letters lower case."""
if not text:
return ''
return text[0].upper() + text[1:].lower()
def convert_new_lines(text):
""" Convert new lines to a common format. """
return text.replace('\r\n', '\n').replace('\r', '\n')
def escape_quotes(text):
"""Escape quotes, and escaped quotes, in the specified text."""
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\\\'').replace('\"', '\\\"')
def escape_single_quotes(text):
"""Escape single quotes, and escaped single quotes, in the specified text."""
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\'\\\'\'')
def escape_double_quotes(text):
"""Escape double quotes, and escaped double quotes, in the specified text."""
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\"', '\\\"')
def getch(count=1):
"""Read ``count`` characters from standard input."""
if os.name == 'nt':
import msvcrt # pylint: disable=F0401
return ''.join([msvcrt.getch() for _ in xrange(count)])
else: # assume Unix
import tty # NOQA
import termios # NOQA
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(sys.stdin.fileno())
ch = sys.stdin.read(count)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
return ch
def isiterable(obj):
"""Returns ``True`` if the specified object is iterable and
*is not a string type*, ``False`` otherwise."""
return hasattr(obj, '__iter__') and not isinstance(obj, basestring)
def utc_to_local(dt):
"""Convert naive datetime to local time zone, assuming UTC."""
return dt.replace(tzinfo=tz.tzutc()).astimezone(tz.tzlocal())
def local_to_utc(dt):
"""Convert naive datetime to UTC, assuming local time zone."""
return dt.replace(tzinfo=tz.tzlocal()).astimezone(tz.tzutc())
def as_relative(path):
"""Convert path to relative by stripping away the leading '/' on UNIX or
the equivalent on other platforms."""
path = os.path.splitdrive(path)[1]
return path.lstrip(os.sep)
def get_cpu_mask(cores):
"""Return a string with the hex for the cpu mask for the specified core numbers."""
mask = 0
for i in cores:
mask |= 1 << i
return '0x{0:x}'.format(mask)
def load_class(classpath):
"""Loads the specified Python class. ``classpath`` must be a fully-qualified
class name (i.e. namespaced under module/package)."""
modname, clsname = classpath.rsplit('.', 1)
return getattr(__import__(modname), clsname)
def get_pager():
"""Returns the name of the system pager program."""
pager = os.getenv('PAGER')
if not pager:
pager = find_executable('less')
if not pager:
pager = find_executable('more')
return pager
def enum_metaclass(enum_param, return_name=False, start=0):
"""
Returns a ``type`` subclass that may be used as a metaclass for
an enum.
Parameters:
:enum_param: the name of class attribute that defines enum values.
The metaclass will add a class attribute for each value in
``enum_param``. The value of the attribute depends on the type
of ``enum_param`` and on the values of ``return_name``. If
``return_name`` is ``True``, then the value of the new attribute is
the name of that attribute; otherwise, if ``enum_param`` is a ``list``
or a ``tuple``, the value will be the index of that param in
``enum_param``, optionally offset by ``start``, otherwise, it will
be assumed that ``enum_param`` implements a dict-like interface and
the value will be ``enum_param[attr_name]``.
:return_name: If ``True``, the enum values will be the names of enum attributes. If
``False``, the default, the values will depend on the type of
``enum_param`` (see above).
:start: If ``enum_param`` is a list or a tuple, and ``return_name`` is ``False``,
this specifies an "offset" that will be added to the index of the attribute
within ``enum_param`` to form the value.
"""
class __EnumMeta(type):
def __new__(mcs, clsname, bases, attrs):
cls = type.__new__(mcs, clsname, bases, attrs)
values = getattr(cls, enum_param, [])
if return_name:
for name in values:
setattr(cls, name, name)
else:
if isinstance(values, list) or isinstance(values, tuple):
for i, name in enumerate(values):
setattr(cls, name, i + start)
else: # assume dict-like
for name in values:
setattr(cls, name, values[name])
return cls
return __EnumMeta
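# Illustrative example (not part of the original module): with a list of values
# and return_name=False, each name becomes a class attribute holding its index
# offset by ``start``.
class _ExampleStatus(object):
    __metaclass__ = enum_metaclass('values', start=1)
    values = ['pending', 'running', 'done']
# _ExampleStatus.pending == 1, _ExampleStatus.running == 2, _ExampleStatus.done == 3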
def which(name):
"""Platform-independent version of UNIX which utility."""
if os.name == 'nt':
paths = os.getenv('PATH').split(os.pathsep)
exts = os.getenv('PATHEXT').split(os.pathsep)
for path in paths:
testpath = os.path.join(path, name)
if os.path.isfile(testpath):
return testpath
for ext in exts:
testpathext = testpath + ext
if os.path.isfile(testpathext):
return testpathext
return None
else: # assume UNIX-like
try:
return check_output(['which', name])[0].strip()
except subprocess.CalledProcessError:
return None
_bash_color_regex = re.compile('\x1b\[[0-9;]+m')
def strip_bash_colors(text):
return _bash_color_regex.sub('', text)
def format_duration(seconds, sep=' ', order=['day', 'hour', 'minute', 'second']): # pylint: disable=dangerous-default-value
"""
Formats the specified number of seconds into human-readable duration.
"""
if isinstance(seconds, timedelta):
td = seconds
else:
td = timedelta(seconds=seconds)
dt = datetime(1, 1, 1) + td
result = []
for item in order:
value = getattr(dt, item, None)
if item == 'day':
value -= 1
if not value:
continue
suffix = '' if value == 1 else 's'
result.append('{} {}{}'.format(value, item, suffix))
return sep.join(result)
def get_article(word):
"""
Returns the appropriate indefinite article for the word (ish).
.. note:: Indefinite article assignment in English is based on
sound rather than spelling, so this will not work correctly
in all cases; e.g. this will return ``"a hour"``.
"""
return 'an' if word[0] in 'aoeiu' else 'a'
def get_random_string(length):
"""Returns a random ASCII string of the specified length."""
return ''.join(random.choice(string.ascii_letters + string.digits) for _ in xrange(length))
class LoadSyntaxError(Exception):
def __init__(self, message, filepath, lineno):
super(LoadSyntaxError, self).__init__(message)
self.filepath = filepath
self.lineno = lineno
def __str__(self):
message = 'Syntax Error in {}, line {}:\n\t{}'
return message.format(self.filepath, self.lineno, self.message)
RAND_MOD_NAME_LEN = 30
BAD_CHARS = string.punctuation + string.whitespace
TRANS_TABLE = string.maketrans(BAD_CHARS, '_' * len(BAD_CHARS))
def to_identifier(text):
"""Converts text to a valid Python identifier by replacing all
whitespace and punctuation."""
return re.sub('_+', '_', text.translate(TRANS_TABLE))
def load_struct_from_python(filepath=None, text=None):
"""Parses a config structure from a .py file. The structure should be composed
of basic Python types (strings, ints, lists, dicts, etc.)."""
if not (filepath or text) or (filepath and text):
raise ValueError('Exactly one of filepath or text must be specified.')
try:
if filepath:
modname = to_identifier(filepath)
mod = imp.load_source(modname, filepath)
else:
modname = get_random_string(RAND_MOD_NAME_LEN)
while modname in sys.modules: # highly unlikely, but...
modname = get_random_string(RAND_MOD_NAME_LEN)
mod = imp.new_module(modname)
exec text in mod.__dict__ # pylint: disable=exec-used
return dict((k, v)
for k, v in mod.__dict__.iteritems()
if not k.startswith('_'))
except SyntaxError as e:
raise LoadSyntaxError(e.message, e.filepath, e.lineno)
def load_struct_from_yaml(filepath=None, text=None):
"""Parses a config structure from a .yaml file. The structure should be composed
of basic Python types (strings, ints, lists, dicts, etc.)."""
if not (filepath or text) or (filepath and text):
raise ValueError('Exactly one of filepath or text must be specified.')
try:
if filepath:
with open(filepath) as fh:
return yaml.load(fh)
else:
return yaml.load(text)
except yaml.YAMLError as e:
lineno = None
if hasattr(e, 'problem_mark'):
lineno = e.problem_mark.line
raise LoadSyntaxError(e.message, filepath=filepath, lineno=lineno)
def load_struct_from_file(filepath):
"""
Attempts to parse a Python structure consisting of basic types from the specified file.
Raises a ``ValueError`` if the specified file is of unknown format; ``LoadSyntaxError`` if
there is an issue parsing the file.
"""
extn = os.path.splitext(filepath)[1].lower()
if (extn == '.py') or (extn == '.pyc') or (extn == '.pyo'):
return load_struct_from_python(filepath)
elif extn == '.yaml':
return load_struct_from_yaml(filepath)
else:
raise ValueError('Unknown format "{}": {}'.format(extn, filepath))
def unique(alist):
"""
Returns a list containing only unique elements from the input list (but preserves
order, unlike sets).
"""
result = []
for item in alist:
if item not in result:
result.append(item)
return result
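A short sketch of a few of the helpers above; all values are arbitrary:

from wlauto.utils.misc import (check_output, format_duration, get_cpu_mask,
                               merge_dicts, parse_value, unique)

out, _ = check_output('echo hello', shell=True, timeout=10)
print out.strip()                                  # hello
print merge_dicts({'a': 1, 'b': [1]}, {'b': [2]})  # lists under the same key are merged
print format_duration(3725)                        # 1 hour 2 minutes 5 seconds
print get_cpu_mask([0, 1, 4])                      # 0x13
print parse_value('250 mA')                        # (250, 'milliamps')
print unique([3, 1, 3, 2, 1])                      # [3, 1, 2]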

98
wlauto/utils/netio.py Normal file

@@ -0,0 +1,98 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module contains utilities for implementing device hard reset
using Netio 230 series power switches. This utilizes the KSHELL connection.
"""
import telnetlib
import socket
import re
import time
import logging
logger = logging.getLogger('NetIO')
class NetioError(Exception):
pass
class KshellConnection(object):
response_regex = re.compile(r'^(\d+) (.*?)\r\n')
delay = 0.5
def __init__(self, host='ippowerbar', port=1234, timeout=None):
"""Parameters are passed into ``telnetlib.Telnet`` -- see Python docs."""
self.host = host
self.port = port
self.conn = telnetlib.Telnet(host, port, timeout)
time.sleep(self.delay) # give time to respond
output = self.conn.read_very_eager()
if 'HELLO' not in output:
raise NetioError('Could not connect: did not see a HELLO. Got: {}'.format(output))
def login(self, user, password):
code, out = self.send_command('login {} {}\r\n'.format(user, password))
if not code == 250:
raise NetioError('Login failed. Got: {} {}'.format(code, out))
def enable_port(self, port):
"""Enable the power supply at the specified port."""
self.set_port(port, 1)
def disable_port(self, port):
"""Enable the power supply at the specified port."""
self.set_port(port, 0)
def set_port(self, port, value):
code, out = self.send_command('port {} {}'.format(port, value))
if not code == 250:
raise NetioError('Could not set {} on port {}. Got: {} {}'.format(value, port, code, out))
def send_command(self, command):
try:
if command.startswith('login'):
parts = command.split()
parts[2] = '*' * len(parts[2])
logger.debug(' '.join(parts))
else:
logger.debug(command)
self.conn.write('{}\n'.format(command))
time.sleep(self.delay) # give time to respond
out = self.conn.read_very_eager()
match = self.response_regex.search(out)
if not match:
raise NetioError('Invalid response: {}'.format(out.strip()))
logger.debug('response: {} {}'.format(match.group(1), match.group(2)))
return int(match.group(1)), match.group(2)
except socket.error as err:
try:
time.sleep(self.delay) # give time to respond
out = self.conn.read_very_eager()
if out.startswith('130 CONNECTION TIMEOUT'):
raise NetioError('130 Timed out.')
except EOFError:
pass
raise err
def close(self):
self.conn.close()
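# A minimal usage sketch, assuming a Netio 230 switch is reachable with the
# (made-up) hostname and credentials below; the socket number is illustrative.
if __name__ == '__main__':
    conn = KshellConnection(host='ippowerbar', port=1234)
    conn.login('admin', 'admin')
    conn.disable_port(1)   # cut power to whatever is plugged into socket 1
    time.sleep(5)
    conn.enable_port(1)    # and power it back up
    conn.close()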

111
wlauto/utils/serial_port.py Normal file
View File

@@ -0,0 +1,111 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from contextlib import contextmanager
import serial
import fdpexpect
# Adding pexpect exceptions into this module's namespace
from pexpect import EOF, TIMEOUT # NOQA pylint: disable=W0611
from wlauto.exceptions import HostError
from wlauto.utils.log import LogWriter
class PexpectLogger(LogWriter):
def __init__(self, kind):
"""
File-like object class designed to be used for logging with pexpect or
fdpexpect. Each complete line (terminated by a newline character) gets logged
at DEBUG level. Incomplete lines are buffered until the next newline.
:param kind: This specifies which of the pexpect logfile attributes this logger
will be set to. It should be "read" for logfile_read, "send" for
logfile_send, and "" (empty string) for logfile.
"""
if kind not in ('read', 'send', ''):
raise ValueError('kind must be "read", "send" or ""; got {}'.format(kind))
self.kind = kind
logger_name = 'serial_{}'.format(kind) if kind else 'serial'
super(PexpectLogger, self).__init__(logger_name)
def pulse_dtr(conn, state=True, duration=0.1):
"""Set the DTR line of the specified serial connection to the specified state
for the specified duration (note: the initial state of the line is *not* checked)."""
conn.setDTR(state)
time.sleep(duration)
conn.setDTR(not state)
@contextmanager
def open_serial_connection(timeout, get_conn=False, init_dtr=None, *args, **kwargs):
"""
Opens a serial connection to a device.
:param timeout: timeout for the fdpexpect spawn object.
:param get_conn: ``bool`` that specifies whether the underlying connection
object should be yielded as well.
:param init_dtr: specifies the initial DTR state that should be set.
All arguments are passed into the __init__ of serial.Serial. See
pyserial documentation for details:
http://pyserial.sourceforge.net/pyserial_api.html#serial.Serial
:returns: a pexpect spawn object connected to the device.
See: http://pexpect.sourceforge.net/pexpect.html
"""
if init_dtr is not None:
kwargs['dsrdtr'] = True
try:
conn = serial.Serial(*args, **kwargs)
except serial.SerialException as e:
raise HostError(e.message)
if init_dtr is not None:
conn.setDTR(init_dtr)
conn.nonblocking()
conn.flushOutput()
target = fdpexpect.fdspawn(conn.fileno(), timeout=timeout)
target.logfile_read = PexpectLogger('read')
target.logfile_send = PexpectLogger('send')
# Monkey-patching sendline to introduce a short delay after
# characters are sent to the serial. If two sendlines are issued
# one after another the second one might start putting characters
# into the serial device before the first one has finished, causing
# corruption. The delay prevents that.
tsln = target.sendline
def sendline(x):
tsln(x)
time.sleep(0.1)
target.sendline = sendline
if get_conn:
yield target, conn
else:
yield target
target.close() # Closes the file descriptor used by the conn.
del conn
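# A minimal usage sketch, assuming a serial console on /dev/ttyUSB0 at 115200
# baud with a '#' shell prompt; the device node and prompt are made up.
if __name__ == '__main__':
    with open_serial_connection(timeout=30, get_conn=True,
                                port='/dev/ttyUSB0', baudrate=115200) as (target, conn):
        pulse_dtr(conn)        # e.g. to reset a board whose reset line is wired to DTR
        target.sendline('')
        target.expect('#', timeout=30)
        target.sendline('uname -a')
        target.expect('#', timeout=30)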

198
wlauto/utils/ssh.py Normal file
View File

@@ -0,0 +1,198 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import subprocess
import re
import pxssh
from pexpect import EOF, TIMEOUT, spawn
from wlauto.exceptions import HostError, DeviceError, TimeoutError, ConfigError
from wlauto.utils.misc import which, strip_bash_colors, escape_single_quotes, check_output
ssh = None
scp = None
sshpass = None
logger = logging.getLogger('ssh')
def ssh_get_shell(host, username, password=None, keyfile=None, port=None, timeout=10, telnet=False):
_check_env()
if telnet:
if keyfile:
raise ConfigError('keyfile may not be used with a telnet connection.')
conn = TelnetConnection()
else: # ssh
conn = pxssh.pxssh()
try:
if keyfile:
conn.SSH_OPTS += ' -i {}'.format(keyfile)
conn.login(host, username, port=port, login_timeout=timeout)
else:
conn.login(host, username, password, port=port, login_timeout=timeout)
except EOF:
raise DeviceError('Could not connect to {}; is the host name correct?'.format(host))
return conn
class TelnetConnection(pxssh.pxssh):
# pylint: disable=arguments-differ
def login(self, server, username, password='', original_prompt=r'[#$]', login_timeout=10,
auto_prompt_reset=True, sync_multiplier=1):
cmd = 'telnet -l {} {}'.format(username, server)
spawn._spawn(self, cmd) # pylint: disable=protected-access
i = self.expect('(?i)(?:password)', timeout=login_timeout)
if i == 0:
self.sendline(password)
i = self.expect([original_prompt, 'Login incorrect'], timeout=login_timeout)
else:
raise pxssh.ExceptionPxssh('could not log in: did not see a password prompt')
if i:
raise pxssh.ExceptionPxssh('could not log in: password was incorrect')
if not self.sync_original_prompt(sync_multiplier):
self.close()
raise pxssh.ExceptionPxssh('could not synchronize with original prompt')
if auto_prompt_reset:
if not self.set_unique_prompt():
self.close()
message = 'could not set shell prompt (received: {}, expected: {}).'
raise pxssh.ExceptionPxssh(message.format(self.before, self.PROMPT))
return True
class SshShell(object):
def __init__(self, timeout=10):
self.timeout = timeout
self.conn = None
def login(self, host, username, password=None, keyfile=None, port=None, timeout=None, telnet=False):
# pylint: disable=attribute-defined-outside-init
logger.debug('Logging in {}@{}'.format(username, host))
self.host = host
self.username = username
self.password = password
self.keyfile = keyfile
self.port = port
timeout = self.timeout if timeout is None else timeout
self.conn = ssh_get_shell(host, username, password, keyfile, port, timeout, telnet)
def push_file(self, source, dest, timeout=30):
dest = '{}@{}:{}'.format(self.username, self.host, dest)
return self._scp(source, dest, timeout)
def pull_file(self, source, dest, timeout=30):
source = '{}@{}:{}'.format(self.username, self.host, source)
return self._scp(source, dest, timeout)
def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE):
port_string = '-p {}'.format(self.port) if self.port else ''
keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
command = '{} {} {} {}@{} {}'.format(ssh, keyfile_string, port_string, self.username, self.host, command)
logger.debug(command)
if self.password:
command = _give_password(self.password, command)
return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)
def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True):
output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
if check_exit_code:
exit_code = int(self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False))
if exit_code:
message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
raise DeviceError(message.format(exit_code, command, output))
return output
def logout(self):
logger.debug('Logging out {}@{}'.format(self.username, self.host))
self.conn.logout()
def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, strip_colors=True, log=True):
timeout = self.timeout if timeout is None else timeout
if as_root:
command = "sudo -- sh -c '{}'".format(escape_single_quotes(command))
if log:
logger.debug(command)
self.conn.sendline(command)
index = self.conn.expect_exact(['[sudo] password', TIMEOUT], timeout=0.5)
if index == 0:
self.conn.sendline(self.password)
timed_out = not self.conn.prompt(timeout)
output = re.sub(r'.*?{}'.format(re.escape(command)), '', self.conn.before, 1).strip()
else:
if log:
logger.debug(command)
self.conn.sendline(command)
timed_out = not self.conn.prompt(timeout)
# the regex removes line breaks potentially introduced when writing
# the command to the shell.
command_index = re.sub(r' \r([^\n])', r'\1', self.conn.before).find(command)
while not timed_out and command_index == -1:
# In case of a "premature" timeout (i.e. timeout, but no hang,
# so command completes afterwards), there may be a prompt from
# the previous command completion in the serial output. This
# checks for this case by making sure that the original command
# is present in the serial output and waiting for the next
# prompt if it is not.
timed_out = not self.conn.prompt(timeout)
command_index = re.sub(r' \r([^\n])', r'\1', self.conn.before).find(command)
output = self.conn.before[command_index + len(command):].strip()
if timed_out:
raise TimeoutError(command, output)
if strip_colors:
output = strip_bash_colors(output)
return output
def _scp(self, source, dest, timeout=30):
port_string = '-P {}'.format(self.port) if self.port else ''
keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
command = '{} -r {} {} {} {}'.format(scp, keyfile_string, port_string, source, dest)
pass_string = ''
logger.debug(command)
if self.password:
command = _give_password(self.password, command)
try:
check_output(command, timeout=timeout, shell=True)
except subprocess.CalledProcessError as e:
raise subprocess.CalledProcessError(e.returncode, e.cmd.replace(pass_string, ''), e.output)
except TimeoutError as e:
raise TimeoutError(e.command.replace(pass_string, ''), e.output)
def _give_password(password, command):
if not sshpass:
raise HostError('Must have sshpass installed on the host in order to use password-based auth.')
pass_string = "sshpass -p '{}' ".format(password)
return pass_string + command
def _check_env():
global ssh, scp, sshpass # pylint: disable=global-statement
if not ssh:
ssh = which('ssh')
scp = which('scp')
sshpass = which('sshpass')
if not (ssh and scp):
raise HostError('OpenSSH must be installed on the host.')
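# A minimal usage sketch, assuming a Linux target reachable at the (made-up)
# address below, and that OpenSSH -- plus sshpass for password auth -- is
# installed on the host.
if __name__ == '__main__':
    shell = SshShell(timeout=30)
    shell.login('192.168.0.100', 'root', password='password')
    print shell.execute('uname -a')
    shell.pull_file('/proc/cpuinfo', 'cpuinfo.txt')
    shell.logout()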

176
wlauto/utils/types.py Normal file
View File

@@ -0,0 +1,176 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Routines for doing various type conversions. These usually embody some higher-level
semantics than are present in standard Python types (e.g. ``boolean`` will convert the
string ``"false"`` to ``False``, where as non-empty strings are usually considered to be
``True``).
A lot of these are intended to specify type conversions declaratively in places like
``Parameter``'s ``kind`` argument. These are basically "hacks" around the fact that Python
is not the best language to use for configuration.
"""
import re
import math
from collections import defaultdict
from wlauto.utils.misc import isiterable, to_identifier
def identifier(text):
"""Converts text to a valid Python identifier by replacing all
whitespace and punctuation."""
return to_identifier(text)
def boolean(value):
"""
Returns bool represented by the value. This is different from
calling the builtin bool() in that it will interpret string representations.
e.g. boolean('0') and boolean('false') will both yield False.
"""
false_strings = ['', '0', 'n', 'no']
if isinstance(value, basestring):
value = value.lower()
if value in false_strings or 'false'.startswith(value):
return False
return bool(value)
def numeric(value):
"""
Returns the value as number (int if possible, or float otherwise), or
raises ``ValueError`` if the specified ``value`` does not have a straight
forward numeric conversion.
"""
if isinstance(value, int):
return value
try:
fvalue = float(value)
except ValueError:
raise ValueError('Not numeric: {}'.format(value))
if not math.isnan(fvalue) and not math.isinf(fvalue):
ivalue = int(fvalue)
if ivalue == fvalue: # yeah, yeah, I know. Whatever. This is best-effort.
return ivalue
return fvalue
def list_or_string(value):
"""
If the value is a string, it will be kept as a string, otherwise it will be interpreted
as a list. If that is not possible, it will be interpreted as a string.
"""
if isinstance(value, basestring):
return value
else:
try:
return list(value)
except ValueError:
return str(value)
def list_of_strs(value):
"""
Value must be iterable. All elements will be converted to strings.
"""
if not isiterable(value):
raise ValueError(value)
return map(str, value)
list_of_strings = list_of_strs
def list_of_ints(value):
"""
Value must be iterable. All elements will be converted to ``int``\ s.
"""
if not isiterable(value):
raise ValueError(value)
return map(int, value)
list_of_integers = list_of_ints
def list_of_numbers(value):
"""
Value must be iterable. All elements will be converted to numbers (either ``ints`` or
``float``\ s depending on the elements).
"""
if not isiterable(value):
raise ValueError(value)
return map(numeric, value)
def list_of_bools(value, interpret_strings=True):
"""
Value must be iterable. All elements will be converted to ``bool``\ s.
.. note:: By default, ``boolean()`` conversion function will be used, which means that
strings like ``"0"`` or ``"false"`` will be interpreted as ``False``. If this
is undesirable, set ``interpret_strings`` to ``False``.
"""
if not isiterable(value):
raise ValueError(value)
if interpret_strings:
return map(boolean, value)
else:
return map(bool, value)
regex_type = type(re.compile(''))
def regex(value):
"""
Regular expression. If value is a string, it will be compiled with no flags. If you
want to specify flags, value must be precompiled.
"""
if isinstance(value, regex_type):
return value
else:
return re.compile(value)
__counters = defaultdict(int)
def reset_counter(name=None):
__counters[name] = 0
def counter(name=None):
"""
An auto-incrementing value (kind of like an AUTO INCREMENT field in SQL).
Optionally, the name of the counter to be used is specified (each counter
increments separately).
Counts start at 1, not 0.
"""
__counters[name] += 1
value = __counters[name]
return value
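# A minimal sketch of how these conversions behave; the values (and the
# counter name) below are illustrative only.
if __name__ == '__main__':
    assert boolean('false') is False and boolean('1') is True
    assert numeric('2.0') == 2 and isinstance(numeric('2.0'), int)
    assert list_of_ints(['1', '2', '3']) == [1, 2, 3]
    assert counter('example') == 1 and counter('example') == 2
    reset_counter('example')
    assert counter('example') == 1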

214
wlauto/utils/uefi.py Normal file
View File

@@ -0,0 +1,214 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import time
import logging
from wlauto.utils.serial_port import TIMEOUT
logger = logging.getLogger('UEFI')
class UefiMenu(object):
"""
Allows navigating the UEFI menu over serial (it relies on a pexpect connection).
"""
option_regex = re.compile(r'^\[(\d+)\]\s+([^\r]+)\r\n', re.M)
prompt_regex = re.compile(r'^([^\r\n]+):\s*', re.M)
invalid_regex = re.compile(r'Invalid input \(max (\d+)\)', re.M)
load_delay = 1 # seconds
default_timeout = 60 # seconds
def __init__(self, conn, prompt='The default boot selection will start in'):
"""
:param conn: A serial connection as returned by ``pexpect.spawn()``.
:param prompt: The starting prompt to wait for during ``open()``.
"""
self.conn = conn
self.start_prompt = prompt
self.options = {}
self.prompt = None
def open(self, timeout=default_timeout):
"""
"Open" the UEFI menu by sending an interrupt on STDIN after seeing the
starting prompt (configurable upon creation of the ``UefiMenu`` object).
"""
self.conn.expect(self.start_prompt, timeout)
self.conn.sendline('')
time.sleep(self.load_delay)
def create_entry(self, name, image, args, fdt_support, initrd=None, fdt_path=None):
"""Create a new UEFI entry using the parameters. The menu is assumed
to be at the top level. Upon return, the menu will be at the top level."""
logger.debug('Creating UEFI entry {}'.format(name))
self.nudge()
self.select('Boot Manager')
self.select('Add Boot Device Entry')
self.select('NOR Flash')
self.enter(image)
self.enter('y' if fdt_support else 'n')
if initrd:
self.enter('y')
self.enter(initrd)
else:
self.enter('n')
self.enter(args)
self.enter(name)
if fdt_path:
self.select('Update FDT path')
self.enter(fdt_path)
self.select('Return to main menu')
def delete_entry(self, name):
"""Delete the specified UEFI entry. The menu is assumed
to be at the top level. Upon return, the menu will be at the top level."""
logger.debug('Removing UEFI entry {}'.format(name))
self.nudge()
self.select('Boot Manager')
self.select('Remove Boot Device Entry')
self.select(name)
self.select('Return to main menu')
def select(self, option, timeout=default_timeout):
"""
Select the specified option from the current menu.
:param option: Could be an ``int`` index of the option, or a string/regex to
match option text against.
:param timeout: If a non-``int`` option is specified, the option list may need
to be parsed (if it hasn't been already); this may block, and the
timeout is used to cap that, resulting in a ``TIMEOUT`` exception.
"""
if isinstance(option, basestring):
option = self.get_option_index(option, timeout)
self.enter(option)
def enter(self, value, delay=load_delay):
"""Like ``select()`` except no resolution is performed -- the value is sent directly
to the serial connection."""
# Empty the buffer first, so that only response to the input about to
# be sent will be processed by subsequent commands.
value = str(value)
self._reset()
self.write_characters(value)
# TODO: in case the value is long and complicated, things may get
# screwed up (e.g. there may be line breaks injected); additionally,
# special chars might cause the regex to fail. To avoid these issues I'm
# only matching against the first 5 chars of the value. This is
# entirely arbitrary and I'll probably have to find a better way of
# doing this at some point.
self.conn.expect(value[:5], timeout=delay)
time.sleep(self.load_delay)
def read_menu(self, timeout=default_timeout):
"""Parse serial output to get the menu options and the following prompt."""
attempting_timeout_retry = False
attempting_invalid_retry = False
while True:
index = self.conn.expect([self.option_regex, self.prompt_regex, self.invalid_regex, TIMEOUT],
timeout=timeout)
match = self.conn.match
if index == 0: # matched menu option
self.options[match.group(1)] = match.group(2)
elif index == 1: # matched prompt
self.prompt = match.group(1)
break
elif index == 2: # matched invalid selection
# We've sent an invalid input (which includes an empty line) at
# the top-level menu. To get back the menu options, it seems we
# need to enter what the error reports as the max + 1, so...
if not attempting_invalid_retry:
attempting_invalid_retry = True
val = int(match.group(1)) + 1
self.empty_buffer()
self.enter(val)
else: # OK, that didn't work; panic!
raise RuntimeError('Could not read menu entries; stuck on "{}" prompt'.format(self.prompt))
elif index == 3: # timed out
if not attempting_timeout_retry:
attempting_timeout_retry = True
self.nudge()
else: # Didn't help. Run away!
raise RuntimeError('Did not see a valid UEFI menu.')
else:
raise AssertionError('Unexpected response waiting for UEFI menu') # should never get here
def get_option_index(self, text, timeout=default_timeout):
"""Returns the menu index of the specified option text (uses regex matching). If the option
is not in the current menu, ``LookupError`` will be raised."""
if not self.prompt:
self.read_menu(timeout)
for k, v in self.options.iteritems():
if re.search(text, v):
return k
raise LookupError(text)
def has_option(self, text, timeout=default_timeout):
"""Returns ``True`` if at least one of the options in the current menu has
matched (using regex) the specified text."""
try:
self.get_option_index(text, timeout)
return True
except LookupError:
return False
def nudge(self):
"""Send a little nudge to ensure there is something to read. This is useful when you're not
sure if all out put from the serial has been read already."""
self.enter('')
def empty_buffer(self):
"""Read everything from the serial and clear the internal pexpect buffer. This ensures
that the next ``expect()`` call will time out (unless further input will be sent to the
serial beforehand). This is used to create a "known" state and avoid unexpected matches."""
try:
while True:
time.sleep(0.1)
self.conn.read_nonblocking(size=1024, timeout=0.1)
except TIMEOUT:
pass
self.conn.buffer = ''
def write_characters(self, line):
"""Write a single line out to serial charcter-by-character. This will ensure that nothing will
be dropped for longer lines."""
line = line.rstrip('\r\n')
for c in line:
self.conn.send(c)
time.sleep(0.05)
self.conn.sendline('')
def _reset(self):
self.options = {}
self.prompt = None
self.empty_buffer()
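# A minimal usage sketch, assuming a UEFI console is exposed over a serial
# port; the device node, entry name, image and kernel arguments are made up.
if __name__ == '__main__':
    from wlauto.utils.serial_port import open_serial_connection
    with open_serial_connection(timeout=60, port='/dev/ttyS0', baudrate=115200) as target:
        menu = UefiMenu(target)
        menu.open(timeout=120)
        if not menu.has_option('WA Test Image'):
            menu.create_entry(name='WA Test Image',
                              image='Image',
                              args='console=ttyAMA0,115200 root=/dev/sda1',
                              fdt_support=True)
        menu.select('WA Test Image')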