From b8e25efdd47a1cdf35c60735a52b954fb050ac23 Mon Sep 17 00:00:00 2001
From: Sergei Trofimov <sergei.trofimov@arm.com>
Date: Fri, 10 Jul 2015 11:43:21 +0100
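Subject: [PATCH] ssh: attempt to deal with dropped connections

When EOF is raised during SshShell.execute(), log the dropped connection,
set the new connection_lost flag and re-raise the exception. The next
execute() call sees the flag and re-establishes the session through the
new reconnect() method before running its command.

The telnet flag moves from SshShell.login() to the SshShell constructor,
where it is stored as self.telnet and reused by reconnect();
LinuxDevice.connect() is updated to pass it to the constructor instead.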

---
 wlauto/common/linux/device.py |  5 ++--
 wlauto/utils/ssh.py           | 45 +++++++++++++++++++++++------------
 2 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/wlauto/common/linux/device.py b/wlauto/common/linux/device.py
index 6e2364ab..1f00faed 100644
--- a/wlauto/common/linux/device.py
+++ b/wlauto/common/linux/device.py
@@ -580,8 +580,9 @@ class LinuxDevice(BaseLinuxDevice):
             raise DeviceError('Could not connect to {} after reboot'.format(self.host))
 
     def connect(self):  # NOQA pylint: disable=R0912
-        self.shell = SshShell(password_prompt=self.password_prompt, timeout=self.default_timeout)
-        self.shell.login(self.host, self.username, self.password, self.keyfile, self.port, telnet=self.use_telnet)
+        self.shell = SshShell(password_prompt=self.password_prompt,
+                              timeout=self.default_timeout, telnet=self.use_telnet)
+        self.shell.login(self.host, self.username, self.password, self.keyfile, self.port)
         self._is_ready = True
 
     def disconnect(self):  # NOQA pylint: disable=R0912
diff --git a/wlauto/utils/ssh.py b/wlauto/utils/ssh.py
index 91e84e1b..d6d3db7a 100644
--- a/wlauto/utils/ssh.py
+++ b/wlauto/utils/ssh.py
@@ -109,13 +109,15 @@ class SshShell(object):
     default_password_prompt = '[sudo] password'
     max_cancel_attempts = 5
 
-    def __init__(self, password_prompt=None, timeout=10):
+    def __init__(self, password_prompt=None, timeout=10, telnet=False):
         self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt
         self.timeout = timeout
+        self.telnet = telnet
         self.conn = None
         self.lock = threading.Lock()
+        self.connection_lost = False
 
-    def login(self, host, username, password=None, keyfile=None, port=None, timeout=None, telnet=False):
+    def login(self, host, username, password=None, keyfile=None, port=None, timeout=None):
         # pylint: disable=attribute-defined-outside-init
         logger.debug('Logging in {}@{}'.format(username, host))
         self.host = host
@@ -124,7 +126,7 @@ class SshShell(object):
         self.keyfile = check_keyfile(keyfile) if keyfile else keyfile
         self.port = port
         timeout = self.timeout if timeout is None else timeout
-        self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, telnet)
+        self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, self.telnet)
 
     def push_file(self, source, dest, timeout=30):
         dest = '{}@{}:{}'.format(self.username, self.host, dest)
@@ -143,19 +145,32 @@ class SshShell(object):
             command = _give_password(self.password, command)
         return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)
 
+    def reconnect(self):
+        self.conn = ssh_get_shell(self.host, self.username, self.password,
+                                  self.keyfile, self.port, self.timeout, self.telnet)
+
     def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True):
-        with self.lock:
-            output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
-            if check_exit_code:
-                exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
-                try:
-                    exit_code = int(exit_code_text.split()[0])
-                    if exit_code:
-                        message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
-                        raise DeviceError(message.format(exit_code, command, output))
-                except ValueError:
-                    logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text))
-            return output
+        try:
+            with self.lock:
+                if self.connection_lost:
+                    logger.debug('Attempting to reconnect...')
+                    self.reconnect()
+                    self.connection_lost = False
+                output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
+                if check_exit_code:
+                    exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
+                    try:
+                        exit_code = int(exit_code_text.split()[0])
+                        if exit_code:
+                            message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
+                            raise DeviceError(message.format(exit_code, command, output))
+                    except ValueError:
+                        logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text))
+                return output
+        except EOF:
+            logger.error('Dropped connection detected.')
+            self.connection_lost = True
+            raise
 
     def logout(self):
         logger.debug('Logging out {}@{}'.format(self.username, self.host))