# Copyright 2014-2015 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging import subprocess import re import threading import pxssh from pexpect import EOF, TIMEOUT, spawn from wlauto.exceptions import HostError, DeviceError, TimeoutError, ConfigError from wlauto.utils.misc import which, strip_bash_colors, escape_single_quotes, check_output ssh = None scp = None sshpass = None logger = logging.getLogger('ssh') def ssh_get_shell(host, username, password=None, keyfile=None, port=None, timeout=10, telnet=False): _check_env() if telnet: if keyfile: raise ConfigError('keyfile may not be used with a telnet connection.') conn = TelnetConnection() else: # ssh conn = pxssh.pxssh() try: if keyfile: conn.login(host, username, ssh_key=keyfile, port=port, login_timeout=timeout) else: conn.login(host, username, password, port=port, login_timeout=timeout) except EOF: raise DeviceError('Could not connect to {}; is the host name correct?'.format(host)) return conn class TelnetConnection(pxssh.pxssh): # pylint: disable=arguments-differ def login(self, server, username, password='', original_prompt=r'[#$]', login_timeout=10, auto_prompt_reset=True, sync_multiplier=1): cmd = 'telnet -l {} {}'.format(username, server) spawn._spawn(self, cmd) # pylint: disable=protected-access i = self.expect('(?i)(?:password)', timeout=login_timeout) if i == 0: self.sendline(password) i = self.expect([original_prompt, 'Login incorrect'], timeout=login_timeout) else: raise pxssh.ExceptionPxssh('could not log in: did not see a password prompt') if i: raise pxssh.ExceptionPxssh('could not log in: password was incorrect') if not self.sync_original_prompt(sync_multiplier): self.close() raise pxssh.ExceptionPxssh('could not synchronize with original prompt') if auto_prompt_reset: if not self.set_unique_prompt(): self.close() message = 'could not set shell prompt (recieved: {}, expected: {}).' raise pxssh.ExceptionPxssh(message.format(self.before, self.PROMPT)) return True class SshShell(object): default_password_prompt = '[sudo] password' def __init__(self, password_prompt=None, timeout=10): self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt self.timeout = timeout self.conn = None self.lock = threading.Lock() def login(self, host, username, password=None, keyfile=None, port=None, timeout=None, telnet=False): # pylint: disable=attribute-defined-outside-init logger.debug('Logging in {}@{}'.format(username, host)) self.host = host self.username = username self.password = password self.keyfile = keyfile self.port = port timeout = self.timeout if timeout is None else timeout self.conn = ssh_get_shell(host, username, password, keyfile, port, timeout, telnet) def push_file(self, source, dest, timeout=30): dest = '{}@{}:{}'.format(self.username, self.host, dest) return self._scp(source, dest, timeout) def pull_file(self, source, dest, timeout=30): source = '{}@{}:{}'.format(self.username, self.host, source) return self._scp(source, dest, timeout) def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE): port_string = '-p {}'.format(self.port) if self.port else '' keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else '' command = '{} {} {} {}@{} {}'.format(ssh, keyfile_string, port_string, self.username, self.host, command) logger.debug(command) if self.password: command = _give_password(self.password, command) return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True) def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True): with self.lock: output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors) if check_exit_code: exit_code = int(self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)) if exit_code: message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}' raise DeviceError(message.format(exit_code, command, output)) return output def logout(self): logger.debug('Logging out {}@{}'.format(self.username, self.host)) self.conn.logout() def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, strip_colors=True, log=True): timeout = self.timeout if timeout is None else timeout if as_root: command = "sudo -- sh -c '{}'".format(escape_single_quotes(command)) if log: logger.debug(command) self.conn.sendline(command) index = self.conn.expect_exact([self.password_prompt, TIMEOUT], timeout=0.5) if index == 0: self.conn.sendline(self.password) timed_out = not self.conn.prompt(timeout) output = re.sub(r'.*?{}'.format(re.escape(command)), '', self.conn.before, 1).strip() else: if log: logger.debug(command) self.conn.sendline(command) timed_out = not self.conn.prompt(timeout) # the regex removes line breaks potentiall introduced when writing # command to shell. output = re.sub(r' \r([^\n])', r'\1', self.conn.before) command_index = output.find(command) while not timed_out and command_index == -1: # In case of a "premature" timeout (i.e. timeout, but no hang, # so command completes afterwards), there may be a prompt from # the previous command completion in the serial output. This # checks for this case by making sure that the original command # is present in the serial output and waiting for the next # prompt if it is not. output = re.sub(r' \r([^\n])', r'\1', self.conn.before) command_index = output.find(command) output = output[command_index + len(command):].strip() if timed_out: raise TimeoutError(command, output) if strip_colors: output = strip_bash_colors(output) return output def _scp(self, source, dest, timeout=30): # NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely) # fails to connect to a device if port is explicitly specified using -P # option, even if it is the default port, 22. To minimize this problem, # only specify -P for scp if the port is *not* the default. port_string = '-P {}'.format(self.port) if (self.port and self.port != 22) else '' keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else '' command = '{} -r {} {} {} {}'.format(scp, keyfile_string, port_string, source, dest) pass_string = '' logger.debug(command) if self.password: command = _give_password(self.password, command) try: check_output(command, timeout=timeout, shell=True) except subprocess.CalledProcessError as e: raise subprocess.CalledProcessError(e.returncode, e.cmd.replace(pass_string, ''), e.output) except TimeoutError as e: raise TimeoutError(e.command.replace(pass_string, ''), e.output) def _give_password(password, command): if not sshpass: raise HostError('Must have sshpass installed on the host in order to use password-based auth.') pass_string = "sshpass -p '{}' ".format(password) return pass_string + command def _check_env(): global ssh, scp, sshpass # pylint: disable=global-statement if not ssh: ssh = which('ssh') scp = which('scp') sshpass = which('sshpass') if not (ssh and scp): raise HostError('OpenSSH must be installed on the host.')