diff --git a/devlib/utils/ssh.py b/devlib/utils/ssh.py index 24c1c2e..f95aade 100644 --- a/devlib/utils/ssh.py +++ b/devlib/utils/ssh.py @@ -296,8 +296,56 @@ class SshConnectionBase(ConnectionBase): self.sudo_cmd = sanitize_cmd_template(sudo_cmd) self.platform = platform self.strict_host_check = strict_host_check + self.options = {} logger.debug('Logging in {}@{}'.format(username, host)) + def push(self, source, dest, timeout=30): + dest = '{}@{}:{}'.format(self.username, self.host, dest) + return self._scp(source, dest, timeout) + + def pull(self, source, dest, timeout=30): + source = '{}@{}:{}'.format(self.username, self.host, source) + return self._scp(source, dest, timeout) + + def _scp(self, source, dest, timeout=30): + # NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely) + # fails to connect to a device if port is explicitly specified using -P + # option, even if it is the default port, 22. To minimize this problem, + # only specify -P for scp if the port is *not* the default. + port_string = '-P {}'.format(quote(str(self.port))) if (self.port and self.port != 22) else '' + keyfile_string = '-i {}'.format(quote(self.keyfile)) if self.keyfile else '' + options = " ".join(["-o {}={}".format(key, val) + for key, val in self.options.items()]) + command = '{} {} -r {} {} {} {}'.format(scp, + options, + keyfile_string, + port_string, + quote(source), + quote(dest)) + command_redacted = command + logger.debug(command) + if self.password: + command, command_redacted = _give_password(self.password, command) + try: + check_output(command, timeout=timeout, shell=True) + except subprocess.CalledProcessError as e: + raise_from(HostError("Failed to copy file with '{}'. Output:\n{}".format( + command_redacted, e.output)), None) + except TimeoutError as e: + raise TimeoutError(command_redacted, e.output) + + def _get_default_options(self): + if self.strict_host_check: + options = { + 'StrictHostKeyChecking': 'yes', + } + else: + options = { + 'StrictHostKeyChecking': 'no', + 'UserKnownHostsFile': '/dev/null', + } + return options + class SshConnection(SshConnectionBase): # pylint: disable=unused-argument,super-init-not-called @@ -717,16 +765,7 @@ class TelnetConnection(SshConnectionBase): strict_host_check=strict_host_check, ) - if self.strict_host_check: - options = { - 'StrictHostKeyChecking': 'yes', - } - else: - options = { - 'StrictHostKeyChecking': 'no', - 'UserKnownHostsFile': '/dev/null', - } - self.options = options + self.options = self._get_default_options() self.lock = threading.Lock() self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt @@ -736,14 +775,6 @@ class TelnetConnection(SshConnectionBase): self.conn = telnet_get_shell(host, username, password, port, timeout, original_prompt) atexit.register(self.close) - def push(self, source, dest, timeout=30): - dest = '{}@{}:{}'.format(self.username, self.host, dest) - return self._scp(source, dest, timeout) - - def pull(self, source, dest, timeout=30): - source = '{}@{}:{}'.format(self.username, self.host, source) - return self._scp(source, dest, timeout) - def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument if command == '': @@ -863,33 +894,6 @@ class TelnetConnection(SshConnectionBase): pass return False - def _scp(self, source, dest, timeout=30): - # NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely) - # fails to connect to a device if port is explicitly specified using -P - # option, even if it is the default port, 22. To minimize this problem, - # only specify -P for scp if the port is *not* the default. - port_string = '-P {}'.format(quote(str(self.port))) if (self.port and self.port != 22) else '' - keyfile_string = '-i {}'.format(quote(self.keyfile)) if self.keyfile else '' - options = " ".join(["-o {}={}".format(key, val) - for key, val in self.options.items()]) - command = '{} {} -r {} {} {} {}'.format(scp, - options, - keyfile_string, - port_string, - quote(source), - quote(dest)) - command_redacted = command - logger.debug(command) - if self.password: - command, command_redacted = _give_password(self.password, command) - try: - check_output(command, timeout=timeout, shell=True) - except subprocess.CalledProcessError as e: - raise_from(HostError("Failed to copy file with '{}'. Output:\n{}".format( - command_redacted, e.output)), None) - except TimeoutError as e: - raise TimeoutError(command_redacted, e.output) - def _sendline(self, command): # Workaround for https://github.com/pexpect/pexpect/issues/552 if len(command) == self._get_window_size()[1] - self._get_prompt_length():