#!/usr/bin/env python3
import argparse, codecs, hmac, socket, sys, time, os, datetime
from hashlib import sha1, sha256
from Crypto.Cipher import ARC4
from pyasn1.codec.der import decoder # do not use any other imports/libraries
from urllib.parse import urlparse
# took x.y hours (please specify here how much time your solution required)
# Command-line interface: a positional URL to request and an optional
# --certificate file name (help text: file to write the PEM-encoded server
# certificate). Parsed once at import time into the module-level ``args``.
parser = argparse.ArgumentParser(description='TLS v1.2 client')
parser.add_argument(
    'url',
    type=str,
    help='URL to request',
)
parser.add_argument(
    '--certificate',
    type=str,
    help='File to write PEM-encoded server certificate',
)
args = parser.parse_args()
def get_pubkey_certificate(cert):
    """Return the RSA public key ``(n, e)`` contained in an X.509 certificate.

    cert: the certificate as DER bytes, or PEM bytes (anything starting with
        ``-----`` is first converted with ``pem_to_der``).
    Returns a tuple of Python ints ``(modulus, public_exponent)``.
    """
    cert_der = pem_to_der(cert) if cert.startswith(b'-----') else cert
    certificate, _ = decoder.decode(cert_der)
    # Field 6 of tbsCertificate is taken as subjectPublicKeyInfo
    # (assumes the explicit [0] version field is present, as in v3 certs).
    spki = certificate[0][6]
    # subjectPublicKey is a BIT STRING whose octets are a DER-encoded
    # SEQUENCE { modulus, publicExponent }.
    rsa_public_key, _ = decoder.decode(spki[1].asOctets())
    modulus, exponent = rsa_public_key[0], rsa_public_key[1]
    return int(modulus), int(exponent)
def pem_to_der(content):
    """Convert PEM-armoured bytes to the DER bytes they encode.

    Every ``-----BEGIN ...-----`` / ``-----END ...-----`` boundary line is
    dropped regardless of its label, so certificates work as well as the
    RSA / PKCS#8 key formats. The remaining base64 body is then decoded.

    content: PEM data as bytes.
    Returns the decoded DER bytes.
    """
    # A fixed header list used to omit CERTIFICATE. The label's letters were
    # then fed to the (lenient) base64 decoder and corrupted the DER output.
    body = b"".join(
        line for line in content.splitlines()
        if not line.strip().startswith(b"-----")
    )
    return codecs.decode(body, 'base64')
def pkcsv15pad_encrypt(plaintext, n):
    """Apply PKCS#1 v1.5 encryption padding (block type 2) for modulus *n*.

    Output layout: 0x00 0x02 || non-zero random bytes || 0x00 || plaintext,
    exactly as long as the modulus in bytes.
    Exits the process if *plaintext* is longer than ``k - 11`` bytes.
    """
    modulus_len = (n.bit_length() + 7) // 8
    if len(plaintext) + 11 > modulus_len:
        print("[-] Plaintext larger than modulus - 11 bytes")
        sys.exit(1)
    pad_len = modulus_len - 3 - len(plaintext)
    filler = bytearray()
    while len(filler) < pad_len:
        # padding bytes must be non-zero; redraw any zero byte
        candidate = os.urandom(1)
        if candidate != b"\x00":
            filler += candidate
    return b"\x00\x02" + bytes(filler) + b"\x00" + plaintext
def rsa_encrypt(cert, m):
    """RSA-encrypt *m* with PKCS#1 v1.5 padding under the public key in *cert*.

    cert: certificate (DER, or PEM) holding the RSA public key.
    m: plaintext bytes; must fit in ``k - 11`` bytes for a ``k``-byte modulus
        (otherwise pkcsv15pad_encrypt exits the process).
    Returns the ciphertext as exactly ``k`` big-endian bytes.
    """
    n, e = get_pubkey_certificate(cert)
    k = (n.bit_length() + 7) // 8
    padded = pkcsv15pad_encrypt(m, n)
    c = pow(bn(padded), e, n)
    # I2OSP(c, k): the ciphertext is always modulus-length. A minimal encoding
    # would drop leading zero bytes (roughly 1 in 256 encryptions).
    return nb(c, k)
def nb(i, length=False):
    """Convert an integer to big-endian bytes.

    i: integer to encode.
    length: output width in bytes. When omitted (``False``) or ``None``, the
        minimal width is used, so ``nb(0) == b''``. When given, the result is
        exactly ``length`` bytes. Values that do not fit are truncated to their
        low-order ``length`` bytes (e.g. a 4-byte timestamp).
    Returns the encoded bytes.
    """
    # Identity checks: an explicit width of 0 used to compare equal to False
    # and was silently treated as "auto".
    if length is None or length is False:
        length = (i.bit_length() + 7) // 8
    return (i & ((1 << (8 * length)) - 1)).to_bytes(length, "big")
def bn(b):
    """Interpret bytes *b* as a big-endian unsigned integer (``bn(b'') == 0``)."""
    # int.from_bytes replaces the per-byte Python shift/or loop
    return int.from_bytes(b, "big")
def client_hello():
    """Build the TLS record carrying the ClientHello handshake message.

    Offers TLS 1.2, an empty session id, the single cipher suite
    TLS_RSA_WITH_RC4_128_SHA and null compression.
    Side effects: sets the global ``client_random`` (4-byte Unix time followed
    by 28 random bytes) and appends the handshake message to the global
    ``handshake_messages`` transcript.
    Returns the complete record (5-byte header + handshake message).
    """
    global client_random, handshake_messages
    print("--> ClientHello()")
    client_random = nb(int(time.time()), 4)
    client_random += os.urandom(32 - len(client_random))
    cipher_suites = b"\x00\x05"  # TLS_RSA_WITH_RC4_128_SHA
    compression = b"\x00"  # null compression
    body = b"".join([
        b"\x03\x03",  # client_version: TLS 1.2
        client_random,
        b"\x00",  # session_id length 0
        nb(len(cipher_suites), 2), cipher_suites,
        nb(len(compression)), compression,
    ])
    message = b"\x01" + nb(len(body), 3) + body  # client_hello(1)
    handshake_messages += message
    # record header: handshake(22), TLS 1.2, 2-byte length
    return b"\x16" + b"\x03\x03" + nb(len(message), 2) + message
def client_key_exchange():
    """Build the ClientKeyExchange record carrying the RSA-encrypted pre-master secret.

    Generates a fresh 48-byte pre-master secret (version bytes 0x0303 followed
    by 46 random bytes) and stores it in the global ``premaster``. The secret is
    encrypted with the server certificate's RSA key.
    The handshake message is appended to the global ``handshake_messages``.
    Returns the complete record (5-byte header + handshake message).
    """
    global server_cert, premaster, handshake_messages
    print("--> ClientKeyExchange()")
    premaster = b"\x03\x03" + os.urandom(46)
    encrypted_premaster = rsa_encrypt(server_cert, premaster)
    # Since TLS 1.0 the EncryptedPreMasterSecret is an opaque<0..2^16-1>
    # vector. The 2-byte length prefix is mandatory (RFC 5246 7.4.7.1), and
    # servers reject a ClientKeyExchange that omits it.
    handshake_message_body = nb(len(encrypted_premaster), 2) + encrypted_premaster
    handshake_message_type = b"\x10"  # client_key_exchange(16)
    handshake_message_len = nb(len(handshake_message_body), 3)
    handshake = handshake_message_type + handshake_message_len + handshake_message_body
    handshake_messages += handshake
    # record layer header: handshake(22), TLS 1.2, 2-byte length
    record_message_type = b"\x16"
    record_tls_version = b"\x03\x03"
    record_data_length = nb(len(handshake), 2)
    return record_message_type + record_tls_version + record_data_length + handshake
def change_cipher_spec():
    """Build the ChangeCipherSpec record: type 20, TLS 1.2, one payload byte 0x01.

    This is a record type of its own, not a handshake message, so it is not
    added to the handshake transcript.
    """
    print("--> ChangeCipherSpec()")
    payload = b"\x01"
    header = b"\x14" + b"\x03\x03" + nb(len(payload), 2)
    return header + payload
def finished():
    """Build the encrypted client Finished record.

    verify_data = PRF(master_secret, "client finished" + SHA-256(transcript), 12)
    where the transcript is the global ``handshake_messages`` accumulated so far.
    The Finished message is then appended to ``handshake_messages`` and
    protected with ``encrypt`` under record type 22 (handshake).
    Returns the complete record (5-byte header + protected payload).
    """
    global handshake_messages, master_secret
    print("--> Finished()")
    transcript_hash = sha256(handshake_messages).digest()
    verify_data = PRF(master_secret, b"client finished" + transcript_hash, 12)
    message = b"\x14" + nb(len(verify_data), 3) + verify_data  # finished(20)
    handshake_messages += message
    content_type, version = b"\x16", b"\x03\x03"
    protected = encrypt(message, content_type, version)
    return content_type + version + nb(len(protected), 2) + protected
def application_data(data):
    """Build an encrypted application_data record carrying *data*.

    data: plaintext bytes (e.g. an HTTP request); echoed to stdout first.
    Returns the record: type 23, version 0x0303, 2-byte length, followed by
    the output of ``encrypt`` for the payload.
    """
    print("--> Application_data()")
    print(data.decode(errors="replace").strip())
    # The previous body returned an undefined name ``record`` (NameError).
    # Protect the payload and frame it like the other records.
    content_type = b"\x17"  # application_data(23)
    version = b"\x03\x03"
    protected = encrypt(data, content_type, version)
    return content_type + version + nb(len(protected), 2) + protected
def parsehandshake(r):
    """Parse every handshake message carried in one TLS record body.

    r: record body without the 5-byte record header. Once the server's
        ChangeCipherSpec has been seen, the body is decrypted (and its MAC
        checked) exactly once before parsing.

    Handles ServerHello, Certificate, ServerHelloDone and Finished. It updates
    the matching module globals and appends each message to the
    ``handshake_messages`` transcript.
    Exits the process in these cases:
      - an unsupported cipher suite or compression method,
      - an unknown message type,
      - a failed Finished verification,
      - a message that is cut off (fragmented across records).
    """
    global server_hello_done_received, server_random, server_cert, handshake_messages, server_change_cipher_spec_received, server_finished_received
    # Decrypt the whole body once. The old recursive handling of trailing
    # messages would have run RC4 + MAC check over plaintext a second time.
    if server_change_cipher_spec_received:
        r = decrypt(r, b"\x16", b"\x03\x03")
    offset = 0
    while offset < len(r):
        # handshake header: 1-byte type, 3-byte body length
        htype, hlength = r[offset:offset + 1], bn(r[offset + 1:offset + 4])
        handshake = r[offset:offset + 4 + hlength]
        if len(handshake) < 4 + hlength:
            print("[-] Handshake message fragmented across records is not supported")
            sys.exit(1)
        body = handshake[4:]
        offset += 4 + hlength
        # Finished covers every handshake message *before* itself
        transcript_before = handshake_messages
        handshake_messages += handshake

        if htype == b"\x02":
            print(" <--- ServerHello()")
            # body: version(2) random(32) session_id_len(1) session_id
            #       cipher_suite(2) compression(1) [extensions]
            server_random = body[2:34]
            gmt = datetime.datetime.fromtimestamp(
                bn(server_random[:4]), datetime.timezone.utc
            ).strftime('%Y-%m-%d %H:%M:%S')
            sessid_len = body[34]
            sessid = body[35:35 + sessid_len]
            bookmark = 35 + sessid_len
            cipher = body[bookmark:bookmark + 2]
            compression = body[bookmark + 2:bookmark + 3]
            print(" [+] server randomness:", server_random.hex().upper())
            print(" [+] server timestamp:", gmt)
            print(" [+] TLS session ID:", sessid.hex().upper())
            # Only RC4_128_SHA is offered in client_hello and implemented by
            # derive_keys/encrypt/decrypt. Any other choice cannot work.
            if cipher != b"\x00\x05":
                print("[-] Unsupported cipher suite selected:", cipher.hex())
                sys.exit(1)
            print(" [+] Cipher suite: TLS_RSA_WITH_RC4_128_SHA")
            if compression != b"\x00":
                print("[-] Wrong compression:", compression.hex())
                sys.exit(1)
        elif htype == b"\x0b":
            print(" <--- Certificate()")
            # body: certificate_list length(3), then per certificate
            # length(3) + DER. Only the first (server) certificate is kept.
            certlen = bn(body[3:6])
            print(" [+] Server certificate length:", certlen)
            server_cert = body[6:6 + certlen]
        elif htype == b"\x0e":
            print(" <--- ServerHelloDone()")
            server_hello_done_received = True
        elif htype == b"\x14":
            print(" <--- Finished()")
            server_verify = body
            verify_data_calc = PRF(master_secret, b"server finished" + sha256(transcript_before).digest(), 12)
            if not hmac.compare_digest(server_verify, verify_data_calc):
                print("[-] Server finished verification failed!")
                sys.exit(1)
            server_finished_received = True
        else:
            print("[-] Unknown Handshake Type:", htype.hex())
            sys.exit(1)
