mirror of
https://github.com/ARM-software/workload-automation.git
synced 2025-01-19 12:24:32 +00:00
242df842bc
Standard string representation of a subprocess.CalledProcessError does not include the output of the command, so it was not previously included in the resulting DeviceError. This commit ensures that the output is propagated, regardless of whether it came from stdout or stderr of the underlying process.
278 lines
11 KiB
Python
278 lines
11 KiB
Python
# Copyright 2014-2015 ARM Limited
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
|
|
import os
|
|
import stat
|
|
import logging
|
|
import subprocess
|
|
import re
|
|
import threading
|
|
import tempfile
|
|
import shutil
|
|
|
|
from pexpect import EOF, TIMEOUT, spawn, pxssh
|
|
|
|
from wlauto.exceptions import HostError, DeviceError, TimeoutError, ConfigError
|
|
from wlauto.utils.misc import (which, strip_bash_colors, escape_single_quotes, check_output,
|
|
CalledProcessErrorWithStderr)
|
|
|
|
# Paths to the host's ssh, scp and sshpass executables. They start out as None
# and are filled in lazily by _check_env() from `which(...)` lookups.
ssh = scp = sshpass = None

logger = logging.getLogger('ssh')
|
|
|
|
|
|
def ssh_get_shell(host, username, password=None, keyfile=None, port=None, timeout=10, telnet=False):
    """
    Log in to ``host`` and return the resulting interactive shell connection.

    :param host: host name or address to connect to.
    :param username: user name to log in as.
    :param password: password to log in with; only used when no ``keyfile`` is given.
    :param keyfile: private key file to authenticate with (ssh only).
    :param port: port to connect to; forwarded as-is to the connection's ``login``.
    :param timeout: forwarded to ``login`` as ``login_timeout``.
    :param telnet: if ``True``, connect with :class:`TelnetConnection` rather than
                   ``pxssh.pxssh``.

    :returns: the logged-in connection object.
    :raises ConfigError: if ``keyfile`` is combined with ``telnet``.
    :raises DeviceError: if EOF is hit while logging in.
    """
    _check_env()
    if not telnet:
        conn = pxssh.pxssh()  # pylint: disable=redefined-variable-type
    elif keyfile:
        raise ConfigError('keyfile may not be used with a telnet connection.')
    else:
        conn = TelnetConnection()

    try:
        if not keyfile:
            conn.login(host, username, password, port=port, login_timeout=timeout)
        else:
            conn.login(host, username, ssh_key=keyfile, port=port, login_timeout=timeout)
    except EOF:
        raise DeviceError('Could not connect to {}; is the host name correct?'.format(host))
    return conn
|
|
|
|
|
|
class TelnetConnection(pxssh.pxssh):
    """
    A ``pxssh``-style shell connection that logs in over telnet instead of ssh.

    Only :meth:`login` is overridden; prompt handling, ``sendline``, ``logout``
    etc. are inherited from ``pxssh.pxssh``.
    """
    # pylint: disable=arguments-differ

    def login(self, server, username, password='', original_prompt=r'[#$]', login_timeout=10,
              auto_prompt_reset=True, sync_multiplier=1, port=23):
        """
        Spawn ``telnet`` to ``server`` and log in as ``username``.

        :param server: host to connect to.
        :param username: passed to ``telnet -l``.
        :param password: sent if a password prompt appears. If no prompt appears
                         before ``login_timeout`` and no password was given, the
                         login is assumed to have succeeded.
        :param original_prompt: regex matching the shell prompt after login.
        :param login_timeout: seconds to wait at each login step.
        :param auto_prompt_reset: if ``True``, replace the remote prompt using
                                  ``set_unique_prompt()``.
        :param sync_multiplier: passed to ``sync_original_prompt()``.
        :param port: telnet port. ``None`` (which ``ssh_get_shell`` forwards when
                     no port is configured) selects the default port, 23.

        :returns: ``True`` on success.
        :raises pxssh.ExceptionPxssh: on an incorrect password, a missing password
                                      prompt when a password was given, or a
                                      failure to synchronise/set the shell prompt.
        """
        if port is None:
            # Without this, a caller forwarding port=None produced the command
            # "telnet -l <user> <server> None".
            port = 23
        cmd = 'telnet -l {} {} {}'.format(username, server, port)

        spawn._spawn(self, cmd)  # pylint: disable=protected-access
        try:
            i = self.expect('(?i)(?:password)', timeout=login_timeout)
            if i == 0:
                self.sendline(password)
                i = self.expect([original_prompt, 'Login incorrect'], timeout=login_timeout)
            if i:
                raise pxssh.ExceptionPxssh('could not log in: password was incorrect')
        except TIMEOUT:
            if not password:
                # There was no password prompt before TIMEOUT, and we didn't
                # have a password to enter. Assume everything is OK.
                pass
            else:
                raise pxssh.ExceptionPxssh('could not log in: did not see a password prompt')

        if not self.sync_original_prompt(sync_multiplier):
            self.close()
            raise pxssh.ExceptionPxssh('could not synchronize with original prompt')

        if auto_prompt_reset:
            if not self.set_unique_prompt():
                self.close()
                message = 'could not set shell prompt (recieved: {}, expected: {}).'
                raise pxssh.ExceptionPxssh(message.format(self.before, self.PROMPT))
        return True
