Untitled

mail@pastecode.io avatar
unknown
plain_text
2 years ago
16 kB
1
Indexable
Never
#!/usr/bin/env python3

import argparse, codecs, hmac, socket, sys, time, os, datetime
from hashlib import sha1, sha256
from Crypto.Cipher import ARC4
from pyasn1.codec.der import decoder  # do not use any other imports/libraries
from urllib.parse import urlparse

# took 14 hours (please specify here how much time your solution required)

# had to set up VM with ubuntu to run the tls server, but after I send the client finished message
# the tls server crashes with errno 104 Connection reset by peer
# so I can't even see the debug output as it just auto-closes the gnome terminal tab...
# EDIT: had to enable some property so that it doesn't close and I can finally debug

# Also, the lecture slides are misleading:
# slide 16 says the RSA-encrypted premaster secret should have a length of 256, while
# slide 26 shows the debugging output "premaster length: 128".
# I wasted about 4 hours trying to figure out why my premaster wasn't 256 as slide 16 says it should be.
# EDIT: apparently the encrypted premaster secret needs a 2-byte length prefix, which was on the slides, but presented in a really confusing way.

# Command-line interface: the URL to fetch, plus an optional output path
# for the server's certificate.
parser = argparse.ArgumentParser(description='TLS v1.2 client')
parser.add_argument('url', type=str,
                    help='URL to request')
parser.add_argument('--certificate', type=str,
                    help='File to write PEM-encoded server certificate')
args = parser.parse_args()

def get_pubkey_certificate(cert):
    """Return the RSA public key of an X.509 certificate as ``(n, e)``.

    cert: certificate bytes; content starting with b'-----' is treated as
    PEM and converted with pem_to_der(), anything else is used as DER.
    """
    der = pem_to_der(cert) if cert.startswith(b'-----') else cert

    # decoder.decode returns (value, remaining substrate)
    certificate, _ = decoder.decode(der)

    # Certificate[0] is tbsCertificate; field 6 is taken as subjectPublicKeyInfo.
    # NOTE(review): index 6 assumes the explicit [0] version field is present
    # (v2/v3 certificates) -- confirm for v1 certificates.
    spki = certificate[0][6]

    # subjectPublicKey is a BIT STRING whose octets are themselves DER
    # (a SEQUENCE of modulus and public exponent)
    rsa_key, _ = decoder.decode(spki[1].asOctets())
    return int(rsa_key[0]), int(rsa_key[1])

def pem_to_der(content):
    """Convert PEM-armoured bytes to the DER bytes they contain.

    Removes every "-----BEGIN <label>-----" / "-----END <label>-----"
    armour line -- any label, so certificates work as well as the RSA /
    PKCS#8 key formats -- and base64-decodes what remains. The base64
    decoder skips line breaks and other non-alphabet characters.

    content: PEM data as bytes.
    Returns the decoded DER bytes.
    """
    import re

    # Non-greedy so the match ends at the armour's own closing dashes; labels
    # may themselves contain single hyphens or spaces.
    body = re.sub(rb"-----(?:BEGIN|END)[^\r\n]*?-----", b"", content)
    return codecs.decode(body, 'base64')


def pkcsv15pad_encrypt(plaintext, n):
    """Apply PKCS#1 v1.5 encryption padding (block type 2) for modulus n.

    Produces 0x00 0x02 || PS || 0x00 || plaintext, where PS consists of
    random non-zero bytes and the result is exactly as long as n in bytes.
    Prints an error and exits if plaintext is longer than len(n) - 11.
    """
    modulus_bytes = (n.bit_length() + 7) // 8

    if len(plaintext) > modulus_bytes - 11:
        print("[-] Plaintext larger than modulus - 11 bytes")
        sys.exit(1)

    filler_len = modulus_bytes - len(plaintext) - 3
    filler = bytearray()
    # draw one byte at a time, discarding zeros (0x00 is the PS terminator)
    while len(filler) < filler_len:
        candidate = os.urandom(1)
        if candidate != b"\x00":
            filler += candidate

    return b"\x00\x02" + bytes(filler) + b"\x00" + plaintext

def rsa_encrypt(cert, m):
    """RSA-encrypt m under the public key found in cert (PKCS#1 v1.5).

    cert: certificate bytes accepted by get_pubkey_certificate().
    m: message bytes (at most modulus length - 11 bytes).
    Returns the ciphertext, left-padded to the byte length of the modulus.
    """
    n, e = get_pubkey_certificate(cert)
    modulus_len = (n.bit_length() + 7) // 8

    encoded = pkcsv15pad_encrypt(m, n)
    # textbook RSA on the padded block, then back to a fixed-width byte string
    return nb(pow(bn(encoded), e, n), modulus_len)

