From b8e25efdd47a1cdf35c60735a52b954fb050ac23 Mon Sep 17 00:00:00 2001
From: Sergei Trofimov
Date: Fri, 10 Jul 2015 11:43:21 +0100
Subject: [PATCH] ssh: attempt to deal with dropped connections

---
 wlauto/common/linux/device.py |  5 ++--
 wlauto/utils/ssh.py           | 45 +++++++++++++++++++++++------------
 2 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/wlauto/common/linux/device.py b/wlauto/common/linux/device.py
index 6e2364ab..1f00faed 100644
--- a/wlauto/common/linux/device.py
+++ b/wlauto/common/linux/device.py
@@ -580,8 +580,9 @@ class LinuxDevice(BaseLinuxDevice):
             raise DeviceError('Could not connect to {} after reboot'.format(self.host))
 
     def connect(self):  # NOQA pylint: disable=R0912
-        self.shell = SshShell(password_prompt=self.password_prompt, timeout=self.default_timeout)
-        self.shell.login(self.host, self.username, self.password, self.keyfile, self.port, telnet=self.use_telnet)
+        self.shell = SshShell(password_prompt=self.password_prompt,
+                              timeout=self.default_timeout, telnet=self.use_telnet)
+        self.shell.login(self.host, self.username, self.password, self.keyfile, self.port)
         self._is_ready = True
 
     def disconnect(self):  # NOQA pylint: disable=R0912
diff --git a/wlauto/utils/ssh.py b/wlauto/utils/ssh.py
index 91e84e1b..d6d3db7a 100644
--- a/wlauto/utils/ssh.py
+++ b/wlauto/utils/ssh.py
@@ -109,13 +109,15 @@ class SshShell(object):
     default_password_prompt = '[sudo] password'
     max_cancel_attempts = 5
 
-    def __init__(self, password_prompt=None, timeout=10):
+    def __init__(self, password_prompt=None, timeout=10, telnet=False):
         self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt
         self.timeout = timeout
+        self.telnet = telnet
         self.conn = None
         self.lock = threading.Lock()
+        self.connection_lost = False
 
-    def login(self, host, username, password=None, keyfile=None, port=None, timeout=None, telnet=False):
+    def login(self, host, username, password=None, keyfile=None, port=None, timeout=None):
         # pylint: disable=attribute-defined-outside-init
         logger.debug('Logging in {}@{}'.format(username, host))
         self.host = host
@@ -124,7 +126,7 @@ class SshShell(object):
         self.keyfile = check_keyfile(keyfile) if keyfile else keyfile
         self.port = port
         timeout = self.timeout if timeout is None else timeout
-        self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, telnet)
+        self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, self.telnet)
 
     def push_file(self, source, dest, timeout=30):
         dest = '{}@{}:{}'.format(self.username, self.host, dest)
@@ -143,19 +145,32 @@ class SshShell(object):
             command = _give_password(self.password, command)
         return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)
 
+    def reconnect(self):
+        self.conn = ssh_get_shell(self.host, self.username, self.password,
+                                  self.keyfile, self.port, self.timeout, self.telnet)
+
     def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True):
-        with self.lock:
-            output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
-            if check_exit_code:
-                exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
-                try:
-                    exit_code = int(exit_code_text.split()[0])
-                    if exit_code:
-                        message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
-                        raise DeviceError(message.format(exit_code, command, output))
-                except ValueError:
-                    logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text))
-            return output
+        try:
+            with self.lock:
+                if self.connection_lost:
+                    logger.debug('Attempting to reconnect...')
+                    self.reconnect()
+                    self.connection_lost = False
+                output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
+                if check_exit_code:
+                    exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
+                    try:
+                        exit_code = int(exit_code_text.split()[0])
+                        if exit_code:
+                            message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
+                            raise DeviceError(message.format(exit_code, command, output))
+                    except ValueError:
+                        logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text))
+                return output
+        except EOF:
+            logger.error('Dropped connection detected.')
+            self.connection_lost = True
+            raise
 
     def logout(self):
         logger.debug('Logging out {}@{}'.format(self.username, self.host))