1
0
mirror of https://github.com/ARM-software/devlib.git synced 2025-01-31 02:00:45 +00:00

connection: Rework TransferManager

* Split TransferManager and TransferHandle:
    * TransferManager deals with the generic monitoring. To abort a
      transfer, it simply cancels the transfer and raises an exception
      from manage().
    * TransferHandle provides a way for the manager to query the state
      of the transfer and cancel it. It is backend-specific.

* Remove most of the state in TransferManager, along with the associated
  background command leak etc

* Use a daemonic monitor thread to behave as excpected on interpreter
  shutdown.

* Ensure a transfer manager _always_ exists. When no management is
  desired, a noop object is used. This avoids using a None sentinel,
  which is invariably mishandled by some code leading to crashes.

* Try to merge more paths in the code to uncover as many issues as
  possible in testing.

* Fix percentage for SSHTransferHandle (transferred / (remaining +
  transferred) instead of transferred / remaining)

* Rename total_timeout TransferManager parameter and attribute to
  total_transfer_timeout to match the connection name parameter.
This commit is contained in:
Douglas Raillard 2023-02-20 18:18:11 +00:00 committed by Marc Bonnici
parent 1c5412be2f
commit ddaa2f1621
4 changed files with 227 additions and 196 deletions

View File

@ -14,12 +14,11 @@
#
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from datetime import datetime
from functools import partial
from weakref import WeakSet
from shlex import quote
from time import monotonic
import os
import signal
import socket
@ -61,13 +60,27 @@ class ConnectionBase(InitCheckpoint):
"""
Base class for all connections.
"""
def __init__(self):
def __init__(
self,
poll_transfers=False,
start_transfer_poll_delay=30,
total_transfer_timeout=3600,
transfer_poll_period=30,
):
self._current_bg_cmds = set()
self._closed = False
self._close_lock = threading.Lock()
self.busybox = None
self.logger = logging.getLogger('Connection')
self.transfer_manager = TransferManager(
self,
start_transfer_poll_delay=start_transfer_poll_delay,
total_transfer_timeout=total_transfer_timeout,
transfer_poll_period=transfer_poll_period,
) if poll_transfers else NoopTransferManager()
def cancel_running_command(self):
bg_cmds = set(self._current_bg_cmds)
for bg_cmd in bg_cmds:
@ -405,13 +418,13 @@ class ParamikoBackgroundCommand(BackgroundCommand):
b''.join(out[stderr])
)
start = monotonic()
start = time.monotonic()
while ret is None:
# Even if ret is not None anymore, we need to drain the streams
ret = self.poll()
if timeout is not None and ret is None and monotonic() - start >= timeout:
if timeout is not None and ret is None and time.monotonic() - start >= timeout:
self.cancel()
_stdout, _stderr = create_out()
raise subprocess.TimeoutExpired(self.cmd, timeout, _stdout, _stderr)
@ -563,9 +576,89 @@ class AdbBackgroundCommand(BackgroundCommand):
return self
class TransferManagerBase(ABC):
class TransferManager:
def __init__(self, conn, transfer_poll_period=30, start_transfer_poll_delay=30, total_transfer_timeout=3600):
self.conn = conn
self.transfer_poll_period = transfer_poll_period
self.total_transfer_timeout = total_transfer_timeout
self.start_transfer_poll_delay = start_transfer_poll_delay
def _pull_dest_size(self, dest):
self.logger = logging.getLogger('FileTransfer')
@contextmanager
def manage(self, sources, dest, direction, handle):
excep = None
stop_thread = threading.Event()
def monitor():
nonlocal excep
def cancel(reason):
self.logger.warning(
f'Cancelling file transfer {sources} -> {dest} due to: {reason}'
)
handle.cancel()
start_t = time.monotonic()
stop_thread.wait(self.start_transfer_poll_delay)
while not stop_thread.wait(self.transfer_poll_period):
if not handle.isactive():
cancel(reason='transfer inactive')
elif time.monotonic() - start_t > self.total_transfer_timeout:
cancel(reason='transfer timed out')
excep = TimeoutError(f'{direction}: {sources} -> {dest}')
m_thread = threading.Thread(target=monitor, daemon=True)
try:
m_thread.start()
yield self
finally:
stop_thread.set()
m_thread.join()
if excep is not None:
raise excep
class NoopTransferManager:
def manage(self, *args, **kwargs):
return nullcontext(self)
class TransferHandleBase(ABC):
def __init__(self, manager):
self.manager = manager
@property
def logger(self):
return self.manager.logger
@abstractmethod
def isactive(self):
pass
@abstractmethod
def cancel(self):
pass
class PopenTransferHandle(TransferHandleBase):
def __init__(self, bg_cmd, dest, direction, *args, **kwargs):
super().__init__(*args, **kwargs)
if direction == 'push':
sample_size = self._push_dest_size
elif direction == 'pull':
sample_size = self._pull_dest_size
else:
raise ValueError(f'Unknown direction: {direction}')
self.sample_size = lambda: sample_size(dest)
self.bg_cmd = bg_cmd
self.last_sample = 0
@staticmethod
def _pull_dest_size(dest):
if os.path.isdir(dest):
return sum(
os.stat(os.path.join(dirpath, f)).st_size
@ -576,148 +669,58 @@ class TransferManagerBase(ABC):
return os.stat(dest).st_size
def _push_dest_size(self, dest):
cmd = '{} du -s {}'.format(quote(self.conn.busybox), quote(dest))
out = self.conn.execute(cmd)
try:
conn = self.manager.conn
cmd = '{} du -s -- {}'.format(quote(conn.busybox), quote(dest))
out = conn.execute(cmd)
return int(out.split()[0])
except ValueError:
return 0
def __init__(self, conn, poll_period, start_transfer_poll_delay, total_timeout):
self.conn = conn
self.poll_period = poll_period
self.total_timeout = total_timeout
self.start_transfer_poll_delay = start_transfer_poll_delay
def cancel(self):
self.bg_cmd.cancel()
self.logger = logging.getLogger('FileTransfer')
self.managing = threading.Event()
self.transfer_started = threading.Event()
self.transfer_completed = threading.Event()
self.transfer_aborted = threading.Event()
self.monitor_thread = None
self.sources = None
self.dest = None
self.direction = None
@abstractmethod
def _cancel(self):
pass
def cancel(self, reason=None):
msg = 'Cancelling file transfer {} -> {}'.format(self.sources, self.dest)
if reason is not None:
msg += ' due to \'{}\''.format(reason)
self.logger.warning(msg)
self.transfer_aborted.set()
self._cancel()
@abstractmethod
def isactive(self):
pass
@contextmanager
def manage(self, sources, dest, direction):
try:
self.sources, self.dest, self.direction = sources, dest, direction
m_thread = threading.Thread(target=self._monitor)
self.transfer_completed.clear()
self.transfer_aborted.clear()
self.transfer_started.set()
m_thread.start()
yield self
except BaseException:
self.cancel(reason='exception during transfer')
raise
finally:
self.transfer_completed.set()
self.transfer_started.set()
m_thread.join()
self.transfer_started.clear()
self.transfer_completed.clear()
self.transfer_aborted.clear()
def _monitor(self):
start_t = monotonic()
self.transfer_completed.wait(self.start_transfer_poll_delay)
while not self.transfer_completed.wait(self.poll_period):
if not self.isactive():
self.cancel(reason='transfer inactive')
elif monotonic() - start_t > self.total_timeout:
self.cancel(reason='transfer timed out')
class PopenTransferManager(TransferManagerBase):
def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600):
super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout)
self.transfer = None
self.last_sample = None
def _cancel(self):
if self.transfer:
self.transfer.cancel()
self.transfer = None
self.last_sample = None
def isactive(self):
size_fn = self._push_dest_size if self.direction == 'push' else self._pull_dest_size
curr_size = size_fn(self.dest)
self.logger.debug('Polled file transfer, destination size {}'.format(curr_size))
active = True if self.last_sample is None else curr_size > self.last_sample
curr_size = self.sample_size()
except Exception as e:
self.logger.debug(f'File size polling failed: {e}')
return True
else:
self.logger.debug(f'Polled file transfer, destination size: {curr_size}')
if curr_size:
active = curr_size > self.last_sample
self.last_sample = curr_size
return active
def set_transfer_and_wait(self, popen_bg_cmd):
self.transfer = popen_bg_cmd
self.last_sample = None
ret = self.transfer.wait()
if ret and not self.transfer_aborted.is_set():
raise subprocess.CalledProcessError(ret, self.transfer.popen.args)
elif self.transfer_aborted.is_set():
raise TimeoutError(self.transfer.popen.args)
# If the file is empty it will never grow in size, so we assume
# everything is going well.
else:
return True
class SSHTransferManager(TransferManagerBase):
class SSHTransferHandle(TransferHandleBase):
def __init__(self, handle, *args, **kwargs):
super().__init__(*args, **kwargs)
# SFTPClient or SSHClient
self.handle = handle
def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600):
super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout)
self.transferer = None
self.progressed = False
self.transferred = None
self.to_transfer = None
self.transferred = 0
self.to_transfer = 0
def _cancel(self):
self.transferer.close()
def cancel(self):
self.handle.close()
def isactive(self):
progressed = self.progressed
if progressed:
self.progressed = False
msg = 'Polled transfer: {}% [{}B/{}B]'
pc = format((self.transferred / self.to_transfer) * 100, '.2f')
self.logger.debug(msg.format(pc, self.transferred, self.to_transfer))
pc = (self.transferred / self.to_transfer) * 100
self.logger.debug(
f'Polled transfer: {pc:.2f}% [{self.transferred}B/{self.to_transfer}B]'
)
return progressed
@contextmanager
def manage(self, sources, dest, direction, transferer):
with super().manage(sources, dest, direction):
try:
self.progressed = False
self.transferer = transferer # SFTPClient or SCPClient
yield self
except socket.error as e:
if self.transfer_aborted.is_set():
self.transfer_aborted.clear()
method = 'SCP' if self.conn.use_scp else 'SFTP'
raise TimeoutError('{} {}: {} -> {}'.format(method, self.direction, sources, self.dest))
else:
raise e
def progress_cb(self, to_transfer, transferred):
if self.transfer_started.is_set():
def progress_cb(self, transferred, to_transfer):
self.progressed = True
self.transferred = transferred
self.to_transfer = to_transfer

View File

@ -2931,7 +2931,7 @@ class ChromeOsTarget(LinuxTarget):
ssh_conn_params = ['host', 'username', 'password', 'keyfile',
'port', 'timeout', 'sudo_cmd',
'strict_host_check', 'use_scp',
'total_timeout', 'poll_transfers',
'total_transfer_timeout', 'poll_transfers',
'start_transfer_poll_delay']
self.ssh_connection_settings = {}
for setting in ssh_conn_params:

View File

@ -39,7 +39,7 @@ from shlex import quote
from devlib.exception import TargetTransientError, TargetStableError, HostError, TargetTransientCalledProcessError, TargetStableCalledProcessError, AdbRootError
from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams, get_subprocess
from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferManager
from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferHandle
logger = logging.getLogger('android')
@ -278,26 +278,32 @@ class AdbConnection(ConnectionBase):
self._connected_as_root[self.device] = state
# pylint: disable=unused-argument
def __init__(self, device=None, timeout=None, platform=None, adb_server=None,
adb_as_root=False, connection_attempts=MAX_ATTEMPTS,
def __init__(
self,
device=None,
timeout=None,
platform=None,
adb_server=None,
adb_as_root=False,
connection_attempts=MAX_ATTEMPTS,
poll_transfers=False,
start_transfer_poll_delay=30,
total_transfer_timeout=3600,
transfer_poll_period=30,):
super().__init__()
transfer_poll_period=30,
):
super().__init__(
poll_transfers=poll_transfers,
start_transfer_poll_delay=start_transfer_poll_delay,
total_transfer_timeout=total_transfer_timeout,
transfer_poll_period=transfer_poll_period,
)
self.timeout = timeout if timeout is not None else self.default_timeout
if device is None:
device = adb_get_device(timeout=timeout, adb_server=adb_server)
self.device = device
self.adb_server = adb_server
self.adb_as_root = adb_as_root
self.poll_transfers = poll_transfers
if poll_transfers:
transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay,
'total_timeout': total_transfer_timeout,
'poll_period': transfer_poll_period,
}
self.transfer_mgr = PopenTransferManager(self, **transfer_opts) if poll_transfers else None
lock, nr_active = AdbConnection.active_connections
with lock:
nr_active[self.device] += 1
@ -330,17 +336,24 @@ class AdbConnection(ConnectionBase):
paths = ' '.join(map(do_quote, paths))
command = "{} {}".format(action, paths)
if timeout or not self.poll_transfers:
if timeout:
adb_command(self.device, command, timeout=timeout, adb_server=self.adb_server)
else:
with self.transfer_mgr.manage(sources, dest, action):
bg_cmd = adb_command_background(
device=self.device,
conn=self,
command=command,
adb_server=self.adb_server
)
self.transfer_mgr.set_transfer_and_wait(bg_cmd)
handle = PopenTransferHandle(
manager=self.transfer_manager,
bg_cmd=bg_cmd,
dest=dest,
direction=action
)
with bg_cmd, self.transfer_manager.manage(sources, dest, action, handle):
bg_cmd.communicate()
# pylint: disable=unused-argument
def execute(self, command, timeout=None, check_exit_code=False,

View File

@ -59,7 +59,7 @@ from devlib.utils.misc import (which, strip_bash_colors, check_output,
sanitize_cmd_template, memoized, redirect_streams)
from devlib.utils.types import boolean
from devlib.connection import (ConnectionBase, ParamikoBackgroundCommand, PopenBackgroundCommand,
SSHTransferManager)
SSHTransferHandle)
DEFAULT_SSH_SUDO_COMMAND = "sudo -k -p ' ' -S -- sh -c {}"
@ -295,8 +295,18 @@ class SshConnectionBase(ConnectionBase):
platform=None,
sudo_cmd=DEFAULT_SSH_SUDO_COMMAND,
strict_host_check=True,
poll_transfers=False,
start_transfer_poll_delay=30,
total_transfer_timeout=3600,
transfer_poll_period=30,
):
super().__init__()
super().__init__(
poll_transfers=poll_transfers,
start_transfer_poll_delay=start_transfer_poll_delay,
total_transfer_timeout=total_transfer_timeout,
transfer_poll_period=transfer_poll_period,
)
self._connected_as_root = None
self.host = host
self.username = username
@ -337,24 +347,21 @@ class SshConnection(SshConnectionBase):
platform=platform,
sudo_cmd=sudo_cmd,
strict_host_check=strict_host_check,
poll_transfers=poll_transfers,
start_transfer_poll_delay=start_transfer_poll_delay,
total_transfer_timeout=total_transfer_timeout,
transfer_poll_period=transfer_poll_period,
)
self.timeout = timeout if timeout is not None else self.default_timeout
# Allow using scp for file transfer if sftp is not supported
self.use_scp = use_scp
self.poll_transfers=poll_transfers
if poll_transfers:
transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay,
'total_timeout': total_transfer_timeout,
'poll_period': transfer_poll_period,
}
if self.use_scp:
logger.debug('Using SCP for file transfer')
else:
logger.debug('Using SFTP for file transfer')
self.transfer_mgr = SSHTransferManager(self, **transfer_opts) if poll_transfers else None
self.client = self._make_client()
atexit.register(self.close)
@ -426,12 +433,12 @@ class SshConnection(SshConnectionBase):
return sftp
@functools.lru_cache()
def _get_scp(self, timeout):
cb = lambda _, to_transfer, transferred: self.transfer_mgr.progress_cb(to_transfer, transferred)
def _get_scp(self, timeout, callback=lambda *_: None):
cb = lambda _, to_transfer, transferred: callback(to_transfer, transferred)
return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=cb)
def _push_file(self, sftp, src, dst):
sftp.put(src, dst, callback=self.transfer_mgr.progress_cb)
def _push_file(self, sftp, src, dst, callback):
sftp.put(src, dst, callback=callback)
@classmethod
def _path_exists(cls, sftp, path):
@ -442,7 +449,7 @@ class SshConnection(SshConnectionBase):
else:
return True
def _push_folder(self, sftp, src, dst):
def _push_folder(self, sftp, src, dst, callback):
sftp.mkdir(dst)
for entry in os.scandir(src):
name = entry.name
@ -453,17 +460,17 @@ class SshConnection(SshConnectionBase):
else:
push = self._push_file
push(sftp, src_path, dst_path)
push(sftp, src_path, dst_path, callback)
def _push_path(self, sftp, src, dst):
def _push_path(self, sftp, src, dst, callback=None):
logger.debug('Pushing via sftp: {} -> {}'.format(src, dst))
push = self._push_folder if os.path.isdir(src) else self._push_file
push(sftp, src, dst)
push(sftp, src, dst, callback)
def _pull_file(self, sftp, src, dst):
sftp.get(src, dst, callback=self.transfer_mgr.progress_cb)
def _pull_file(self, sftp, src, dst, callback):
sftp.get(src, dst, callback=callback)
def _pull_folder(self, sftp, src, dst):
def _pull_folder(self, sftp, src, dst, callback):
os.makedirs(dst)
for fileattr in sftp.listdir_attr(src):
filename = fileattr.filename
@ -474,15 +481,15 @@ class SshConnection(SshConnectionBase):
else:
pull = self._pull_file
pull(sftp, src_path, dst_path)
pull(sftp, src_path, dst_path, callback)
def _pull_path(self, sftp, src, dst):
def _pull_path(self, sftp, src, dst, callback=None):
logger.debug('Pulling via sftp: {} -> {}'.format(src, dst))
try:
self._pull_file(sftp, src, dst)
self._pull_file(sftp, src, dst, callback)
except IOError:
# Maybe that was a directory, so retry as such
self._pull_folder(sftp, src, dst)
self._pull_folder(sftp, src, dst, callback)
def push(self, sources, dest, timeout=None):
self._push_pull('push', sources, dest, timeout)
@ -494,8 +501,13 @@ class SshConnection(SshConnectionBase):
if action not in ['push', 'pull']:
raise ValueError("Action must be either `push` or `pull`")
# If timeout is set, or told not to poll
if timeout is not None or not self.poll_transfers:
def make_handle(obj):
handle = SSHTransferHandle(obj, manager=self.transfer_manager)
cm = self.transfer_manager.manage(sources, dest, action, handle)
return (handle, cm)
# If timeout is set
if timeout is not None:
if self.use_scp:
scp = self._get_scp(timeout)
scp_cmd = getattr(scp, 'put' if action == 'push' else 'get')
@ -509,20 +521,23 @@ class SshConnection(SshConnectionBase):
for source in sources:
sftp_cmd(sftp, source, dest)
# No timeout, and polling is set
# No timeout
elif self.use_scp:
scp = self._get_scp(timeout)
scp = self._get_scp(timeout, callback=handle.progress_cb)
handle, cm = make_handle(scp)
scp_cmd = getattr(scp, 'put' if action == 'push' else 'get')
with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, scp):
with _handle_paramiko_exceptions(), cm:
scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest)
logger.debug(scp_msg.capitalize())
scp_cmd(sources, dest, recursive=True)
else:
sftp = self._get_sftp(timeout)
handle, cm = make_handle(sftp)
sftp_cmd = getattr(self, '_' + action + '_path')
with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, sftp):
with _handle_paramiko_exceptions(), cm:
for source in sources:
sftp_cmd(sftp, source, dest)
sftp_cmd(sftp, source, dest, callback=handle.progress_cb)
def execute(self, command, timeout=None, check_exit_code=True,
as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument