1
0
mirror of https://github.com/ARM-software/devlib.git synced 2025-09-09 21:41:53 +01:00

devlib initial commit.

This commit is contained in:
Sergei Trofimov
2015-10-09 09:30:04 +01:00
commit 4e6afe960b
64 changed files with 8938 additions and 0 deletions

16
devlib/utils/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

428
devlib/utils/android.py Normal file
View File

@@ -0,0 +1,428 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Utility functions for working with Android devices through adb.
"""
# pylint: disable=E1103
import os
import time
import subprocess
import logging
import re
from collections import defaultdict
from devlib.exception import TargetError, HostError
from devlib.utils.misc import check_output, which
from devlib.utils.misc import escape_single_quotes, escape_double_quotes
logger = logging.getLogger('android')
MAX_ATTEMPTS = 5
AM_START_ERROR = re.compile(r"Error: Activity class {[\w|.|/]*} does not exist")
# See:
# http://developer.android.com/guide/topics/manifest/uses-sdk-element.html#ApiLevels
ANDROID_VERSION_MAP = {
22: 'LOLLYPOP_MR1',
21: 'LOLLYPOP',
20: 'KITKAT_WATCH',
19: 'KITKAT',
18: 'JELLY_BEAN_MR2',
17: 'JELLY_BEAN_MR1',
16: 'JELLY_BEAN',
15: 'ICE_CREAM_SANDWICH_MR1',
14: 'ICE_CREAM_SANDWICH',
13: 'HONEYCOMB_MR2',
12: 'HONEYCOMB_MR1',
11: 'HONEYCOMB',
10: 'GINGERBREAD_MR1',
9: 'GINGERBREAD',
8: 'FROYO',
7: 'ECLAIR_MR1',
6: 'ECLAIR_0_1',
5: 'ECLAIR',
4: 'DONUT',
3: 'CUPCAKE',
2: 'BASE_1_1',
1: 'BASE',
}
# Initialized in functions near the botton of the file
android_home = None
platform_tools = None
adb = None
aapt = None
fastboot = None
class AndroidProperties(object):
def __init__(self, text):
self._properties = {}
self.parse(text)
def parse(self, text):
self._properties = dict(re.findall(r'\[(.*?)\]:\s+\[(.*?)\]', text))
def iteritems(self):
return self._properties.iteritems()
def __iter__(self):
return iter(self._properties)
def __getattr__(self, name):
return self._properties.get(name)
__getitem__ = __getattr__
class AdbDevice(object):
def __init__(self, name, status):
self.name = name
self.status = status
def __cmp__(self, other):
if isinstance(other, AdbDevice):
return cmp(self.name, other.name)
else:
return cmp(self.name, other)
def __str__(self):
return 'AdbDevice({}, {})'.format(self.name, self.status)
__repr__ = __str__
class ApkInfo(object):
version_regex = re.compile(r"name='(?P<name>[^']+)' versionCode='(?P<vcode>[^']+)' versionName='(?P<vname>[^']+)'")
name_regex = re.compile(r"name='(?P<name>[^']+)'")
def __init__(self, path=None):
self.path = path
self.package = None
self.activity = None
self.label = None
self.version_name = None
self.version_code = None
self.parse(path)
def parse(self, apk_path):
_check_env()
command = [aapt, 'dump', 'badging', apk_path]
logger.debug(' '.join(command))
output = subprocess.check_output(command)
for line in output.split('\n'):
if line.startswith('application-label:'):
self.label = line.split(':')[1].strip().replace('\'', '')
elif line.startswith('package:'):
match = self.version_regex.search(line)
if match:
self.package = match.group('name')
self.version_code = match.group('vcode')
self.version_name = match.group('vname')
elif line.startswith('launchable-activity:'):
match = self.name_regex.search(line)
self.activity = match.group('name')
else:
pass # not interested
class AdbConnection(object):
# maintains the count of parallel active connections to a device, so that
# adb disconnect is not invoked untill all connections are closed
active_connections = defaultdict(int)
@property
def name(self):
return self.device
def __init__(self, device=None, timeout=10):
self.timeout = timeout
if device is None:
device = adb_get_device(timeout=timeout)
self.device = device
adb_connect(self.device)
AdbConnection.active_connections[self.device] += 1
def push(self, source, dest, timeout=None):
if timeout is None:
timeout = self.timeout
command = 'push {} {}'.format(source, dest)
return adb_command(self.device, command, timeout=timeout)
def pull(self, source, dest, timeout=None):
if timeout is None:
timeout = self.timeout
command = 'pull {} {}'.format(source, dest)
return adb_command(self.device, command, timeout=timeout)
def execute(self, command, timeout=None, check_exit_code=False, as_root=False):
return adb_shell(self.device, command, timeout, check_exit_code, as_root)
def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False):
return adb_background_shell(self.device, command, stdout, stderr, as_root)
def close(self):
AdbConnection.active_connections[self.device] -= 1
if AdbConnection.active_connections[self.device] <= 0:
adb_disconnect(self.device)
del AdbConnection.active_connections[self.device]
def cancel_running_command(self):
# adbd multiplexes commands so that they don't interfer with each
# other, so there is no need to explicitly cancel a running command
# before the next one can be issued.
pass
def fastboot_command(command, timeout=None):
_check_env()
full_command = "fastboot {}".format(command)
logger.debug(full_command)
output, _ = check_output(full_command, timeout, shell=True)
return output
def fastboot_flash_partition(partition, path_to_image):
command = 'flash {} {}'.format(partition, path_to_image)
fastboot_command(command)
def adb_get_device(timeout=None):
"""
Returns the serial number of a connected android device.
If there are more than one device connected to the machine, or it could not
find any device connected, :class:`devlib.exceptions.HostError` is raised.
"""
# TODO this is a hacky way to issue a adb command to all listed devices
# The output of calling adb devices consists of a heading line then
# a list of the devices sperated by new line
# The last line is a blank new line. in otherwords, if there is a device found
# then the output length is 2 + (1 for each device)
start = time.time()
while True:
output = adb_command(None, "devices").splitlines() # pylint: disable=E1103
output_length = len(output)
if output_length == 3:
# output[1] is the 2nd line in the output which has the device name
# Splitting the line by '\t' gives a list of two indexes, which has
# device serial in 0 number and device type in 1.
return output[1].split('\t')[0]
elif output_length > 3:
message = '{} Android devices found; either explicitly specify ' +\
'the device you want, or make sure only one is connected.'
raise HostError(message.format(output_length - 2))
else:
if timeout < time.time() - start:
raise HostError('No device is connected and available')
time.sleep(1)
def adb_connect(device, timeout=None, attempts=MAX_ATTEMPTS):
_check_env()
tries = 0
output = None
while tries <= attempts:
tries += 1
if device:
command = 'adb connect {}'.format(device)
logger.debug(command)
output, _ = check_output(command, shell=True, timeout=timeout)
if _ping(device):
break
time.sleep(10)
else: # did not connect to the device
message = 'Could not connect to {}'.format(device or 'a device')
if output:
message += '; got: "{}"'.format(output)
raise HostError(message)
def adb_disconnect(device):
_check_env()
if not device:
return
if ":" in device:
command = "adb disconnect " + device
logger.debug(command)
retval = subprocess.call(command, stdout=open(os.devnull, 'wb'), shell=True)
if retval:
raise TargetError('"{}" returned {}'.format(command, retval))
def _ping(device):
_check_env()
device_string = ' -s {}'.format(device) if device else ''
command = "adb{} shell \"ls / > /dev/null\"".format(device_string)
logger.debug(command)
result = subprocess.call(command, stderr=subprocess.PIPE, shell=True)
if not result:
return True
else:
return False
def adb_shell(device, command, timeout=None, check_exit_code=False, as_root=False): # NOQA
_check_env()
if as_root:
command = 'echo "{}" | su'.format(escape_double_quotes(command))
device_string = ' -s {}'.format(device) if device else ''
full_command = 'adb{} shell "{}"'.format(device_string,
escape_double_quotes(command))
logger.debug(full_command)
if check_exit_code:
actual_command = "adb{} shell '({}); echo $?'".format(device_string,
escape_single_quotes(command))
raw_output, error = check_output(actual_command, timeout, shell=True)
if raw_output:
try:
output, exit_code, _ = raw_output.rsplit('\n', 2)
except ValueError:
exit_code, _ = raw_output.rsplit('\n', 1)
output = ''
else: # raw_output is empty
exit_code = '969696' # just because
output = ''
exit_code = exit_code.strip()
if exit_code.isdigit():
if int(exit_code):
message = 'Got exit code {}\nfrom: {}\nSTDOUT: {}\nSTDERR: {}'
raise TargetError(message.format(exit_code, full_command, output, error))
elif AM_START_ERROR.findall(output):
message = 'Could not start activity; got the following:'
message += '\n{}'.format(AM_START_ERROR.findall(output)[0])
raise TargetError(message)
else: # not all digits
if AM_START_ERROR.findall(output):
message = 'Could not start activity; got the following:\n{}'
raise TargetError(message.format(AM_START_ERROR.findall(output)[0]))
else:
message = 'adb has returned early; did not get an exit code. '\
'Was kill-server invoked?'
raise TargetError(message)
else: # do not check exit code
output, _ = check_output(full_command, timeout, shell=True)
return output
def adb_background_shell(device, command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
as_root=False):
"""Runs the sepcified command in a subprocess, returning the the Popen object."""
_check_env()
if as_root:
command = 'echo \'{}\' | su'.format(escape_single_quotes(command))
device_string = ' -s {}'.format(device) if device else ''
full_command = 'adb{} shell "{}"'.format(device_string, escape_double_quotes(command))
logger.debug(full_command)
return subprocess.Popen(full_command, stdout=stdout, stderr=stderr, shell=True)
def adb_list_devices():
output = adb_command(None, 'devices')
devices = []
for line in output.splitlines():
parts = [p.strip() for p in line.split()]
if len(parts) == 2:
devices.append(AdbDevice(*parts))
return devices
def adb_command(device, command, timeout=None):
_check_env()
device_string = ' -s {}'.format(device) if device else ''
full_command = "adb{} {}".format(device_string, command)
logger.debug(full_command)
output, _ = check_output(full_command, timeout, shell=True)
return output
# Messy environment initialisation stuff...
class _AndroidEnvironment(object):
def __init__(self):
self.android_home = None
self.platform_tools = None
self.adb = None
self.aapt = None
self.fastboot = None
def _initialize_with_android_home(env):
logger.debug('Using ANDROID_HOME from the environment.')
env.android_home = android_home
env.platform_tools = os.path.join(android_home, 'platform-tools')
os.environ['PATH'] += os.pathsep + env.platform_tools
_init_common(env)
return env
def _initialize_without_android_home(env):
if which('adb'):
env.adb = 'adb'
else:
raise HostError('ANDROID_HOME is not set and adb is not in PATH. '
'Have you installed Android SDK?')
logger.debug('Discovering ANDROID_HOME from adb path.')
env.platform_tools = os.path.dirname(env.adb)
env.android_home = os.path.dirname(env.platform_tools)
_init_common(env)
return env
def _init_common(env):
logger.debug('ANDROID_HOME: {}'.format(env.android_home))
build_tools_directory = os.path.join(env.android_home, 'build-tools')
if not os.path.isdir(build_tools_directory):
msg = '''ANDROID_HOME ({}) does not appear to have valid Android SDK install
(cannot find build-tools)'''
raise HostError(msg.format(env.android_home))
versions = os.listdir(build_tools_directory)
for version in reversed(sorted(versions)):
aapt_path = os.path.join(build_tools_directory, version, 'aapt')
if os.path.isfile(aapt_path):
logger.debug('Using aapt for version {}'.format(version))
env.aapt = aapt_path
break
else:
raise HostError('aapt not found. Please make sure at least one Android '
'platform is installed.')
def _check_env():
global android_home, platform_tools, adb, aapt # pylint: disable=W0603
if not android_home:
android_home = os.getenv('ANDROID_HOME')
if android_home:
_env = _initialize_with_android_home(_AndroidEnvironment())
else:
_env = _initialize_without_android_home(_AndroidEnvironment())
android_home = _env.android_home
platform_tools = _env.platform_tools
adb = _env.adb
aapt = _env.aapt

552
devlib/utils/misc.py Normal file
View File

@@ -0,0 +1,552 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Miscellaneous functions that don't fit anywhere else.
"""
from __future__ import division
import os
import sys
import re
import string
import threading
import signal
import subprocess
import pkgutil
import logging
import random
from operator import itemgetter
from itertools import groupby
from functools import partial
# ABI --> architectures list
ABI_MAP = {
'armeabi': ['armeabi', 'armv7', 'armv7l', 'armv7el', 'armv7lh'],
'arm64': ['arm64', 'armv8', 'arm64-v8a', 'aarch64'],
}
# Vendor ID --> CPU part ID --> CPU variant ID --> Core Name
# None means variant is not used.
CPU_PART_MAP = {
0x41: { # ARM
0x926: {None: 'ARM926'},
0x946: {None: 'ARM946'},
0x966: {None: 'ARM966'},
0xb02: {None: 'ARM11MPCore'},
0xb36: {None: 'ARM1136'},
0xb56: {None: 'ARM1156'},
0xb76: {None: 'ARM1176'},
0xc05: {None: 'A5'},
0xc07: {None: 'A7'},
0xc08: {None: 'A8'},
0xc09: {None: 'A9'},
0xc0f: {None: 'A15'},
0xc14: {None: 'R4'},
0xc15: {None: 'R5'},
0xc20: {None: 'M0'},
0xc21: {None: 'M1'},
0xc23: {None: 'M3'},
0xc24: {None: 'M4'},
0xc27: {None: 'M7'},
0xd03: {None: 'A53'},
0xd07: {None: 'A57'},
0xd08: {None: 'A72'},
},
0x4e: { # Nvidia
0x0: {None: 'Denver'},
},
0x51: { # Qualcomm
0x02d: {None: 'Scorpion'},
0x04d: {None: 'MSM8960'},
0x06f: { # Krait
0x2: 'Krait400',
0x3: 'Krait450',
},
},
0x56: { # Marvell
0x131: {
0x2: 'Feroceon 88F6281',
}
},
}
def get_cpu_name(implementer, part, variant):
part_data = CPU_PART_MAP.get(implementer, {}).get(part, {})
if None in part_data: # variant does not determine core Name for this vendor
name = part_data[None]
else:
name = part_data.get(variant)
return name
def preexec_function():
# Ignore the SIGINT signal by setting the handler to the standard
# signal handler SIG_IGN.
signal.signal(signal.SIGINT, signal.SIG_IGN)
# Change process group in case we have to kill the subprocess and all of
# its children later.
# TODO: this is Unix-specific; would be good to find an OS-agnostic way
# to do this in case we wanna port WA to Windows.
os.setpgrp()
check_output_logger = logging.getLogger('check_output')
# Defined here rather than in devlib.exceptions due to module load dependencies
class TimeoutError(Exception):
"""Raised when a subprocess command times out. This is basically a ``WAError``-derived version
of ``subprocess.CalledProcessError``, the thinking being that while a timeout could be due to
programming error (e.g. not setting long enough timers), it is often due to some failure in the
environment, and there fore should be classed as a "user error"."""
def __init__(self, command, output):
super(TimeoutError, self).__init__('Timed out: {}'.format(command))
self.command = command
self.output = output
def __str__(self):
return '\n'.join([self.message, 'OUTPUT:', self.output or ''])
def check_output(command, timeout=None, ignore=None, inputtext=None, **kwargs):
"""This is a version of subprocess.check_output that adds a timeout parameter to kill
the subprocess if it does not return within the specified time."""
# pylint: disable=too-many-branches
if ignore is None:
ignore = []
elif isinstance(ignore, int):
ignore = [ignore]
elif not isinstance(ignore, list) and ignore != 'all':
message = 'Invalid value for ignore parameter: "{}"; must be an int or a list'
raise ValueError(message.format(ignore))
if 'stdout' in kwargs:
raise ValueError('stdout argument not allowed, it will be overridden.')
def callback(pid):
try:
check_output_logger.debug('{} timed out; sending SIGKILL'.format(pid))
os.killpg(pid, signal.SIGKILL)
except OSError:
pass # process may have already terminated.
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
stdin=subprocess.PIPE,
preexec_fn=preexec_function, **kwargs)
if timeout:
timer = threading.Timer(timeout, callback, [process.pid, ])
timer.start()
try:
output, error = process.communicate(inputtext)
finally:
if timeout:
timer.cancel()
retcode = process.poll()
if retcode:
if retcode == -9: # killed, assume due to timeout callback
raise TimeoutError(command, output='\n'.join([output, error]))
elif ignore != 'all' and retcode not in ignore:
raise subprocess.CalledProcessError(retcode, command, output='\n'.join([output, error]))
return output, error
def walk_modules(path):
"""
Given package name, return a list of all modules (including submodules, etc)
in that package.
"""
root_mod = __import__(path, {}, {}, [''])
mods = [root_mod]
for _, name, ispkg in pkgutil.iter_modules(root_mod.__path__):
submod_path = '.'.join([path, name])
if ispkg:
mods.extend(walk_modules(submod_path))
else:
submod = __import__(submod_path, {}, {}, [''])
mods.append(submod)
return mods
def ensure_directory_exists(dirpath):
"""A filter for directory paths to ensure they exist."""
if not os.path.isdir(dirpath):
os.makedirs(dirpath)
return dirpath
def ensure_file_directory_exists(filepath):
"""
A filter for file paths to ensure the directory of the
file exists and the file can be created there. The file
itself is *not* going to be created if it doesn't already
exist.
"""
ensure_directory_exists(os.path.dirname(filepath))
return filepath
def merge_dicts(*args, **kwargs):
if not len(args) >= 2:
raise ValueError('Must specify at least two dicts to merge.')
func = partial(_merge_two_dicts, **kwargs)
return reduce(func, args)
def _merge_two_dicts(base, other, list_duplicates='all', match_types=False, # pylint: disable=R0912,R0914
dict_type=dict, should_normalize=True, should_merge_lists=True):
"""Merge dicts normalizing their keys."""
merged = dict_type()
base_keys = base.keys()
other_keys = other.keys()
norm = normalize if should_normalize else lambda x, y: x
base_only = []
other_only = []
both = []
union = []
for k in base_keys:
if k in other_keys:
both.append(k)
else:
base_only.append(k)
union.append(k)
for k in other_keys:
if k in base_keys:
union.append(k)
else:
union.append(k)
other_only.append(k)
for k in union:
if k in base_only:
merged[k] = norm(base[k], dict_type)
elif k in other_only:
merged[k] = norm(other[k], dict_type)
elif k in both:
base_value = base[k]
other_value = other[k]
base_type = type(base_value)
other_type = type(other_value)
if (match_types and (base_type != other_type) and
(base_value is not None) and (other_value is not None)):
raise ValueError('Type mismatch for {} got {} ({}) and {} ({})'.format(k, base_value, base_type,
other_value, other_type))
if isinstance(base_value, dict):
merged[k] = _merge_two_dicts(base_value, other_value, list_duplicates, match_types, dict_type)
elif isinstance(base_value, list):
if should_merge_lists:
merged[k] = _merge_two_lists(base_value, other_value, list_duplicates, dict_type)
else:
merged[k] = _merge_two_lists([], other_value, list_duplicates, dict_type)
elif isinstance(base_value, set):
merged[k] = norm(base_value.union(other_value), dict_type)
else:
merged[k] = norm(other_value, dict_type)
else: # Should never get here
raise AssertionError('Unexpected merge key: {}'.format(k))
return merged
def merge_lists(*args, **kwargs):
if not len(args) >= 2:
raise ValueError('Must specify at least two lists to merge.')
func = partial(_merge_two_lists, **kwargs)
return reduce(func, args)
def _merge_two_lists(base, other, duplicates='all', dict_type=dict): # pylint: disable=R0912
"""
Merge lists, normalizing their entries.
parameters:
:base, other: the two lists to be merged. ``other`` will be merged on
top of base.
:duplicates: Indicates the strategy of handling entries that appear
in both lists. ``all`` will keep occurrences from both
lists; ``first`` will only keep occurrences from
``base``; ``last`` will only keep occurrences from
``other``;
.. note:: duplicate entries that appear in the *same* list
will never be removed.
"""
if not isiterable(base):
base = [base]
if not isiterable(other):
other = [other]
if duplicates == 'all':
merged_list = []
for v in normalize(base, dict_type) + normalize(other, dict_type):
if not _check_remove_item(merged_list, v):
merged_list.append(v)
return merged_list
elif duplicates == 'first':
base_norm = normalize(base, dict_type)
merged_list = normalize(base, dict_type)
for v in base_norm:
_check_remove_item(merged_list, v)
for v in normalize(other, dict_type):
if not _check_remove_item(merged_list, v):
if v not in base_norm:
merged_list.append(v) # pylint: disable=no-member
return merged_list
elif duplicates == 'last':
other_norm = normalize(other, dict_type)
merged_list = []
for v in normalize(base, dict_type):
if not _check_remove_item(merged_list, v):
if v not in other_norm:
merged_list.append(v)
for v in other_norm:
if not _check_remove_item(merged_list, v):
merged_list.append(v)
return merged_list
else:
raise ValueError('Unexpected value for list duplicates argument: {}. '.format(duplicates) +
'Must be in {"all", "first", "last"}.')
def _check_remove_item(the_list, item):
"""Helper function for merge_lists that implements checking wether an items
should be removed from the list and doing so if needed. Returns ``True`` if
the item has been removed and ``False`` otherwise."""
if not isinstance(item, basestring):
return False
if not item.startswith('~'):
return False
actual_item = item[1:]
if actual_item in the_list:
del the_list[the_list.index(actual_item)]
return True
def normalize(value, dict_type=dict):
"""Normalize values. Recursively normalizes dict keys to be lower case,
no surrounding whitespace, underscore-delimited strings."""
if isinstance(value, dict):
normalized = dict_type()
for k, v in value.iteritems():
key = k.strip().lower().replace(' ', '_')
normalized[key] = normalize(v, dict_type)
return normalized
elif isinstance(value, list):
return [normalize(v, dict_type) for v in value]
elif isinstance(value, tuple):
return tuple([normalize(v, dict_type) for v in value])
else:
return value
def convert_new_lines(text):
""" Convert new lines to a common format. """
return text.replace('\r\n', '\n').replace('\r', '\n')
def escape_quotes(text):
"""Escape quotes, and escaped quotes, in the specified text."""
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\\\'').replace('\"', '\\\"')
def escape_single_quotes(text):
"""Escape single quotes, and escaped single quotes, in the specified text."""
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\'\\\'\'')
def escape_double_quotes(text):
"""Escape double quotes, and escaped double quotes, in the specified text."""
return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\"', '\\\"')
def getch(count=1):
"""Read ``count`` characters from standard input."""
if os.name == 'nt':
import msvcrt # pylint: disable=F0401
return ''.join([msvcrt.getch() for _ in xrange(count)])
else: # assume Unix
import tty # NOQA
import termios # NOQA
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(sys.stdin.fileno())
ch = sys.stdin.read(count)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
return ch
def isiterable(obj):
"""Returns ``True`` if the specified object is iterable and
*is not a string type*, ``False`` otherwise."""
return hasattr(obj, '__iter__') and not isinstance(obj, basestring)
def as_relative(path):
"""Convert path to relative by stripping away the leading '/' on UNIX or
the equivant on other platforms."""
path = os.path.splitdrive(path)[1]
return path.lstrip(os.sep)
def get_cpu_mask(cores):
"""Return a string with the hex for the cpu mask for the specified core numbers."""
mask = 0
for i in cores:
mask |= 1 << i
return '0x{0:x}'.format(mask)
def which(name):
"""Platform-independent version of UNIX which utility."""
if os.name == 'nt':
paths = os.getenv('PATH').split(os.pathsep)
exts = os.getenv('PATHEXT').split(os.pathsep)
for path in paths:
testpath = os.path.join(path, name)
if os.path.isfile(testpath):
return testpath
for ext in exts:
testpathext = testpath + ext
if os.path.isfile(testpathext):
return testpathext
return None
else: # assume UNIX-like
try:
return check_output(['which', name])[0].strip() # pylint: disable=E1103
except subprocess.CalledProcessError:
return None
_bash_color_regex = re.compile('\x1b\\[[0-9;]+m')
def strip_bash_colors(text):
return _bash_color_regex.sub('', text)
def get_random_string(length):
"""Returns a random ASCII string of the specified length)."""
return ''.join(random.choice(string.ascii_letters + string.digits) for _ in xrange(length))
class LoadSyntaxError(Exception):
def __init__(self, message, filepath, lineno):
super(LoadSyntaxError, self).__init__(message)
self.filepath = filepath
self.lineno = lineno
def __str__(self):
message = 'Syntax Error in {}, line {}:\n\t{}'
return message.format(self.filepath, self.lineno, self.message)
RAND_MOD_NAME_LEN = 30
BAD_CHARS = string.punctuation + string.whitespace
TRANS_TABLE = string.maketrans(BAD_CHARS, '_' * len(BAD_CHARS))
def to_identifier(text):
"""Converts text to a valid Python identifier by replacing all
whitespace and punctuation."""
return re.sub('_+', '_', text.translate(TRANS_TABLE))
def unique(alist):
"""
Returns a list containing only unique elements from the input list (but preserves
order, unlike sets).
"""
result = []
for item in alist:
if item not in result:
result.append(item)
return result
def ranges_to_list(ranges_string):
"""Converts a sysfs-style ranges string, e.g. ``"0,2-4"``, into a list ,e.g ``[0,2,3,4]``"""
values = []
for rg in ranges_string.split(','):
if '-' in rg:
first, last = map(int, rg.split('-'))
values.extend(xrange(first, last + 1))
else:
values.append(int(rg))
return values
def list_to_ranges(values):
"""Converts a list, e.g ``[0,2,3,4]``, into a sysfs-style ranges string, e.g. ``"0,2-4"``"""
range_groups = []
for _, g in groupby(enumerate(values), lambda (i, x): i - x):
range_groups.append(map(itemgetter(1), g))
range_strings = []
for group in range_groups:
if len(group) == 1:
range_strings.append(str(group[0]))
else:
range_strings.append('{}-{}'.format(group[0], group[-1]))
return ','.join(range_strings)
def list_to_mask(values, base=0x0):
"""Converts the specified list of integer values into
a bit mask for those values. Optinally, the list can be
applied to an existing mask."""
for v in values:
base |= (1 << v)
return base
def mask_to_list(mask):
"""Converts the specfied integer bitmask into a list of
indexes of bits that are set in the mask."""
size = len(bin(mask)) - 2 # because of "0b"
return [size - i - 1 for i in xrange(size)
if mask & (1 << size - i - 1)]
__memo_cache = {}
def memoized(func):
"""A decorator for memoizing functions and methods."""
func_id = repr(func)
def memoize_wrapper(*args, **kwargs):
id_string = func_id + ','.join([str(id(a)) for a in args])
id_string += ','.join('{}={}'.format(k, v)
for k, v in kwargs.iteritems())
if id_string not in __memo_cache:
__memo_cache[id_string] = func(*args, **kwargs)
return __memo_cache[id_string]
return memoize_wrapper

