#!/usr/bin/env python
#
# Script to put copyright headers into source files.
#
import argparse
import logging
import os
import re
import string
import subprocess
from datetime import datetime

# Maps a lower-case file extension to the comment tokens used to render a
# header for that kind of file, as a (begin_symbol, end_symbol, symbol) tuple:
#   begin_symbol - prefix of the first header line,
#   end_symbol   - the whole of the last header line,
#   symbol       - prefix of every line in between.
# Only files whose extension is a key of this dict are processed.
SOURCE_EXTENSIONS = {
    '.py': ('#', '#', '#'),
    '.sh': ('#', '#', '#'),
    '.java': ('/*', '*/', ' *'),
    '.c': ('/*', '*/', ' *'),
    '.h': ('/*', '*/', ' *'),
    '.cpp': ('/*', '*/', ' *'),
}

# Layout of the legacy proprietary header. Placeholders: begin_symbol,
# end_symbol, symbol, year and file. "$$" is string.Template's escape for a
# literal "$". The "(C) COPYRIGHT" line is the fifth line of this header.
OLD_HEADER_TEMPLATE = string.Template(
"""${begin_symbol} $$Copyright:
${symbol} ----------------------------------------------------------------
${symbol} This confidential and proprietary software may be used only as
${symbol} authorised by a licensing agreement from ARM Limited
${symbol}  (C) COPYRIGHT ${year} ARM Limited
${symbol}       ALL RIGHTS RESERVED
${symbol} The entire notice above must be reproduced on all authorised
${symbol} copies and copies may only be made to the extent permitted
${symbol} by a licensing agreement from ARM Limited.
${symbol} ----------------------------------------------------------------
${symbol} File:        ${file}
${symbol} ----------------------------------------------------------------
${symbol} $$
${end_symbol}
"""
)

# Layout of the Apache 2.0 header written into source files. Placeholders:
# begin_symbol, end_symbol, symbol and year. The rendered text ends with a
# newline after the end_symbol line.
HEADER_TEMPLATE = string.Template(
"""${begin_symbol}    Copyright ${year} ARM Limited
${symbol}
${symbol} Licensed under the Apache License, Version 2.0 (the "License");
${symbol} you may not use this file except in compliance with the License.
${symbol} You may obtain a copy of the License at
${symbol}
${symbol}     http://www.apache.org/licenses/LICENSE-2.0
${symbol}
${symbol} Unless required by applicable law or agreed to in writing, software
${symbol} distributed under the License is distributed on an "AS IS" BASIS,
${symbol} WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
${symbol} See the License for the specific language governing permissions and
${symbol} limitations under the License.
${end_symbol}
"""
)

# Minimum length, in characters, of a copyright header. Files shorter than
# this are given a header without being inspected; for longer files only the
# first MIN_HEADER_LENGTH characters are searched for an existing header.
MIN_HEADER_LENGTH = 500

# Matches the year line of the legacy header, e.g. "(C) COPYRIGHT 2014" or
# "(C) COPYRIGHT 2013-2014". Group 1 is the optional start year, group 2 the
# only/end year.
OLD_COPYRIGHT_REGEX = re.compile(r'\(C\) COPYRIGHT\s+(?:(\d+)-)?(\d+)')
# Matches the first line of the current header, e.g.
# "Copyright 2014 ARM Limited" or "Copyright 2013-2014 ARM Limited" (the two
# years may also be separated by a comma). Group 1 is the optional start
# year, group 2 the only/end year.
COPYRIGHT_REGEX = re.compile(r'Copyright\s+(?:(\d+)\s*[-,]\s*)?(\d+) ARM Limited')

# Path fragments that are always excluded from the scan. A directory or file
# is skipped when any fragment occurs as a substring of its real path.
DEFAULT_EXCLUDE_PATHS = [
    os.path.join('wa', 'commands', 'templates'),
]


# Configure the root logger at import time: INFO level, with the level name
# left-aligned in an 8-character column in front of each message.
logging.basicConfig(level=logging.INFO, format='%(levelname)-8s %(message)s')


