# /// script
# dependencies = ["cryptography"]
# ///
import os
import sys

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.x509 import load_pem_x509_certificate
from cryptography.x509.oid import SignatureAlgorithmOID

# Secret printed by verify() for an accepted coupon. Read from the FLAG
# environment variable, falling back to a placeholder for local runs.
FLAG = os.environ.get('FLAG', 'flag{dummy}')

def get_pubkey(cert):
    """Return the certificate's public key rendered as a digit string.

    The key is serialised as DER-encoded SubjectPublicKeyInfo. The
    resulting bytes are then passed through ``tostr``, which writes decimal
    byte values with no separators.
    """
    key = cert.public_key()
    der = key.public_bytes(
        encoding=Encoding.DER,
        format=PublicFormat.SubjectPublicKeyInfo,
    )
    return tostr(der)

def tostr(sig):
    """Concatenate ``str()`` of every item in *sig* into a single string.

    For a bytes object each byte contributes its decimal value, so
    ``bytes([1, 23])`` becomes ``'123'``. No separator is inserted, which
    makes the mapping ambiguous: ``bytes([12, 3])`` also becomes ``'123'``.
    """
    pieces = [str(item) for item in sig]
    return ''.join(pieces)

def check_cert(cert):
    """Return True if *cert* matches the single accepted coupon certificate.

    Two checks are made. First, the certificate must declare
    sha256WithRSAEncryption as its signature algorithm. Second, its public
    key, as rendered by ``get_pubkey``, must equal the hard-coded expected
    string.

    NOTE(review): the certificate's own signature is never verified, so the
    first check only inspects a field the certificate declares about itself.
    The second check is also weak. ``get_pubkey`` joins byte values as
    decimal digits without separators, so it is ambiguous: [1, 23] and
    [12, 3] both give '123'. A different key could therefore produce the same
    string. Comparing the raw DER bytes, or a hash of them, would close that
    gap.
    """
    # Compare against the public OID constant instead of the private
    # ObjectIdentifier._name attribute, which is not part of the API.
    if cert.signature_algorithm_oid != SignatureAlgorithmOID.RSA_WITH_SHA256:
        return False
    if get_pubkey(cert) != '48130134481369421347213424713111503130115048130110213011021924970598520811352158981504217353124223501521914517310415115122317014628428419615126174118128816126112149243421791019225391201532439637150328314631395125517122415923120751247193223226106601691281881592437815322988192928219510893914011517925210615213752925918821995117205157140116154361972172512267410210801873014715151955794163153402291181401751268539190442001209621822611417818836172432242176258153671894611314662191867512219119012672115245651184332152222323723071331522081401618121066209183763627195109361701738814115213610218733987642095922382237233595716892525093251792031431783415512855162143392359422826802815493520923123822016676137157823523101':
        return False
    return True

def check_sig(cert, data, signature):
    """Check *signature* over *data* using the public key embedded in *cert*.

    The check uses PKCS#1 v1.5 padding with SHA-256. It returns False only
    when the library raises ``InvalidSignature``; any other exception
    propagates to the caller.
    """
    public_key = cert.public_key()
    try:
        public_key.verify(
            signature,
            data,
            padding=PKCS1v15(),
            algorithm=SHA256(),
        )
    except InvalidSignature:
        return False
    else:
        return True

def verify(cert_bytes, data, signature):
    """Validate a coupon and print the flag if every check passes.

    Prints exactly one line:

    - 'Bad cert!' if the certificate cannot be parsed or fails
      ``check_cert``;
    - 'Bad sig!' if *signature* does not verify over *data* with the
      certificate's key;
    - 'Bad data!' if *data* is not the expected coupon text;
    - otherwise, FLAG.

    :param cert_bytes: PEM-encoded certificate bytes supplied by the user.
    :param data: the signed message bytes.
    :param signature: raw signature bytes.
    :returns: None in all cases.
    """
    try:
        cert = load_pem_x509_certificate(cert_bytes, default_backend())
    except ValueError:
        # Malformed user-supplied PEM/DER: report it rather than crashing
        # with a traceback.
        print('Bad cert!')
        return
    if not check_cert(cert):
        print('Bad cert!')
        return
    if not check_sig(cert, data, signature):
        print('Bad sig!')
        return
    if data != b'gib flag bitti bitti':
        print('Bad data!')
        return
    print(FLAG)

def read_cert():
    """Read a PEM certificate from stdin one line at a time.

    The first line must be exactly '-----BEGIN CERTIFICATE-----'. Lines are
    then collected up to and including '-----END CERTIFICATE-----'. A
    trailing carriage return is removed from each line, so clients that send
    CRLF line endings are accepted.

    :returns: the collected lines, each terminated by '\\n', encoded as
        bytes. Returns None after printing 'Bad format!' if the first line is
        wrong or input ends before the END marker.
    """
    begin_marker = '-----BEGIN CERTIFICATE-----'
    end_marker = '-----END CERTIFICATE-----'
    try:
        line = input('Cert:\n').rstrip('\r')
        if line != begin_marker:
            print('Bad format!')
            return None
        lines = [line]
        while line != end_marker:
            line = input().rstrip('\r')
            lines.append(line)
    except EOFError:
        # Input closed before a complete certificate arrived.
        print('Bad format!')
        return None
    return ('\n'.join(lines) + '\n').encode()

def main():
    """Run the interactive coupon checker on stdin/stdout.

    Reads a PEM certificate, the signed data and a hex-encoded signature,
    then passes them to ``verify``. Exits with status 1 if the certificate
    could not be read or input ends before all fields are supplied. Prints
    'Bad sig!' if the signature is not valid hex.
    """
    print('KUPONG - Coupon Checker v0.1')
    print('Please enter your coupon to get a discounted flag!')
    cert = read_cert()
    if not cert:
        # Use sys.exit rather than the site-provided exit(). That helper is
        # meant for the interactive interpreter and is absent under
        # ``python -S``.
        sys.exit(1)
    try:
        data = input('Data: ').encode()
        sig_hex = input('Sig: ')
    except EOFError:
        sys.exit(1)
    try:
        signature = bytes.fromhex(sig_hex)
    except ValueError:
        # Non-hex input can never be a valid signature.
        print('Bad sig!')
        return
    verify(cert, data, signature)

if __name__ == '__main__':
    main()
