1
0
mirror of https://github.com/ARM-software/devlib.git synced 2024-10-05 18:30:50 +01:00

target: Factorize push/pull path resolution

Handle in one place the decision of what is the real destination of each
file in push/pull operations.

The following can now be assumed by the connection:

    * The destination does not exist.
    * The folder containing the destination does exist.

This ensures consistent errors and behaviors across all connection
types, at the cost of:

    * At least an extra execute() per source (up to 2 if the destination
      is a file that needs to be removed to make room).
    * For now, globbing will lead to a separate request for each file,
      rather than a merged one. This is because the destination of each
      source is prepared so that the connection will not have any
      interpretation work to do.
This commit is contained in:
Douglas Raillard 2021-08-12 11:33:36 +01:00 committed by Marc Bonnici
parent 79be8bc5ad
commit 173df18f29
4 changed files with 132 additions and 89 deletions

View File

@ -68,11 +68,6 @@ class LocalConnection(ConnectionBase):
def _copy_path(self, source, dest):
self.logger.debug('copying {} to {}'.format(source, dest))
if os.path.isdir(source):
# Behave similarly as cp, scp, adb push, etc. by creating a new
# folder instead of merging hierarchies
if os.path.exists(dest):
dest = os.path.join(dest, os.path.basename(os.path.normpath(source)))
# Use distutils copy_tree since it behaves the same as
# shutils.copytree except that it won't fail if some folders
# already exist.

View File

