mirror of
				https://github.com/ARM-software/devlib.git
				synced 2025-11-04 07:51:21 +00:00 
			
		
		
		
	target: Factorize push/pull path resolution
Handle in one place the decision of what is the real destination of each
file in push/pull operations.
The following can now be assumed by the connection:
    * The destination does not exist.
    * The folder containing the destination does exist.
This ensures consistent errors and behaviors across all connection
types, at the cost of:
    * At least an extra execute() per source (up to 2 if the destination
      is a file that needs to be removed to make room).
    * For now, globbing will lead to a separate request for each file,
      rather than a merged one. This is because the destination of each
      source is prepared so that the connection will not have any
      interpretation work to do.
			
			
This commit is contained in:
		
				
					committed by
					
						
						Marc Bonnici
					
				
			
			
				
	
			
			
			
						parent
						
							79be8bc5ad
						
					
				
				
					commit
					173df18f29
				
			@@ -68,11 +68,6 @@ class LocalConnection(ConnectionBase):
 | 
			
		||||
    def _copy_path(self, source, dest):
 | 
			
		||||
        self.logger.debug('copying {} to {}'.format(source, dest))
 | 
			
		||||
        if os.path.isdir(source):
 | 
			
		||||
            # Behave similarly as cp, scp, adb push, etc. by creating a new
 | 
			
		||||
            # folder instead of merging hierarchies
 | 
			
		||||
            if os.path.exists(dest):
 | 
			
		||||
                dest = os.path.join(dest, os.path.basename(os.path.normpath(source)))
 | 
			
		||||
 | 
			
		||||
            # Use distutils copy_tree since it behaves the same as
 | 
			
		||||
            # shutils.copytree except that it won't fail if some folders
 | 
			
		||||
            # already exist.
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										153
									
								
								devlib/target.py
									
									
									
									
									
								
							
							
						
						
									
										153
									
								
								devlib/target.py
									
									
									
									
									
								
							@@ -451,37 +451,136 @@ class Target(object):
 | 
			
		||||
        finally:
 | 
			
		||||
            self.execute('rm -rf -- {}'.format(quote(folder)), check_exit_code=check_rm)
 | 
			
		||||
 | 
			
		||||
    def _prepare_xfer(self, action, sources, dest):
 | 
			
		||||
    def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False):
 | 
			
		||||
        """
 | 
			
		||||
        Check the sanity of sources and destination and prepare the ground for
 | 
			
		||||
        transfering multiple sources.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        once = functools.lru_cache(maxsize=None)
 | 
			
		||||
 | 
			
		||||
        _target_cache = {}
 | 
			
		||||
        def target_paths_kind(paths):
 | 
			
		||||
            def process(x):
 | 
			
		||||
                x = x.strip()
 | 
			
		||||
                if x == 'notexist':
 | 
			
		||||
                    return None
 | 
			
		||||
                else:
 | 
			
		||||
                    return x
 | 
			
		||||
 | 
			
		||||
            _paths = [
 | 
			
		||||
                path
 | 
			
		||||
                for path in paths
 | 
			
		||||
                if path not in _target_cache
 | 
			
		||||
            ]
 | 
			
		||||
            if _paths:
 | 
			
		||||
                cmd = '; '.join(
 | 
			
		||||
                    'if [ -d {path} ]; then echo dir; elif [ -e {path} ]; then echo file; else echo notexist; fi'.format(
 | 
			
		||||
                        path=quote(path)
 | 
			
		||||
                    )
 | 
			
		||||
                    for path in _paths
 | 
			
		||||
                )
 | 
			
		||||
                res = self.execute(cmd)
 | 
			
		||||
                _target_cache.update(zip(_paths, map(process, res.split())))
 | 
			
		||||
 | 
			
		||||
            return [
 | 
			
		||||
                _target_cache[path]
 | 
			
		||||
                for path in paths
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
        _host_cache = {}
 | 
			
		||||
        def host_paths_kind(paths):
 | 
			
		||||
            def path_kind(path):
 | 
			
		||||
                if os.path.isdir(path):
 | 
			
		||||
                    return 'dir'
 | 
			
		||||
                elif os.path.exists(path):
 | 
			
		||||
                    return 'file'
 | 
			
		||||
                else:
 | 
			
		||||
                    return None
 | 
			
		||||
 | 
			
		||||
            for path in paths:
 | 
			
		||||
                if path not in _host_cache:
 | 
			
		||||
                    _host_cache[path] = path_kind(path)
 | 
			
		||||
 | 
			
		||||
            return [
 | 
			
		||||
                _host_cache[path]
 | 
			
		||||
                for path in paths
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
        # TODO: Target.remove() and Target.makedirs() would probably benefit
 | 
			
		||||
        # from being implemented by connections, with the current
 | 
			
		||||
        # implementation in ConnectionBase. This would allow SshConnection to
 | 
			
		||||
        # use SFTP for these operations, which should be cheaper than
 | 
			
		||||
        # Target.execute()
 | 
			
		||||
        if action == 'push':
 | 
			
		||||
            src_excep = HostError
 | 
			
		||||
            dst_excep = TargetStableError
 | 
			
		||||
            dst_path_exists = self.file_exists
 | 
			
		||||
            dst_is_dir = self.directory_exists
 | 
			
		||||
            dst_mkdir = self.makedirs
 | 
			
		||||
            src_path_kind = host_paths_kind
 | 
			
		||||
 | 
			
		||||
            for source in sources:
 | 
			
		||||
                if not os.path.exists(source):
 | 
			
		||||
                    raise HostError('No such file "{}"'.format(source))
 | 
			
		||||
        else:
 | 
			
		||||
            dst_mkdir = once(self.makedirs)
 | 
			
		||||
            dst_path_join = self.path.join
 | 
			
		||||
            dst_paths_kind = target_paths_kind
 | 
			
		||||
            dst_remove_file = once(functools.partial(self.remove, as_root=as_root))
 | 
			
		||||
        elif action == 'pull':
 | 
			
		||||
            src_excep = TargetStableError
 | 
			
		||||
            dst_excep = HostError
 | 
			
		||||
            dst_path_exists = os.path.exists
 | 
			
		||||
            dst_is_dir = os.path.isdir
 | 
			
		||||
            dst_mkdir = functools.partial(os.makedirs, exist_ok=True)
 | 
			
		||||
            src_path_kind = target_paths_kind
 | 
			
		||||
 | 
			
		||||
        if not sources:
 | 
			
		||||
            raise src_excep('No file matching: {}'.format(sources))
 | 
			
		||||
        elif len(sources) > 1:
 | 
			
		||||
            if dst_path_exists(dest):
 | 
			
		||||
                if not dst_is_dir(dest):
 | 
			
		||||
                    raise dst_excep('A folder dest is required for multiple matches but destination is a file: {}'.format(dest))
 | 
			
		||||
            dst_mkdir = once(functools.partial(os.makedirs, exist_ok=True))
 | 
			
		||||
            dst_path_join = os.path.join
 | 
			
		||||
            dst_paths_kind = host_paths_kind
 | 
			
		||||
            dst_remove_file = once(os.remove)
 | 
			
		||||
        else:
 | 
			
		||||
                dst_mkdir(dest)
 | 
			
		||||
            raise ValueError('Unknown action "{}"'.format(action))
 | 
			
		||||
 | 
			
		||||
        def rewrite_dst(src, dst):
 | 
			
		||||
            new_dst = dst_path_join(dst, os.path.basename(src))
 | 
			
		||||
 | 
			
		||||
            src_kind, = src_path_kind([src])
 | 
			
		||||
            # Batch both checks to avoid a costly extra execute()
 | 
			
		||||
            dst_kind, new_dst_kind = dst_paths_kind([dst, new_dst])
 | 
			
		||||
 | 
			
		||||
            if src_kind == 'file':
 | 
			
		||||
                if dst_kind == 'dir':
 | 
			
		||||
                    if new_dst_kind == 'dir':
 | 
			
		||||
                        raise IsADirectoryError(new_dst)
 | 
			
		||||
                    if new_dst_kind == 'file':
 | 
			
		||||
                        dst_remove_file(new_dst)
 | 
			
		||||
                        return new_dst
 | 
			
		||||
                    else:
 | 
			
		||||
                        return new_dst
 | 
			
		||||
                elif dst_kind == 'file':
 | 
			
		||||
                    dst_remove_file(dst)
 | 
			
		||||
                    return dst
 | 
			
		||||
                else:
 | 
			
		||||
                    dst_mkdir(os.path.dirname(dst))
 | 
			
		||||
                    return dst
 | 
			
		||||
            elif src_kind == 'dir':
 | 
			
		||||
                if dst_kind == 'dir':
 | 
			
		||||
                    # Do not allow writing over an existing folder
 | 
			
		||||
                    if new_dst_kind == 'dir':
 | 
			
		||||
                        raise FileExistsError(new_dst)
 | 
			
		||||
                    if new_dst_kind == 'file':
 | 
			
		||||
                        raise FileExistsError(new_dst)
 | 
			
		||||
                    else:
 | 
			
		||||
                        return new_dst
 | 
			
		||||
                elif dst_kind == 'file':
 | 
			
		||||
                    raise FileExistsError(dst_kind)
 | 
			
		||||
                else:
 | 
			
		||||
                    dst_mkdir(os.path.dirname(dst))
 | 
			
		||||
                    return dst
 | 
			
		||||
            else:
 | 
			
		||||
                raise FileNotFoundError(src)
 | 
			
		||||
 | 
			
		||||
        if pattern:
 | 
			
		||||
            if not sources:
 | 
			
		||||
                raise src_excep('No file matching source pattern: {}'.format(pattern))
 | 
			
		||||
 | 
			
		||||
            if dst_path_exists(dest) and not dst_is_dir(dest):
 | 
			
		||||
                raise NotADirectoryError('A folder dest is required for multiple matches but destination is a file: {}'.format(dest))
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            src: rewrite_dst(src, dest)
 | 
			
		||||
            for src in sources
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    @call_conn
 | 
			
		||||
    def push(self, source, dest, as_root=False, timeout=None, globbing=False):  # pylint: disable=arguments-differ
 | 
			
		||||
@@ -489,18 +588,19 @@ class Target(object):
 | 
			
		||||
        dest = str(dest)
 | 
			
		||||
 | 
			
		||||
        sources = glob.glob(source) if globbing else [source]
 | 
			
		||||
        self._prepare_xfer('push', sources, dest)
 | 
			
		||||
        mapping = self._prepare_xfer('push', sources, dest, pattern=source if globbing else None, as_root=as_root)
 | 
			
		||||
 | 
			
		||||
        def do_push(sources, dest):
 | 
			
		||||
            return self.conn.push(sources, dest, timeout=timeout)
 | 
			
		||||
 | 
			
		||||
        if as_root:
 | 
			
		||||
            for source in sources:
 | 
			
		||||
            for source, dest in mapping.items():
 | 
			
		||||
                with self._xfer_cache_path(source) as device_tempfile:
 | 
			
		||||
                    do_push([source], device_tempfile)
 | 
			
		||||
                    self.execute("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True)
 | 
			
		||||
        else:
 | 
			
		||||
            do_push(sources, dest)
 | 
			
		||||
            for source, dest in mapping.items():
 | 
			
		||||
                do_push([source], dest)
 | 
			
		||||
 | 
			
		||||
    def _expand_glob(self, pattern, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
@@ -552,19 +652,20 @@ class Target(object):
 | 
			
		||||
        else:
 | 
			
		||||
            sources = [source]
 | 
			
		||||
 | 
			
		||||
        self._prepare_xfer('pull', sources, dest)
 | 
			
		||||
        mapping = self._prepare_xfer('pull', sources, dest, pattern=source if globbing else None, as_root=as_root)
 | 
			
		||||
 | 
			
		||||
        def do_pull(sources, dest):
 | 
			
		||||
            self.conn.pull(sources, dest, timeout=timeout)
 | 
			
		||||
 | 
			
		||||
        if as_root:
 | 
			
		||||
            for source in sources:
 | 
			
		||||
            for source, dest in mapping.items():
 | 
			
		||||
                with self._xfer_cache_path(source) as device_tempfile:
 | 
			
		||||
                    self.execute("cp -r -- {} {}".format(quote(source), quote(device_tempfile)), as_root=True)
 | 
			
		||||
                    self.execute("{} chmod 0644 -- {}".format(self.busybox, quote(device_tempfile)), as_root=True)
 | 
			
		||||
                    do_pull([device_tempfile], dest)
 | 
			
		||||
        else:
 | 
			
		||||
            do_pull(sources, dest)
 | 
			
		||||
            for source, dest in mapping.items():
 | 
			
		||||
                do_pull([source], dest)
 | 
			
		||||
 | 
			
		||||
    def get_directory(self, source_dir, dest, as_root=False):
 | 
			
		||||
        """ Pull a directory from the device, after compressing dir """
 | 
			
		||||
 
 | 
			
		||||
@@ -315,6 +315,7 @@ class AdbConnection(ConnectionBase):
 | 
			
		||||
        return self._push_pull('pull', sources, dest, timeout)
 | 
			
		||||
 | 
			
		||||
    def _push_pull(self, action, sources, dest, timeout):
 | 
			
		||||
        sources = list(sources)
 | 
			
		||||
        paths = sources + [dest]
 | 
			
		||||
 | 
			
		||||
        # Quote twice to avoid expansion by host shell, then ADB globbing
 | 
			
		||||
 
 | 
			
		||||
@@ -424,26 +424,7 @@ class SshConnection(SshConnectionBase):
 | 
			
		||||
        return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=self._get_progress_cb())
 | 
			
		||||
 | 
			
		||||
    def _push_file(self, sftp, src, dst):
 | 
			
		||||
        try:
 | 
			
		||||
        sftp.put(src, dst, callback=self._get_progress_cb())
 | 
			
		||||
        # Maybe the dst was a folder
 | 
			
		||||
        except OSError as orig_excep:
 | 
			
		||||
            # If dst was an existing folder, we add the src basename to create
 | 
			
		||||
            # a new destination for the file as cp would do
 | 
			
		||||
            new_dst = os.path.join(
 | 
			
		||||
                dst,
 | 
			
		||||
                os.path.basename(src),
 | 
			
		||||
            )
 | 
			
		||||
            logger.debug('Trying: {} -> {}'.format(src, new_dst))
 | 
			
		||||
            try:
 | 
			
		||||
                sftp.put(src, new_dst, callback=self._get_progress_cb())
 | 
			
		||||
            # This still failed, which either means:
 | 
			
		||||
            # * There are some missing folders in the dirnames
 | 
			
		||||
            # * Something else SFTP-related is wrong
 | 
			
		||||
            except OSError as e:
 | 
			
		||||
                # Raise the original exception, as it is closer to what the
 | 
			
		||||
                # user asked in the first place
 | 
			
		||||
                raise orig_excep
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _path_exists(cls, sftp, path):
 | 
			
		||||
@@ -455,29 +436,13 @@ class SshConnection(SshConnectionBase):
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
    def _push_folder(self, sftp, src, dst):
 | 
			
		||||
        # Behave like the "mv" command or adb push: a new folder is created
 | 
			
		||||
        # inside the destination folder, rather than merging the trees, but
 | 
			
		||||
        # only if the destination already exists. Otherwise, it is use as-is as
 | 
			
		||||
        # the new hierarchy name.
 | 
			
		||||
        if self._path_exists(sftp, dst):
 | 
			
		||||
            dst = os.path.join(
 | 
			
		||||
                dst,
 | 
			
		||||
                os.path.basename(os.path.normpath(src)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return self._push_folder_internal(sftp, src, dst)
 | 
			
		||||
 | 
			
		||||
    def _push_folder_internal(self, sftp, src, dst):
 | 
			
		||||
        # This might fail if the folder already exists
 | 
			
		||||
        with contextlib.suppress(IOError):
 | 
			
		||||
        sftp.mkdir(dst)
 | 
			
		||||
 | 
			
		||||
        for entry in os.scandir(src):
 | 
			
		||||
            name = entry.name
 | 
			
		||||
            src_path = os.path.join(src, name)
 | 
			
		||||
            dst_path = os.path.join(dst, name)
 | 
			
		||||
            if entry.is_dir():
 | 
			
		||||
                push = self._push_folder_internal
 | 
			
		||||
                push = self._push_folder
 | 
			
		||||
            else:
 | 
			
		||||
                push = self._push_file
 | 
			
		||||
 | 
			
		||||
@@ -489,25 +454,9 @@ class SshConnection(SshConnectionBase):
 | 
			
		||||
        push(sftp, src, dst)
 | 
			
		||||
 | 
			
		||||
    def _pull_file(self, sftp, src, dst):
 | 
			
		||||
        # Pulling a file into a folder will use the source basename
 | 
			
		||||
        if os.path.isdir(dst):
 | 
			
		||||
            dst = os.path.join(
 | 
			
		||||
                dst,
 | 
			
		||||
                os.path.basename(src),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        with contextlib.suppress(FileNotFoundError):
 | 
			
		||||
            os.remove(dst)
 | 
			
		||||
 | 
			
		||||
        sftp.get(src, dst, callback=self._get_progress_cb())
 | 
			
		||||
 | 
			
		||||
    def _pull_folder(self, sftp, src, dst):
 | 
			
		||||
        with contextlib.suppress(FileNotFoundError):
 | 
			
		||||
            try:
 | 
			
		||||
                shutil.rmtree(dst)
 | 
			
		||||
            except OSError:
 | 
			
		||||
                os.remove(dst)
 | 
			
		||||
 | 
			
		||||
        os.makedirs(dst)
 | 
			
		||||
        for fileattr in sftp.listdir_attr(src):
 | 
			
		||||
            filename = fileattr.filename
 | 
			
		||||
@@ -836,7 +785,7 @@ class TelnetConnection(SshConnectionBase):
 | 
			
		||||
    def push(self, sources, dest, timeout=30):
 | 
			
		||||
        # Quote the destination as SCP would apply globbing too
 | 
			
		||||
        dest = self.fmt_remote_path(quote(dest))
 | 
			
		||||
        paths = sources + [dest]
 | 
			
		||||
        paths = list(sources) + [dest]
 | 
			
		||||
        return self._scp(paths, timeout)
 | 
			
		||||
 | 
			
		||||
    def pull(self, sources, dest, timeout=30):
 | 
			
		||||
@@ -1113,9 +1062,6 @@ class Gem5Connection(TelnetConnection):
 | 
			
		||||
            # We need to copy the file to copy to the temporary directory
 | 
			
		||||
            self._move_to_temp_dir(source)
 | 
			
		||||
 | 
			
		||||
            # Dest in gem5 world is a file rather than directory
 | 
			
		||||
            if os.path.basename(dest) != filename:
 | 
			
		||||
                dest = os.path.join(dest, filename)
 | 
			
		||||
            # Back to the gem5 world
 | 
			
		||||
            filename = quote(self.gem5_input_dir + filename)
 | 
			
		||||
            self._gem5_shell("ls -al {}".format(filename))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user