107
devlib/utils/serial_port.py Normal file
View File

@@ -0,0 +1,107 @@
# Copyright 2013-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from contextlib import contextmanager
from logging import Logger
import serial
import fdpexpect
# Adding pexpect exceptions into this module's namespace
from pexpect import EOF, TIMEOUT # NOQA pylint: disable=W0611
from devlib.exception import HostError
def pulse_dtr(conn, state=True, duration=0.1):
"""Set the DTR line of the specified serial connection to the specified state
for the specified duration (note: the initial state of the line is *not* checked."""
conn.setDTR(state)
time.sleep(duration)
conn.setDTR(not state)
def get_connection(timeout, init_dtr=None, logcls=Logger,
*args, **kwargs):
if init_dtr is not None:
kwargs['dsrdtr'] = True
try:
conn = serial.Serial(*args, **kwargs)
except serial.SerialException as e:
raise HostError(e.message)
if init_dtr is not None:
conn.setDTR(init_dtr)
conn.nonblocking()
conn.flushOutput()
target = fdpexpect.fdspawn(conn.fileno(), timeout=timeout)
target.logfile_read = logcls('read')
target.logfile_send = logcls('send')
# Monkey-patching sendline to introduce a short delay after
# chacters are sent to the serial. If two sendline s are issued
# one after another the second one might start putting characters
# into the serial device before the first one has finished, causing
# corruption. The delay prevents that.
tsln = target.sendline
def sendline(x):
tsln(x)
time.sleep(0.1)
target.sendline = sendline
return target, conn
def write_characters(conn, line, delay=0.05):
"""Write a single line out to serial charcter-by-character. This will ensure that nothing will
be dropped for longer lines."""
line = line.rstrip('\r\n')
for c in line:
conn.send(c)
time.sleep(delay)
conn.sendline('')
@contextmanager
def open_serial_connection(timeout, get_conn=False, init_dtr=None,
logcls=Logger, *args, **kwargs):
"""
Opens a serial connection to a device.
:param timeout: timeout for the fdpexpect spawn object.
:param conn: ``bool`` that specfies whether the underlying connection
object should be yielded as well.
:param init_dtr: specifies the initial DTR state stat should be set.
All arguments are passed into the __init__ of serial.Serial. See
pyserial documentation for details:
http://pyserial.sourceforge.net/pyserial_api.html#serial.Serial
:returns: a pexpect spawn object connected to the device.
See: http://pexpect.sourceforge.net/pexpect.html
"""
target, conn = get_connection(timeout, init_dtr=init_dtr,
logcls=logcls, *args, **kwargs)
if get_conn:
yield target, conn
else:
yield target
target.close() # Closes the file descriptor used by the conn.
del conn

261
devlib/utils/ssh.py Normal file
View File

@@ -0,0 +1,261 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import stat
import logging
import subprocess
import re
import threading
import tempfile
import shutil
import pxssh
from pexpect import EOF, TIMEOUT, spawn
from devlib.exception import HostError, TargetError, TimeoutError
from devlib.utils.misc import which, strip_bash_colors, escape_single_quotes, check_output
ssh = None
scp = None
sshpass = None
logger = logging.getLogger('ssh')
def ssh_get_shell(host, username, password=None, keyfile=None, port=None, timeout=10, telnet=False):
_check_env()
if telnet:
if keyfile:
raise ValueError('keyfile may not be used with a telnet connection.')
conn = TelnetConnection()
else: # ssh
conn = pxssh.pxssh()
try:
if keyfile:
conn.login(host, username, ssh_key=keyfile, port=port, login_timeout=timeout)
else:
conn.login(host, username, password, port=port, login_timeout=timeout)
except EOF:
raise TargetError('Could not connect to {}; is the host name correct?'.format(host))
return conn
class TelnetConnection(pxssh.pxssh):
# pylint: disable=arguments-differ
def login(self, server, username, password='', original_prompt=r'[#$]', login_timeout=10,
auto_prompt_reset=True, sync_multiplier=1):
cmd = 'telnet -l {} {}'.format(username, server)
spawn._spawn(self, cmd) # pylint: disable=protected-access
i = self.expect('(?i)(?:password)', timeout=login_timeout)
if i == 0:
self.sendline(password)
i = self.expect([original_prompt, 'Login incorrect'], timeout=login_timeout)
else:
raise pxssh.ExceptionPxssh('could not log in: did not see a password prompt')
if i:
raise pxssh.ExceptionPxssh('could not log in: password was incorrect')
if not self.sync_original_prompt(sync_multiplier):
self.close()
raise pxssh.ExceptionPxssh('could not synchronize with original prompt')
if auto_prompt_reset:
if not self.set_unique_prompt():
self.close()
message = 'could not set shell prompt (recieved: {}, expected: {}).'
raise pxssh.ExceptionPxssh(message.format(self.before, self.PROMPT))
return True
def check_keyfile(keyfile):
"""
keyfile must have the right access premissions in order to be useable. If the specified
file doesn't, create a temporary copy and set the right permissions for that.
Returns either the ``keyfile`` (if the permissions on it are correct) or the path to a
temporary copy with the right permissions.
"""
desired_mask = stat.S_IWUSR | stat.S_IRUSR
actual_mask = os.stat(keyfile).st_mode & 0xFF
if actual_mask != desired_mask:
tmp_file = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile))
shutil.copy(keyfile, tmp_file)
os.chmod(tmp_file, desired_mask)
return tmp_file
else: # permissions on keyfile are OK
return keyfile
class SshConnection(object):
default_password_prompt = '[sudo] password'
max_cancel_attempts = 5
@property
def name(self):
return self.host
def __init__(self,
host,
username,
password=None,
keyfile=None,
port=None,
timeout=10,
telnet=False,
password_prompt=None,
):
self.host = host
self.username = username
self.password = password
self.keyfile = check_keyfile(keyfile) if keyfile else keyfile
self.port = port
self.lock = threading.Lock()
self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt
logger.debug('Logging in {}@{}'.format(username, host))
self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, telnet)
def push(self, source, dest, timeout=30):
dest = '{}@{}:{}'.format(self.username, self.host, dest)
return self._scp(source, dest, timeout)
def pull(self, source, dest, timeout=30):
source = '{}@{}:{}'.format(self.username, self.host, source)
return self._scp(source, dest, timeout)
def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True):
with self.lock:
output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
if check_exit_code:
exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
try:
exit_code = int(exit_code_text.split()[0])
if exit_code:
message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
raise TargetError(message.format(exit_code, command, output))
except (ValueError, IndexError):
logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text))
return output
def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE):
port_string = '-p {}'.format(self.port) if self.port else ''
keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
command = '{} {} {} {}@{} {}'.format(ssh, keyfile_string, port_string, self.username, self.host, command)
logger.debug(command)
if self.password:
command = _give_password(self.password, command)
return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)
def close(self):
logger.debug('Logging out {}@{}'.format(self.username, self.host))
self.conn.logout()
def cancel_running_command(self):
# simulate impatiently hitting ^C until command prompt appears
logger.debug('Sending ^C')
for _ in xrange(self.max_cancel_attempts):
self.conn.sendline(chr(3))
if self.conn.prompt(0.1):
return True
return False
def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, strip_colors=True, log=True):
self.conn.prompt(0.1) # clear an existing prompt if there is one.
if as_root:
command = "sudo -- sh -c '{}'".format(escape_single_quotes(command))
if log:
logger.debug(command)
self.conn.sendline(command)
if self.password:
index = self.conn.expect_exact([self.password_prompt, TIMEOUT], timeout=0.5)
if index == 0:
self.conn.sendline(self.password)
else: # not as_root
if log:
logger.debug(command)
self.conn.sendline(command)
timed_out = self._wait_for_prompt(timeout)
# the regex removes line breaks potential introduced when writing
# command to shell.
output = process_backspaces(self.conn.before)
output = re.sub(r'\r([^\n])', r'\1', output)
if '\r\n' in output: # strip the echoed command
output = output.split('\r\n', 1)[1]
if timed_out:
self.cancel_running_command()
raise TimeoutError(command, output)
if strip_colors:
output = strip_bash_colors(output)
return output
def _wait_for_prompt(self, timeout=None):
if timeout:
return not self.conn.prompt(timeout)
else: # cannot timeout; wait forever
while not self.conn.prompt(1):
pass
return False
def _scp(self, source, dest, timeout=30):
# NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely)
# fails to connect to a device if port is explicitly specified using -P
# option, even if it is the default port, 22. To minimize this problem,
# only specify -P for scp if the port is *not* the default.
port_string = '-P {}'.format(self.port) if (self.port and self.port != 22) else ''
keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
command = '{} -r {} {} {} {}'.format(scp, keyfile_string, port_string, source, dest)
pass_string = ''
logger.debug(command)
if self.password:
command = _give_password(self.password, command)
try:
check_output(command, timeout=timeout, shell=True)
except subprocess.CalledProcessError as e:
raise subprocess.CalledProcessError(e.returncode, e.cmd.replace(pass_string, ''), e.output)
except TimeoutError as e:
raise TimeoutError(e.command.replace(pass_string, ''), e.output)
def _give_password(password, command):
if not sshpass:
raise HostError('Must have sshpass installed on the host in order to use password-based auth.')
pass_string = "sshpass -p '{}' ".format(password)
return pass_string + command
def _check_env():
global ssh, scp, sshpass # pylint: disable=global-statement
if not ssh:
ssh = which('ssh')
scp = which('scp')
sshpass = which('sshpass')
if not (ssh and scp):
raise HostError('OpenSSH must be installed on the host.')
def process_backspaces(text):
chars = []
for c in text:
if c == chr(8) and chars: # backspace
chars.pop()
else:
chars.append(c)
return ''.join(chars)