def remove_old_copyright(filepath):
    """Return the text of ``filepath`` with its old-style header stripped.

    The file is only read, never written. The old header is located through
    its "(C) COPYRIGHT" line; the lines removed are the ones a rendered
    ``OLD_HEADER_TEMPLATE`` would occupy around that line.

    :param filepath: Path of the source file to read. Its extension must be
                     a key of ``SOURCE_EXTENSIONS``.
    :returns: The file contents as a single string, without the old header.
              When no old copyright line is present the contents are
              returned unchanged.
    :raises KeyError: If the file's extension is not in
                      ``SOURCE_EXTENSIONS``.
    """
    # Take the extension from the path itself instead of depending on a
    # module-level ``ext`` left behind by whichever loop called us.
    _, ext = os.path.splitext(filepath)
    begin_symbol, end_symbol, symbol = SOURCE_EXTENSIONS[ext.lower()]
    # Render the *old* template: it is the old header whose height is needed.
    # The year and file values are dummies; only the line count matters.
    header = OLD_HEADER_TEMPLATE.substitute(begin_symbol=begin_symbol,
                                            end_symbol=end_symbol,
                                            symbol=symbol,
                                            year='0',
                                            file=os.path.basename(filepath))
    header_line_count = len(header.splitlines())
    # The "(C) COPYRIGHT" line is the fifth line of the old header, i.e. four
    # lines below the line on which the header starts.
    copyright_line_offset = 4

    with open(filepath) as fh:
        lines = fh.readlines()

    for i, line in enumerate(lines):
        if OLD_COPYRIGHT_REGEX.search(line):
            header_start = i - copyright_line_offset
            break
    else:
        # Nothing to strip; hand the text back untouched.
        return ''.join(lines)

    header_end = header_start + header_line_count
    # A truncated header can put the computed start before the file's first
    # line; clamp it so the slice cannot wrap around to the end of the file.
    header_start = max(header_start, 0)
    return ''.join(lines[:header_start] + lines[header_end:])


def add_copyright_header(filepath, year):
    """Write a new-style copyright header into ``filepath``, in place.

    If the file carries an old-style header, that header is removed and its
    start year is carried over into the new header's year range.

    A leading ``#!`` line, and a ``# -*-`` encoding line on the first line or
    directly after the shebang, are kept above the header so they stay on the
    first two lines of the file.

    :param filepath: Path of the file to rewrite. Its extension must be a key
                     of ``SOURCE_EXTENSIONS``.
    :param year: Year (int or str) to put in the header.
    :raises KeyError: If the file's extension is not in
                      ``SOURCE_EXTENSIONS``.
    """
    _, ext = os.path.splitext(filepath)
    begin_symbol, end_symbol, symbol = SOURCE_EXTENSIONS[ext.lower()]
    with open(filepath) as fh:
        text = fh.read()

    match = OLD_COPYRIGHT_REGEX.search(text)
    if match:
        # Only the resulting year (range) is needed; the old header text is
        # dropped wholesale below. Reuse the match rather than searching again.
        _, year = update_year(text, year, copyright_regex=OLD_COPYRIGHT_REGEX,
                              match=match)
        text = remove_old_copyright(filepath)

    header = HEADER_TEMPLATE.substitute(begin_symbol=begin_symbol,
                                        end_symbol=end_symbol,
                                        symbol=symbol,
                                        year=year)

    if text.strip().startswith(('#!', '# -*-')):
        # partition() copes with a one-line file that has no trailing newline,
        # where unpacking split('\n', 1) would raise ValueError.
        first_line, _, rest = text.partition('\n')
        preamble = [first_line]
        # An encoding declaration is only honoured on line 1 or 2, so it must
        # stay directly below the shebang rather than end up under the header.
        if first_line.startswith('#!') and rest.startswith('# -*-'):
            coding_line, _, rest = rest.partition('\n')
            preamble.append(coding_line)
        updated_text = '\n'.join(preamble + [header, rest])
    else:
        updated_text = '\n'.join([header, text])

    with open(filepath, 'w') as wfh:
        wfh.write(updated_text)


def update_year(text, year, copyright_regex=COPYRIGHT_REGEX, match=None):
    """Extend the year of the copyright line in ``text`` up to ``year``.

    The start year of the existing line is kept. When it equals ``year`` the
    result is that single year, otherwise it is the range "<start>-<year>".

    :param text: Text containing the copyright line.
    :param year: Year (int or str) the copyright should now extend to.
    :param copyright_regex: Compiled pattern used to find the copyright line
                            when ``match`` is not supplied. Group 1 must be
                            the optional start year and group 2 the only/end
                            year.
    :param match: Optional existing match object for the copyright line; when
                  given, ``text`` is not searched again.
    :returns: A ``(updated_text, year_string)`` tuple. ``updated_text`` is
              ``text`` with the matched copyright text replaced by
              "Copyright <year_string> ARM Limited". If no copyright line is
              found, ``text`` is returned unchanged together with ``year``
              formatted as a string.
    """
    if match is None:
        match = copyright_regex.search(text)
    if match is None:
        # Nothing to update; avoid an AttributeError on the missing match.
        return (text, '{}'.format(year))

    old_year = match.group(1) or match.group(2)
    # ``old_year`` is a string taken from the regex while callers pass
    # ``year`` as an int, so compare on the string form.
    if old_year == '{}'.format(year):
        ret_year = old_year
    else:
        ret_year = '{}-{}'.format(old_year, year)
    updated_year_text = 'Copyright {} ARM Limited'.format(ret_year)
    # Replace only the first occurrence (the one the search found) so that
    # identical text further down the file is left alone.
    return (text.replace(match.group(0), updated_year_text, 1), ret_year)


