diff --git a/wlauto/utils/ssh.py b/wlauto/utils/ssh.py index 2ca98186..fd23405c 100644 --- a/wlauto/utils/ssh.py +++ b/wlauto/utils/ssh.py @@ -14,10 +14,14 @@ # +import os +import stat import logging import subprocess import re import threading +import tempfile +import shutil import pxssh from pexpect import EOF, TIMEOUT, spawn @@ -81,6 +85,25 @@ class TelnetConnection(pxssh.pxssh): return True +def check_keyfile(keyfile): + """ + keyfile must have the right access premissions in order to be useable. If the specified + file doesn't, create a temporary copy and set the right permissions for that. + + Returns either the ``keyfile`` (if the permissions on it are correct) or the path to a + temporary copy with the right permissions. + """ + desired_mask = stat.S_IWUSR | stat.S_IRUSR + actual_mask = os.stat(keyfile).st_mode & 0xFF + if actual_mask != desired_mask: + tmp_file = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile)) + shutil.copy(keyfile, tmp_file) + os.chmod(tmp_file, desired_mask) + return tmp_file + else: # permissions on keyfile are OK + return keyfile + + class SshShell(object): default_password_prompt = '[sudo] password' @@ -97,10 +120,10 @@ class SshShell(object): self.host = host self.username = username self.password = password - self.keyfile = keyfile + self.keyfile = check_keyfile(keyfile) self.port = port timeout = self.timeout if timeout is None else timeout - self.conn = ssh_get_shell(host, username, password, keyfile, port, timeout, telnet) + self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, telnet) def push_file(self, source, dest, timeout=30): dest = '{}@{}:{}'.format(self.username, self.host, dest)