mirror of
https://github.com/ARM-software/workload-automation.git
synced 2025-01-19 04:21:17 +00:00
4ea4bc8631
CalledProcessErrorWithStderr is a subclass of CalledProcessError that also takes stderr output. Both that and normal output must be passed as keyword arguments. They were being passed as keyword arguments inside _scp() of SshConnection, causing cryptic errors to appear. Additionally, "output" was not being properly poped off before invoking super's init.
279 lines
11 KiB
Python
279 lines
11 KiB
Python
# Copyright 2014-2015 ARM Limited
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
|
|
import os
|
|
import stat
|
|
import logging
|
|
import subprocess
|
|
import re
|
|
import threading
|
|
import tempfile
|
|
import shutil
|
|
|
|
from pexpect import EOF, TIMEOUT, spawn, pxssh
|
|
|
|
from wlauto.exceptions import HostError, DeviceError, TimeoutError, ConfigError
|
|
from wlauto.utils.misc import (which, strip_bash_colors, escape_single_quotes, check_output,
|
|
CalledProcessErrorWithStderr)
|
|
|
|
ssh = None
|
|
scp = None
|
|
sshpass = None
|
|
|
|
logger = logging.getLogger('ssh')
|
|
|
|
|
|
def ssh_get_shell(host, username, password=None, keyfile=None, port=None, timeout=10, telnet=False):
|
|
_check_env()
|
|
if telnet:
|
|
if keyfile:
|
|
raise ConfigError('keyfile may not be used with a telnet connection.')
|
|
conn = TelnetConnection()
|
|
else: # ssh
|
|
conn = pxssh.pxssh() # pylint: disable=redefined-variable-type
|
|
try:
|
|
if keyfile:
|
|
conn.login(host, username, ssh_key=keyfile, port=port, login_timeout=timeout)
|
|
else:
|
|
conn.login(host, username, password, port=port, login_timeout=timeout)
|
|
except EOF:
|
|
raise DeviceError('Could not connect to {}; is the host name correct?'.format(host))
|
|
return conn
|
|
|
|
|
|
class TelnetConnection(pxssh.pxssh):
|
|
# pylint: disable=arguments-differ
|
|
|
|
def login(self, server, username, password='', original_prompt=r'[#$]', login_timeout=10,
|
|
auto_prompt_reset=True, sync_multiplier=1, port=23):
|
|
cmd = 'telnet -l {} {} {}'.format(username, server, port)
|
|
|
|
spawn._spawn(self, cmd) # pylint: disable=protected-access
|
|
try:
|
|
i = self.expect('(?i)(?:password)', timeout=login_timeout)
|
|
if i == 0:
|
|
self.sendline(password)
|
|
i = self.expect([original_prompt, 'Login incorrect'], timeout=login_timeout)
|
|
if i:
|
|
raise pxssh.ExceptionPxssh('could not log in: password was incorrect')
|
|
except TIMEOUT:
|
|
if not password:
|
|
# There was no password prompt before TIMEOUT, and we didn't
|
|
# have a password to enter. Assume everything is OK.
|
|
pass
|
|
else:
|
|
raise pxssh.ExceptionPxssh('could not log in: did not see a password prompt')
|
|
|
|
if not self.sync_original_prompt(sync_multiplier):
|
|
self.close()
|
|
raise pxssh.ExceptionPxssh('could not synchronize with original prompt')
|
|
|
|
if auto_prompt_reset:
|
|
if not self.set_unique_prompt():
|
|
self.close()
|
|
message = 'could not set shell prompt (recieved: {}, expected: {}).'
|
|
raise pxssh.ExceptionPxssh(message.format(self.before, self.PROMPT))
|
|
return True
|
|
|
|
|
|
def check_keyfile(keyfile):
|
|
"""
|
|
keyfile must have the right access premissions in order to be useable. If the specified
|
|
file doesn't, create a temporary copy and set the right permissions for that.
|
|
|
|
Returns either the ``keyfile`` (if the permissions on it are correct) or the path to a
|
|
temporary copy with the right permissions.
|
|
"""
|
|
desired_mask = stat.S_IWUSR | stat.S_IRUSR
|
|
actual_mask = os.stat(keyfile).st_mode & 0xFF
|
|
if actual_mask != desired_mask:
|
|
tmp_file = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile))
|
|
shutil.copy(keyfile, tmp_file)
|
|
os.chmod(tmp_file, desired_mask)
|
|
return tmp_file
|
|
else: # permissions on keyfile are OK
|
|
return keyfile
|
|
|
|
|
|
class SshShell(object):
|
|
|
|
default_password_prompt = '[sudo] password'
|
|
max_cancel_attempts = 5
|
|
|
|
def __init__(self, password_prompt=None, timeout=10, telnet=False):
|
|
self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt
|
|
self.timeout = timeout
|
|
self.telnet = telnet
|
|
self.conn = None
|
|
self.lock = threading.Lock()
|
|
self.connection_lost = False
|
|
|
|
def login(self, host, username, password=None, keyfile=None, port=None, timeout=None):
|
|
# pylint: disable=attribute-defined-outside-init
|
|
logger.debug('Logging in {}@{}'.format(username, host))
|
|
self.host = host
|
|
self.username = username
|
|
self.password = password
|
|
self.keyfile = check_keyfile(keyfile) if keyfile else keyfile
|
|
self.port = port
|
|
timeout = self.timeout if timeout is None else timeout
|
|
self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, self.telnet)
|
|
|
|
def push_file(self, source, dest, timeout=30):
|
|
dest = '{}@{}:{}'.format(self.username, self.host, dest)
|
|
return self._scp(source, dest, timeout)
|
|
|
|
def pull_file(self, source, dest, timeout=30):
|
|
source = '{}@{}:{}'.format(self.username, self.host, source)
|
|
return self._scp(source, dest, timeout)
|
|
|
|
def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE):
|
|
port_string = '-p {}'.format(self.port) if self.port else ''
|
|
keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
|
|
command = '{} {} {} {}@{} {}'.format(ssh, keyfile_string, port_string, self.username, self.host, command)
|
|
logger.debug(command)
|
|
if self.password:
|
|
command = _give_password(self.password, command)
|
|
return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)
|
|
|
|
def reconnect(self):
|
|
self.conn = ssh_get_shell(self.host, self.username, self.password,
|
|
self.keyfile, self.port, self.timeout, self.telnet)
|
|
|
|
def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True):
|
|
try:
|
|
with self.lock:
|
|
if self.connection_lost:
|
|
logger.debug('Attempting to reconnect...')
|
|
self.reconnect()
|
|
self.connection_lost = False
|
|
output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
|
|
if check_exit_code:
|
|
exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
|
|
try:
|
|
exit_code = int(exit_code_text.split()[0])
|
|
if exit_code:
|
|
message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
|
|
raise DeviceError(message.format(exit_code, command, output))
|
|
except (ValueError, IndexError):
|
|
logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text))
|
|
return output
|
|
except EOF:
|
|
self.connection_lost = True
|
|
raise DeviceError('Connection dropped.')
|
|
|
|
def logout(self):
|
|
logger.debug('Logging out {}@{}'.format(self.username, self.host))
|
|
self.conn.logout()
|
|
|
|
def cancel_running_command(self):
|
|
# simulate impatiently hitting ^C until command prompt appears
|
|
logger.debug('Sending ^C')
|
|
for _ in xrange(self.max_cancel_attempts):
|
|
self.conn.sendline(chr(3))
|
|
if self.conn.prompt(0.1):
|
|
return True
|
|
return False
|
|
|
|
def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, strip_colors=True, log=True):
|
|
self.conn.prompt(0.1) # clear an existing prompt if there is one.
|
|
if as_root:
|
|
command = "sudo -- sh -c '{}'".format(escape_single_quotes(command))
|
|
if log:
|
|
logger.debug(command)
|
|
self.conn.sendline(command)
|
|
index = self.conn.expect_exact([self.password_prompt, TIMEOUT], timeout=0.5)
|
|
if index == 0:
|
|
self.conn.sendline(self.password)
|
|
timed_out = self._wait_for_prompt(timeout)
|
|
output = re.sub(r' \r([^\n])', r'\1', self.conn.before)
|
|
output = process_backspaces(output)
|
|
output = re.sub(r'.*?{}'.format(re.escape(command)), '', output, 1).strip()
|
|
else:
|
|
if log:
|
|
logger.debug(command)
|
|
self.conn.sendline(command)
|
|
timed_out = self._wait_for_prompt(timeout)
|
|
# the regex removes line breaks potential introduced when writing
|
|
# command to shell.
|
|
output = re.sub(r' \r([^\n])', r'\1', self.conn.before)
|
|
output = process_backspaces(output)
|
|
command_index = output.find(command)
|
|
output = output[command_index + len(command):].strip()
|
|
if timed_out:
|
|
self.cancel_running_command()
|
|
raise TimeoutError(command, output)
|
|
if strip_colors:
|
|
output = strip_bash_colors(output)
|
|
return output
|
|
|
|
def _wait_for_prompt(self, timeout=None):
|
|
if timeout:
|
|
return not self.conn.prompt(timeout)
|
|
else: # cannot timeout; wait forever
|
|
while not self.conn.prompt(self.timeout):
|
|
pass
|
|
return False
|
|
|
|
def _scp(self, source, dest, timeout=30):
|
|
# NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely)
|
|
# fails to connect to a device if port is explicitly specified using -P
|
|
# option, even if it is the default port, 22. To minimize this problem,
|
|
# only specify -P for scp if the port is *not* the default.
|
|
port_string = '-P {}'.format(self.port) if (self.port and self.port != 22) else ''
|
|
keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
|
|
command = '{} -r {} {} {} {}'.format(scp, keyfile_string, port_string, source, dest)
|
|
pass_string = ''
|
|
logger.debug(command)
|
|
if self.password:
|
|
command = _give_password(self.password, command)
|
|
try:
|
|
check_output(command, timeout=timeout, shell=True)
|
|
except subprocess.CalledProcessError as e:
|
|
raise CalledProcessErrorWithStderr(e.returncode,
|
|
e.cmd.replace(pass_string, ''),
|
|
output=e.output,
|
|
error=getattr(e, 'error', ''))
|
|
except TimeoutError as e:
|
|
raise TimeoutError(e.command.replace(pass_string, ''), e.output)
|
|
|
|
|
|
def _give_password(password, command):
|
|
if not sshpass:
|
|
raise HostError('Must have sshpass installed on the host in order to use password-based auth.')
|
|
pass_string = "sshpass -p '{}' ".format(password)
|
|
return pass_string + command
|
|
|
|
|
|
def _check_env():
|
|
global ssh, scp, sshpass # pylint: disable=global-statement
|
|
if not ssh:
|
|
ssh = which('ssh')
|
|
scp = which('scp')
|
|
sshpass = which('sshpass')
|
|
if not (ssh and scp):
|
|
raise HostError('OpenSSH must be installed on the host.')
|
|
|
|
|
|
def process_backspaces(text):
|
|
chars = []
|
|
for c in text:
|
|
if c == chr(8) and chars: # backspace
|
|
chars.pop()
|
|
else:
|
|
chars.append(c)
|
|
return ''.join(chars)
|