|
|
|
|
|
|
def check_keyfile(keyfile):
    """
    Return a path to a copy of ``keyfile`` whose permissions are exactly owner
    read/write (0600).

    The key file must have these access permissions in order to be usable. If
    ``keyfile`` already has exactly these permission bits, it is returned
    unchanged. Otherwise it is copied into the system temporary directory (under
    the same base name), the copy is chmod-ed to 0600, and the copy's path is
    returned.
    """
    desired_mask = stat.S_IWUSR | stat.S_IRUSR
    # Compare all nine rwx permission bits. Masking with 0xFF would drop the
    # owner-read bit (0o400), so a correctly protected key would never match.
    permission_bits = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO
    actual_mask = os.stat(keyfile).st_mode & permission_bits
    if actual_mask == desired_mask:
        return keyfile

    tmp_file = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile))
    shutil.copy(keyfile, tmp_file)
    os.chmod(tmp_file, desired_mask)
    return tmp_file
|
|
|
|
|
|
class SshShell(object):
    """
    Interactive shell on a remote host (over ssh, or telnet), with scp-based
    file transfer and ssh-based background commands.

    Calls to :meth:`execute` are serialised with ``self.lock``. If a command
    hits EOF, the connection is marked as lost and re-established at the start
    of the next :meth:`execute` call.
    """

    default_password_prompt = '[sudo] password'
    max_cancel_attempts = 5

    def __init__(self, password_prompt=None, timeout=10, telnet=False):
        """
        :param password_prompt: text that signals a ``sudo`` password request;
                                defaults to ``default_password_prompt``.
        :param timeout: default login timeout, and the polling interval used
                        while waiting indefinitely for a prompt.
        :param telnet: connect with telnet instead of ssh.
        """
        self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt
        self.timeout = timeout
        self.telnet = telnet
        self.conn = None
        self.lock = threading.Lock()
        self.connection_lost = False

    def login(self, host, username, password=None, keyfile=None, port=None, timeout=None):
        """
        Open the interactive connection and remember the credentials for later
        reconnects, scp transfers and background commands.

        If a ``keyfile`` is given, ``check_keyfile`` may substitute a copy of it
        with corrected permissions. ``timeout`` defaults to ``self.timeout``.
        """
        # pylint: disable=attribute-defined-outside-init
        logger.debug('Logging in {}@{}'.format(username, host))
        self.host = host
        self.username = username
        self.password = password
        self.keyfile = check_keyfile(keyfile) if keyfile else keyfile
        self.port = port
        timeout = self.timeout if timeout is None else timeout
        self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, self.telnet)

    def push_file(self, source, dest, timeout=30):
        """Copy local ``source`` to remote path ``dest`` with ``scp -r``."""
        dest = '{}@{}:{}'.format(self.username, self.host, dest)
        return self._scp(source, dest, timeout)

    def pull_file(self, source, dest, timeout=30):
        """Copy remote path ``source`` to local ``dest`` with ``scp -r``."""
        source = '{}@{}:{}'.format(self.username, self.host, source)
        return self._scp(source, dest, timeout)

    def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE):
        """
        Run ``command`` on the remote host through a separate ``ssh`` process
        and return its ``subprocess.Popen`` object without waiting for it.
        """
        port_string = '-p {}'.format(self.port) if self.port else ''
        keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
        command = '{} {} {} {}@{} {}'.format(ssh, keyfile_string, port_string, self.username, self.host, command)
        # Logged before the sshpass prefix is added, so the password is not logged.
        logger.debug(command)
        if self.password:
            command = _give_password(self.password, command)
        return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)

    def reconnect(self):
        """Re-open the interactive connection using the credentials saved by :meth:`login`."""
        self.conn = ssh_get_shell(self.host, self.username, self.password,
                                  self.keyfile, self.port, self.timeout, self.telnet)

    def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True):
        """
        Run ``command`` in the interactive shell and return its output.

        :param timeout: seconds to wait for the command; falsy means wait forever.
        :param check_exit_code: if ``True``, read ``$?`` afterwards and raise on a
                                non-zero value. If the exit code cannot be parsed,
                                log a warning instead.
        :param as_root: run via ``sudo -- sh -c``.
        :param strip_colors: strip bash colour codes from the output.

        :raises DeviceError: on a non-zero exit code, or if the connection drops
                             (EOF). In the EOF case, the next call reconnects first.
        :raises TimeoutError: if the command does not finish within ``timeout``.
        """
        try:
            with self.lock:
                if self.connection_lost:
                    logger.debug('Attempting to reconnect...')
                    self.reconnect()
                    self.connection_lost = False
                output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
                if check_exit_code:
                    exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
                    try:
                        exit_code = int(exit_code_text.split()[0])
                        if exit_code:
                            message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
                            raise DeviceError(message.format(exit_code, command, output))
                    except (ValueError, IndexError):
                        logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text))
                return output
        except EOF:
            self.connection_lost = True
            raise DeviceError('Connection dropped.')

    def logout(self):
        """Log out of the interactive connection."""
        logger.debug('Logging out {}@{}'.format(self.username, self.host))
        self.conn.logout()

    def cancel_running_command(self):
        """
        Send ^C up to ``max_cancel_attempts`` times.

        :returns: ``True`` as soon as a prompt is seen, otherwise ``False``.
        """
        # simulate impatiently hitting ^C until command prompt appears
        logger.debug('Sending ^C')
        for _ in xrange(self.max_cancel_attempts):
            self.conn.sendline(chr(3))
            if self.conn.prompt(0.1):
                return True
        return False

    def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, strip_colors=True, log=True):
        """
        Send ``command`` to the shell, wait for the prompt, and return the text
        printed before it, with the echoed command removed.

        When ``as_root`` is set, the command is wrapped in ``sudo -- sh -c '...'``.
        If the sudo password prompt appears within 0.5 s, it is answered with
        ``self.password``.

        :raises TimeoutError: if the prompt does not return within ``timeout``.
                              The running command is interrupted first.
        """
        self.conn.prompt(0.1)  # clear an existing prompt if there is one.
        if as_root:
            command = "sudo -- sh -c '{}'".format(escape_single_quotes(command))
            if log:
                logger.debug(command)
            self.conn.sendline(command)
            index = self.conn.expect_exact([self.password_prompt, TIMEOUT], timeout=0.5)
            if index == 0:
                self.conn.sendline(self.password)
            timed_out = self._wait_for_prompt(timeout)
            output = re.sub(r' \r([^\n])', r'\1', self.conn.before)
            output = process_backspaces(output)
            # drop everything up to and including the first echo of the command
            output = re.sub(r'.*?{}'.format(re.escape(command)), '', output, 1).strip()
        else:
            if log:
                logger.debug(command)
            self.conn.sendline(command)
            timed_out = self._wait_for_prompt(timeout)
            # the regex removes line breaks potential introduced when writing
            # command to shell.
            output = re.sub(r' \r([^\n])', r'\1', self.conn.before)
            output = process_backspaces(output)
            command_index = output.find(command)
            output = output[command_index + len(command):].strip()
        if timed_out:
            self.cancel_running_command()
            raise TimeoutError(command, output)
        if strip_colors:
            output = strip_bash_colors(output)
        return output

    def _wait_for_prompt(self, timeout=None):
        """
        Wait for the shell prompt.

        :returns: ``True`` if ``timeout`` elapsed without a prompt. When
                  ``timeout`` is falsy, wait indefinitely (polling every
                  ``self.timeout`` seconds) and return ``False``.
        """
        if timeout:
            return not self.conn.prompt(timeout)
        else:  # cannot timeout; wait forever
            while not self.conn.prompt(self.timeout):
                pass
            return False

    def _scp(self, source, dest, timeout=30):
        """
        Run ``scp -r source dest`` on the host, via sshpass if a password is set.

        Any sshpass prefix (which contains the password) is removed from the
        command reported in the raised errors.

        :raises CalledProcessErrorWithStderr: if scp exits with a non-zero status.
        :raises TimeoutError: if scp does not finish within ``timeout`` seconds.
        """
        # NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely)
        # fails to connect to a device if port is explicitly specified using -P
        # option, even if it is the default port, 22. To minimize this problem,
        # only specify -P for scp if the port is *not* the default.
        port_string = '-P {}'.format(self.port) if (self.port and self.port != 22) else ''
        keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
        command = '{} -r {} {} {} {}'.format(scp, keyfile_string, port_string, source, dest)
        pass_string = ''
        # Logged before the sshpass prefix is added, so the password is not logged.
        logger.debug(command)
        if self.password:
            full_command = _give_password(self.password, command)
            # _give_password() prepends an "sshpass -p '<password>' " prefix.
            # Remember it so it can be scrubbed from error messages below.
            # Previously pass_string stayed '' and the password leaked into them.
            pass_string = full_command[:len(full_command) - len(command)]
            command = full_command
        try:
            check_output(command, timeout=timeout, shell=True)
        except subprocess.CalledProcessError as e:
            raise CalledProcessErrorWithStderr(e.returncode,
                                               e.cmd.replace(pass_string, ''),
                                               e.output, getattr(e, 'error', ''))
        except TimeoutError as e:
            raise TimeoutError(e.command.replace(pass_string, ''), e.output)