def nb(i, length=False):
    """Encode integer i as big-endian bytes.

    length: output width in bytes. When omitted -- or given any value equal
    to False, which includes 0 -- the minimal width (i.bit_length()+7)//8 is
    used, so nb(0) == b''. Only the low `length` bytes are kept, so a value
    that does not fit is silently truncated; negative values come out in
    two's-complement form.
    """
    width = (i.bit_length() + 7) // 8 if length == False else length  # noqa: E712 (0 counts as "auto")
    if width <= 0:
        return b""
    # masking first reproduces the truncate-to-low-bytes behaviour
    return (i & ((1 << (8 * width)) - 1)).to_bytes(width, 'big')

def bn(b):
    """Decode a big-endian byte string into a non-negative integer (b'' -> 0)."""
    return int.from_bytes(b, 'big')

# returns TLS record that contains ClientHello handshake message
def client_hello():
    """Build the ClientHello record.

    Side effects: sets the global client_random (4-byte Unix time followed
    by 28 random bytes) and appends the handshake message (without the
    record header) to the global handshake_messages.
    Returns the complete TLS record (5-byte record header + message).
    """
    global client_random, handshake_messages

    print("--> ClientHello()")

    cipher_suites = b"\x00\x05"  # TLS_RSA_WITH_RC4_128_SHA (the only one offered)
    compression = b"\x00"        # null compression only

    unix_time = nb(int(time.time()), 4)
    client_random = unix_time + os.urandom(32 - len(unix_time))

    body = b"".join((
        b"\x03\x03",                        # client_version: TLS 1.2
        client_random,
        b"\x00",                            # empty session_id
        nb(len(cipher_suites), 2), cipher_suites,
        nb(len(compression)), compression,  # 1-byte length (len is 1)
    ))
    message = b"\x01" + nb(len(body), 3) + body  # client_hello(1)

    handshake_messages += message

    # record header: handshake(22), TLS 1.2, 2-byte length
    return b"\x16" + b"\x03\x03" + nb(len(message), 2) + message

# returns TLS record that contains ClientKeyExchange message containing encrypted pre-master secret
def client_key_exchange():
    """Build the ClientKeyExchange record carrying the RSA-encrypted premaster.

    Side effects: sets the global premaster (TLS 1.2 version bytes + 46
    random bytes) and appends the handshake message to handshake_messages.
    Returns the complete TLS record.
    """
    global server_cert, premaster, handshake_messages

    print("--> ClientKeyExchange()")
    premaster = b"\x03\x03" + os.urandom(46)
    encrypted = rsa_encrypt(server_cert, premaster)

    # the encrypted premaster is sent as a vector with a 2-byte length prefix
    body = nb(len(encrypted), 2) + encrypted
    message = b"".join((b"\x10", nb(len(body), 3), body))  # client_key_exchange(16)

    handshake_messages += message

    header = b"".join((b"\x16", b"\x03\x03", nb(len(message), 2)))
    return header + message

# returns TLS record that contains ChangeCipherSpec message
def change_cipher_spec():
    """Return the ChangeCipherSpec record: type 20, TLS 1.2, one-byte payload 0x01."""
    print("--> ChangeCipherSpec()")

    payload = b"\x01"
    return b"".join((b"\x14", b"\x03\x03", nb(len(payload), 2), payload))

# returns TLS record that contains encrypted Finished handshake message
def finished():
    """Build the client's encrypted Finished record.

    verify_data is PRF(master_secret, "client finished" + SHA-256 of all
    handshake messages so far), 12 bytes. The plaintext message is appended
    to handshake_messages (the server's Finished covers it); only the
    record payload is encrypted.
    """
    global handshake_messages, master_secret

    print("--> Finished()")
    transcript_hash = sha256(handshake_messages).digest()
    verify_data = PRF(master_secret, b"client finished" + transcript_hash, 12)
    message = b"\x14" + nb(len(verify_data), 3) + verify_data  # finished(20)

    handshake_messages += message

    content_type, version = b"\x16", b"\x03\x03"
    protected = encrypt(message, content_type, version)
    return content_type + version + nb(len(protected), 2) + protected

