1
0
mirror of https://github.com/ARM-software/workload-automation.git synced 2025-01-19 04:21:17 +00:00

ssh: attempt to deal with dropped connections

This commit is contained in:
Sergei Trofimov 2015-07-10 11:43:21 +01:00
parent 2e4bda71a8
commit b8e25efdd4
2 changed files with 33 additions and 17 deletions

View File

@@ -580,8 +580,9 @@ class LinuxDevice(BaseLinuxDevice):
raise DeviceError('Could not connect to {} after reboot'.format(self.host)) raise DeviceError('Could not connect to {} after reboot'.format(self.host))
def connect(self): # NOQA pylint: disable=R0912 def connect(self): # NOQA pylint: disable=R0912
self.shell = SshShell(password_prompt=self.password_prompt, timeout=self.default_timeout) self.shell = SshShell(password_prompt=self.password_prompt,
self.shell.login(self.host, self.username, self.password, self.keyfile, self.port, telnet=self.use_telnet) timeout=self.default_timeout, telnet=self.use_telnet)
self.shell.login(self.host, self.username, self.password, self.keyfile, self.port)
self._is_ready = True self._is_ready = True
def disconnect(self): # NOQA pylint: disable=R0912 def disconnect(self): # NOQA pylint: disable=R0912

View File

@@ -109,13 +109,15 @@ class SshShell(object):
default_password_prompt = '[sudo] password' default_password_prompt = '[sudo] password'
max_cancel_attempts = 5 max_cancel_attempts = 5
def __init__(self, password_prompt=None, timeout=10): def __init__(self, password_prompt=None, timeout=10, telnet=False):
self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt
self.timeout = timeout self.timeout = timeout
self.telnet = telnet
self.conn = None self.conn = None
self.lock = threading.Lock() self.lock = threading.Lock()
self.connection_lost = False
def login(self, host, username, password=None, keyfile=None, port=None, timeout=None, telnet=False): def login(self, host, username, password=None, keyfile=None, port=None, timeout=None):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
logger.debug('Logging in {}@{}'.format(username, host)) logger.debug('Logging in {}@{}'.format(username, host))
self.host = host self.host = host
@@ -124,7 +126,7 @@ class SshShell(object):
self.keyfile = check_keyfile(keyfile) if keyfile else keyfile self.keyfile = check_keyfile(keyfile) if keyfile else keyfile
self.port = port self.port = port
timeout = self.timeout if timeout is None else timeout timeout = self.timeout if timeout is None else timeout
self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, telnet) self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, self.telnet)
def push_file(self, source, dest, timeout=30): def push_file(self, source, dest, timeout=30):
dest = '{}@{}:{}'.format(self.username, self.host, dest) dest = '{}@{}:{}'.format(self.username, self.host, dest)
@@ -143,19 +145,32 @@ class SshShell(object):
command = _give_password(self.password, command) command = _give_password(self.password, command)
return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True) return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)
def reconnect(self):
self.conn = ssh_get_shell(self.host, self.username, self.password,
self.keyfile, self.port, self.timeout, self.telnet)
def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True): def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True):
with self.lock: try:
output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors) with self.lock:
if check_exit_code: if self.connection_lost:
exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False) logger.debug('Attempting to reconnect...')
try: self.reconnect()
exit_code = int(exit_code_text.split()[0]) self.connection_lost = False
if exit_code: output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}' if check_exit_code:
raise DeviceError(message.format(exit_code, command, output)) exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
except ValueError: try:
logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text)) exit_code = int(exit_code_text.split()[0])
return output if exit_code:
message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
raise DeviceError(message.format(exit_code, command, output))
except ValueError:
logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text))
return output
except EOF:
logger.error('Dropped connection detected.')
self.connection_lost = True
raise
def logout(self): def logout(self):
logger.debug('Logging out {}@{}'.format(self.username, self.host)) logger.debug('Logging out {}@{}'.format(self.username, self.host))