def parserecord(r):
    """Parse one TLS record and dispatch it by content type.

    r: complete record as returned by readrecord(): a 5-byte header
        (type, version, length) followed by the body.

    Behaviour by content type:
      - handshake: the body goes to parsehandshake();
      - ChangeCipherSpec: switches the server direction to encrypted mode;
      - alert: reported; fatal or unrecognised levels exit;
      - application data: decrypted and printed;
      - anything else: exits the process.
    """
    global server_change_cipher_spec_received
    ctype = r[0:1]
    c = r[5:]
    if ctype == b"\x16":
        print("<--- Handshake()")
        parsehandshake(c)
    elif ctype == b"\x14":
        print("<--- ChangeCipherSpec()")
        server_change_cipher_spec_received = True
    elif ctype == b"\x15":
        print("<--- Alert()")
        # After the server's ChangeCipherSpec, alerts are MAC'd and encrypted
        # like any other record. Reading them raw yields garbage level/desc.
        if server_change_cipher_spec_received:
            c = decrypt(c, b"\x15", b"\x03\x03")
        if len(c) < 2:
            print(" [-] malformed alert")
            sys.exit(1)
        level, desc = c[0], c[1]
        if level == 1:
            print(" [-] warning:", desc)
        elif level == 2:
            print(" [-] fatal:", desc)
            sys.exit(1)
        else:
            sys.exit(1)
    elif ctype == b"\x17":
        print("<--- Application_data()")
        data = decrypt(c, b"\x17", b"\x03\x03")
        # the payload need not be UTF-8 (binary bodies, other charsets)
        print(data.decode(errors="replace").strip())
    else:
        print("[-] Unknown TLS Record type:", ctype.hex())
        sys.exit(1)
def PRF(secret, seed, l):
    """TLS 1.2 PRF: P_SHA256(secret, seed) truncated to *l* bytes.

    Callers pass the label already prepended to the seed.
      A(1) = HMAC(secret, seed),  A(i+1) = HMAC(secret, A(i))
      output = HMAC(secret, A(1) + seed) || HMAC(secret, A(2) + seed) || ...
    """
    def _hmac(data):
        return hmac.new(secret, data, sha256).digest()

    chunks = []
    produced = 0
    a = _hmac(seed)
    while produced < l:
        block = _hmac(a + seed)
        chunks.append(block)
        produced += len(block)
        a = _hmac(a)
    return b"".join(chunks)[:l]
def derive_master_secret():
    """Compute the 48-byte master secret from the pre-master secret and both randoms.

    master_secret = PRF(premaster, "master secret" + client_random + server_random, 48).
    The result is stored in the global ``master_secret``.
    """
    global premaster, master_secret, client_random, server_random
    seed = b"".join((b"master secret", client_random, server_random))
    master_secret = PRF(premaster, seed, 48)
def derive_keys():
    """Expand the master secret into the TLS_RSA_WITH_RC4_128_SHA key material.

    key_block = PRF(master_secret, "key expansion" + server_random + client_random)
    is split, in order, into:
      - client MAC key (20 bytes),
      - server MAC key (20 bytes),
      - client RC4 key (16 bytes),
      - server RC4 key (16 bytes).
    RC4 is a stream cipher and needs no IV. The function sets these key
    globals and creates fresh RC4 states: ``rc4c`` (used by encrypt) and
    ``rc4s`` (used by decrypt).
    """
    global master_secret, client_random, server_random
    global client_mac_key, server_mac_key, client_enc_key, server_enc_key, rc4c, rc4s
    mac_size = 20  # HMAC-SHA1
    key_size = 16  # RC4-128
    # Derive only what this suite consumes (72 bytes). The PRF output is
    # prefix-stable, so the keys match the previous 136-byte derivation.
    key_block = PRF(master_secret, b"key expansion" + server_random + client_random,
                    2 * mac_size + 2 * key_size)
    client_mac_key = key_block[:mac_size]
    server_mac_key = key_block[mac_size:2 * mac_size]
    client_enc_key = key_block[2 * mac_size:2 * mac_size + key_size]
    server_enc_key = key_block[2 * mac_size + key_size:2 * mac_size + 2 * key_size]
    rc4c = ARC4.new(client_enc_key)
    rc4s = ARC4.new(server_enc_key)
def HMAC_sha1(key, data):
    """Return the 20-byte HMAC-SHA1 tag of *data* under *key*."""
    mac = hmac.HMAC(key, data, sha1)
    return mac.digest()
def encrypt(plain, type, version):
    """MAC-then-encrypt one outgoing record payload with the client write state.

    plain: record payload bytes.
    type: 1-byte content type.
    version: 2-byte protocol version.
    MAC = HMAC-SHA1(client_mac_key, seq(8) || type || version || len(2) || plain).
    Returns RC4(plain || MAC) from the running ``rc4c`` state, then increments
    the global ``client_seq``.
    """
    global client_mac_key, client_enc_key, client_seq, rc4c
    mac_input = b"".join((nb(client_seq, 8), type, version, nb(len(plain), 2), plain))
    tag = HMAC_sha1(client_mac_key, mac_input)
    protected = rc4c.encrypt(plain + tag)
    client_seq += 1
    return protected
def decrypt(ciphertext, type, version):
    """Decrypt one incoming record payload and verify its HMAC-SHA1 MAC.

    ciphertext: protected payload (RC4 of plaintext || 20-byte MAC).
    type: 1-byte content type.
    version: 2-byte version, both fed into the MAC input.
    Exits the process if the MAC does not verify. On success it increments the
    global ``server_seq`` and returns the plaintext.
    """
    global server_mac_key, server_seq, rc4s
    decrypted = rs4s_decrypt = rc4s.decrypt(ciphertext)
    plain, mac = decrypted[:-20], decrypted[-20:]
    mac_calc = HMAC_sha1(server_mac_key, nb(server_seq, 8) + type + version + nb(len(plain), 2) + plain)
    # constant-time comparison: `!=` leaks how many leading bytes matched
    if not hmac.compare_digest(mac, mac_calc):
        print("[-] MAC verification failed!")
        sys.exit(1)
    server_seq += 1
    return plain