@ -451,37 +451,136 @@ class Target(object):
finally:
self.execute('rm -rf -- {}'.format(quote(folder)), check_exit_code=check_rm)
def _prepare_xfer(self, action, sources, dest):
def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False):
"""
Check the sanity of sources and destination and prepare the ground for
transfering multiple sources.
"""
once = functools.lru_cache(maxsize=None)
_target_cache = {}
def target_paths_kind(paths):
def process(x):
x = x.strip()
if x == 'notexist':
return None
else:
return x
_paths = [
path
for path in paths
if path not in _target_cache
]
if _paths:
cmd = '; '.join(
'if [ -d {path} ]; then echo dir; elif [ -e {path} ]; then echo file; else echo notexist; fi'.format(
path=quote(path)
)
for path in _paths
)
res = self.execute(cmd)
_target_cache.update(zip(_paths, map(process, res.split())))
return [
_target_cache[path]
for path in paths
]
_host_cache = {}
def host_paths_kind(paths):
def path_kind(path):
if os.path.isdir(path):
return 'dir'
elif os.path.exists(path):
return 'file'
else:
return None
for path in paths:
if path not in _host_cache:
_host_cache[path] = path_kind(path)
return [
_host_cache[path]
for path in paths
]
# TODO: Target.remove() and Target.makedirs() would probably benefit
# from being implemented by connections, with the current
# implementation in ConnectionBase. This would allow SshConnection to
# use SFTP for these operations, which should be cheaper than
# Target.execute()
if action == 'push':
src_excep = HostError
dst_excep = TargetStableError
dst_path_exists = self.file_exists
dst_is_dir = self.directory_exists
dst_mkdir = self.makedirs
src_path_kind = host_paths_kind
for source in sources:
if not os.path.exists(source):
raise HostError('No such file "{}"'.format(source))
else:
dst_mkdir = once(self.makedirs)
dst_path_join = self.path.join
dst_paths_kind = target_paths_kind
dst_remove_file = once(functools.partial(self.remove, as_root=as_root))
elif action == 'pull':
src_excep = TargetStableError
dst_excep = HostError
dst_path_exists = os.path.exists
dst_is_dir = os.path.isdir
dst_mkdir = functools.partial(os.makedirs, exist_ok=True)
src_path_kind = target_paths_kind
if not sources:
raise src_excep('No file matching: {}'.format(sources))
elif len(sources) > 1:
if dst_path_exists(dest):
if not dst_is_dir(dest):
raise dst_excep('A folder dest is required for multiple matches but destination is a file: {}'.format(dest))
dst_mkdir = once(functools.partial(os.makedirs, exist_ok=True))
dst_path_join = os.path.join
dst_paths_kind = host_paths_kind
dst_remove_file = once(os.remove)
else:
raise ValueError('Unknown action "{}"'.format(action))
def rewrite_dst(src, dst):
new_dst = dst_path_join(dst, os.path.basename(src))
src_kind, = src_path_kind([src])
# Batch both checks to avoid a costly extra execute()
dst_kind, new_dst_kind = dst_paths_kind([dst, new_dst])
if src_kind == 'file':
if dst_kind == 'dir':
if new_dst_kind == 'dir':
raise IsADirectoryError(new_dst)
if new_dst_kind == 'file':
dst_remove_file(new_dst)
return new_dst
else:
return new_dst
elif dst_kind == 'file':
dst_remove_file(dst)
return dst
else:
dst_mkdir(os.path.dirname(dst))
return dst
elif src_kind == 'dir':
if dst_kind == 'dir':
# Do not allow writing over an existing folder
if new_dst_kind == 'dir':
raise FileExistsError(new_dst)
if new_dst_kind == 'file':
raise FileExistsError(new_dst)
else:
return new_dst
elif dst_kind == 'file':
raise FileExistsError(dst_kind)
else:
dst_mkdir(os.path.dirname(dst))
return dst
else:
dst_mkdir(dest)
raise FileNotFoundError(src)
if pattern:
if not sources:
raise src_excep('No file matching source pattern: {}'.format(pattern))
if dst_path_exists(dest) and not dst_is_dir(dest):
raise NotADirectoryError('A folder dest is required for multiple matches but destination is a file: {}'.format(dest))
return {
src: rewrite_dst(src, dest)
for src in sources
}
@call_conn
def push(self, source, dest, as_root=False, timeout=None, globbing=False): # pylint: disable=arguments-differ
@ -489,18 +588,19 @@ class Target(object):
dest = str(dest)
sources = glob.glob(source) if globbing else [source]
self._prepare_xfer('push', sources, dest)
mapping = self._prepare_xfer('push', sources, dest, pattern=source if globbing else None, as_root=as_root)
def do_push(sources, dest):
return self.conn.push(sources, dest, timeout=timeout)
if as_root:
for source in sources:
for source, dest in mapping.items():
with self._xfer_cache_path(source) as device_tempfile:
do_push([source], device_tempfile)
self.execute("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True)
else:
do_push(sources, dest)
for source, dest in mapping.items():
do_push([source], dest)
def _expand_glob(self, pattern, **kwargs):
"""
@ -552,19 +652,20 @@ class Target(object):
else:
sources = [source]
self._prepare_xfer('pull', sources, dest)
mapping = self._prepare_xfer('pull', sources, dest, pattern=source if globbing else None, as_root=as_root)
def do_pull(sources, dest):
self.conn.pull(sources, dest, timeout=timeout)
if as_root:
for source in sources:
for source, dest in mapping.items():
with self._xfer_cache_path(source) as device_tempfile:
self.execute("cp -r -- {} {}".format(quote(source), quote(device_tempfile)), as_root=True)
self.execute("{} chmod 0644 -- {}".format(self.busybox, quote(device_tempfile)), as_root=True)
do_pull([device_tempfile], dest)
else:
do_pull(sources, dest)
for source, dest in mapping.items():
do_pull([source], dest)
def get_directory(self, source_dir, dest, as_root=False):
""" Pull a directory from the device, after compressing dir """

View File

@ -315,6 +315,7 @@ class AdbConnection(ConnectionBase):
return self._push_pull('pull', sources, dest, timeout)
def _push_pull(self, action, sources, dest, timeout):
sources = list(sources)
paths = sources + [dest]
# Quote twice to avoid expansion by host shell, then ADB globbing

View File

@ -424,26 +424,7 @@ class SshConnection(SshConnectionBase):
return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=self._get_progress_cb())
def _push_file(self, sftp, src, dst):
try:
sftp.put(src, dst, callback=self._get_progress_cb())
# Maybe the dst was a folder
except OSError as orig_excep:
# If dst was an existing folder, we add the src basename to create
# a new destination for the file as cp would do
new_dst = os.path.join(
dst,
os.path.basename(src),
)
logger.debug('Trying: {} -> {}'.format(src, new_dst))
try:
sftp.put(src, new_dst, callback=self._get_progress_cb())
# This still failed, which either means:
# * There are some missing folders in the dirnames
# * Something else SFTP-related is wrong
except OSError as e:
# Raise the original exception, as it is closer to what the
# user asked in the first place
raise orig_excep
sftp.put(src, dst, callback=self._get_progress_cb())
@classmethod
def _path_exists(cls, sftp, path):
@ -455,29 +436,13 @@ class SshConnection(SshConnectionBase):
return True
def _push_folder(self, sftp, src, dst):
# Behave like the "mv" command or adb push: a new folder is created
# inside the destination folder, rather than merging the trees, but
# only if the destination already exists. Otherwise, it is use as-is as
# the new hierarchy name.
if self._path_exists(sftp, dst):
dst = os.path.join(
dst,
os.path.basename(os.path.normpath(src)),
)
return self._push_folder_internal(sftp, src, dst)
def _push_folder_internal(self, sftp, src, dst):
# This might fail if the folder already exists
with contextlib.suppress(IOError):
sftp.mkdir(dst)
sftp.mkdir(dst)
for entry in os.scandir(src):
name = entry.name
src_path = os.path.join(src, name)
dst_path = os.path.join(dst, name)
if entry.is_dir():
push = self._push_folder_internal
push = self._push_folder
else:
push = self._push_file
@ -489,25 +454,9 @@ class SshConnection(SshConnectionBase):
push(sftp, src, dst)
def _pull_file(self, sftp, src, dst):
# Pulling a file into a folder will use the source basename
if os.path.isdir(dst):
dst = os.path.join(
dst,
os.path.basename(src),
)
with contextlib.suppress(FileNotFoundError):
os.remove(dst)
sftp.get(src, dst, callback=self._get_progress_cb())
def _pull_folder(self, sftp, src, dst):
with contextlib.suppress(FileNotFoundError):
try:
shutil.rmtree(dst)
except OSError:
os.remove(dst)
os.makedirs(dst)
for fileattr in sftp.listdir_attr(src):
filename = fileattr.filename
@ -836,7 +785,7 @@ class TelnetConnection(SshConnectionBase):
def push(self, sources, dest, timeout=30):
# Quote the destination as SCP would apply globbing too
dest = self.fmt_remote_path(quote(dest))
paths = sources + [dest]
paths = list(sources) + [dest]
return self._scp(paths, timeout)
def pull(self, sources, dest, timeout=30):
@ -1113,9 +1062,6 @@ class Gem5Connection(TelnetConnection):
# We need to copy the file to copy to the temporary directory
self._move_to_temp_dir(source)
# Dest in gem5 world is a file rather than directory
if os.path.basename(dest) != filename:
dest = os.path.join(dest, filename)
# Back to the gem5 world
filename = quote(self.gem5_input_dir + filename)
self._gem5_shell("ls -al {}".format(filename))