1
0
mirror of https://github.com/ARM-software/devlib.git synced 2024-10-05 18:30:50 +01:00

target: Factorize push/pull path resolution

Handle in one place the decision of what is the real destination of each
file in push/pull operations.

The following can now be assumed by the connection:

    * The destination does not exist.
    * The folder containing the destination does exist.

This ensures consistent errors and behaviors across all connection
types, at the cost of:

    * At least an extra execute() per source (up to 2 if the destination
      is a file that needs to be removed to make room).
    * For now, globbing will lead to a separate request for each file,
      rather than a merged one. This is because the destination of each
      source is prepared so that the connection will not have any
      interpretation work to do.
This commit is contained in:
Douglas Raillard 2021-08-12 11:33:36 +01:00 committed by Marc Bonnici
parent 79be8bc5ad
commit 173df18f29
4 changed files with 132 additions and 89 deletions

View File

@ -68,11 +68,6 @@ class LocalConnection(ConnectionBase):
def _copy_path(self, source, dest): def _copy_path(self, source, dest):
self.logger.debug('copying {} to {}'.format(source, dest)) self.logger.debug('copying {} to {}'.format(source, dest))
if os.path.isdir(source): if os.path.isdir(source):
# Behave similarly as cp, scp, adb push, etc. by creating a new
# folder instead of merging hierarchies
if os.path.exists(dest):
dest = os.path.join(dest, os.path.basename(os.path.normpath(source)))
# Use distutils copy_tree since it behaves the same as # Use distutils copy_tree since it behaves the same as
# shutils.copytree except that it won't fail if some folders # shutils.copytree except that it won't fail if some folders
# already exist. # already exist.

View File

