diff --git a/devlib/connection.py b/devlib/connection.py index 596d3fc..aea423b 100644 --- a/devlib/connection.py +++ b/devlib/connection.py @@ -13,13 +13,20 @@ # limitations under the License. # -import os -import time -import subprocess -import signal -import threading -from weakref import WeakSet from abc import ABC, abstractmethod +from contextlib import contextmanager +from datetime import datetime +from functools import partial +from weakref import WeakSet +from shlex import quote +from time import monotonic +import os +import signal +import socket +import subprocess +import threading +import time +import logging from devlib.utils.misc import InitCheckpoint @@ -38,6 +45,7 @@ class ConnectionBase(InitCheckpoint): self._current_bg_cmds = WeakSet() self._closed = False self._close_lock = threading.Lock() + self.busybox = None def cancel_running_command(self): bg_cmds = set(self._current_bg_cmds) @@ -220,6 +228,7 @@ class PopenBackgroundCommand(BackgroundCommand): def __exit__(self, *args, **kwargs): self.popen.__exit__(*args, **kwargs) + class ParamikoBackgroundCommand(BackgroundCommand): """ :mod:`paramiko`-based background command. @@ -349,3 +358,166 @@ class AdbBackgroundCommand(BackgroundCommand): def __exit__(self, *args, **kwargs): self.adb_popen.__exit__(*args, **kwargs) + + +class TransferManagerBase(ABC): + + def _pull_dest_size(self, dest): + if os.path.isdir(dest): + return sum( + os.stat(os.path.join(dirpath, f)).st_size + for dirpath, _, fnames in os.walk(dest) + for f in fnames + ) + else: + return os.stat(dest).st_size + return 0 + + def _push_dest_size(self, dest): + cmd = '{} du -s {}'.format(quote(self.conn.busybox), quote(dest)) + out = self.conn.execute(cmd) + try: + return int(out.split()[0]) + except ValueError: + return 0 + + def __init__(self, conn, poll_period, start_transfer_poll_delay, total_timeout): + self.conn = conn + self.poll_period = poll_period + self.total_timeout = total_timeout + self.start_transfer_poll_delay = start_transfer_poll_delay + + self.logger = logging.getLogger('FileTransfer') + self.managing = threading.Event() + self.transfer_started = threading.Event() + self.transfer_completed = threading.Event() + self.transfer_aborted = threading.Event() + + self.monitor_thread = None + self.sources = None + self.dest = None + self.direction = None + + @abstractmethod + def _cancel(self): + pass + + def cancel(self, reason=None): + msg = 'Cancelling file transfer {} -> {}'.format(self.sources, self.dest) + if reason is not None: + msg += ' due to \'{}\''.format(reason) + self.logger.warning(msg) + self.transfer_aborted.set() + self._cancel() + + @abstractmethod + def isactive(self): + pass + + @contextmanager + def manage(self, sources, dest, direction): + try: + self.sources, self.dest, self.direction = sources, dest, direction + m_thread = threading.Thread(target=self._monitor) + + self.transfer_completed.clear() + self.transfer_aborted.clear() + self.transfer_started.set() + + m_thread.start() + yield self + except BaseException: + self.cancel(reason='exception during transfer') + raise + finally: + self.transfer_completed.set() + self.transfer_started.set() + m_thread.join() + self.transfer_started.clear() + self.transfer_completed.clear() + self.transfer_aborted.clear() + + def _monitor(self): + start_t = monotonic() + self.transfer_completed.wait(self.start_transfer_poll_delay) + while not self.transfer_completed.wait(self.poll_period): + if not self.isactive(): + self.cancel(reason='transfer inactive') + elif monotonic() - start_t > self.total_timeout: + self.cancel(reason='transfer timed out') + + +class PopenTransferManager(TransferManagerBase): + + def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600): + super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout) + self.transfer = None + self.last_sample = None + + def _cancel(self): + if self.transfer: + self.transfer.cancel() + self.transfer = None + + def isactive(self): + size_fn = self._push_dest_size if self.direction == 'push' else self._pull_dest_size + curr_size = size_fn(self.dest) + self.logger.debug('Polled file transfer, destination size {}'.format(curr_size)) + active = True if self.last_sample is None else curr_size > self.last_sample + self.last_sample = curr_size + return active + + def set_transfer_and_wait(self, popen_bg_cmd): + self.transfer = popen_bg_cmd + ret = self.transfer.wait() + + if ret and not self.transfer_aborted.is_set(): + raise subprocess.CalledProcessError(ret, self.transfer.popen.args) + elif self.transfer_aborted.is_set(): + raise TimeoutError(self.transfer.popen.args) + + +class SSHTransferManager(TransferManagerBase): + + def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600): + super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout) + self.transferer = None + self.progressed = False + self.transferred = None + self.to_transfer = None + + def _cancel(self): + self.transferer.close() + + def isactive(self): + progressed = self.progressed + self.progressed = False + msg = 'Polled transfer: {}% [{}B/{}B]' + pc = format((self.transferred / self.to_transfer) * 100, '.2f') + self.logger.debug(msg.format(pc, self.transferred, self.to_transfer)) + return progressed + + @contextmanager + def manage(self, sources, dest, direction, transferer): + with super().manage(sources, dest, direction): + try: + self.progressed = False + self.transferer = transferer # SFTPClient or SCPClient + yield self + except socket.error as e: + if self.transfer_aborted.is_set(): + self.transfer_aborted.clear() + method = 'SCP' if self.conn.use_scp else 'SFTP' + raise TimeoutError('{} {}: {} -> {}'.format(method, self.direction, sources, self.dest)) + else: + raise e + + def progress_cb(self, *args): + if self.transfer_started.is_set(): + self.progressed = True + if len(args) == 3: # For SCPClient callbacks + self.transferred = args[2] + self.to_transfer = args[1] + elif len(args) == 2: # For SFTPClient callbacks + self.transferred = args[0] + self.to_transfer = args[1] \ No newline at end of file diff --git a/devlib/target.py b/devlib/target.py index 843aabc..82b2dc8 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -299,7 +299,8 @@ class Target(object): self._resolve_paths() self.execute('mkdir -p {}'.format(quote(self.working_directory))) self.execute('mkdir -p {}'.format(quote(self.executables_directory))) - self.busybox = self.install(os.path.join(PACKAGE_BIN_DIRECTORY, self.abi, 'busybox')) + self.busybox = self.install(os.path.join(PACKAGE_BIN_DIRECTORY, self.abi, 'busybox'), timeout=30) + self.conn.busybox = self.busybox self.platform.update_from_target(self) self._update_modules('connected') if self.platform.big_core and self.load_default_modules: @@ -2403,7 +2404,9 @@ class ChromeOsTarget(LinuxTarget): # Pull out ssh connection settings ssh_conn_params = ['host', 'username', 'password', 'keyfile', 'port', 'timeout', 'sudo_cmd', - 'strict_host_check', 'use_scp'] + 'strict_host_check', 'use_scp', + 'total_timeout', 'poll_transfers', + 'start_transfer_poll_delay'] self.ssh_connection_settings = {} for setting in ssh_conn_params: if connection_settings.get(setting, None): diff --git a/devlib/utils/android.py b/devlib/utils/android.py index edc488c..2c5c59e 100755 --- a/devlib/utils/android.py +++ b/devlib/utils/android.py @@ -39,8 +39,8 @@ except ImportError: from pipes import quote from devlib.exception import TargetTransientError, TargetStableError, HostError -from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams -from devlib.connection import ConnectionBase, AdbBackgroundCommand +from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams, get_subprocess +from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferManager logger = logging.getLogger('android') @@ -263,18 +263,21 @@ class AdbConnection(ConnectionBase): @property def connected_as_root(self): if self._connected_as_root[self.device] is None: - result = self.execute('id') - self._connected_as_root[self.device] = 'uid=0(' in result + result = self.execute('id') + self._connected_as_root[self.device] = 'uid=0(' in result return self._connected_as_root[self.device] @connected_as_root.setter def connected_as_root(self, state): self._connected_as_root[self.device] = state - # pylint: disable=unused-argument def __init__(self, device=None, timeout=None, platform=None, adb_server=None, - adb_as_root=False, connection_attempts=MAX_ATTEMPTS): + adb_as_root=False, connection_attempts=MAX_ATTEMPTS, + poll_transfers=False, + start_transfer_poll_delay=30, + total_transfer_timeout=3600, + transfer_poll_period=30,): super().__init__() self.timeout = timeout if timeout is not None else self.default_timeout if device is None: @@ -282,6 +285,13 @@ class AdbConnection(ConnectionBase): self.device = device self.adb_server = adb_server self.adb_as_root = adb_as_root + self.poll_transfers = poll_transfers + if poll_transfers: + transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay, + 'total_timeout': total_transfer_timeout, + 'poll_period': transfer_poll_period, + } + self.transfer_mgr = PopenTransferManager(self, **transfer_opts) if poll_transfers else None if self.adb_as_root: self.adb_root(enable=True) adb_connect(self.device, adb_server=self.adb_server, attempts=connection_attempts) @@ -289,10 +299,13 @@ class AdbConnection(ConnectionBase): self._setup_ls() self._setup_su() - def _push_pull(self, action, sources, dest, timeout): - if timeout is None: - timeout = self.timeout + def push(self, sources, dest, timeout=None): + return self._push_pull('push', sources, dest, timeout) + def pull(self, sources, dest, timeout=None): + return self._push_pull('pull', sources, dest, timeout) + + def _push_pull(self, action, sources, dest, timeout): paths = sources + [dest] # Quote twice to avoid expansion by host shell, then ADB globbing @@ -300,13 +313,12 @@ class AdbConnection(ConnectionBase): paths = ' '.join(map(do_quote, paths)) command = "{} {}".format(action, paths) - adb_command(self.device, command, timeout=timeout, adb_server=self.adb_server) - - def push(self, sources, dest, timeout=None): - return self._push_pull('push', sources, dest, timeout) - - def pull(self, sources, dest, timeout=None): - return self._push_pull('pull', sources, dest, timeout) + if timeout or not self.poll_transfers: + adb_command(self.device, command, timeout=timeout, adb_server=self.adb_server) + else: + with self.transfer_mgr.manage(sources, dest, action): + bg_cmd = adb_command_background(self.device, command, adb_server=self.adb_server) + self.transfer_mgr.set_transfer_and_wait(bg_cmd) # pylint: disable=unused-argument def execute(self, command, timeout=None, check_exit_code=False, @@ -321,6 +333,11 @@ class AdbConnection(ConnectionBase): raise def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + bg_cmd = self._background(command, stdout, stderr, as_root) + self._current_bg_cmds.add(bg_cmd) + return bg_cmd + + def _background(self, command, stdout, stderr, as_root): adb_shell, pid = adb_background_shell(self, command, stdout, stderr, as_root) bg_cmd = AdbBackgroundCommand( conn=self, @@ -328,7 +345,6 @@ class AdbConnection(ConnectionBase): pid=pid, as_root=as_root ) - self._current_bg_cmds.add(bg_cmd) return bg_cmd def _close(self): @@ -610,12 +626,22 @@ def get_adb_command(device, command, adb_server=None): device_string += ' -s {}'.format(device) if device else '' return "adb{} {}".format(device_string, command) + def adb_command(device, command, timeout=None, adb_server=None): full_command = get_adb_command(device, command, adb_server) logger.debug(full_command) output, _ = check_output(full_command, timeout, shell=True) return output + +def adb_command_background(device, command, adb_server=None): + full_command = get_adb_command(device, command, adb_server) + logger.debug(full_command) + proc = get_subprocess(full_command, shell=True) + cmd = PopenBackgroundCommand(proc) + return cmd + + def grant_app_permissions(target, package): """ Grant an app all the permissions it may ask for diff --git a/devlib/utils/ssh.py b/devlib/utils/ssh.py index 1a85208..b8cbc55 100644 --- a/devlib/utils/ssh.py +++ b/devlib/utils/ssh.py @@ -55,7 +55,8 @@ from devlib.exception import (HostError, TargetStableError, TargetNotRespondingE from devlib.utils.misc import (which, strip_bash_colors, check_output, sanitize_cmd_template, memoized, redirect_streams) from devlib.utils.types import boolean -from devlib.connection import ConnectionBase, ParamikoBackgroundCommand, PopenBackgroundCommand +from devlib.connection import (ConnectionBase, ParamikoBackgroundCommand, PopenBackgroundCommand, + SSHTransferManager) ssh = None @@ -367,7 +368,11 @@ class SshConnection(SshConnectionBase): platform=None, sudo_cmd="sudo -S -- sh -c {}", strict_host_check=True, - use_scp=False + use_scp=False, + poll_transfers=False, + start_transfer_poll_delay=30, + total_transfer_timeout=3600, + transfer_poll_period=30, ): super().__init__( @@ -384,6 +389,13 @@ class SshConnection(SshConnectionBase): # Allow using scp for file transfer if sftp is not supported self.use_scp = use_scp + self.poll_transfers=poll_transfers + if poll_transfers: + transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay, + 'total_timeout': total_transfer_timeout, + 'poll_period': transfer_poll_period, + } + if self.use_scp: logger.debug('Using SCP for file transfer') _check_env() @@ -391,6 +403,7 @@ class SshConnection(SshConnectionBase): else: logger.debug('Using SFTP for file transfer') + self.transfer_mgr = SSHTransferManager(self, **transfer_opts) if poll_transfers else None self.client = self._make_client() atexit.register(self.close) @@ -442,19 +455,20 @@ class SshConnection(SshConnectionBase): channel = transport.open_session() return channel + def _get_progress_cb(self): + return self.transfer_mgr.progress_cb if self.transfer_mgr is not None else None + def _get_sftp(self, timeout): sftp = self.client.open_sftp() sftp.get_channel().settimeout(timeout) return sftp - def _get_scp(self, timeout): - return SCPClient(self.client.get_transport(), socket_timeout=timeout) + return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=self._get_progress_cb()) - @classmethod - def _push_file(cls, sftp, src, dst): + def _push_file(self, sftp, src, dst): try: - sftp.put(src, dst) + sftp.put(src, dst, callback=self._get_progress_cb()) # Maybe the dst was a folder except OSError as orig_excep: # If dst was an existing folder, we add the src basename to create @@ -465,7 +479,7 @@ class SshConnection(SshConnectionBase): ) logger.debug('Trying: {} -> {}'.format(src, new_dst)) try: - sftp.put(src, new_dst) + sftp.put(src, new_dst, callback=self._get_progress_cb()) # This still failed, which either means: # * There are some missing folders in the dirnames # * Something else SFTP-related is wrong @@ -483,22 +497,20 @@ class SshConnection(SshConnectionBase): else: return True - @classmethod - def _push_folder(cls, sftp, src, dst): + def _push_folder(self, sftp, src, dst): # Behave like the "mv" command or adb push: a new folder is created # inside the destination folder, rather than merging the trees, but # only if the destination already exists. Otherwise, it is use as-is as # the new hierarchy name. - if cls._path_exists(sftp, dst): + if self._path_exists(sftp, dst): dst = os.path.join( dst, os.path.basename(os.path.normpath(src)), ) - return cls._push_folder_internal(sftp, src, dst) + return self._push_folder_internal(sftp, src, dst) - @classmethod - def _push_folder_internal(cls, sftp, src, dst): + def _push_folder_internal(self, sftp, src, dst): # This might fail if the folder already exists with contextlib.suppress(IOError): sftp.mkdir(dst) @@ -508,20 +520,18 @@ class SshConnection(SshConnectionBase): src_path = os.path.join(src, name) dst_path = os.path.join(dst, name) if entry.is_dir(): - push = cls._push_folder_internal + push = self._push_folder_internal else: - push = cls._push_file + push = self._push_file push(sftp, src_path, dst_path) - @classmethod - def _push_path(cls, sftp, src, dst): - logger.debug('Pushing via sftp: {} -> {}'.format(src,dst)) - push = cls._push_folder if os.path.isdir(src) else cls._push_file + def _push_path(self, sftp, src, dst): + logger.debug('Pushing via sftp: {} -> {}'.format(src, dst)) + push = self._push_folder if os.path.isdir(src) else self._push_file push(sftp, src, dst) - @classmethod - def _pull_file(cls, sftp, src, dst): + def _pull_file(self, sftp, src, dst): # Pulling a file into a folder will use the source basename if os.path.isdir(dst): dst = os.path.join( @@ -532,10 +542,9 @@ class SshConnection(SshConnectionBase): with contextlib.suppress(FileNotFoundError): os.remove(dst) - sftp.get(src, dst) - - @classmethod - def _pull_folder(cls, sftp, src, dst): + sftp.get(src, dst, callback=self._get_progress_cb()) + + def _pull_folder(self, sftp, src, dst): with contextlib.suppress(FileNotFoundError): try: shutil.rmtree(dst) @@ -548,42 +557,59 @@ class SshConnection(SshConnectionBase): src_path = os.path.join(src, filename) dst_path = os.path.join(dst, filename) if stat.S_ISDIR(fileattr.st_mode): - pull = cls._pull_folder + pull = self._pull_folder else: - pull = cls._pull_file + pull = self._pull_file pull(sftp, src_path, dst_path) - @classmethod - def _pull_path(cls, sftp, src, dst): - logger.debug('Pulling via sftp: {} -> {}'.format(src,dst)) + def _pull_path(self, sftp, src, dst): + logger.debug('Pulling via sftp: {} -> {}'.format(src, dst)) try: - cls._pull_file(sftp, src, dst) + self._pull_file(sftp, src, dst) except IOError: # Maybe that was a directory, so retry as such - cls._pull_folder(sftp, src, dst) + self._pull_folder(sftp, src, dst) - def push(self, sources, dest, timeout=30): - # If using scp, use implementation from base class - with _handle_paramiko_exceptions(): + def push(self, sources, dest, timeout=None): + self._push_pull('push', sources, dest, timeout) + + def pull(self, sources, dest, timeout=None): + self._push_pull('pull', sources, dest, timeout) + + def _push_pull(self, action, sources, dest, timeout): + if action not in ['push', 'pull']: + raise ValueError("Action must be either `push` or `pull`") + + # If timeout is set, or told not to poll + if timeout is not None or not self.poll_transfers: if self.use_scp: scp = self._get_scp(timeout) - scp.put(sources, dest, recursive=True) + scp_cmd = getattr(scp, 'put' if action == 'push' else 'get') + scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest) + logger.debug(scp_msg.capitalize()) + scp_cmd(sources, dest, recursive=True) else: - with self._get_sftp(timeout) as sftp: + sftp = self._get_sftp(timeout) + sftp_cmd = getattr(self, '_' + action + '_path') + with _handle_paramiko_exceptions(): for source in sources: - self._push_path(sftp, source, dest) + sftp_cmd(sftp, source, dest) - def pull(self, sources, dest, timeout=30): - # If using scp, use implementation from base class - with _handle_paramiko_exceptions(): - if self.use_scp: - scp = self._get_scp(timeout) - scp.get(sources, dest, recursive=True) - else: - with self._get_sftp(timeout) as sftp: - for source in sources: - self._pull_path(sftp, source, dest) + # No timeout, and polling is set + elif self.use_scp: + scp = self._get_scp(timeout) + scp_cmd = getattr(scp, 'put' if action == 'push' else 'get') + with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, scp): + scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest) + logger.debug(scp_msg.capitalize()) + scp_cmd(sources, dest, recursive=True) + else: + sftp = self._get_sftp(timeout) + sftp_cmd = getattr(self, '_' + action + '_path') + with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, sftp): + for source in sources: + sftp_cmd(sftp, source, dest) def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument