#!/usr/bin/env python3
#
# Copyright (c) Fortanix, Inc.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import os
import sys
import datetime
import time
import argparse
import string

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import (
    Cipher, algorithms, modes
)

# Command-line interface: one positional operation plus three file options.
# argparse opens each file itself, in the mode listed next to its flag; none
# of the file options is marked required here.
parser = argparse.ArgumentParser(description='AES-256-GCM Utility')

parser.add_argument(
    'operation',
    choices=['enc', 'dec'],
    help="Operation to perform",
)

# (flag, destination attribute, file mode, help text)
for _flag, _dest, _mode, _help in (
    ('-K', 'key_file', 'r', "Key file for operation, 64 character hex string file."),
    ('-in', 'input_file', 'rb', "Source file to encrypt/decrypt"),
    ('-out', 'output_file', 'wb', "Target file"),
):
    parser.add_argument(
        _flag,
        dest=_dest,
        type=argparse.FileType(_mode),
        help=_help,
    )

args = parser.parse_args()


# argparse does not mark the file options as required, so enforce their
# presence here, naming the option that is actually missing.
if args.key_file is None:
    raise Exception("Missing key file")

if args.input_file is None:
    raise Exception("Missing input file")

if args.output_file is None:
    raise Exception("Missing output file")

# Layout of a file produced by 'enc': nonce || ciphertext || tag.
NONCE_SIZE = 12  # 96-bit GCM nonce
TAG_SIZE = 16    # 128-bit GCM authentication tag

# The key is read from the first line of the key file: 64 hex characters,
# i.e. 32 bytes, the key size of AES-256.
key = args.key_file.readline().rstrip()
args.key_file.close()

if len(key) != 64:
    raise Exception("Key file must contain a 64 character hex string for AES-256-GCM")

if not all(c in string.hexdigits for c in key):
    raise Exception("Key must be a 64 character hex string")

key = bytes.fromhex(key)

CHUNK_SIZE = 8192


def _encrypt(key, src, dst, chunk_size=CHUNK_SIZE):
    """Encrypt *src* into *dst* as nonce || ciphertext || tag with AES-GCM.

    A fresh random nonce is generated for every call.
    """
    nonce = os.urandom(NONCE_SIZE)
    encryptor = Cipher(algorithms.AES(key), modes.GCM(nonce),
                       backend=default_backend()).encryptor()

    dst.write(nonce)
    for chunk in iter(lambda: src.read(chunk_size), b""):
        dst.write(encryptor.update(chunk))
    dst.write(encryptor.finalize())
    # The tag is only available once finalize() has run.
    dst.write(encryptor.tag)


def _decrypt(key, src, dst, chunk_size=CHUNK_SIZE):
    """Decrypt a nonce || ciphertext || tag stream written by _encrypt.

    *src* must be seekable, because the tag is read from its end first.

    Raises:
        Exception: if *src* is too short to hold a nonce and a tag, or ends
            unexpectedly while reading the ciphertext.
        cryptography.exceptions.InvalidTag: if authentication fails.

    If any error occurs after decryption starts, the partially written output
    is discarded whenever *dst* is seekable.
    """
    file_size = src.seek(0, os.SEEK_END)
    if file_size < NONCE_SIZE + TAG_SIZE:
        raise Exception("Input file is too short to be AES-256-GCM output")

    # The tag occupies the last TAG_SIZE bytes; the nonce the first NONCE_SIZE.
    tag_position = file_size - TAG_SIZE
    src.seek(tag_position, os.SEEK_SET)
    tag = src.read(TAG_SIZE)

    src.seek(0, os.SEEK_SET)
    nonce = src.read(NONCE_SIZE)

    decryptor = Cipher(algorithms.AES(key), modes.GCM(nonce, tag),
                       backend=default_backend()).decryptor()

    try:
        # Ciphertext lies between the nonce and the tag.
        remaining = tag_position - NONCE_SIZE
        while remaining > 0:
            chunk = src.read(min(chunk_size, remaining))
            if not chunk:
                raise Exception("Unexpected end of input file")
            remaining -= len(chunk)
            dst.write(decryptor.update(chunk))
        # finalize() verifies the tag and raises on mismatch.
        dst.write(decryptor.finalize())
    except Exception:
        # update() releases plaintext before the tag has been checked, so on
        # any failure wipe what was already written rather than leave
        # unauthenticated plaintext behind.
        if dst.seekable():
            dst.seek(0)
            dst.truncate()
        raise


try:
    if args.operation == 'enc':
        _encrypt(key, args.input_file, args.output_file)
    elif args.operation == 'dec':
        _decrypt(key, args.input_file, args.output_file)
finally:
    args.input_file.close()
    args.output_file.close()