# returns TLS record that contains encrypted Application data
def application_data(data):
    """Echo data to stdout and return it as an encrypted application_data record."""
    print("--> Application_data()")
    print(data.decode().strip())

    content_type = b"\x17"  # application_data(23)
    version = b"\x03\x03"
    protected = encrypt(data, content_type, version)
    header = content_type + version + nb(len(protected), 2)
    return header + protected

# parse TLS Handshake messages
def parsehandshake(r):
    """Parse every Handshake message carried in one record body.

    r: the record body (bytes after the 5-byte record header). Once the
    server has sent ChangeCipherSpec, the body is first decrypted and
    MAC-checked with decrypt() -- exactly once per record. A record may hold
    several messages (e.g. ServerHello + Certificate + ServerHelloDone);
    each one is parsed in order and appended verbatim (header + body) to
    handshake_messages, which feeds the Finished computations.

    Side effects: sets server_random, server_cert, server_hello_done_received
    or server_finished_received depending on the messages seen. Exits the
    process on an unsupported cipher suite or compression, a bad server
    Finished, or an unknown message type.
    """
    global server_hello_done_received, server_random, server_cert, handshake_messages, server_change_cipher_spec_received, server_finished_received

    # decrypt the whole record before splitting it into messages: the RC4
    # stream and the MAC sequence number advance per record, not per message
    if server_change_cipher_spec_received:
        r = decrypt(r, b"\x16", b"\x03\x03")

    offset = 0
    while offset < len(r):
        # Handshake header: 1-byte type + 3-byte body length
        htype, hlength = r[offset:offset + 1], bn(r[offset + 1:offset + 4])
        handshake = r[offset:offset + 4 + hlength]
        body = handshake[4:]
        handshake_messages += handshake
        offset += 4 + hlength

        if htype == b"\x02":
            print("	<--- ServerHello()")
            # body: version(2) | random(32) | session_id_len(1) | session_id |
            #       cipher_suite(2) | compression_method(1) | ...
            server_random = body[2:34]
            timestamp = server_random[:4]
            gmt = datetime.datetime.fromtimestamp(bn(timestamp)).strftime('%Y-%m-%d %H:%M:%S')
            sessid_len = body[34]
            sessid = body[35:35 + sessid_len]
            bookmark = 35 + sessid_len
            cipher = body[bookmark:bookmark + 2]
            compression = body[bookmark + 2:bookmark + 3]

            print("	[+] server randomness:", server_random.hex().upper())
            print("	[+] server timestamp:", gmt)
            print("	[+] TLS session ID:", sessid.hex().upper())

            if cipher == b"\x00\x05":
                print("	[+] Cipher suite: TLS_RSA_WITH_RC4_128_SHA")
            else:
                print("[-] Unsupported cipher suite selected:", cipher.hex())
                sys.exit(1)

            if compression != b"\x00":
                print("[-] Wrong compression:", compression.hex())
                sys.exit(1)
        elif htype == b"\x0b":
            print("	<--- Certificate()")
            # body: certificate_list length(3) | first cert length(3) | first cert (DER) | ...
            certlen = bn(body[3:6])
            print("	[+] Server certificate length:", certlen)
            server_cert = body[6:6 + certlen]
        elif htype == b"\x0e":
            print("	<--- ServerHelloDone()")
            server_hello_done_received = True
        elif htype == b"\x14":
            print("	<--- Finished()")
            # verify_data covers every handshake message before this Finished
            transcript = handshake_messages[:-len(handshake)]
            verify_data_calc = PRF(master_secret, b"server finished" + sha256(transcript).digest(), 12)
            if not hmac.compare_digest(body, verify_data_calc):
                print("[-] Server finished verification failed!")
                sys.exit(1)
            else:
                server_finished_received = True
        else:
            print("[-] Unknown Handshake Type:", htype.hex())
            sys.exit(1)

# parses TLS record
def parserecord(r):
    """Parse one TLS record and dispatch its body by content type.

    r: a full record as returned by readrecord() (5-byte header + body).
    Handshake bodies go to parsehandshake(). ChangeCipherSpec turns on
    decryption of later server records. Alerts and application data are
    decrypted once that has happened, then reported.
    Exits the process on a fatal, unknown or malformed alert, or on an
    unknown record type.
    """
    global server_change_cipher_spec_received

    # parse TLS record header and pass the record body to the corresponding parsing method
    ctype = r[0:1]
    c = r[5:]

    if ctype == b"\x16":
        print("<--- Handshake()")
        parsehandshake(c)
    elif ctype == b"\x14":
        print("<--- ChangeCipherSpec()")
        server_change_cipher_spec_received = True
    elif ctype == b"\x15":
        print("<--- Alert()")
        # After the server's ChangeCipherSpec, alerts are encrypted and MACed
        # like any other record. Decrypting them also keeps rc4s and
        # server_seq in step with the server.
        if server_change_cipher_spec_received:
            c = decrypt(c, b"\x15", b"\x03\x03")
        if len(c) < 2:
            print("[-] Malformed Alert record")
            sys.exit(1)
        level, desc = c[0], c[1]
        if level == 1:
            print("	[-] warning:", desc)
        elif level == 2:
            print("	[-] fatal:", desc)
            sys.exit(1)
        else:
            print("	[-] unknown alert level:", level)
            sys.exit(1)
    elif ctype == b"\x17":
        print("<--- Application_data()")
        data = decrypt(c, b"\x17", b"\x03\x03")
        # server content need not be valid UTF-8 (and may be split mid-character across records)
        print(data.decode(errors="replace").strip())
    else:
        print("[-] Unknown TLS Record type:", ctype.hex())
        sys.exit(1)

# PRF defined in TLS v1.2
def PRF(secret, seed, l):
    """TLS 1.2 PRF (P_SHA256): return the first l bytes of the HMAC-SHA256 expansion.

    A(0) = seed, A(i) = HMAC(secret, A(i-1)); output = HMAC(secret, A(1)+seed)
    || HMAC(secret, A(2)+seed) || ... truncated to l bytes. Callers pass
    label + seed concatenated as `seed`.
    """
    def _mac(data):
        return hmac.new(secret, data, sha256).digest()

    blocks = []
    produced = 0
    a = _mac(seed)
    while produced < l:
        block = _mac(a + seed)
        blocks.append(block)
        produced += len(block)
        a = _mac(a)
    return b"".join(blocks)[:l]

# derives master_secret
def derive_master_secret():
    """Set the global master_secret = PRF(premaster, "master secret" + client_random + server_random, 48)."""
    global premaster, master_secret, client_random, server_random

    seed = b"".join((b"master secret", client_random, server_random))
    master_secret = PRF(premaster, seed, 48)

# derives keys for encryption and MAC
def derive_keys():
    """Expand master_secret into the TLS_RSA_WITH_RC4_128_SHA key material.

    Sets the global client/server MAC keys (20 bytes each) and client/server
    RC4 keys (16 bytes each), and creates fresh ARC4 objects rc4c/rc4s.
    """
    global premaster, master_secret, client_random, server_random
    global client_mac_key, server_mac_key, client_enc_key, server_enc_key, rc4c, rc4s

    mac_size = 20  # HMAC-SHA1 key length
    key_size = 16  # RC4-128 key length

    # Key block layout: client MAC | server MAC | client key | server key.
    # RC4 is a stream cipher, so there are no IVs; request exactly the bytes used.
    # Note the seed order here is server_random + client_random.
    key_block = PRF(master_secret, b"key expansion" + server_random + client_random,
                    2 * mac_size + 2 * key_size)

    client_mac_key = key_block[:mac_size]
    server_mac_key = key_block[mac_size:mac_size * 2]
    client_enc_key = key_block[mac_size * 2:mac_size * 2 + key_size]
    server_enc_key = key_block[mac_size * 2 + key_size:mac_size * 2 + key_size * 2]

    rc4c = ARC4.new(client_enc_key)
    rc4s = ARC4.new(server_enc_key)

# HMAC SHA1 wrapper
def HMAC_sha1(key, data):
    """Return the 20-byte HMAC-SHA1 tag of data under key."""
    mac = hmac.new(key, digestmod=sha1)
    mac.update(data)
    return mac.digest()

# calculates MAC and encrypts plaintext
def encrypt(plain, type, version):
    """MAC-then-encrypt one outgoing record fragment.

    plain: the record fragment to protect.
    type, version: the record's content-type and version bytes; together
        with the 8-byte client sequence number and the 2-byte length they
        form the MAC input.
    Returns RC4(plain || HMAC-SHA1) and increments client_seq.
    """
    global client_mac_key, client_enc_key, client_seq, rc4c

    mac_input = b"".join((nb(client_seq, 8), type, version, nb(len(plain), 2), plain))
    tag = HMAC_sha1(client_mac_key, mac_input)
    protected = rc4c.encrypt(plain + tag)
    client_seq += 1
    return protected