def get_git_year(path):
    """Return the year of the last non-copyright commit that touched ``path``.

    Commits are inspected newest first. Any commit whose ``git log`` entry
    contains the word "copyright" (case-insensitive) is skipped.

    :param path: Path of a file inside a git work tree.
    :returns: The commit year as an int, or ``None`` if git reports no
              (remaining) commit for the file or no date line is found.
    :raises subprocess.CalledProcessError: If ``git log`` exits with a
                                           non-zero status.
    """
    # An empty cwd is rejected by subprocess; None means "current directory".
    directory = os.path.dirname(path) or None
    filename = os.path.basename(path)

    skip = 0
    while True:
        # Argument list instead of a shell string, so file names with spaces
        # or shell metacharacters are passed through verbatim.
        # --date=default pins the date layout parsed below, whatever the
        # user's git configuration says.
        command = ['git', 'log', '-n', '1', '--skip={}'.format(skip),
                   '--date=default', '--', filename]
        # universal_newlines=True yields text rather than bytes on Python 3.
        info = subprocess.check_output(command, cwd=directory,
                                       universal_newlines=True)
        if not info.strip():
            return None
        if 'copyright' not in info.lower():
            break
        # Move on to the next older commit; without this the same commit
        # would be queried forever.
        skip += 1

    # Find the header by name rather than by position: merge commits carry
    # an extra "Merge:" line that shifts the date down. Message lines are
    # indented in the output, so they cannot be mistaken for the header.
    for line in info.splitlines():
        if line.startswith('Date:'):
            # "Date:   Thu Apr 5 14:09:30 2018 +0100" -> sixth field is the year.
            return int(line.split()[5])
    return None


if __name__ == '__main__':
    # Command-line entry point: walk the given path and add, replace or
    # refresh the copyright header of every file whose extension is a key of
    # SOURCE_EXTENSIONS.
    parser = argparse.ArgumentParser()
    parser.add_argument('path', help='Location to add copyrights to source files in.')
    parser.add_argument('-n', '--update-no-ext', action='store_true',
                        help='Will update files with no extension using # as the comment symbol.')
    parser.add_argument('-x', '--exclude', action='append',
                        help='Exclude this directory from the scan. May be used multiple times.')
    args = parser.parse_args()

    if args.update_no_ext:
        # Files without an extension are handled like '#'-commented scripts.
        SOURCE_EXTENSIONS[''] = ('#', '#', '#')

    exclude_paths = DEFAULT_EXCLUDE_PATHS + (args.exclude or [])

    current_year = datetime.now().year
    for root, dirs, files in os.walk(args.path):
        # Exclusions are substring matches against the resolved path.
        real_root = os.path.realpath(root)
        if any(exclude_path in real_root for exclude_path in exclude_paths):
            logging.info('Skipping {}'.format(root))
            continue

        logging.info('Checking {}'.format(root))
        for entry in files:
            _, ext = os.path.splitext(entry)
            if ext.lower() not in SOURCE_EXTENSIONS:
                continue

            filepath = os.path.join(root, entry)
            real_filepath = os.path.realpath(filepath)
            if any(exclude_path in real_filepath for exclude_path in exclude_paths):
                logging.info('\tSkipping {}'.format(entry))
                continue

            with open(filepath) as fh:
                text = fh.read()
            if not text.strip():
                logging.info('\tSkipping empty  {}'.format(entry))
                continue

            # Fall back to the current year when git knows of no suitable
            # commit for the file.
            year_modified = get_git_year(filepath) or current_year
            if len(text) < MIN_HEADER_LENGTH:
                # Shorter than MIN_HEADER_LENGTH, so no header is looked for.
                logging.info('\tAdding header to {}'.format(entry))
                add_copyright_header(filepath, year_modified)
                continue

            # Only the start of the file is searched for an existing header.
            first_chunk = text[:MIN_HEADER_LENGTH]
            match = COPYRIGHT_REGEX.search(first_chunk)
            if not match:
                if OLD_COPYRIGHT_REGEX.search(first_chunk):
                    logging.warning('\tOld copyright message detected and replaced in {}'.format(entry))
                    add_copyright_header(filepath, year_modified)
                elif '(c)' in first_chunk or '(C)' in first_chunk:
                    # Some other notice is present; leave the file alone.
                    logging.warning('\tAnother copyright header appears to be in {}'.format(entry))
                else:
                    logging.info('\tAdding header to {}'.format(entry))
                    # NOTE(review): this branch uses current_year whereas the
                    # other header-adding branches use year_modified --
                    # confirm that this is intended.
                    add_copyright_header(filepath, current_year)
            else:
                # Found an existing copyright header. Update the
                # year if needed, otherwise, leave it alone.
                last_year = int(match.group(2))
                if year_modified > last_year:
                    logging.info('\tUpdating year in {}'.format(entry))
                    text, _ = update_year(text, year_modified, COPYRIGHT_REGEX, match)
                    with open(filepath, 'w') as wfh:
                        wfh.write(text)
                else:
                    logging.info('\t{}: OK'.format(entry))