CHATGPT_RPCLIB_PYTHON_0

mail@pastecode.io avatar
unknown
python
2 years ago
6.7 kB
3
Indexable
Never
import functools
import logging
import queue
import random
import ssl
import threading
import time
import unittest
import xmlrpc.client
import xmlrpc.server

class ConnectionPool:
    """Thread-safe pool of ``xmlrpc.client.ServerProxy`` objects for one endpoint.

    At most ``pool_size`` proxies are created. Callers borrow a proxy with
    :meth:`get_connection` and must return it with :meth:`release_connection`.
    When every proxy is checked out, :meth:`get_connection` blocks until one
    is released.

    Args:
        host: Server host name.
        port: Server port.
        pool_size: Maximum number of proxies this pool will create.
        retry_interval: Seconds to sleep between failed proxy-creation attempts.
        max_retries: Extra attempts allowed after the first failed creation.
        use_ssl: Connect over ``https`` instead of ``http``.
        cafile: CA bundle used to verify the server certificate. When ``None``,
            the system default CA store is loaded instead.
        certfile: Optional client certificate chain (PEM) for mutual TLS.
        keyfile: Optional private key matching ``certfile``.
    """

    def __init__(self, host, port, pool_size=10, retry_interval=5, max_retries=5, use_ssl=False, cafile=None, certfile=None, keyfile=None):
        self.host = host
        self.port = port
        self.pool_size = pool_size
        self.retry_interval = retry_interval
        self.max_retries = max_retries
        self.pool = []  # idle proxies available for reuse (LIFO)
        self.current_size = 0  # proxies created, or reserved and being created
        self.use_ssl = use_ssl
        self.cafile = cafile
        self.certfile = certfile
        self.keyfile = keyfile
        self.logger = logging.getLogger(__name__)
        # Guards ``pool`` and ``current_size``; waiters are woken on release.
        self._cond = threading.Condition()

    def _make_ssl_context(self):
        """Build a client TLS context: TLS 1.2+, certificate and hostname checks."""
        # PROTOCOL_TLS_CLIENT turns on CERT_REQUIRED and check_hostname by
        # default. The deprecated PROTOCOL_TLSv1_2 context left hostname
        # checking disabled.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        context.minimum_version = ssl.TLSVersion.TLSv1_2
        # load_verify_locations() raises TypeError when every source is None.
        if self.cafile is not None:
            context.load_verify_locations(cafile=self.cafile)
        else:
            context.load_default_certs()
        # The client certificate is optional; load_cert_chain(None) would raise.
        if self.certfile is not None:
            context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
        return context

    def create_connection(self):
        """Create a new proxy, retrying up to ``max_retries`` extra times.

        Returns:
            A new ``xmlrpc.client.ServerProxy`` for ``host:port``.

        Raises:
            Exception: whatever the final failed attempt raised.
        """
        retries = 0
        while True:
            try:
                if self.use_ssl:
                    return xmlrpc.client.ServerProxy(f"https://{self.host}:{self.port}", context=self._make_ssl_context())
                return xmlrpc.client.ServerProxy(f"http://{self.host}:{self.port}")
            except Exception:
                retries += 1
                if retries > self.max_retries:
                    raise
                self.logger.warning(
                    "Creating connection to %s:%s failed (attempt %d of %d); retrying in %ss",
                    self.host, self.port, retries, self.max_retries + 1, self.retry_interval,
                    exc_info=True,
                )
                time.sleep(self.retry_interval)

    def get_connection(self):
        """Borrow a proxy, reusing an idle one before creating a new one.

        Blocks while ``pool_size`` proxies exist and none is idle.

        Raises:
            Exception: propagated from :meth:`create_connection`. The reserved
                slot is released first, so a failure cannot shrink the pool.
        """
        with self._cond:
            while not self.pool and self.current_size >= self.pool_size:
                self._cond.wait()
            if self.pool:
                return self.pool.pop()
            # Reserve the slot under the lock, but build the proxy outside it:
            # create_connection() may sleep between retries.
            self.current_size += 1
        try:
            return self.create_connection()
        except BaseException:
            with self._cond:
                self.current_size -= 1
                self._cond.notify()
            raise

    def release_connection(self, connection):
        """Return a borrowed proxy to the pool and wake one waiting borrower."""
        with self._cond:
            self.pool.append(connection)
            self._cond.notify()

class LoadBalancer:
    """Hand out the items of *servers* in round-robin order."""

    def __init__(self, servers):
        self.servers = servers  # items returned in order, wrapping at the end
        self.current_index = 0  # position of the next item to return

    def get_server(self):
        """Return the server at the cursor, then advance the cursor with wrap-around."""
        position = self.current_index
        chosen = self.servers[position]
        self.current_index = (position + 1) % len(self.servers)
        return chosen

class LoadShedder:
    """Bounded request queue that rejects new work instead of blocking when full.

    Args:
        max_queue_size: Queue capacity, passed to ``queue.Queue`` (which
            treats values <= 0 as unbounded).
    """

    def __init__(self, max_queue_size=100):
        self.max_queue_size = max_queue_size
        self.queue = queue.Queue(max_queue_size)

    def add_request(self, request):
        """Enqueue *request* without blocking.

        Raises:
            Exception: ``"Server overloaded"`` when the queue is at capacity.
        """
        # put_nowait checks capacity and inserts in one atomic step. The old
        # full()-then-put() pair could block forever if another thread filled
        # the queue between the two calls.
        try:
            self.queue.put_nowait(request)
        except queue.Full:
            raise Exception("Server overloaded") from None

    def get_request(self, block=True, timeout=None):
        """Remove and return the oldest request.

        With the defaults this blocks until a request is available. With
        ``block=False``, or when *timeout* expires, it raises ``queue.Empty``.
        """
        return self.queue.get(block, timeout)

class RPCServer:
    """``SimpleXMLRPCServer`` wrapper with optional TLS and lifecycle helpers.

    Args:
        host: Address to bind.
        port: Port to bind (0 lets the OS pick one; read the bound port from
            ``self.server.server_address``).
        use_ssl: Wrap the listening socket in TLS.
        cafile: When given with ``use_ssl``, CA bundle used to require and
            verify client certificates (mutual TLS).
        certfile: Server certificate chain (PEM); required when ``use_ssl``.
        keyfile: Private key matching ``certfile``.
    """

    def __init__(self, host, port, use_ssl=False, cafile=None, certfile=None, keyfile=None):
        self.server = xmlrpc.server.SimpleXMLRPCServer((host, port), use_builtin_types=True, allow_none=True)
        if use_ssl:
            try:
                # PROTOCOL_TLS_SERVER replaces the deprecated PROTOCOL_TLSv1_2;
                # minimum_version keeps the TLS 1.2 floor.
                context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
                context.minimum_version = ssl.TLSVersion.TLSv1_2
                context.load_cert_chain(certfile=certfile, keyfile=keyfile)
                if cafile is not None:
                    # Previously ``cafile`` was accepted and silently ignored.
                    context.verify_mode = ssl.CERT_REQUIRED
                    context.load_verify_locations(cafile=cafile)
                self.server.socket = context.wrap_socket(self.server.socket, server_side=True)
            except BaseException:
                # The socket is already bound; don't leak it when TLS setup fails.
                self.server.server_close()
                raise
        self.server.register_introspection_functions()

    def register_function(self, func, name=None):
        """Expose *func* to clients as *name* (``func.__name__`` when None)."""
        self.server.register_function(func, name)

    def serve_forever(self, poll_interval=0.5):
        """Handle requests until :meth:`shutdown` is called from another thread."""
        self.server.serve_forever(poll_interval)

    def shutdown(self):
        """Stop a running :meth:`serve_forever` loop and wait until it exits.

        Call this from a thread other than the one running ``serve_forever``.
        """
        self.server.shutdown()

    def server_close(self):
        """Close the listening socket (call :meth:`shutdown` first if serving)."""
        self.server.server_close()


