#    Copyright 2014-2015 ARM Limited
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import subprocess
import re
import threading

import pxssh
from pexpect import EOF, TIMEOUT, spawn

from wlauto.exceptions import HostError, DeviceError, TimeoutError, ConfigError
from wlauto.utils.misc import which, strip_bash_colors, escape_single_quotes, check_output

# Host tool paths; all start as None and are filled in lazily by _check_env()
# using which().
ssh = scp = sshpass = None

logger = logging.getLogger('ssh')

def ssh_get_shell(host, username, password=None, keyfile=None, port=None, timeout=10, telnet=False):
    """Open an interactive shell on ``host`` and log in to it.

    :param host: host name or address to connect to.
    :param username: user to log in as.
    :param password: login password; not used when ``keyfile`` is given.
    :param keyfile: private key file for SSH key authentication; may not be
                    combined with ``telnet``.
    :param port: port to connect to (passed through to ``login``).
    :param timeout: passed to ``login`` as ``login_timeout``.
    :param telnet: connect with telnet (``TelnetConnection``) instead of SSH.
    :returns: the logged-in connection object.
    :raises HostError: (from ``_check_env``) if ssh/scp are not installed.
    :raises ConfigError: if ``keyfile`` is used together with ``telnet``.
    :raises DeviceError: if the connection hits EOF while logging in.
    """
    # Resolve the module-level ssh/scp/sshpass paths used by SshShell's
    # background/_scp helpers and by _give_password.
    _check_env()
    if telnet:
        if keyfile:
            raise ConfigError('keyfile may not be used with a telnet connection.')
        conn = TelnetConnection()
    else:  # ssh
        conn = pxssh.pxssh()
    try:
        if keyfile:
            conn.login(host, username, ssh_key=keyfile, port=port, login_timeout=timeout)
        else:
            conn.login(host, username, password, port=port, login_timeout=timeout)
    except EOF:
        raise DeviceError('Could not connect to {}; is the host name correct?'.format(host))
    return conn

class TelnetConnection(pxssh.pxssh):
    """``pxssh.pxssh`` variant that logs in with ``telnet`` instead of ``ssh``.

    Only ``login`` is overridden; the rest of the shell API (``prompt``,
    ``sendline``, ``logout``, ...) is inherited from ``pxssh.pxssh``.
    """
    # pylint: disable=arguments-differ

    def login(self, server, username, password='', original_prompt=r'[#$]', login_timeout=10,
              auto_prompt_reset=True, sync_multiplier=1, port=None):
        """Spawn ``telnet`` to ``server`` and log in as ``username``.

        :param password: password sent at the password prompt; ``None`` is
                         treated as an empty password.
        :param original_prompt: regex matching the remote shell prompt after login.
        :param login_timeout: timeout for each login step.
        :param auto_prompt_reset: if true, replace the remote prompt using
                                  ``set_unique_prompt``.
        :param sync_multiplier: passed to ``sync_original_prompt``.
        :param port: telnet port; when ``None`` no port is given to telnet.
        :returns: ``True`` on success.
        :raises pxssh.ExceptionPxssh: on a failed login step; the spawned
                                      telnet process is closed first.
        """
        cmd = 'telnet -l {} {}'.format(username, server)
        if port is not None:
            cmd = '{} {}'.format(cmd, port)
        if password is None:
            password = ''

        spawn._spawn(self, cmd)  # pylint: disable=protected-access
        # TIMEOUT is part of the pattern list so that a missing prompt returns
        # an index instead of raising, making the error branch below reachable.
        i = self.expect(['(?i)(?:password)', TIMEOUT], timeout=login_timeout)
        if i != 0:
            self.close()
            raise pxssh.ExceptionPxssh('could not log in: did not see a password prompt')

        self.sendline(password)
        i = self.expect([original_prompt, 'Login incorrect'], timeout=login_timeout)
        if i:
            self.close()
            raise pxssh.ExceptionPxssh('could not log in: password was incorrect')

        if not self.sync_original_prompt(sync_multiplier):
            self.close()
            raise pxssh.ExceptionPxssh('could not synchronize with original prompt')

        if auto_prompt_reset:
            if not self.set_unique_prompt():
                self.close()
                message = 'could not set shell prompt (recieved: {}, expected: {}).'
                raise pxssh.ExceptionPxssh(message.format(self.before, self.PROMPT))
        return True