@ -451,37 +451,136 @@ class Target(object):
finally: finally:
self.execute('rm -rf -- {}'.format(quote(folder)), check_exit_code=check_rm) self.execute('rm -rf -- {}'.format(quote(folder)), check_exit_code=check_rm)
def _prepare_xfer(self, action, sources, dest): def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False):
""" """
Check the sanity of sources and destination and prepare the ground for Check the sanity of sources and destination and prepare the ground for
transfering multiple sources. transfering multiple sources.
""" """
once = functools.lru_cache(maxsize=None)
_target_cache = {}
def target_paths_kind(paths):
def process(x):
x = x.strip()
if x == 'notexist':
return None
else:
return x
_paths = [
path
for path in paths
if path not in _target_cache
]
if _paths:
cmd = '; '.join(
'if [ -d {path} ]; then echo dir; elif [ -e {path} ]; then echo file; else echo notexist; fi'.format(
path=quote(path)
)
for path in _paths
)
res = self.execute(cmd)
_target_cache.update(zip(_paths, map(process, res.split())))
return [
_target_cache[path]
for path in paths
]
_host_cache = {}
def host_paths_kind(paths):
def path_kind(path):
if os.path.isdir(path):
return 'dir'
elif os.path.exists(path):
return 'file'
else:
return None
for path in paths:
if path not in _host_cache:
_host_cache[path] = path_kind(path)
return [
_host_cache[path]
for path in paths
]
# TODO: Target.remove() and Target.makedirs() would probably benefit
# from being implemented by connections, with the current
# implementation in ConnectionBase. This would allow SshConnection to
# use SFTP for these operations, which should be cheaper than
# Target.execute()
if action == 'push': if action == 'push':
src_excep = HostError src_excep = HostError
dst_excep = TargetStableError src_path_kind = host_paths_kind
dst_path_exists = self.file_exists
dst_is_dir = self.directory_exists
dst_mkdir = self.makedirs
for source in sources: dst_mkdir = once(self.makedirs)
if not os.path.exists(source): dst_path_join = self.path.join
raise HostError('No such file "{}"'.format(source)) dst_paths_kind = target_paths_kind
else: dst_remove_file = once(functools.partial(self.remove, as_root=as_root))
elif action == 'pull':
src_excep = TargetStableError src_excep = TargetStableError
dst_excep = HostError src_path_kind = target_paths_kind
dst_path_exists = os.path.exists
dst_is_dir = os.path.isdir
dst_mkdir = functools.partial(os.makedirs, exist_ok=True)
if not sources: dst_mkdir = once(functools.partial(os.makedirs, exist_ok=True))
raise src_excep('No file matching: {}'.format(sources)) dst_path_join = os.path.join
elif len(sources) > 1: dst_paths_kind = host_paths_kind
if dst_path_exists(dest): dst_remove_file = once(os.remove)
if not dst_is_dir(dest):
raise dst_excep('A folder dest is required for multiple matches but destination is a file: {}'.format(dest))
else: else:
dst_mkdir(dest) raise ValueError('Unknown action "{}"'.format(action))
def rewrite_dst(src, dst):
new_dst = dst_path_join(dst, os.path.basename(src))
src_kind, = src_path_kind([src])
# Batch both checks to avoid a costly extra execute()
dst_kind, new_dst_kind = dst_paths_kind([dst, new_dst])
if src_kind == 'file':
if dst_kind == 'dir':
if new_dst_kind == 'dir':
raise IsADirectoryError(new_dst)
if new_dst_kind == 'file':
dst_remove_file(new_dst)
return new_dst
else:
return new_dst
elif dst_kind == 'file':
dst_remove_file(dst)
return dst
else:
dst_mkdir(os.path.dirname(dst))
return dst
elif src_kind == 'dir':
if dst_kind == 'dir':
# Do not allow writing over an existing folder
if new_dst_kind == 'dir':
raise FileExistsError(new_dst)
if new_dst_kind == 'file':
raise FileExistsError(new_dst)
else:
return new_dst
elif dst_kind == 'file':
raise FileExistsError(dst_kind)
else:
dst_mkdir(os.path.dirname(dst))
return dst
else:
raise FileNotFoundError(src)
if pattern:
if not sources:
raise src_excep('No file matching source pattern: {}'.format(pattern))
if dst_path_exists(dest) and not dst_is_dir(dest):
raise NotADirectoryError('A folder dest is required for multiple matches but destination is a file: {}'.format(dest))
return {
src: rewrite_dst(src, dest)
for src in sources
}
@call_conn @call_conn
def push(self, source, dest, as_root=False, timeout=None, globbing=False): # pylint: disable=arguments-differ def push(self, source, dest, as_root=False, timeout=None, globbing=False): # pylint: disable=arguments-differ
@ -489,18 +588,19 @@ class Target(object):
dest = str(dest) dest = str(dest)
sources = glob.glob(source) if globbing else [source] sources = glob.glob(source) if globbing else [source]
self._prepare_xfer('push', sources, dest) mapping = self._prepare_xfer('push', sources, dest, pattern=source if globbing else None, as_root=as_root)
def do_push(sources, dest): def do_push(sources, dest):
return self.conn.push(sources, dest, timeout=timeout) return self.conn.push(sources, dest, timeout=timeout)
if as_root: if as_root:
for source in sources: for source, dest in mapping.items():
with self._xfer_cache_path(source) as device_tempfile: with self._xfer_cache_path(source) as device_tempfile:
do_push([source], device_tempfile) do_push([source], device_tempfile)
self.execute("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True) self.execute("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True)
else: else:
do_push(sources, dest) for source, dest in mapping.items():
do_push([source], dest)
def _expand_glob(self, pattern, **kwargs): def _expand_glob(self, pattern, **kwargs):
""" """
@ -552,19 +652,20 @@ class Target(object):
else: else:
sources = [source] sources = [source]
self._prepare_xfer('pull', sources, dest) mapping = self._prepare_xfer('pull', sources, dest, pattern=source if globbing else None, as_root=as_root)
def do_pull(sources, dest): def do_pull(sources, dest):
self.conn.pull(sources, dest, timeout=timeout) self.conn.pull(sources, dest, timeout=timeout)
if as_root: if as_root:
for source in sources: for source, dest in mapping.items():
with self._xfer_cache_path(source) as device_tempfile: with self._xfer_cache_path(source) as device_tempfile:
self.execute("cp -r -- {} {}".format(quote(source), quote(device_tempfile)), as_root=True) self.execute("cp -r -- {} {}".format(quote(source), quote(device_tempfile)), as_root=True)
self.execute("{} chmod 0644 -- {}".format(self.busybox, quote(device_tempfile)), as_root=True) self.execute("{} chmod 0644 -- {}".format(self.busybox, quote(device_tempfile)), as_root=True)
do_pull([device_tempfile], dest) do_pull([device_tempfile], dest)
else: else:
do_pull(sources, dest) for source, dest in mapping.items():
do_pull([source], dest)
def get_directory(self, source_dir, dest, as_root=False): def get_directory(self, source_dir, dest, as_root=False):
""" Pull a directory from the device, after compressing dir """ """ Pull a directory from the device, after compressing dir """

View File

@ -315,6 +315,7 @@ class AdbConnection(ConnectionBase):
return self._push_pull('pull', sources, dest, timeout) return self._push_pull('pull', sources, dest, timeout)
def _push_pull(self, action, sources, dest, timeout): def _push_pull(self, action, sources, dest, timeout):
sources = list(sources)
paths = sources + [dest] paths = sources + [dest]
# Quote twice to avoid expansion by host shell, then ADB globbing # Quote twice to avoid expansion by host shell, then ADB globbing

