1
0
mirror of https://github.com/ARM-software/workload-automation.git synced 2025-01-19 04:21:17 +00:00

ssh: ensure keyfile has the right permissions

The key file must only be readable by the owner. If the specified key
file has different access permissions, create a temporary copy with the
right permissions and use that.
This commit is contained in:
Sergei Trofimov 2015-05-12 11:06:21 +01:00
parent 98b259be33
commit b2981a57bc

View File

@ -14,10 +14,14 @@
# #
import os
import stat
import logging import logging
import subprocess import subprocess
import re import re
import threading import threading
import tempfile
import shutil
import pxssh import pxssh
from pexpect import EOF, TIMEOUT, spawn from pexpect import EOF, TIMEOUT, spawn
@ -81,6 +85,25 @@ class TelnetConnection(pxssh.pxssh):
return True return True
def check_keyfile(keyfile):
"""
keyfile must have the right access premissions in order to be useable. If the specified
file doesn't, create a temporary copy and set the right permissions for that.
Returns either the ``keyfile`` (if the permissions on it are correct) or the path to a
temporary copy with the right permissions.
"""
desired_mask = stat.S_IWUSR | stat.S_IRUSR
actual_mask = os.stat(keyfile).st_mode & 0xFF
if actual_mask != desired_mask:
tmp_file = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile))
shutil.copy(keyfile, tmp_file)
os.chmod(tmp_file, desired_mask)
return tmp_file
else: # permissions on keyfile are OK
return keyfile
class SshShell(object): class SshShell(object):
default_password_prompt = '[sudo] password' default_password_prompt = '[sudo] password'
@ -97,10 +120,10 @@ class SshShell(object):
self.host = host self.host = host
self.username = username self.username = username
self.password = password self.password = password
self.keyfile = keyfile self.keyfile = check_keyfile(keyfile)
self.port = port self.port = port
timeout = self.timeout if timeout is None else timeout timeout = self.timeout if timeout is None else timeout
self.conn = ssh_get_shell(host, username, password, keyfile, port, timeout, telnet) self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, telnet)
def push_file(self, source, dest, timeout=30): def push_file(self, source, dest, timeout=30):
dest = '{}@{}:{}'.format(self.username, self.host, dest) dest = '{}@{}:{}'.format(self.username, self.host, dest)