113
devlib/utils/types.py Normal file
View File

@@ -0,0 +1,113 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Routines for doing various type conversions. These usually embody some higher-level
semantics than are present in standard Python types (e.g. ``boolean`` will convert the
string ``"false"`` to ``False``, where as non-empty strings are usually considered to be
``True``).
A lot of these are intened to stpecify type conversions declaratively in place like
``Parameter``'s ``kind`` argument. These are basically "hacks" around the fact that Python
is not the best language to use for configuration.
"""
import math
from devlib.utils.misc import isiterable, to_identifier, ranges_to_list, list_to_mask
def identifier(text):
"""Converts text to a valid Python identifier by replacing all
whitespace and punctuation."""
return to_identifier(text)
def boolean(value):
"""
Returns bool represented by the value. This is different from
calling the builtin bool() in that it will interpret string representations.
e.g. boolean('0') and boolean('false') will both yield False.
"""
false_strings = ['', '0', 'n', 'no', 'off']
if isinstance(value, basestring):
value = value.lower()
if value in false_strings or 'false'.startswith(value):
return False
return bool(value)
def integer(value):
"""Handles conversions for string respresentations of binary, octal and hex."""
if isinstance(value, basestring):
return int(value, 0)
else:
return int(value)
def numeric(value):
"""
Returns the value as number (int if possible, or float otherwise), or
raises ``ValueError`` if the specified ``value`` does not have a straight
forward numeric conversion.
"""
if isinstance(value, int):
return value
try:
fvalue = float(value)
except ValueError:
raise ValueError('Not numeric: {}'.format(value))
if not math.isnan(fvalue) and not math.isinf(fvalue):
ivalue = int(fvalue)
if ivalue == fvalue: # yeah, yeah, I know. Whatever. This is best-effort.
return ivalue
return fvalue
class caseless_string(str):
"""
Just like built-in Python string except case-insensitive on comparisons. However, the
case is preserved otherwise.
"""
def __eq__(self, other):
if isinstance(other, basestring):
other = other.lower()
return self.lower() == other
def __ne__(self, other):
return not self.__eq__(other)
def __cmp__(self, other):
if isinstance(basestring, other):
other = other.lower()
return cmp(self.lower(), other)
def format(self, *args, **kwargs):
return caseless_string(super(caseless_string, self).format(*args, **kwargs))
def bitmask(value):
if isinstance(value, basestring):
value = ranges_to_list(value)
if isiterable(value):
value = list_to_mask(value)
if not isinstance(value, int):
raise ValueError(value)
return value

116
devlib/utils/uboot.py Normal file
View File

@@ -0,0 +1,116 @@
#
# Copyright 2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import time
import logging
from devlib.utils.serial_port import TIMEOUT
logger = logging.getLogger('U-Boot')
class UbootMenu(object):
"""
Allows navigating Das U-boot menu over serial (it relies on a pexpect connection).
"""
option_regex = re.compile(r'^\[(\d+)\]\s+([^\r]+)\r\n', re.M)
prompt_regex = re.compile(r'^([^\r\n]+):\s*', re.M)
invalid_regex = re.compile(r'Invalid input \(max (\d+)\)', re.M)
load_delay = 1 # seconds
default_timeout = 60 # seconds
def __init__(self, conn, start_prompt='Hit any key to stop autoboot'):
"""
:param conn: A serial connection as returned by ``pexect.spawn()``.
:param prompt: U-Boot menu prompt
:param start_prompt: The starting prompt to wait for during ``open()``.
"""
self.conn = conn
self.conn.crlf = '\n\r' # TODO: this has *got* to be a bug in U-Boot...
self.start_prompt = start_prompt
self.options = {}
self.prompt = None
def open(self, timeout=default_timeout):
"""
"Open" the UEFI menu by sending an interrupt on STDIN after seeing the
starting prompt (configurable upon creation of the ``UefiMenu`` object.
"""
self.conn.expect(self.start_prompt, timeout)
self.conn.sendline('')
time.sleep(self.load_delay)
self.conn.readline() # garbage
self.conn.sendline('')
self.prompt = self.conn.readline().strip()
def getenv(self):
output = self.enter('printenv')
result = {}
for line in output.split('\n'):
if '=' in line:
variable, value = line.split('=', 1)
result[variable.strip()] = value.strip()
return result
def setenv(self, variable, value, force=False):
force_str = ' -f' if force else ''
if value is not None:
command = 'setenv{} {} {}'.format(force_str, variable, value)
else:
command = 'setenv{} {}'.format(force_str, variable)
return self.enter(command)
def boot(self):
self.write_characters('boot')
def nudge(self):
"""Send a little nudge to ensure there is something to read. This is useful when you're not
sure if all out put from the serial has been read already."""
self.enter('')
def enter(self, value, delay=load_delay):
"""Like ``select()`` except no resolution is performed -- the value is sent directly
to the serial connection."""
# Empty the buffer first, so that only response to the input about to
# be sent will be processed by subsequent commands.
value = str(value)
self.empty_buffer()
self.write_characters(value)
self.conn.expect(self.prompt, timeout=delay)
return self.conn.before
def write_characters(self, line):
line = line.rstrip('\r\n')
for c in line:
self.conn.send(c)
time.sleep(0.05)
self.conn.sendline('')
def empty_buffer(self):
try:
while True:
time.sleep(0.1)
self.conn.read_nonblocking(size=1024, timeout=0.1)
except TIMEOUT:
pass
self.conn.buffer = ''

239
devlib/utils/uefi.py Normal file
View File

@@ -0,0 +1,239 @@
# Copyright 2014-2015 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import time
import logging
from copy import copy
from devlib.utils.serial_port import write_characters, TIMEOUT
from devlib.utils.types import boolean
logger = logging.getLogger('UEFI')
class UefiConfig(object):
def __init__(self, config_dict):
if isinstance(config_dict, UefiConfig):
self.__dict__ = copy(config_dict.__dict__)
else:
try:
self.image_name = config_dict['image_name']
self.image_args = config_dict['image_args']
self.fdt_support = boolean(config_dict['fdt_support'])
except KeyError as e:
raise ValueError('Missing mandatory parameter for UEFI entry config: "{}"'.format(e))
self.initrd = config_dict.get('initrd')
self.fdt_path = config_dict.get('fdt_path')
if self.fdt_path and not self.fdt_support:
raise ValueError('FDT path has been specfied for UEFI entry, when FDT support is "False"')
class UefiMenu(object):
"""
Allows navigating UEFI menu over serial (it relies on a pexpect connection).
"""
option_regex = re.compile(r'^\[(\d+)\]\s+([^\r]+)\r\n', re.M)
prompt_regex = re.compile(r'^(\S[^\r\n]+):\s*', re.M)
invalid_regex = re.compile(r'Invalid input \(max (\d+)\)', re.M)
load_delay = 1 # seconds
default_timeout = 60 # seconds
def __init__(self, conn, prompt='The default boot selection will start in'):
"""
:param conn: A serial connection as returned by ``pexect.spawn()``.
:param prompt: The starting prompt to wait for during ``open()``.
"""
self.conn = conn
self.start_prompt = prompt
self.options = {}
self.prompt = None
self.attempting_invalid_retry = False
def wait(self, timeout=default_timeout):
"""
"Open" the UEFI menu by sending an interrupt on STDIN after seeing the
starting prompt (configurable upon creation of the ``UefiMenu`` object.
"""
self.conn.expect(self.start_prompt, timeout)
self.connect()
def connect(self, timeout=default_timeout):
self.nudge()
time.sleep(self.load_delay)
self.read_menu(timeout=timeout)
def create_entry(self, name, config):
"""Create a new UEFI entry using the parameters. The menu is assumed
to be at the top level. Upon return, the menu will be at the top level."""
logger.debug('Creating UEFI entry {}'.format(name))
self.nudge()
self.select('Boot Manager')
self.select('Add Boot Device Entry')
self.select('NOR Flash')
self.enter(config.image_name)
self.enter('y' if config.fdt_support else 'n')
if config.initrd:
self.enter('y')
self.enter(config.initrd)
else:
self.enter('n')
self.enter(config.image_args)
self.enter(name)
if config.fdt_path:
self.select('Update FDT path')
self.enter(config.fdt_path)
self.select('Return to main menu')
def delete_entry(self, name):
"""Delete the specified UEFI entry. The menu is assumed
to be at the top level. Upon return, the menu will be at the top level."""
logger.debug('Removing UEFI entry {}'.format(name))
self.nudge()
self.select('Boot Manager')
self.select('Remove Boot Device Entry')
self.select(name)
self.select('Return to main menu')
def select(self, option, timeout=default_timeout):
"""
Select the specified option from the current menu.
:param option: Could be an ``int`` index of the option, or a string/regex to
match option text against.
:param timeout: If a non-``int`` option is specified, the option list may need
need to be parsed (if it hasn't been already), this may block
and the timeout is used to cap that , resulting in a ``TIMEOUT``
exception.
:param delay: A fixed delay to wait after sending the input to the serial connection.
This should be set if input this action is known to result in a
long-running operation.
"""
if isinstance(option, basestring):
option = self.get_option_index(option, timeout)
self.enter(option)
def enter(self, value, delay=load_delay):
"""Like ``select()`` except no resolution is performed -- the value is sent directly
to the serial connection."""
# Empty the buffer first, so that only response to the input about to
# be sent will be processed by subsequent commands.
value = str(value)
self._reset()
write_characters(self.conn, value)
# TODO: in case the value is long an complicated, things may get
# screwed up (e.g. there may be line breaks injected), additionally,
# special chars might cause regex to fail. To avoid these issues i'm
# only matching against the first 5 chars of the value. This is
# entirely arbitrary and I'll probably have to find a better way of
# doing this at some point.
self.conn.expect(value[:5], timeout=delay)
time.sleep(self.load_delay)
def read_menu(self, timeout=default_timeout):
"""Parse serial output to get the menu options and the following prompt."""
attempting_timeout_retry = False
self.attempting_invalid_retry = False
while True:
index = self.conn.expect([self.option_regex, self.prompt_regex, self.invalid_regex, TIMEOUT],
timeout=timeout)
match = self.conn.match
if index == 0: # matched menu option
self.options[match.group(1)] = match.group(2)
elif index == 1: # matched prompt
self.prompt = match.group(1)
break
elif index == 2: # matched invalid selection
# We've sent an invalid input (which includes an empty line) at
# the top-level menu. To get back the menu options, it seems we
# need to enter what the error reports as the max + 1, so...
if not self.attempting_invalid_retry:
self.attempting_invalid_retry = True
val = int(match.group(1))
self.empty_buffer()
self.enter(val)
self.select('Return to main menu')
self.attempting_invalid_retry = False
else: # OK, that didn't work; panic!
raise RuntimeError('Could not read menu entries stuck on "{}" prompt'.format(self.prompt))
elif index == 3: # timed out
if not attempting_timeout_retry:
attempting_timeout_retry = True
self.nudge()
else: # Didn't help. Run away!
raise RuntimeError('Did not see a valid UEFI menu.')
else:
raise AssertionError('Unexpected response waiting for UEFI menu') # should never get here
def list_options(self, timeout=default_timeout):
"""Returns the menu index of the specified option text (uses regex matching). If the option
is not in the current menu, ``LookupError`` will be raised."""
if not self.prompt:
self.read_menu(timeout)
return self.options.items()
def get_option_index(self, text, timeout=default_timeout):
"""Returns the menu index of the specified option text (uses regex matching). If the option
is not in the current menu, ``LookupError`` will be raised."""
if not self.prompt:
self.read_menu(timeout)
for k, v in self.options.iteritems():
if re.search(text, v):
return k
raise LookupError(text)
def has_option(self, text, timeout=default_timeout):
"""Returns ``True`` if at least one of the options in the current menu has
matched (using regex) the specified text."""
try:
self.get_option_index(text, timeout)
return True
except LookupError:
return False
def nudge(self):
"""Send a little nudge to ensure there is something to read. This is useful when you're not
sure if all out put from the serial has been read already."""
self.enter('')
def empty_buffer(self):
"""Read everything from the serial and clear the internal pexpect buffer. This ensures
that the next ``expect()`` call will time out (unless further input will be sent to the
serial beforehand. This is used to create a "known" state and avoid unexpected matches."""
try:
while True:
time.sleep(0.1)
self.conn.read_nonblocking(size=1024, timeout=0.1)
except TIMEOUT:
pass
self.conn.buffer = ''
def _reset(self):
self.options = {}
self.prompt = None
self.empty_buffer()