|
|
|
|
|
|
def _give_password(password, command):
    """
    Return ``command`` prefixed with an ``sshpass -p`` invocation that supplies
    ``password``.

    :raises HostError: if the module-level ``sshpass`` path (set by
                       ``_check_env``) is unset.
    """
    # NOTE(review): the password is wrapped in single quotes without escaping,
    # so a password containing "'" would break the resulting shell command.
    if sshpass:
        return "sshpass -p '{}' ".format(password) + command
    raise HostError('Must have sshpass installed on the host in order to use password-based auth.')
|
|
|
|
|
|
def _check_env():
    """
    Locate the ``ssh``, ``scp`` and ``sshpass`` executables on the host and
    cache their paths in module-level globals.

    The lookup is repeated on every call until both ``ssh`` and ``scp`` have
    been found. ``sshpass`` is optional here; ``_give_password`` checks for it.

    :raises HostError: if ``ssh`` or ``scp`` is not available.
    """
    global ssh, scp, sshpass  # pylint: disable=global-statement
    if not (ssh and scp):
        ssh = which('ssh')
        scp = which('scp')
        sshpass = which('sshpass')
    # Checked on every call. Previously the check only ran while ``ssh`` was
    # unset, so after a failed first call (ssh found, scp missing) later calls
    # succeeded silently with scp=None.
    if not (ssh and scp):
        raise HostError('OpenSSH must be installed on the host.')
|
|
|
|
|
|
def process_backspaces(text):
    """
    Apply backspace characters (``\\b``, chr(8)) in ``text``.

    Each backspace deletes the previously kept character. A backspace that has
    nothing before it to delete is kept in the output as-is.
    """
    kept = []
    for ch in text:
        if ch != '\b' or not kept:
            kept.append(ch)
        else:
            kept.pop()
    return ''.join(kept)
|