1
0
mirror of https://github.com/ARM-software/devlib.git synced 2024-10-05 18:30:50 +01:00

utils/ssh: Allow passing known_hosts path via strict_host_check value

Allow passing a known_hosts file path to strict_host_check.
This commit is contained in:
Douglas Raillard 2024-07-09 17:20:01 +01:00 committed by Marc Bonnici
parent de84a08bf8
commit 38d4796e41
2 changed files with 27 additions and 12 deletions

View File

@ -17,6 +17,7 @@
import os import os
import stat import stat
import logging import logging
from pathlib import Path
import subprocess import subprocess
import re import re
import threading import threading
@ -174,6 +175,18 @@ def _read_paramiko_streams_internal(stdout, stderr, select_timeout, callback, in
return (callback_state, exit_code) return (callback_state, exit_code)
def _resolve_known_hosts(strict_host_check):
if strict_host_check:
if isinstance(strict_host_check, (str, os.PathLike)):
path = Path(strict_host_check)
else:
path = Path('~/.ssh/known_hosts').expandvars()
else:
path = Path('/dev/null')
return str(path.resolve())
def telnet_get_shell(host, def telnet_get_shell(host,
username, username,
password=None, password=None,
@ -407,7 +420,9 @@ class SshConnection(SshConnectionBase):
with _handle_paramiko_exceptions(): with _handle_paramiko_exceptions():
client = SSHClient() client = SSHClient()
if self.strict_host_check: if self.strict_host_check:
client.load_system_host_keys() client.load_system_host_keys(_resolve_known_hosts(
self.strict_host_check
))
client.set_missing_host_key_policy(policy) client.set_missing_host_key_policy(policy)
client.connect( client.connect(
hostname=self.host, hostname=self.host,
@ -818,16 +833,12 @@ class TelnetConnection(SshConnectionBase):
return '{}@{}:{}'.format(self.username, self.host, path) return '{}@{}:{}'.format(self.username, self.host, path)
def _get_default_options(self): def _get_default_options(self):
if self.strict_host_check: check = self.strict_host_check
options = { known_hosts = _resolve_known_hosts(check)
'StrictHostKeyChecking': 'yes', return {
} 'StrictHostKeyChecking': 'yes' if check else 'no',
else: 'UserKnownHostsFile': str(known_hosts),
options = { }
'StrictHostKeyChecking': 'no',
'UserKnownHostsFile': '/dev/null',
}
return options
def push(self, sources, dest, timeout=30): def push(self, sources, dest, timeout=30):
# Quote the destination as SCP would apply globbing too # Quote the destination as SCP would apply globbing too

View File

@ -177,7 +177,11 @@ Connection Types
:param platform: Specify the platform to be used. The generic :class:`~devlib.platform.Platform` :param platform: Specify the platform to be used. The generic :class:`~devlib.platform.Platform`
class is used by default. class is used by default.
:param sudo_cmd: Specify the format of the command used to grant sudo access. :param sudo_cmd: Specify the format of the command used to grant sudo access.
:param strict_host_check: Specify the ssh connection parameter ``StrictHostKeyChecking``, :param strict_host_check: Specify the ssh connection parameter
``StrictHostKeyChecking``. If a path is passed
rather than a boolean, it will be taken for a
``known_hosts`` file. Otherwise, the default
``$HOME/.ssh/known_hosts`` will be used.
:param use_scp: Use SCP for file transfers, defaults to SFTP. :param use_scp: Use SCP for file transfers, defaults to SFTP.
:param poll_transfers: Specify whether file transfers should be polled. Polling :param poll_transfers: Specify whether file transfers should be polled. Polling
monitors the progress of file transfers and periodically monitors the progress of file transfers and periodically