#!/usr/bin/env python3

import argparse

# Command-line interface: two required positional arguments.
arg_parser = argparse.ArgumentParser(description="BitBot log decrypting utility")

# Path to the private key file; it is opened in binary mode further down.
arg_parser.add_argument(
    "key",
    help="Location of private key for decrypting given log file",
)
# Path to the log; the script later treats the value "-" as standard input.
arg_parser.add_argument(
    "log",
    help="Location of the log file to decrypt",
)

# Parsed eagerly so that usage errors are reported before any further work.
args = arg_parser.parse_args()

import base64, sys
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding as a_padding

def rsa_decrypt(key, data):
    """Base64-decode *data* and decrypt it with *key* using RSA-OAEP.

    OAEP is configured with SHA-256 for both the hash algorithm and the
    MGF1 mask generation function, and with no label.

    Args:
        key: Object exposing ``decrypt(ciphertext, padding)``; the script
            passes the private key loaded from the PEM file.
        data: Base64-encoded ciphertext.

    Returns:
        Whatever ``key.decrypt`` returns for the decoded ciphertext.
    """
    ciphertext = base64.b64decode(data)
    oaep = a_padding.OAEP(
        mgf=a_padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    return key.decrypt(ciphertext, oaep)

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding

def aes_decrypt(key: bytes, data_str: str) -> str:
    """Decrypt a base64-encoded AES-CBC message and return it as text.

    The decoded payload is laid out as a 16-byte IV followed by the
    ciphertext. After decryption the PKCS7 padding is stripped and the
    result is decoded as UTF-8.

    Args:
        key: AES key bytes.
        data_str: Base64 text containing the IV followed by the ciphertext.

    Returns:
        The decrypted, unpadded plaintext as a string.
    """
    raw = base64.b64decode(data_str)
    iv = raw[:16]
    ciphertext = raw[16:]

    cipher = Cipher(
        algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    decryptor = cipher.decryptor()
    padded = decryptor.update(ciphertext)
    padded += decryptor.finalize()

    # NOTE(review): the PKCS7 block size here is 256 bits although AES uses
    # 128-bit blocks; presumably this mirrors the encrypting side -- confirm
    # before changing it.
    unpadder = padding.PKCS7(256).unpadder()
    plain = unpadder.update(padded) + unpadder.finalize()
    return plain.decode("utf8")

import os

# Load the RSA private key (PEM, no passphrase) named on the command line.
with open(args.key, "rb") as key_file:
    key_content = key_file.read()

key = serialization.load_pem_private_key(
    key_content, password=None, backend=default_backend())

# Read the whole log up front; a path of "-" selects standard input.
if args.log == "-":
    lines = sys.stdin.read().split("\n")
else:
    with open(args.log) as log_file:
        lines = log_file.read().split("\n")
# Drop empty lines so that indexing line[0] below is always valid.
lines = [line for line in lines if line]

# The first character of each line selects how it is handled:
#   "\x02"  RSA-encrypted text, printed after decryption
#   "\x03"  RSA-encrypted symmetric key, remembered for later lines
#   "\x04"  AES-encrypted text, decrypted with the most recent key
#   other   the line is printed unchanged
symm_key = None
try:
    for line in lines:
        marker, payload = line[0], line[1:]
        printable = None
        if marker == "\x02":
            printable = rsa_decrypt(key, payload).decode("utf8")
        elif marker == "\x03":
            symm_key = rsa_decrypt(key, payload)
        elif marker == "\x04":
            if symm_key is None:
                # Without this check a None key reaches the AES cipher and
                # the script dies with an unhelpful traceback.
                sys.exit("error: AES-encrypted line found before any "
                    "symmetric key line")
            printable = aes_decrypt(symm_key, payload)
        else:
            printable = line

        if printable is not None:
            sys.stdout.write("%s\n" % printable)
    # Flush inside the try block so a closed pipe is detected here rather
    # than during interpreter shutdown.
    sys.stdout.flush()
except BrokenPipeError:
    # The reader (e.g. `head`) closed the pipe. Point stdout at devnull so
    # the implicit flush at exit cannot raise a second BrokenPipeError.
    devnull = os.open(os.devnull, os.O_WRONLY)
    os.dup2(devnull, sys.stdout.fileno())