class RPCClient:
    """XML-RPC client that borrows proxies from a ``ConnectionPool``.

    All constructor arguments are forwarded unchanged to ``ConnectionPool``.
    """

    def __init__(self, host, port, pool_size=10, retry_interval=5, max_retries=5, use_ssl=False, cafile=None, certfile=None, keyfile=None):
        self.connection_pool = ConnectionPool(host, port, pool_size, retry_interval, max_retries, use_ssl, cafile, certfile, keyfile)
        # Outcome of the most recently finished call_async():
        # (True, result) or (False, exception).
        self.async_result = None

    @staticmethod
    def _is_method_not_found(fault):
        """Return True if *fault* reports that the remote method does not exist."""
        if fault.faultCode == -32601:  # fault-code interoperability spec: method not found
            return True
        # Python's SimpleXMLRPCServer reports unknown methods as fault code 1
        # with a 'method "<name>" is not supported' message.
        return fault.faultCode == 1 and "is not supported" in str(fault.faultString)

    def call(self, method, *args, **kwargs):
        """Invoke remote *method* with *args* and return its result.

        ``xmlrpc.client`` proxies take positional arguments only, so any
        *kwargs* are passed through and rejected by the proxy.

        Raises:
            Exception: ``"Method not found: <method>"`` (chained from the
                ``xmlrpc.client.Fault``) when the server reports an unknown method.
            xmlrpc.client.Fault: for any other server-side fault.
        Other exceptions from the proxy propagate unchanged. The proxy is
        always returned to the pool.
        """
        connection = self.connection_pool.get_connection()
        try:
            return getattr(connection, method)(*args, **kwargs)
        except xmlrpc.client.Fault as e:
            if self._is_method_not_found(e):
                raise Exception(f"Method not found: {method}") from e
            raise
        finally:
            self.connection_pool.release_connection(connection)

    def call_async(self, method, *args, **kwargs):
        """Run :meth:`call` on a new, already-started thread and return that thread.

        When the thread finishes, its outcome, ``(True, result)`` or
        ``(False, exception)``, is stored as ``thread.async_result`` on the
        returned thread and also as ``self.async_result``. ``self.async_result``
        is shared by every call, so prefer the per-thread attribute when calls
        overlap.
        """
        def _worker():
            try:
                outcome = (True, self.call(method, *args, **kwargs))
            except Exception as e:
                outcome = (False, e)
            thread.async_result = outcome
            self.async_result = outcome

        thread = threading.Thread(target=_worker)
        thread.start()
        return thread

def add(x, y):
    """Return ``x + y`` for any operands that support ``+``."""
    total = x + y
    return total

class RPCTest(unittest.TestCase):
    """End-to-end tests: two local plain-HTTP ``add`` servers plus the client helpers."""

    def _start_server(self):
        """Start an ``add`` server on a free localhost port.

        Returns:
            ``(server, thread, port)``. Stopping is registered with
            ``addCleanup``, so it runs even if a later ``setUp`` step fails.
        """
        server = RPCServer("localhost", 0)
        server.register_function(add)
        thread = threading.Thread(target=server.serve_forever, daemon=True)
        thread.start()
        self.addCleanup(self._stop_server, server, thread)
        return server, thread, server.server.server_address[1]

    @staticmethod
    def _stop_server(server, thread):
        """Stop the serve loop, close the socket, and wait for the thread."""
        # shutdown() is what ends serve_forever(); server_close() only releases
        # the socket. The original tearDown never called shutdown(), so join()
        # could hang.
        server.server.shutdown()
        server.server.server_close()
        thread.join()

    def setUp(self):
        # Ephemeral ports avoid clashes with anything already using 8000/8001.
        self.server1, self.server1_thread, port1 = self._start_server()
        self.server2, self.server2_thread, port2 = self._start_server()
        self.load_balancer = LoadBalancer([("localhost", port1), ("localhost", port2)])
        self.load_shedder = LoadShedder()
        # The test servers speak plain HTTP, so the client must as well.
        self.client = RPCClient("localhost", port1)

    def test_call(self):
        result = self.client.call("add", 1, 2)
        self.assertEqual(result, 3)

    def test_call_async(self):
        thread = self.client.call_async("add", 3, 4)
        thread.join()
        success, result = self.client.async_result
        self.assertTrue(success)
        self.assertEqual(result, 7)

    def test_load_balancer(self):
        results = []
        for i in range(10):
            host, port = self.load_balancer.get_server()
            client = RPCClient(host, port)
            results.append(client.call("add", i, i))
        self.assertEqual(results, list(range(0, 20, 2)))

    def test_load_shedder(self):
        # The default capacity is 100, so this loop must overflow.
        with self.assertRaisesRegex(Exception, "Server overloaded"):
            for i in range(200):
                self.load_shedder.add_request((add, (i, i)))

        results = []
        for _ in range(100):
            # Dequeue once per item; the old code popped two items per result
            # and deadlocked after 50 iterations.
            func, args = self.load_shedder.get_request()
            results.append(func(*args))
        self.assertEqual(results, list(range(0, 200, 2)))

if __name__ == "__main__":
    # Run this module's unittest test cases when it is executed as a script.
    unittest.main()