#!/usr/bin/env python3
#
# Copyright (c) Fortanix, Inc.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import argparse
import os
import string
import sys

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import (
    Cipher, algorithms, modes
)

# Key length in *hex characters*: 32 key bytes (AES-256), two hex digits each.
ENCRYPTION_KEY_SIZE = 32 * 2
# Length in bytes of the random GCM nonce stored at the start of encrypted files (96 bits).
NONCE_SIZE = 96 // 8
# Length in bytes of the GCM authentication tag stored at the end of encrypted files (128 bits).
TAG_SIZE = 128 // 8


def parse_arguments(argv=None):
    """
    Build the command-line parser for the utility and parse the arguments.

    Note: the ``-i``/``-o`` options use ``argparse.FileType``, so the input and
    output files are opened while parsing (the output file is created/truncated
    before any key is read).

    :param argv: list of arguments to parse, without the program name. When
                 None (the default) ``sys.argv[1:]`` is used, exactly as
                 ``ArgumentParser.parse_args`` does.
    :return: the parsed namespace; ``command`` holds the sub-command name or
             alias exactly as the user typed it (e.g. ``"enc"`` or ``"encrypt"``)
    """
    parser = argparse.ArgumentParser(description='AES-256-GCM Utility')
    subparsers = parser.add_subparsers(help="AES-256-GCM sub-commands", dest='command', required=True)
    parser_enc = subparsers.add_parser("encrypt", aliases=["enc"], help="encrypt a file")
    parser_dec = subparsers.add_parser("decrypt", aliases=["dec"], help="decrypt a file")
    parser_gen = subparsers.add_parser("gen_key", aliases=["gen"],
                                       help="generate a 64 character hex string as the encryption key")

    # encrypt sub-command
    parser_enc.add_argument('-i', '-in', '--input-file', required=True, dest='input_file',
                            type=argparse.FileType('rb'), help="the input file to encrypt")
    parser_enc.add_argument('-o', '-out', '--output-file', required=True, dest='output_file',
                            type=argparse.FileType('wb'), help="the encrypted output file")
    parser_enc.add_argument('-k', '-K', '--key-file', required=True, dest='key_file',
                            help="the file containing the encryption key (64 character hex string)")
    parser_enc.add_argument('-n', '--new-key', required=False, dest='new_key', action="store_true", default=False,
                            help="encrypt using a newly generated key, and store the new key to `key-file`")

    # decrypt sub-command
    parser_dec.add_argument('-i', '-in', '--input-file', required=True, dest='input_file',
                            type=argparse.FileType('rb'), help="the input file to decrypt")
    parser_dec.add_argument('-o', '-out', '--output-file', required=True, dest='output_file',
                            type=argparse.FileType('wb'), help="the decrypted output file")
    parser_dec.add_argument('-k', '-K', '--key-file', required=True, dest='key_file',
                            help="the file containing the encryption key (64 character hex string)")

    # generate key sub-command
    parser_gen.add_argument('-k', '-K', '--key-file', required=True, dest='key_file',
                            help="the file where the encryption key will be stored")

    return parser.parse_args(argv)


def get_encryption_key_from_file(filename, key_size=None):
    """
    Read and validate the encryption key stored in the given file.

    Only the first line of the file is used; surrounding whitespace (including
    the trailing newline) is stripped before validation.

    :param filename: path of the file containing the encryption key
    :param key_size: expected key length in hex characters; when None the
                     module-level ``ENCRYPTION_KEY_SIZE`` is used, so the check
                     stays in sync with ``generate_encryption_key``
    :return: the encryption key as a hex string
    :raises ValueError: if the key has the wrong length or contains non-hex
                        characters (undecodable file content also surfaces as
                        ``UnicodeDecodeError``, a ``ValueError`` subclass)
    :raises OSError: if the file cannot be opened or read
    """
    if key_size is None:
        key_size = ENCRYPTION_KEY_SIZE
    with open(filename, "r") as f:
        encryption_key = f.readline().strip()
    if len(encryption_key) != key_size:
        raise ValueError("Encryption key must be a %d character hex string" % key_size)
    if not all(c in string.hexdigits for c in encryption_key):
        raise ValueError("Encryption key must only contain hex characters")
    return encryption_key


def generate_encryption_key(length=None):
    """
    Generate a random lower-case hex string to be used as the encryption key.

    The characters are drawn with ``secrets`` (a cryptographically secure
    generator). ``random`` must not be used for keys: its Mersenne Twister
    output is predictable from observed values.

    :param length: number of hex characters to generate; when None the
                   module-level ``ENCRYPTION_KEY_SIZE`` (64 characters, i.e. a
                   32-byte AES-256 key) is used
    :return: a newly generated random encryption key
    """
    import secrets
    if length is None:
        length = ENCRYPTION_KEY_SIZE
    hex_digits = "0123456789abcdef"
    return ''.join(secrets.choice(hex_digits) for _ in range(length))


