mirror of
https://github.com/ARM-software/devlib.git
synced 2025-01-31 02:00:45 +00:00
connection: Rework TransferManager
* Split TransferManager and TransferHandle: * TransferManager deals with the generic monitoring. To abort a transfer, it simply cancels the transfer and raises an exception from manage(). * TransferHandle provides a way for the manager to query the state of the transfer and cancel it. It is backend-specific. * Remove most of the state in TransferManager, along with the associated background command leak etc * Use a daemonic monitor thread to behave as excpected on interpreter shutdown. * Ensure a transfer manager _always_ exists. When no management is desired, a noop object is used. This avoids using a None sentinel, which is invariably mishandled by some code leading to crashes. * Try to merge more paths in the code to uncover as many issues as possible in testing. * Fix percentage for SSHTransferHandle (transferred / (remaining + transferred) instead of transferred / remaining) * Rename total_timeout TransferManager parameter and attribute to total_transfer_timeout to match the connection name parameter.
This commit is contained in:
parent
1c5412be2f
commit
ddaa2f1621
@ -14,12 +14,11 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager, nullcontext
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from weakref import WeakSet
|
from weakref import WeakSet
|
||||||
from shlex import quote
|
from shlex import quote
|
||||||
from time import monotonic
|
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import socket
|
import socket
|
||||||
@ -61,13 +60,27 @@ class ConnectionBase(InitCheckpoint):
|
|||||||
"""
|
"""
|
||||||
Base class for all connections.
|
Base class for all connections.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(
|
||||||
|
self,
|
||||||
|
poll_transfers=False,
|
||||||
|
start_transfer_poll_delay=30,
|
||||||
|
total_transfer_timeout=3600,
|
||||||
|
transfer_poll_period=30,
|
||||||
|
):
|
||||||
self._current_bg_cmds = set()
|
self._current_bg_cmds = set()
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._close_lock = threading.Lock()
|
self._close_lock = threading.Lock()
|
||||||
self.busybox = None
|
self.busybox = None
|
||||||
self.logger = logging.getLogger('Connection')
|
self.logger = logging.getLogger('Connection')
|
||||||
|
|
||||||
|
self.transfer_manager = TransferManager(
|
||||||
|
self,
|
||||||
|
start_transfer_poll_delay=start_transfer_poll_delay,
|
||||||
|
total_transfer_timeout=total_transfer_timeout,
|
||||||
|
transfer_poll_period=transfer_poll_period,
|
||||||
|
) if poll_transfers else NoopTransferManager()
|
||||||
|
|
||||||
|
|
||||||
def cancel_running_command(self):
|
def cancel_running_command(self):
|
||||||
bg_cmds = set(self._current_bg_cmds)
|
bg_cmds = set(self._current_bg_cmds)
|
||||||
for bg_cmd in bg_cmds:
|
for bg_cmd in bg_cmds:
|
||||||
@ -405,13 +418,13 @@ class ParamikoBackgroundCommand(BackgroundCommand):
|
|||||||
b''.join(out[stderr])
|
b''.join(out[stderr])
|
||||||
)
|
)
|
||||||
|
|
||||||
start = monotonic()
|
start = time.monotonic()
|
||||||
|
|
||||||
while ret is None:
|
while ret is None:
|
||||||
# Even if ret is not None anymore, we need to drain the streams
|
# Even if ret is not None anymore, we need to drain the streams
|
||||||
ret = self.poll()
|
ret = self.poll()
|
||||||
|
|
||||||
if timeout is not None and ret is None and monotonic() - start >= timeout:
|
if timeout is not None and ret is None and time.monotonic() - start >= timeout:
|
||||||
self.cancel()
|
self.cancel()
|
||||||
_stdout, _stderr = create_out()
|
_stdout, _stderr = create_out()
|
||||||
raise subprocess.TimeoutExpired(self.cmd, timeout, _stdout, _stderr)
|
raise subprocess.TimeoutExpired(self.cmd, timeout, _stdout, _stderr)
|
||||||
@ -563,9 +576,89 @@ class AdbBackgroundCommand(BackgroundCommand):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class TransferManagerBase(ABC):
|
class TransferManager:
|
||||||
|
def __init__(self, conn, transfer_poll_period=30, start_transfer_poll_delay=30, total_transfer_timeout=3600):
|
||||||
|
self.conn = conn
|
||||||
|
self.transfer_poll_period = transfer_poll_period
|
||||||
|
self.total_transfer_timeout = total_transfer_timeout
|
||||||
|
self.start_transfer_poll_delay = start_transfer_poll_delay
|
||||||
|
|
||||||
def _pull_dest_size(self, dest):
|
self.logger = logging.getLogger('FileTransfer')
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def manage(self, sources, dest, direction, handle):
|
||||||
|
excep = None
|
||||||
|
stop_thread = threading.Event()
|
||||||
|
|
||||||
|
def monitor():
|
||||||
|
nonlocal excep
|
||||||
|
|
||||||
|
def cancel(reason):
|
||||||
|
self.logger.warning(
|
||||||
|
f'Cancelling file transfer {sources} -> {dest} due to: {reason}'
|
||||||
|
)
|
||||||
|
handle.cancel()
|
||||||
|
|
||||||
|
start_t = time.monotonic()
|
||||||
|
stop_thread.wait(self.start_transfer_poll_delay)
|
||||||
|
while not stop_thread.wait(self.transfer_poll_period):
|
||||||
|
if not handle.isactive():
|
||||||
|
cancel(reason='transfer inactive')
|
||||||
|
elif time.monotonic() - start_t > self.total_transfer_timeout:
|
||||||
|
cancel(reason='transfer timed out')
|
||||||
|
excep = TimeoutError(f'{direction}: {sources} -> {dest}')
|
||||||
|
|
||||||
|
m_thread = threading.Thread(target=monitor, daemon=True)
|
||||||
|
try:
|
||||||
|
m_thread.start()
|
||||||
|
yield self
|
||||||
|
finally:
|
||||||
|
stop_thread.set()
|
||||||
|
m_thread.join()
|
||||||
|
if excep is not None:
|
||||||
|
raise excep
|
||||||
|
|
||||||
|
|
||||||
|
class NoopTransferManager:
|
||||||
|
def manage(self, *args, **kwargs):
|
||||||
|
return nullcontext(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TransferHandleBase(ABC):
|
||||||
|
def __init__(self, manager):
|
||||||
|
self.manager = manager
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logger(self):
|
||||||
|
return self.manager.logger
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def isactive(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PopenTransferHandle(TransferHandleBase):
|
||||||
|
def __init__(self, bg_cmd, dest, direction, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
if direction == 'push':
|
||||||
|
sample_size = self._push_dest_size
|
||||||
|
elif direction == 'pull':
|
||||||
|
sample_size = self._pull_dest_size
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown direction: {direction}')
|
||||||
|
|
||||||
|
self.sample_size = lambda: sample_size(dest)
|
||||||
|
|
||||||
|
self.bg_cmd = bg_cmd
|
||||||
|
self.last_sample = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pull_dest_size(dest):
|
||||||
if os.path.isdir(dest):
|
if os.path.isdir(dest):
|
||||||
return sum(
|
return sum(
|
||||||
os.stat(os.path.join(dirpath, f)).st_size
|
os.stat(os.path.join(dirpath, f)).st_size
|
||||||
@ -576,148 +669,58 @@ class TransferManagerBase(ABC):
|
|||||||
return os.stat(dest).st_size
|
return os.stat(dest).st_size
|
||||||
|
|
||||||
def _push_dest_size(self, dest):
|
def _push_dest_size(self, dest):
|
||||||
cmd = '{} du -s {}'.format(quote(self.conn.busybox), quote(dest))
|
conn = self.manager.conn
|
||||||
out = self.conn.execute(cmd)
|
cmd = '{} du -s -- {}'.format(quote(conn.busybox), quote(dest))
|
||||||
try:
|
out = conn.execute(cmd)
|
||||||
return int(out.split()[0])
|
return int(out.split()[0])
|
||||||
except ValueError:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def __init__(self, conn, poll_period, start_transfer_poll_delay, total_timeout):
|
def cancel(self):
|
||||||
self.conn = conn
|
self.bg_cmd.cancel()
|
||||||
self.poll_period = poll_period
|
|
||||||
self.total_timeout = total_timeout
|
|
||||||
self.start_transfer_poll_delay = start_transfer_poll_delay
|
|
||||||
|
|
||||||
self.logger = logging.getLogger('FileTransfer')
|
|
||||||
self.managing = threading.Event()
|
|
||||||
self.transfer_started = threading.Event()
|
|
||||||
self.transfer_completed = threading.Event()
|
|
||||||
self.transfer_aborted = threading.Event()
|
|
||||||
|
|
||||||
self.monitor_thread = None
|
|
||||||
self.sources = None
|
|
||||||
self.dest = None
|
|
||||||
self.direction = None
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _cancel(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def cancel(self, reason=None):
|
|
||||||
msg = 'Cancelling file transfer {} -> {}'.format(self.sources, self.dest)
|
|
||||||
if reason is not None:
|
|
||||||
msg += ' due to \'{}\''.format(reason)
|
|
||||||
self.logger.warning(msg)
|
|
||||||
self.transfer_aborted.set()
|
|
||||||
self._cancel()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def isactive(self):
|
def isactive(self):
|
||||||
pass
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def manage(self, sources, dest, direction):
|
|
||||||
try:
|
try:
|
||||||
self.sources, self.dest, self.direction = sources, dest, direction
|
curr_size = self.sample_size()
|
||||||
m_thread = threading.Thread(target=self._monitor)
|
except Exception as e:
|
||||||
|
self.logger.debug(f'File size polling failed: {e}')
|
||||||
self.transfer_completed.clear()
|
return True
|
||||||
self.transfer_aborted.clear()
|
else:
|
||||||
self.transfer_started.set()
|
self.logger.debug(f'Polled file transfer, destination size: {curr_size}')
|
||||||
|
if curr_size:
|
||||||
m_thread.start()
|
active = curr_size > self.last_sample
|
||||||
yield self
|
|
||||||
except BaseException:
|
|
||||||
self.cancel(reason='exception during transfer')
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self.transfer_completed.set()
|
|
||||||
self.transfer_started.set()
|
|
||||||
m_thread.join()
|
|
||||||
self.transfer_started.clear()
|
|
||||||
self.transfer_completed.clear()
|
|
||||||
self.transfer_aborted.clear()
|
|
||||||
|
|
||||||
def _monitor(self):
|
|
||||||
start_t = monotonic()
|
|
||||||
self.transfer_completed.wait(self.start_transfer_poll_delay)
|
|
||||||
while not self.transfer_completed.wait(self.poll_period):
|
|
||||||
if not self.isactive():
|
|
||||||
self.cancel(reason='transfer inactive')
|
|
||||||
elif monotonic() - start_t > self.total_timeout:
|
|
||||||
self.cancel(reason='transfer timed out')
|
|
||||||
|
|
||||||
|
|
||||||
class PopenTransferManager(TransferManagerBase):
|
|
||||||
|
|
||||||
def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600):
|
|
||||||
super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout)
|
|
||||||
self.transfer = None
|
|
||||||
self.last_sample = None
|
|
||||||
|
|
||||||
def _cancel(self):
|
|
||||||
if self.transfer:
|
|
||||||
self.transfer.cancel()
|
|
||||||
self.transfer = None
|
|
||||||
self.last_sample = None
|
|
||||||
|
|
||||||
def isactive(self):
|
|
||||||
size_fn = self._push_dest_size if self.direction == 'push' else self._pull_dest_size
|
|
||||||
curr_size = size_fn(self.dest)
|
|
||||||
self.logger.debug('Polled file transfer, destination size {}'.format(curr_size))
|
|
||||||
active = True if self.last_sample is None else curr_size > self.last_sample
|
|
||||||
self.last_sample = curr_size
|
self.last_sample = curr_size
|
||||||
return active
|
return active
|
||||||
|
# If the file is empty it will never grow in size, so we assume
|
||||||
def set_transfer_and_wait(self, popen_bg_cmd):
|
# everything is going well.
|
||||||
self.transfer = popen_bg_cmd
|
else:
|
||||||
self.last_sample = None
|
return True
|
||||||
ret = self.transfer.wait()
|
|
||||||
|
|
||||||
if ret and not self.transfer_aborted.is_set():
|
|
||||||
raise subprocess.CalledProcessError(ret, self.transfer.popen.args)
|
|
||||||
elif self.transfer_aborted.is_set():
|
|
||||||
raise TimeoutError(self.transfer.popen.args)
|
|
||||||
|
|
||||||
|
|
||||||
class SSHTransferManager(TransferManagerBase):
|
class SSHTransferHandle(TransferHandleBase):
|
||||||
|
|
||||||
|
def __init__(self, handle, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# SFTPClient or SSHClient
|
||||||
|
self.handle = handle
|
||||||
|
|
||||||
def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600):
|
|
||||||
super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout)
|
|
||||||
self.transferer = None
|
|
||||||
self.progressed = False
|
self.progressed = False
|
||||||
self.transferred = None
|
self.transferred = 0
|
||||||
self.to_transfer = None
|
self.to_transfer = 0
|
||||||
|
|
||||||
def _cancel(self):
|
def cancel(self):
|
||||||
self.transferer.close()
|
self.handle.close()
|
||||||
|
|
||||||
def isactive(self):
|
def isactive(self):
|
||||||
progressed = self.progressed
|
progressed = self.progressed
|
||||||
|
if progressed:
|
||||||
self.progressed = False
|
self.progressed = False
|
||||||
msg = 'Polled transfer: {}% [{}B/{}B]'
|
pc = (self.transferred / self.to_transfer) * 100
|
||||||
pc = format((self.transferred / self.to_transfer) * 100, '.2f')
|
self.logger.debug(
|
||||||
self.logger.debug(msg.format(pc, self.transferred, self.to_transfer))
|
f'Polled transfer: {pc:.2f}% [{self.transferred}B/{self.to_transfer}B]'
|
||||||
|
)
|
||||||
return progressed
|
return progressed
|
||||||
|
|
||||||
@contextmanager
|
def progress_cb(self, transferred, to_transfer):
|
||||||
def manage(self, sources, dest, direction, transferer):
|
|
||||||
with super().manage(sources, dest, direction):
|
|
||||||
try:
|
|
||||||
self.progressed = False
|
|
||||||
self.transferer = transferer # SFTPClient or SCPClient
|
|
||||||
yield self
|
|
||||||
except socket.error as e:
|
|
||||||
if self.transfer_aborted.is_set():
|
|
||||||
self.transfer_aborted.clear()
|
|
||||||
method = 'SCP' if self.conn.use_scp else 'SFTP'
|
|
||||||
raise TimeoutError('{} {}: {} -> {}'.format(method, self.direction, sources, self.dest))
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def progress_cb(self, to_transfer, transferred):
|
|
||||||
if self.transfer_started.is_set():
|
|
||||||
self.progressed = True
|
self.progressed = True
|
||||||
self.transferred = transferred
|
self.transferred = transferred
|
||||||
self.to_transfer = to_transfer
|
self.to_transfer = to_transfer
|
||||||
|
@ -2931,7 +2931,7 @@ class ChromeOsTarget(LinuxTarget):
|
|||||||
ssh_conn_params = ['host', 'username', 'password', 'keyfile',
|
ssh_conn_params = ['host', 'username', 'password', 'keyfile',
|
||||||
'port', 'timeout', 'sudo_cmd',
|
'port', 'timeout', 'sudo_cmd',
|
||||||
'strict_host_check', 'use_scp',
|
'strict_host_check', 'use_scp',
|
||||||
'total_timeout', 'poll_transfers',
|
'total_transfer_timeout', 'poll_transfers',
|
||||||
'start_transfer_poll_delay']
|
'start_transfer_poll_delay']
|
||||||
self.ssh_connection_settings = {}
|
self.ssh_connection_settings = {}
|
||||||
for setting in ssh_conn_params:
|
for setting in ssh_conn_params:
|
||||||
|
@ -39,7 +39,7 @@ from shlex import quote
|
|||||||
|
|
||||||
from devlib.exception import TargetTransientError, TargetStableError, HostError, TargetTransientCalledProcessError, TargetStableCalledProcessError, AdbRootError
|
from devlib.exception import TargetTransientError, TargetStableError, HostError, TargetTransientCalledProcessError, TargetStableCalledProcessError, AdbRootError
|
||||||
from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams, get_subprocess
|
from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams, get_subprocess
|
||||||
from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferManager
|
from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferHandle
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger('android')
|
logger = logging.getLogger('android')
|
||||||
@ -278,26 +278,32 @@ class AdbConnection(ConnectionBase):
|
|||||||
self._connected_as_root[self.device] = state
|
self._connected_as_root[self.device] = state
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def __init__(self, device=None, timeout=None, platform=None, adb_server=None,
|
def __init__(
|
||||||
adb_as_root=False, connection_attempts=MAX_ATTEMPTS,
|
self,
|
||||||
|
device=None,
|
||||||
|
timeout=None,
|
||||||
|
platform=None,
|
||||||
|
adb_server=None,
|
||||||
|
adb_as_root=False,
|
||||||
|
connection_attempts=MAX_ATTEMPTS,
|
||||||
|
|
||||||
poll_transfers=False,
|
poll_transfers=False,
|
||||||
start_transfer_poll_delay=30,
|
start_transfer_poll_delay=30,
|
||||||
total_transfer_timeout=3600,
|
total_transfer_timeout=3600,
|
||||||
transfer_poll_period=30,):
|
transfer_poll_period=30,
|
||||||
super().__init__()
|
):
|
||||||
|
super().__init__(
|
||||||
|
poll_transfers=poll_transfers,
|
||||||
|
start_transfer_poll_delay=start_transfer_poll_delay,
|
||||||
|
total_transfer_timeout=total_transfer_timeout,
|
||||||
|
transfer_poll_period=transfer_poll_period,
|
||||||
|
)
|
||||||
self.timeout = timeout if timeout is not None else self.default_timeout
|
self.timeout = timeout if timeout is not None else self.default_timeout
|
||||||
if device is None:
|
if device is None:
|
||||||
device = adb_get_device(timeout=timeout, adb_server=adb_server)
|
device = adb_get_device(timeout=timeout, adb_server=adb_server)
|
||||||
self.device = device
|
self.device = device
|
||||||
self.adb_server = adb_server
|
self.adb_server = adb_server
|
||||||
self.adb_as_root = adb_as_root
|
self.adb_as_root = adb_as_root
|
||||||
self.poll_transfers = poll_transfers
|
|
||||||
if poll_transfers:
|
|
||||||
transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay,
|
|
||||||
'total_timeout': total_transfer_timeout,
|
|
||||||
'poll_period': transfer_poll_period,
|
|
||||||
}
|
|
||||||
self.transfer_mgr = PopenTransferManager(self, **transfer_opts) if poll_transfers else None
|
|
||||||
lock, nr_active = AdbConnection.active_connections
|
lock, nr_active = AdbConnection.active_connections
|
||||||
with lock:
|
with lock:
|
||||||
nr_active[self.device] += 1
|
nr_active[self.device] += 1
|
||||||
@ -330,17 +336,24 @@ class AdbConnection(ConnectionBase):
|
|||||||
paths = ' '.join(map(do_quote, paths))
|
paths = ' '.join(map(do_quote, paths))
|
||||||
|
|
||||||
command = "{} {}".format(action, paths)
|
command = "{} {}".format(action, paths)
|
||||||
if timeout or not self.poll_transfers:
|
if timeout:
|
||||||
adb_command(self.device, command, timeout=timeout, adb_server=self.adb_server)
|
adb_command(self.device, command, timeout=timeout, adb_server=self.adb_server)
|
||||||
else:
|
else:
|
||||||
with self.transfer_mgr.manage(sources, dest, action):
|
|
||||||
bg_cmd = adb_command_background(
|
bg_cmd = adb_command_background(
|
||||||
device=self.device,
|
device=self.device,
|
||||||
conn=self,
|
conn=self,
|
||||||
command=command,
|
command=command,
|
||||||
adb_server=self.adb_server
|
adb_server=self.adb_server
|
||||||
)
|
)
|
||||||
self.transfer_mgr.set_transfer_and_wait(bg_cmd)
|
|
||||||
|
handle = PopenTransferHandle(
|
||||||
|
manager=self.transfer_manager,
|
||||||
|
bg_cmd=bg_cmd,
|
||||||
|
dest=dest,
|
||||||
|
direction=action
|
||||||
|
)
|
||||||
|
with bg_cmd, self.transfer_manager.manage(sources, dest, action, handle):
|
||||||
|
bg_cmd.communicate()
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def execute(self, command, timeout=None, check_exit_code=False,
|
def execute(self, command, timeout=None, check_exit_code=False,
|
||||||
|
@ -59,7 +59,7 @@ from devlib.utils.misc import (which, strip_bash_colors, check_output,
|
|||||||
sanitize_cmd_template, memoized, redirect_streams)
|
sanitize_cmd_template, memoized, redirect_streams)
|
||||||
from devlib.utils.types import boolean
|
from devlib.utils.types import boolean
|
||||||
from devlib.connection import (ConnectionBase, ParamikoBackgroundCommand, PopenBackgroundCommand,
|
from devlib.connection import (ConnectionBase, ParamikoBackgroundCommand, PopenBackgroundCommand,
|
||||||
SSHTransferManager)
|
SSHTransferHandle)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_SSH_SUDO_COMMAND = "sudo -k -p ' ' -S -- sh -c {}"
|
DEFAULT_SSH_SUDO_COMMAND = "sudo -k -p ' ' -S -- sh -c {}"
|
||||||
@ -295,8 +295,18 @@ class SshConnectionBase(ConnectionBase):
|
|||||||
platform=None,
|
platform=None,
|
||||||
sudo_cmd=DEFAULT_SSH_SUDO_COMMAND,
|
sudo_cmd=DEFAULT_SSH_SUDO_COMMAND,
|
||||||
strict_host_check=True,
|
strict_host_check=True,
|
||||||
|
|
||||||
|
poll_transfers=False,
|
||||||
|
start_transfer_poll_delay=30,
|
||||||
|
total_transfer_timeout=3600,
|
||||||
|
transfer_poll_period=30,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(
|
||||||
|
poll_transfers=poll_transfers,
|
||||||
|
start_transfer_poll_delay=start_transfer_poll_delay,
|
||||||
|
total_transfer_timeout=total_transfer_timeout,
|
||||||
|
transfer_poll_period=transfer_poll_period,
|
||||||
|
)
|
||||||
self._connected_as_root = None
|
self._connected_as_root = None
|
||||||
self.host = host
|
self.host = host
|
||||||
self.username = username
|
self.username = username
|
||||||
@ -337,24 +347,21 @@ class SshConnection(SshConnectionBase):
|
|||||||
platform=platform,
|
platform=platform,
|
||||||
sudo_cmd=sudo_cmd,
|
sudo_cmd=sudo_cmd,
|
||||||
strict_host_check=strict_host_check,
|
strict_host_check=strict_host_check,
|
||||||
|
|
||||||
|
poll_transfers=poll_transfers,
|
||||||
|
start_transfer_poll_delay=start_transfer_poll_delay,
|
||||||
|
total_transfer_timeout=total_transfer_timeout,
|
||||||
|
transfer_poll_period=transfer_poll_period,
|
||||||
)
|
)
|
||||||
self.timeout = timeout if timeout is not None else self.default_timeout
|
self.timeout = timeout if timeout is not None else self.default_timeout
|
||||||
|
|
||||||
# Allow using scp for file transfer if sftp is not supported
|
# Allow using scp for file transfer if sftp is not supported
|
||||||
self.use_scp = use_scp
|
self.use_scp = use_scp
|
||||||
self.poll_transfers=poll_transfers
|
|
||||||
if poll_transfers:
|
|
||||||
transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay,
|
|
||||||
'total_timeout': total_transfer_timeout,
|
|
||||||
'poll_period': transfer_poll_period,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.use_scp:
|
if self.use_scp:
|
||||||
logger.debug('Using SCP for file transfer')
|
logger.debug('Using SCP for file transfer')
|
||||||
else:
|
else:
|
||||||
logger.debug('Using SFTP for file transfer')
|
logger.debug('Using SFTP for file transfer')
|
||||||
|
|
||||||
self.transfer_mgr = SSHTransferManager(self, **transfer_opts) if poll_transfers else None
|
|
||||||
self.client = self._make_client()
|
self.client = self._make_client()
|
||||||
atexit.register(self.close)
|
atexit.register(self.close)
|
||||||
|
|
||||||
@ -426,12 +433,12 @@ class SshConnection(SshConnectionBase):
|
|||||||
return sftp
|
return sftp
|
||||||
|
|
||||||
@functools.lru_cache()
|
@functools.lru_cache()
|
||||||
def _get_scp(self, timeout):
|
def _get_scp(self, timeout, callback=lambda *_: None):
|
||||||
cb = lambda _, to_transfer, transferred: self.transfer_mgr.progress_cb(to_transfer, transferred)
|
cb = lambda _, to_transfer, transferred: callback(to_transfer, transferred)
|
||||||
return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=cb)
|
return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=cb)
|
||||||
|
|
||||||
def _push_file(self, sftp, src, dst):
|
def _push_file(self, sftp, src, dst, callback):
|
||||||
sftp.put(src, dst, callback=self.transfer_mgr.progress_cb)
|
sftp.put(src, dst, callback=callback)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _path_exists(cls, sftp, path):
|
def _path_exists(cls, sftp, path):
|
||||||
@ -442,7 +449,7 @@ class SshConnection(SshConnectionBase):
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _push_folder(self, sftp, src, dst):
|
def _push_folder(self, sftp, src, dst, callback):
|
||||||
sftp.mkdir(dst)
|
sftp.mkdir(dst)
|
||||||
for entry in os.scandir(src):
|
for entry in os.scandir(src):
|
||||||
name = entry.name
|
name = entry.name
|
||||||
@ -453,17 +460,17 @@ class SshConnection(SshConnectionBase):
|
|||||||
else:
|
else:
|
||||||
push = self._push_file
|
push = self._push_file
|
||||||
|
|
||||||
push(sftp, src_path, dst_path)
|
push(sftp, src_path, dst_path, callback)
|
||||||
|
|
||||||
def _push_path(self, sftp, src, dst):
|
def _push_path(self, sftp, src, dst, callback=None):
|
||||||
logger.debug('Pushing via sftp: {} -> {}'.format(src, dst))
|
logger.debug('Pushing via sftp: {} -> {}'.format(src, dst))
|
||||||
push = self._push_folder if os.path.isdir(src) else self._push_file
|
push = self._push_folder if os.path.isdir(src) else self._push_file
|
||||||
push(sftp, src, dst)
|
push(sftp, src, dst, callback)
|
||||||
|
|
||||||
def _pull_file(self, sftp, src, dst):
|
def _pull_file(self, sftp, src, dst, callback):
|
||||||
sftp.get(src, dst, callback=self.transfer_mgr.progress_cb)
|
sftp.get(src, dst, callback=callback)
|
||||||
|
|
||||||
def _pull_folder(self, sftp, src, dst):
|
def _pull_folder(self, sftp, src, dst, callback):
|
||||||
os.makedirs(dst)
|
os.makedirs(dst)
|
||||||
for fileattr in sftp.listdir_attr(src):
|
for fileattr in sftp.listdir_attr(src):
|
||||||
filename = fileattr.filename
|
filename = fileattr.filename
|
||||||
@ -474,15 +481,15 @@ class SshConnection(SshConnectionBase):
|
|||||||
else:
|
else:
|
||||||
pull = self._pull_file
|
pull = self._pull_file
|
||||||
|
|
||||||
pull(sftp, src_path, dst_path)
|
pull(sftp, src_path, dst_path, callback)
|
||||||
|
|
||||||
def _pull_path(self, sftp, src, dst):
|
def _pull_path(self, sftp, src, dst, callback=None):
|
||||||
logger.debug('Pulling via sftp: {} -> {}'.format(src, dst))
|
logger.debug('Pulling via sftp: {} -> {}'.format(src, dst))
|
||||||
try:
|
try:
|
||||||
self._pull_file(sftp, src, dst)
|
self._pull_file(sftp, src, dst, callback)
|
||||||
except IOError:
|
except IOError:
|
||||||
# Maybe that was a directory, so retry as such
|
# Maybe that was a directory, so retry as such
|
||||||
self._pull_folder(sftp, src, dst)
|
self._pull_folder(sftp, src, dst, callback)
|
||||||
|
|
||||||
def push(self, sources, dest, timeout=None):
|
def push(self, sources, dest, timeout=None):
|
||||||
self._push_pull('push', sources, dest, timeout)
|
self._push_pull('push', sources, dest, timeout)
|
||||||
@ -494,8 +501,13 @@ class SshConnection(SshConnectionBase):
|
|||||||
if action not in ['push', 'pull']:
|
if action not in ['push', 'pull']:
|
||||||
raise ValueError("Action must be either `push` or `pull`")
|
raise ValueError("Action must be either `push` or `pull`")
|
||||||
|
|
||||||
# If timeout is set, or told not to poll
|
def make_handle(obj):
|
||||||
if timeout is not None or not self.poll_transfers:
|
handle = SSHTransferHandle(obj, manager=self.transfer_manager)
|
||||||
|
cm = self.transfer_manager.manage(sources, dest, action, handle)
|
||||||
|
return (handle, cm)
|
||||||
|
|
||||||
|
# If timeout is set
|
||||||
|
if timeout is not None:
|
||||||
if self.use_scp:
|
if self.use_scp:
|
||||||
scp = self._get_scp(timeout)
|
scp = self._get_scp(timeout)
|
||||||
scp_cmd = getattr(scp, 'put' if action == 'push' else 'get')
|
scp_cmd = getattr(scp, 'put' if action == 'push' else 'get')
|
||||||
@ -509,20 +521,23 @@ class SshConnection(SshConnectionBase):
|
|||||||
for source in sources:
|
for source in sources:
|
||||||
sftp_cmd(sftp, source, dest)
|
sftp_cmd(sftp, source, dest)
|
||||||
|
|
||||||
# No timeout, and polling is set
|
# No timeout
|
||||||
elif self.use_scp:
|
elif self.use_scp:
|
||||||
scp = self._get_scp(timeout)
|
scp = self._get_scp(timeout, callback=handle.progress_cb)
|
||||||
|
handle, cm = make_handle(scp)
|
||||||
|
|
||||||
scp_cmd = getattr(scp, 'put' if action == 'push' else 'get')
|
scp_cmd = getattr(scp, 'put' if action == 'push' else 'get')
|
||||||
with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, scp):
|
with _handle_paramiko_exceptions(), cm:
|
||||||
scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest)
|
scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest)
|
||||||
logger.debug(scp_msg.capitalize())
|
logger.debug(scp_msg.capitalize())
|
||||||
scp_cmd(sources, dest, recursive=True)
|
scp_cmd(sources, dest, recursive=True)
|
||||||
else:
|
else:
|
||||||
sftp = self._get_sftp(timeout)
|
sftp = self._get_sftp(timeout)
|
||||||
|
handle, cm = make_handle(sftp)
|
||||||
sftp_cmd = getattr(self, '_' + action + '_path')
|
sftp_cmd = getattr(self, '_' + action + '_path')
|
||||||
with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, sftp):
|
with _handle_paramiko_exceptions(), cm:
|
||||||
for source in sources:
|
for source in sources:
|
||||||
sftp_cmd(sftp, source, dest)
|
sftp_cmd(sftp, source, dest, callback=handle.progress_cb)
|
||||||
|
|
||||||
def execute(self, command, timeout=None, check_exit_code=True,
|
def execute(self, command, timeout=None, check_exit_code=True,
|
||||||
as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument
|
as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument
|
||||||
|
Loading…
x
Reference in New Issue
Block a user