View File

@ -424,26 +424,7 @@ class SshConnection(SshConnectionBase):
return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=self._get_progress_cb()) return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=self._get_progress_cb())
def _push_file(self, sftp, src, dst): def _push_file(self, sftp, src, dst):
try:
sftp.put(src, dst, callback=self._get_progress_cb()) sftp.put(src, dst, callback=self._get_progress_cb())
# Maybe the dst was a folder
except OSError as orig_excep:
# If dst was an existing folder, we add the src basename to create
# a new destination for the file as cp would do
new_dst = os.path.join(
dst,
os.path.basename(src),
)
logger.debug('Trying: {} -> {}'.format(src, new_dst))
try:
sftp.put(src, new_dst, callback=self._get_progress_cb())
# This still failed, which either means:
# * There are some missing folders in the dirnames
# * Something else SFTP-related is wrong
except OSError as e:
# Raise the original exception, as it is closer to what the
# user asked in the first place
raise orig_excep
@classmethod @classmethod
def _path_exists(cls, sftp, path): def _path_exists(cls, sftp, path):
@ -455,29 +436,13 @@ class SshConnection(SshConnectionBase):
return True return True
def _push_folder(self, sftp, src, dst): def _push_folder(self, sftp, src, dst):
# Behave like the "mv" command or adb push: a new folder is created
# inside the destination folder, rather than merging the trees, but
# only if the destination already exists. Otherwise, it is use as-is as
# the new hierarchy name.
if self._path_exists(sftp, dst):
dst = os.path.join(
dst,
os.path.basename(os.path.normpath(src)),
)
return self._push_folder_internal(sftp, src, dst)
def _push_folder_internal(self, sftp, src, dst):
# This might fail if the folder already exists
with contextlib.suppress(IOError):
sftp.mkdir(dst) sftp.mkdir(dst)
for entry in os.scandir(src): for entry in os.scandir(src):
name = entry.name name = entry.name
src_path = os.path.join(src, name) src_path = os.path.join(src, name)
dst_path = os.path.join(dst, name) dst_path = os.path.join(dst, name)
if entry.is_dir(): if entry.is_dir():
push = self._push_folder_internal push = self._push_folder
else: else:
push = self._push_file push = self._push_file
@ -489,25 +454,9 @@ class SshConnection(SshConnectionBase):
push(sftp, src, dst) push(sftp, src, dst)
def _pull_file(self, sftp, src, dst): def _pull_file(self, sftp, src, dst):
# Pulling a file into a folder will use the source basename
if os.path.isdir(dst):
dst = os.path.join(
dst,
os.path.basename(src),
)
with contextlib.suppress(FileNotFoundError):
os.remove(dst)
sftp.get(src, dst, callback=self._get_progress_cb()) sftp.get(src, dst, callback=self._get_progress_cb())
def _pull_folder(self, sftp, src, dst): def _pull_folder(self, sftp, src, dst):
with contextlib.suppress(FileNotFoundError):
try:
shutil.rmtree(dst)
except OSError:
os.remove(dst)
os.makedirs(dst) os.makedirs(dst)
for fileattr in sftp.listdir_attr(src): for fileattr in sftp.listdir_attr(src):
filename = fileattr.filename filename = fileattr.filename
@ -836,7 +785,7 @@ class TelnetConnection(SshConnectionBase):
def push(self, sources, dest, timeout=30): def push(self, sources, dest, timeout=30):
# Quote the destination as SCP would apply globbing too # Quote the destination as SCP would apply globbing too
dest = self.fmt_remote_path(quote(dest)) dest = self.fmt_remote_path(quote(dest))
paths = sources + [dest] paths = list(sources) + [dest]
return self._scp(paths, timeout) return self._scp(paths, timeout)
def pull(self, sources, dest, timeout=30): def pull(self, sources, dest, timeout=30):
@ -1113,9 +1062,6 @@ class Gem5Connection(TelnetConnection):
# We need to copy the file to copy to the temporary directory # We need to copy the file to copy to the temporary directory
self._move_to_temp_dir(source) self._move_to_temp_dir(source)
# Dest in gem5 world is a file rather than directory
if os.path.basename(dest) != filename:
dest = os.path.join(dest, filename)
# Back to the gem5 world # Back to the gem5 world
filename = quote(self.gem5_input_dir + filename) filename = quote(self.gem5_input_dir + filename)
self._gem5_shell("ls -al {}".format(filename)) self._gem5_shell("ls -al {}".format(filename))