# decrypts ciphertext and verifies MAC
def decrypt(ciphertext, type, version):
    """Decrypt one server record body and verify its HMAC-SHA1.

    ciphertext: encrypted record body (fragment followed by a 20-byte MAC).
    type, version: the record's content-type and version bytes, covered by
        the MAC together with the server sequence number and length.
    Returns the plaintext fragment and increments server_seq. Prints an
    error and exits if the record is too short or the MAC does not match.
    """
    global server_mac_key, server_enc_key, server_seq, rc4s

    mac_len = 20  # HMAC-SHA1 output size
    d = rc4s.decrypt(ciphertext)
    mac = d[-mac_len:]
    plain = d[:-mac_len]

    # verify MAC
    mac_calc = HMAC_sha1(server_mac_key, nb(server_seq, 8) + type + version + nb(len(plain), 2) + plain)
    # constant-time comparison: do not leak how many leading MAC bytes matched
    if len(d) < mac_len or not hmac.compare_digest(mac, mac_calc):
        print("[-] MAC verification failed!")
        sys.exit(1)
    server_seq += 1
    return plain

def _recv_exact(sock, count):
    """Read count bytes from sock, returning fewer only if the peer closes the connection."""
    chunks = []
    remaining = count
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:  # peer closed the connection
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


# read from the socket full TLS record
def readrecord(sock=None):
    """Read one complete TLS record (5-byte header + body).

    sock: socket to read from; defaults to the module-level connection `s`.
    Returns the raw record bytes. If the peer closes the connection, the
    bytes received so far are returned (b'' when it closed between records).
    """
    if sock is None:
        sock = s

    # TLS record header: type(1) | version(2) | length(2)
    header = _recv_exact(sock, 5)
    if len(header) < 5:
        return header

    datalen = int.from_bytes(header[3:5], 'big')
    return header + _recv_exact(sock, datalen)

url = urlparse(args.url)
# urlparse's hostname/port understand "[v6]:port" and "user@host" forms that
# a plain netloc.split(':') would mangle; default to the HTTPS port
host = url.hostname or ''
port = url.port or 443
# request the URL's own path (and query string), defaulting to "/"
path = url.path or '/'
if url.query:
    path += '?' + url.query

client_random = b""	# will hold client randomness
server_random = b""	# will hold server randomness
server_cert = b""	# will hold DER encoded server certificate
premaster = b""		# will hold 48 byte pre-master secret
master_secret = b""	# will hold master secret
handshake_messages = b"" # will hold concatenation of handshake messages

# client/server keys and sequence numbers
client_mac_key = b""
server_mac_key = b""
client_enc_key = b""
server_enc_key = b""
client_seq = 0
server_seq = 0

# client/server RC4 instances
rc4c = b""
rc4s = b""

server_hello_done_received = False
server_change_cipher_spec_received = False
server_finished_received = False

# create_connection resolves the name and tries each returned address (IPv4/IPv6)
s = socket.create_connection((host, port))
try:
    # sendall: a single send() may write only part of the buffer
    s.sendall(client_hello())

    while not server_hello_done_received:
        parserecord(readrecord())

    if args.certificate:
        # server_cert holds the DER certificate; wrap it as PEM with the
        # customary 64-character base64 lines
        b64 = codecs.encode(server_cert, "base64").replace(b"\n", b"")
        pem = b"-----BEGIN CERTIFICATE-----\n"
        pem += b"".join(b64[i:i + 64] + b"\n" for i in range(0, len(b64), 64))
        pem += b"-----END CERTIFICATE-----\n"
        with open(args.certificate, "wb") as cert_file:
            cert_file.write(pem)
        print("[+] Server certificate written to:", args.certificate)

    s.sendall(client_key_exchange())
    s.sendall(change_cipher_spec())
    derive_master_secret()
    derive_keys()
    s.sendall(finished())

    while not server_finished_received:
        parserecord(readrecord())

    # HTTP/1.0 request for the URL's path; Host is netloc without any user-info
    host_header = url.netloc.rpartition('@')[2]
    request = b"GET " + path.encode() + b" HTTP/1.0\r\nHost: " + host_header.encode() + b"\r\n\r\n"
    s.sendall(application_data(request))

    # The response may span several application_data records. Stop at EOF
    # (readrecord() returns b'') or at the first non-application-data record,
    # typically the server's close_notify alert.
    while True:
        record = readrecord()
        if record[0:1] != b"\x17":
            break
        parserecord(record)

    print("[+] Closing TCP connection!")
finally:
    # also runs when a handler calls sys.exit(), so the socket is never leaked
    s.close()