def readrecord():
    """Read one complete TLS record (5-byte header + body) from the global socket ``s``.

    Returns b"" if the peer closed the connection before a new record started.
    Exits the process if the connection drops in the middle of a record, instead
    of returning a silently truncated record.
    """
    def recv_upto(count):
        # Read until *count* bytes arrive or the peer closes the connection.
        # This takes many bytes per syscall instead of one recv(1) per byte.
        data = bytearray()
        while len(data) < count:
            chunk = s.recv(count - len(data))
            if not chunk:
                break
            data += chunk
        return bytes(data)

    header = recv_upto(5)
    if not header:
        return b""  # clean EOF at a record boundary
    if len(header) < 5:
        print("[-] Connection closed inside a TLS record header")
        sys.exit(1)
    datalen = bn(header[3:5])
    body = recv_upto(datalen)
    if len(body) < datalen:
        print("[-] Connection closed inside a TLS record body")
        sys.exit(1)
    return header + body
# ---- main program: TLS handshake, one HTTP GET, print the response ----
url = urlparse(args.url)
# hostname/port also handle IPv6 literals ("[::1]:8443"), which splitting
# netloc on ':' broke
host = url.hostname
port = url.port or 443
if not host:
    print("[-] URL must include a host, e.g. https://example.com/")
    sys.exit(1)
path = url.path or "/"
if url.query:
    path += "?" + url.query

client_random = b""  # will hold client randomness
server_random = b""  # will hold server randomness
server_cert = b""  # will hold DER encoded server certificate
premaster = b""  # will hold 48 byte pre-master secret
master_secret = b""  # will hold master secret
handshake_messages = b""  # will hold concatenation of handshake messages
# client/server keys and sequence numbers
client_mac_key = b""
server_mac_key = b""
client_enc_key = b""
server_enc_key = b""
client_seq = 0
server_seq = 0
# client/server RC4 instances
rc4c = b""
rc4s = b""
server_hello_done_received = False
server_change_cipher_spec_received = False
server_finished_received = False


def _handshake_record():
    """Read the next record while handshaking; a closed connection is fatal here."""
    record = readrecord()
    if not record:
        print("[-] Server closed the connection during the handshake")
        sys.exit(1)
    return record


s = socket.create_connection((host, port))
try:
    # sendall: a single send() may transmit only part of the buffer
    s.sendall(client_hello())
    while not server_hello_done_received:
        parserecord(_handshake_record())

    # --certificate: save the server certificate in PEM (64-column base64)
    if args.certificate:
        b64 = codecs.encode(server_cert, "base64").replace(b"\n", b"")
        pem_lines = [b64[i:i + 64] for i in range(0, len(b64), 64)]
        pem = (b"-----BEGIN CERTIFICATE-----\n" + b"\n".join(pem_lines)
               + b"\n-----END CERTIFICATE-----\n")
        with open(args.certificate, "wb") as cert_file:
            cert_file.write(pem)
        print("[+] Server certificate written to", args.certificate)

    s.sendall(client_key_exchange())
    s.sendall(change_cipher_spec())
    derive_master_secret()
    derive_keys()
    s.sendall(finished())
    while not server_finished_received:
        parserecord(_handshake_record())

    # Request the path from the URL. Host is needed for virtual hosting, and
    # netloc minus userinfo keeps IPv6 brackets and an explicit port.
    host_header = url.netloc.rpartition("@")[2]
    request = (b"GET " + path.encode() + b" HTTP/1.0\r\nHost: "
               + host_header.encode() + b"\r\n\r\n")
    s.sendall(application_data(request))

    # The response may span several records. Read until the peer closes the
    # connection or sends an alert (e.g. close_notify).
    while True:
        record = readrecord()
        if not record:
            break
        parserecord(record)
        if record[:1] == b"\x15":
            break
    print("[+] Closing TCP connection!")
finally:
    s.close()