def encrypt_file_command(args):
    """
    Encrypt ``args.input_file`` into ``args.output_file`` with AES-256-GCM.

    The output is written as ``nonce || ciphertext || tag``. The nonce is
    NONCE_SIZE fresh random bytes, and the GCM tag is appended after
    ``finalize()``. This is the layout ``decrypt_file_command`` reads back.

    :param args: parsed ``encrypt`` arguments: ``input_file`` / ``output_file``
                 (binary file objects opened by argparse), ``key_file`` (path)
                 and ``new_key`` (bool; when set a new key is generated and
                 written to ``key_file`` instead of being read from it)
    Exits with status 1 if the key file cannot be read or is invalid.
    """
    if args.new_key:
        encryption_key = generate_encryption_key()
        # Create the key file readable/writable by its owner only, so the secret
        # is not exposed by a permissive umask. The mode is applied when the
        # file is created; an existing file keeps its current mode.
        fd = os.open(args.key_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            f.write(encryption_key)
    else:
        try:
            encryption_key = get_encryption_key_from_file(args.key_file)
        except (ValueError, OSError) as e:
            # ValueError has no `.value` attribute; format the exception itself.
            print("Could not get encryption key from file `%s`: %s" % (args.key_file, str(e)))
            sys.exit(1)

    key = bytes.fromhex(encryption_key)
    # A fresh random nonce per encryption: GCM must never reuse a nonce with a key.
    nonce = os.urandom(NONCE_SIZE)
    cipher = Cipher(algorithms.AES(key), modes.GCM(nonce), backend=default_backend())
    encryptor = cipher.encryptor()

    args.output_file.write(nonce)

    chunk_size = 8192
    while True:
        chunk = args.input_file.read(chunk_size)
        if not chunk:
            break
        ciphertext = encryptor.update(chunk)
        if ciphertext:
            args.output_file.write(ciphertext)

    ciphertext = encryptor.finalize()
    if ciphertext:
        args.output_file.write(ciphertext)
    # The authentication tag is only available once finalize() has run.
    args.output_file.write(encryptor.tag)
    args.input_file.close()
    args.output_file.close()
    print("Encrypted file `%s` to file `%s`" % (args.input_file.name, args.output_file.name))


def _discard_output(output_file):
    """Best-effort removal of unauthenticated plaintext already written to *output_file*."""
    try:
        output_file.seek(0)
        output_file.truncate()
    except (OSError, ValueError):
        # Non-seekable outputs (pipes, a terminal) cannot be rewound; what was
        # written there cannot be retracted, so there is nothing more to do.
        pass


def decrypt_file_command(args):
    """
    Decrypt a file produced by ``encrypt_file_command``.

    The input is expected to be laid out as ``nonce || ciphertext || tag``:
    NONCE_SIZE bytes of nonce, then the ciphertext, then TAG_SIZE bytes of
    GCM tag.

    Plaintext is streamed to ``args.output_file`` before the tag can be checked
    by ``finalize()``. If verification fails, the output is truncated (best
    effort) so unauthenticated data is not left behind. The process then exits
    with status 1.

    :param args: parsed ``decrypt`` arguments: ``input_file`` / ``output_file``
                 (binary file objects opened by argparse) and ``key_file`` (path)
    Exits with status 1 on an unreadable/invalid key file, an input too short
    to hold a nonce and a tag, or a failed authentication check.
    """
    try:
        encryption_key = get_encryption_key_from_file(args.key_file)
    except (ValueError, OSError) as e:
        print("Could not get encryption key from file `%s`: %s" % (args.key_file, str(e)))
        sys.exit(1)

    key = bytes.fromhex(encryption_key)

    # A valid file holds at least a nonce and a tag. Check this up front:
    # seeking TAG_SIZE bytes before the start of a tiny file raises OSError.
    file_size = args.input_file.seek(0, os.SEEK_END)
    if file_size < NONCE_SIZE + TAG_SIZE:
        print("Decryption failed. Input file `%s` is too short to be an encrypted file." % args.input_file.name)
        sys.exit(1)

    # Read tag from the end of the file
    ciphertext_end = file_size - TAG_SIZE
    args.input_file.seek(ciphertext_end, os.SEEK_SET)
    tag = args.input_file.read(TAG_SIZE)

    # Read nonce from the start
    args.input_file.seek(0, os.SEEK_SET)
    nonce = args.input_file.read(NONCE_SIZE)

    # Ciphertext is in the middle
    cipher = Cipher(algorithms.AES(key), modes.GCM(nonce, tag), backend=default_backend())
    decryptor = cipher.decryptor()

    chunk_size = 8192
    remaining = ciphertext_end - NONCE_SIZE
    while remaining > 0:
        buf = args.input_file.read(min(chunk_size, remaining))
        if not buf:
            # Unexpected EOF (e.g. the file shrank while being read): stop
            # feeding data and let tag verification in finalize() decide.
            break
        remaining -= len(buf)
        plaintext = decryptor.update(buf)
        if plaintext:
            args.output_file.write(plaintext)

    try:
        plaintext = decryptor.finalize()
    except InvalidTag:
        _discard_output(args.output_file)
        print("Decryption failed. Invalid ciphertext or encryption key (Invalid ciphertext tag).")
        sys.exit(1)

    if plaintext:
        args.output_file.write(plaintext)
    args.input_file.close()
    args.output_file.close()
    print("Decrypted file `%s` to file `%s`" % (args.input_file.name, args.output_file.name))


def generate_key_command(args):
    """
    Generate a new random encryption key and store it in ``args.key_file``.

    Any existing content of the file is overwritten. The file is created
    readable/writable by its owner only. The mode is applied on creation; an
    existing file keeps its current mode.

    :param args: parsed ``gen_key`` arguments (``key_file``: destination path)
    """
    encryption_key = generate_encryption_key()
    fd = os.open(args.key_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(encryption_key)
    print("Generated encryption key in file `%s`" % args.key_file)


if __name__ == '__main__':
    parsed_args = parse_arguments()
    # argparse records the spelling the user typed (name or alias), so each
    # spelling maps to its handler; an unmatched command does nothing.
    command_handlers = {
        "encrypt": encrypt_file_command,
        "enc": encrypt_file_command,
        "decrypt": decrypt_file_command,
        "dec": decrypt_file_command,
        "gen_key": generate_key_command,
        "gen": generate_key_command,
    }
    selected_handler = command_handlers.get(parsed_args.command)
    if selected_handler is not None:
        selected_handler(parsed_args)