class SshShell(object):
    """Interactive shell session on a remote host, plus scp/ssh helpers.

    ``login`` opens a persistent shell (via ``ssh_get_shell``) that ``execute``
    drives by sending commands and waiting for the prompt. ``push_file``,
    ``pull_file`` and ``background`` instead spawn separate local ``scp``/``ssh``
    processes using the login details stored by ``login``.
    """

    default_password_prompt = '[sudo] password'

    def __init__(self, password_prompt=None, timeout=10):
        """
        :param password_prompt: text to expect when ``sudo`` asks for a password;
                                ``default_password_prompt`` is used when ``None``.
        :param timeout: default timeout for login and for waiting on the shell
                        prompt when a call does not supply its own.
        """
        self.password_prompt = password_prompt if password_prompt is not None else self.default_password_prompt
        self.timeout = timeout
        self.conn = None
        # Serializes execute() so that a command and its "echo $?" follow-up
        # are not interleaved with another thread's use of the same shell.
        self.lock = threading.Lock()

    def login(self, host, username, password=None, keyfile=None, port=None, timeout=None, telnet=False):
        """Open the interactive shell and remember the connection details."""
        # pylint: disable=attribute-defined-outside-init
        logger.debug('Logging in {}@{}'.format(username, host))
        self.host = host
        self.username = username
        self.password = password
        self.keyfile = keyfile
        self.port = port
        timeout = self.timeout if timeout is None else timeout
        self.conn = ssh_get_shell(host, username, password, keyfile, port, timeout, telnet)

    def push_file(self, source, dest, timeout=30):
        """Copy local ``source`` to ``dest`` on the remote host using scp."""
        dest = '{}@{}:{}'.format(self.username, self.host, dest)
        return self._scp(source, dest, timeout)

    def pull_file(self, source, dest, timeout=30):
        """Copy remote ``source`` to local ``dest`` using scp."""
        source = '{}@{}:{}'.format(self.username, self.host, source)
        return self._scp(source, dest, timeout)

    def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE):
        """Start ``command`` on the remote host in a separate ``ssh`` process.

        :returns: the ``subprocess.Popen`` object; the call does not wait for it.
        """
        port_string = '-p {}'.format(self.port) if self.port else ''
        keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
        command = '{} {} {} {}@{} {}'.format(ssh, keyfile_string, port_string, self.username, self.host, command)
        # Log before the password is spliced in so it never reaches the log.
        logger.debug(command)
        if self.password:
            command = _give_password(self.password, command)
        return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)

    def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True):
        """Run ``command`` in the interactive shell and return its output.

        :param check_exit_code: if true, read ``$?`` afterwards and fail on a
                                non-zero exit code.
        :raises DeviceError: if the exit code is non-zero or cannot be parsed.
        :raises TimeoutError: if the shell prompt does not return in time.
        """
        with self.lock:
            output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
            if check_exit_code:
                exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
                try:
                    exit_code = int(exit_code_text)
                except ValueError:
                    # Report an unreadable exit code as a device failure rather
                    # than surfacing a bare ValueError.
                    message = 'Could not parse exit code of: {}\ngot: {}'
                    raise DeviceError(message.format(command, exit_code_text))
                if exit_code:
                    message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
                    raise DeviceError(message.format(exit_code, command, output))
            return output

    def logout(self):
        """Log out of the interactive shell opened by ``login``."""
        logger.debug('Logging out {}@{}'.format(self.username, self.host))
        self.conn.logout()

    def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, strip_colors=True, log=True):
        """Send ``command`` to the shell, wait for the prompt and return its output.

        The echoed command text is removed from the returned output.

        :param as_root: wrap the command in ``sudo -- sh -c '...'`` and answer the
                        sudo password prompt with the login password if it appears.
        :param log: log the (possibly sudo-wrapped) command at debug level.
        :raises TimeoutError: if the prompt does not return within ``timeout``.
        """
        timeout = self.timeout if timeout is None else timeout
        if as_root:
            command = "sudo -- sh -c '{}'".format(escape_single_quotes(command))
        if log:
            logger.debug(command)
        self.conn.sendline(command)
        if as_root:
            # sudo does not always ask for a password, so only wait briefly.
            index = self.conn.expect_exact([self.password_prompt, TIMEOUT], timeout=0.5)
            if index == 0:
                self.conn.sendline(self.password)
            timed_out = not self.conn.prompt(timeout)
            # Drop everything up to and including the echoed command.
            output = re.sub(r'.*?{}'.format(re.escape(command)), '', self.conn.before, count=1).strip()
        else:
            timed_out, output = self._wait_for_command_output(command, timeout)
        if timed_out:
            raise TimeoutError(command, output)
        if strip_colors:
            output = strip_bash_colors(output)
        return output

    def _wait_for_command_output(self, command, timeout):
        """Wait for the prompt following ``command``; return ``(timed_out, output)``."""
        timed_out = not self.conn.prompt(timeout)
        output = self._join_wrapped_lines(self.conn.before)
        command_index = output.find(command)
        while not timed_out and command_index == -1:
            # In case of a "premature" timeout (i.e. timeout, but no hang,
            # so command completes afterwards), there may be a prompt from
            # the previous command completion in the serial output. This
            # checks for this case by making sure that the original command
            # is present in the serial output and waiting for the next
            # prompt if it is not.
            timed_out = not self.conn.prompt(timeout)
            output = self._join_wrapped_lines(self.conn.before)
            command_index = output.find(command)
        # Only strip the echo when it was actually found; otherwise keep the
        # whole buffer (e.g. for the TimeoutError message).
        if command_index != -1:
            output = output[command_index + len(command):]
        return timed_out, output.strip()

    @staticmethod
    def _join_wrapped_lines(text):
        """Remove line breaks potentially introduced when writing a command to the shell."""
        return re.sub(r' \r([^\n])', r'\1', text)

    def _scp(self, source, dest, timeout=30):
        """Run ``scp -r source dest`` locally, hiding the password in any error.

        :raises subprocess.CalledProcessError: if scp exits with an error.
        :raises TimeoutError: if scp does not finish within ``timeout``.
        """
        # NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely)
        # fails to connect to a device if port is explicitly specified using -P
        # option, even if it is the default port, 22. To minimize this problem,
        # only specify -P for scp if the port is *not* the default.
        port_string = '-P {}'.format(self.port) if (self.port and self.port != 22) else ''
        keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
        command = '{} -r {} {} {} {}'.format(scp, keyfile_string, port_string, source, dest)
        logger.debug(command)
        pass_string = ''
        if self.password:
            full_command = _give_password(self.password, command)
            # _give_password prepends the sshpass invocation; remember that
            # prefix so the password can be scrubbed from re-raised errors.
            pass_string = full_command[:len(full_command) - len(command)]
            command = full_command
        try:
            check_output(command, timeout=timeout, shell=True)
        except subprocess.CalledProcessError as e:
            raise subprocess.CalledProcessError(e.returncode, e.cmd.replace(pass_string, ''), e.output)
        except TimeoutError as e:
            raise TimeoutError(e.command.replace(pass_string, ''), e.output)

def _give_password(password, command):
    """Return ``command`` prefixed with an ``sshpass`` call supplying ``password``.

    :raises HostError: if the ``sshpass`` path (set by ``_check_env``) is unset.
    """
    if not sshpass:
        raise HostError('Must have sshpass installed on the host in order to use password-based auth.')
    # The password is embedded in a single-quoted shell word, so any single
    # quotes inside it must be escaped or the command line breaks.
    pass_string = "sshpass -p '{}' ".format(escape_single_quotes(password))
    return pass_string + command

def _check_env():
    """Look up the host's ssh, scp and sshpass binaries via ``which``.

    Results are cached in the module-level globals; the lookup is repeated
    only while ``ssh`` is still unset. ``sshpass`` is optional here and is
    checked later by ``_give_password``.

    :raises HostError: if either ``ssh`` or ``scp`` was not found.
    """
    global ssh, scp, sshpass  # pylint: disable=global-statement
    if not ssh:
        ssh, scp, sshpass = which('ssh'), which('scp'), which('sshpass')
    openssh_found = ssh and scp
    if not openssh_found:
        raise HostError('OpenSSH must be installed on the host.')