# Paste metadata (not part of the module): Untitled · unknown · plain_text · 2 years ago · 4.2 kB · 8 · Indexable
# Copyright (c) Aptos
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import unittest
from nacl.signing import SigningKey, VerifyKey
from .bcs import Deserializer, Serializer
class PrivateKey:
    """Wrapper around a PyNaCl ``SigningKey`` with hex and BCS helpers.

    ``LENGTH`` is the byte length that ``deserialize`` requires of an
    encoded key; the wrapped key itself is stored on ``key``.
    """

    LENGTH: int = 32

    key: SigningKey

    def __init__(self, key: SigningKey):
        """Wrap an existing ``SigningKey``."""
        self.key = key

    def __eq__(self, other: object) -> bool:
        """Return whether *other* is a ``PrivateKey`` wrapping an equal key."""
        if not isinstance(other, PrivateKey):
            # Defer to Python's default handling so comparing against an
            # unrelated type yields False instead of raising AttributeError.
            return NotImplemented
        return self.key == other.key

    def __str__(self):
        """Return the same ``0x``-prefixed hex string as :meth:`hex`."""
        return self.hex()

    @staticmethod
    def from_hex(value: str) -> PrivateKey:
        """Build a key from a hex string, with or without a leading ``0x``.

        Raises:
            ValueError: from ``bytes.fromhex`` if *value* is not valid hex.
        """
        if value.startswith("0x"):
            value = value[2:]
        return PrivateKey(SigningKey(bytes.fromhex(value)))

    def hex(self) -> str:
        """Return the encoded key as a ``0x``-prefixed hex string."""
        return f"0x{self.key.encode().hex()}"

    def public_key(self) -> PublicKey:
        """Return the ``PublicKey`` wrapping this key's ``verify_key``."""
        return PublicKey(self.key.verify_key)

    @staticmethod
    def random() -> PrivateKey:
        """Return a new key produced by ``SigningKey.generate()``."""
        return PrivateKey(SigningKey.generate())

    def sign(self, data: bytes) -> Signature:
        """Sign *data* and return the detached signature as a ``Signature``."""
        return Signature(self.key.sign(data).signature)

    @staticmethod
    def deserialize(deserializer: Deserializer) -> PrivateKey:
        """Read a byte string from *deserializer* and build a key from it.

        Raises:
            Exception: "Length mismatch" if the byte string is not exactly
                ``PrivateKey.LENGTH`` bytes long.
        """
        key = deserializer.bytes()
        # Check the length before handing the bytes to SigningKey.
        if len(key) != PrivateKey.LENGTH:
            raise Exception("Length mismatch")
        return PrivateKey(SigningKey(key))

    def serialize(self, serializer: Serializer):
        """Write the encoded key to *serializer* as a byte string."""
        serializer.bytes(self.key.encode())
class PublicKey:
    """Wrapper around a PyNaCl ``VerifyKey`` with BCS helpers.

    ``LENGTH`` is the byte length that ``deserialize`` requires of an
    encoded key; the wrapped key itself is stored on ``key``.
    """

    LENGTH: int = 32

    key: VerifyKey

    def __init__(self, key: VerifyKey):
        """Wrap an existing ``VerifyKey``."""
        self.key = key

    def __eq__(self, other: object) -> bool:
        """Return whether *other* is a ``PublicKey`` wrapping an equal key."""
        if not isinstance(other, PublicKey):
            # Defer to Python's default handling so comparing against an
            # unrelated type yields False instead of raising AttributeError.
            return NotImplemented
        return self.key == other.key

    def __str__(self) -> str:
        """Return the encoded key as a ``0x``-prefixed hex string."""
        return f"0x{self.key.encode().hex()}"

    def verify(self, data: bytes, signature: Signature) -> bool:
        """Return True if *signature* verifies over *data*, else False.

        Any ordinary exception raised during verification is reported as
        False rather than propagated.
        """
        try:
            self.key.verify(data, signature.data())
        except Exception:
            # Deliberately broad: every verification failure maps to False.
            # Unlike a bare ``except``, this lets KeyboardInterrupt and
            # SystemExit propagate.
            return False
        return True

    @staticmethod
    def deserialize(deserializer: Deserializer) -> PublicKey:
        """Read a byte string from *deserializer* and build a key from it.

        Raises:
            Exception: "Length mismatch" if the byte string is not exactly
                ``PublicKey.LENGTH`` bytes long.
        """
        key = deserializer.bytes()
        # Check the length before handing the bytes to VerifyKey.
        if len(key) != PublicKey.LENGTH:
            raise Exception("Length mismatch")
        return PublicKey(VerifyKey(key))

    def serialize(self, serializer: Serializer):
        """Write the encoded key to *serializer* as a byte string."""
        serializer.bytes(self.key.encode())
class Signature:
    """Raw signature bytes with hex formatting and BCS helpers.

    ``LENGTH`` is the byte length that ``deserialize`` requires; the
    constructor itself does not validate the length.
    """

    LENGTH: int = 64

    signature: bytes

    def __init__(self, signature: bytes):
        """Store *signature* as given."""
        self.signature = signature

    def __eq__(self, other: object) -> bool:
        """Return whether *other* is a ``Signature`` holding equal bytes."""
        if not isinstance(other, Signature):
            # Defer to Python's default handling so comparing against an
            # unrelated type yields False instead of raising AttributeError.
            return NotImplemented
        return self.signature == other.signature

    def __str__(self) -> str:
        """Return the signature as a ``0x``-prefixed hex string."""
        return f"0x{self.signature.hex()}"

    def data(self) -> bytes:
        """Return the raw signature bytes."""
        return self.signature

    @staticmethod
    def deserialize(deserializer: Deserializer) -> Signature:
        """Read a byte string from *deserializer* and wrap it.

        Raises:
            Exception: "Length mismatch" if the byte string is not exactly
                ``Signature.LENGTH`` bytes long.
        """
        signature = deserializer.bytes()
        if len(signature) != Signature.LENGTH:
            raise Exception("Length mismatch")
        return Signature(signature)

    def serialize(self, serializer: Serializer):
        """Write the signature bytes to *serializer* as a byte string."""
        serializer.bytes(self.signature)
class Test(unittest.TestCase):
    """Sign/verify and serialization round-trip checks for the key types."""

    def test_sign_and_verify(self):
        """A signature from a fresh private key verifies under its public key."""
        message = b"test_message"
        signer = PrivateKey.random()
        verifier = signer.public_key()
        self.assertTrue(verifier.verify(message, signer.sign(message)))

    def test_private_key_serialization(self):
        """A private key survives a serialize/deserialize round trip."""
        original = PrivateKey.random()
        writer = Serializer()
        original.serialize(writer)
        reader = Deserializer(writer.output())
        self.assertEqual(original, PrivateKey.deserialize(reader))

    def test_public_key_serialization(self):
        """A public key survives a serialize/deserialize round trip."""
        original = PrivateKey.random().public_key()
        writer = Serializer()
        original.serialize(writer)
        reader = Deserializer(writer.output())
        self.assertEqual(original, PublicKey.deserialize(reader))

    def test_signature_key_serialization(self):
        """A signature survives a serialize/deserialize round trip."""
        signer = PrivateKey.random()
        original = signer.sign(b"another_message")
        writer = Serializer()
        original.serialize(writer)
        reader = Deserializer(writer.output())
        self.assertEqual(original, Signature.deserialize(reader))
# Paste-site UI text (not part of the module): "Editor is loading..." / "Leave a Comment"