exo networking layer — discovery and peer-handle interfaces plus the gRPC transport implementation (multiple source files concatenated; each file begins at its `## path` header).
## exo/networking/discovery.py
from abc import ABC, abstractmethod
from typing import List
from .peer_handle import PeerHandle
class Discovery(ABC):
  """Abstract interface for a peer-discovery mechanism.

  Implementations find other exo nodes on the network and hand back
  PeerHandle objects for talking to them.
  """

  @abstractmethod
  async def start(self) -> None:
    """Start the discovery process (e.g. begin broadcasting/listening)."""
    pass

  @abstractmethod
  async def stop(self) -> None:
    """Stop discovery and release any associated resources."""
    pass

  @abstractmethod
  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
    """Return handles for the peers discovered so far.

    Args:
      wait_for_peers: minimum number of peers to wait for before returning
        (0 returns immediately with whatever has been found).
    """
    pass
## exo/networking/peer_handle.py
from abc import ABC, abstractmethod
from typing import Optional, Tuple, List
import numpy as np
from exo.inference.shard import Shard
from exo.topology.device_capabilities import DeviceCapabilities
from exo.topology.topology import Topology
class PeerHandle(ABC):
  """Abstract handle to a remote exo node.

  Defines the full RPC surface used by the orchestration layer: connection
  lifecycle, inference requests, result delivery, and topology collection.
  """

  @abstractmethod
  def id(self) -> str:
    """Stable unique identifier of the peer."""
    pass

  @abstractmethod
  def addr(self) -> str:
    """Network address of the peer (transport-specific format)."""
    pass

  @abstractmethod
  def device_capabilities(self) -> DeviceCapabilities:
    """Hardware capabilities (model, chip, memory, flops) of the peer."""
    pass

  @abstractmethod
  async def connect(self) -> None:
    pass

  @abstractmethod
  async def is_connected(self) -> bool:
    pass

  @abstractmethod
  async def disconnect(self) -> None:
    pass

  @abstractmethod
  async def health_check(self) -> bool:
    """Return True iff the peer is reachable and reports itself healthy."""
    pass

  @abstractmethod
  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Ask the peer to run a text (and optional image) prompt on `shard`; returns the output tensor, if any."""
    pass

  @abstractmethod
  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Forward an intermediate activation tensor to the peer for `shard`; returns the output tensor, if any."""
    pass

  @abstractmethod
  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    """Push (token) results for `request_id` to the peer."""
    pass

  @abstractmethod
  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    """Fetch the current result for `request_id` as (tensor-or-None, is_finished)."""
    pass

  @abstractmethod
  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
    """Recursively collect the cluster topology, skipping nodes in `visited`, up to `max_depth` hops."""
    pass
## exo/networking/server.py
from abc import ABC, abstractmethod
class Server(ABC):
  """Abstract interface for a server that exposes a local node to remote peers."""

  @abstractmethod
  async def start(self) -> None:
    """Start listening for incoming peer requests."""
    pass

  @abstractmethod
  async def stop(self) -> None:
    """Stop the server and close open connections."""
    pass
## exo/networking/grpc/grpc_peer_handle.py
import grpc
import numpy as np
import asyncio
from typing import Optional, Tuple, List
from . import node_service_pb2
from . import node_service_pb2_grpc
from ..peer_handle import PeerHandle
from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities
from exo.helpers import DEBUG
class GRPCPeerHandle(PeerHandle):
  """PeerHandle implementation that talks to a remote node over gRPC (grpc.aio).

  Tensors are shipped as raw bytes plus shape/dtype metadata and rebuilt with
  numpy on the receiving side.
  """

  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
    self._id = _id
    self.address = address
    self._device_capabilities = device_capabilities
    self.channel = None  # grpc.aio channel, created lazily in connect()
    self.stub = None     # NodeServiceStub bound to the channel

  def id(self) -> str:
    return self._id

  def addr(self) -> str:
    return self.address

  def device_capabilities(self) -> DeviceCapabilities:
    return self._device_capabilities

  async def connect(self):
    """Open the channel if needed (idempotent) and wait until it is READY."""
    if self.channel is None:
      # Large metadata limit: topology/capability payloads can be sizeable.
      self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
    await self.channel.channel_ready()

  async def is_connected(self) -> bool:
    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

  async def disconnect(self):
    if self.channel:
      await self.channel.close()
    self.channel = None
    self.stub = None

  async def _ensure_connected(self):
    """(Re)connect with a bounded wait so callers cannot hang forever."""
    if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)

  async def health_check(self) -> bool:
    """Return True iff the peer answers a HealthCheck RPC within 5 seconds.

    Any failure (timeout, connection error, RPC error) is reported as
    unhealthy rather than raised.
    """
    try:
      await self._ensure_connected()
      request = node_service_pb2.HealthCheckRequest()
      response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
      return response.is_healthy
    except asyncio.TimeoutError:
      return False
    except Exception:
      # Narrowed from a bare `except:` so BaseExceptions such as
      # asyncio.CancelledError and KeyboardInterrupt propagate instead of
      # being silently reported as "unhealthy".
      if DEBUG >= 4:
        print(f"Health check failed for {self._id}@{self.address}.")
        import traceback
        traceback.print_exc()
      return False

  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Run `prompt` on the peer for `shard`; returns the output tensor or None.

    NOTE(review): this (and the other RPC methods below) assumes connect()
    has already been called — self.stub is None otherwise. Consider calling
    _ensure_connected() here if callers don't guarantee that.
    """
    request = node_service_pb2.PromptRequest(
      prompt=prompt,
      image_str=image_str,
      shard=node_service_pb2.Shard(
        model_id=shard.model_id,
        start_layer=shard.start_layer,
        end_layer=shard.end_layer,
        n_layers=shard.n_layers,
      ),
      request_id=request_id,
      inference_state=inference_state,
    )
    response = await self.stub.SendPrompt(request)
    if not response.tensor_data or not response.shape or not response.dtype:
      return None
    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Forward an activation `tensor` to the peer for `shard`; returns the output tensor or None."""
    request = node_service_pb2.TensorRequest(
      shard=node_service_pb2.Shard(
        model_id=shard.model_id,
        start_layer=shard.start_layer,
        end_layer=shard.end_layer,
        n_layers=shard.n_layers,
      ),
      tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
      request_id=request_id,
      inference_state=inference_state,
    )
    response = await self.stub.SendTensor(request)
    if not response.tensor_data or not response.shape or not response.dtype:
      return None
    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    """Fetch the current result for `request_id` as (tensor-or-None, is_finished)."""
    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
    response = await self.stub.GetInferenceResult(request)
    # BUGFIX: proto3 submessage fields are never None in Python — accessing
    # `response.tensor` returns a default Tensor instance, so the old
    # `response.tensor is None` check could never trigger and an absent
    # tensor fell through to np.dtype("") and crashed. The proto declares
    # `optional Tensor tensor`, so HasField() is the correct presence test.
    if not response.HasField("tensor"):
      return None, response.is_finished
    return (
      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
      response.is_finished,
    )

  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
    """Ask the peer for its view of the cluster topology and rebuild it locally."""
    request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
    response = await self.stub.CollectTopology(request)
    topology = Topology()
    for node_id, capabilities in response.nodes.items():
      device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
      topology.update_node(node_id, device_capabilities)
    for node_id, peers in response.peer_graph.items():
      for peer_id in peers.peer_ids:
        topology.add_edge(node_id, peer_id)
    return topology

  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    """Push (token) results for `request_id` to the peer."""
    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
    await self.stub.SendResult(request)

  async def send_opaque_status(self, request_id: str, status: str) -> None:
    """Forward an opaque status string for `request_id` to the peer."""
    request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
    await self.stub.SendOpaqueStatus(request)
## exo/networking/grpc/grpc_server.py
import grpc
from concurrent import futures
import numpy as np
from asyncio import CancelledError
from . import node_service_pb2
from . import node_service_pb2_grpc
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
  """gRPC servicer exposing a local orchestration Node to remote peers.

  Translates between the protobuf wire format (node_service.proto) and the
  Node's Python API; each RPC handler delegates to `self.node`.
  """

  def __init__(self, node: Node, host: str, port: int):
    self.node = node    # local node all RPCs are delegated to
    self.host = host
    self.port = port
    self.server = None  # grpc.aio server, created in start()

  async def start(self) -> None:
    """Create the asyncio gRPC server, bind it to host:port, and start serving."""
    self.server = grpc.aio.server(
      futures.ThreadPoolExecutor(max_workers=10),
      options=[
        # Large limits: tensors are shipped inline as raw bytes in messages.
        ("grpc.max_metadata_size", 32*1024*1024),
        ("grpc.max_send_message_length", 128*1024*1024),
        ("grpc.max_receive_message_length", 128*1024*1024),
      ],
    )
    node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
    listen_addr = f"{self.host}:{self.port}"
    self.server.add_insecure_port(listen_addr)
    await self.server.start()
    if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")

  async def stop(self) -> None:
    """Gracefully stop the server (5s grace) and wait for full termination."""
    if self.server:
      try:
        await self.server.stop(grace=5)
        await self.server.wait_for_termination()
      except CancelledError:
        # Shutdown may race with task cancellation; treat it as a clean stop.
        pass
      if DEBUG >= 1: print("Server stopped and all connections are closed")

  async def SendPrompt(self, request, context):
    """Run a prompt on the local node; returns the result as a Tensor message (empty Tensor when there is no result)."""
    shard = Shard(
      model_id=request.shard.model_id,
      start_layer=request.shard.start_layer,
      end_layer=request.shard.end_layer,
      n_layers=request.shard.n_layers,
    )
    prompt = request.prompt
    image_str = request.image_str
    request_id = request.request_id
    # NOTE(review): request.inference_state is present on PromptRequest but is
    # not forwarded here (SendTensor does forward it) — confirm process_prompt
    # does not expect it.
    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
    tensor_data = result.tobytes() if result is not None else None
    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

  async def SendTensor(self, request, context):
    """Run an activation tensor through the local node; returns the result as a Tensor message."""
    shard = Shard(
      model_id=request.shard.model_id,
      start_layer=request.shard.start_layer,
      end_layer=request.shard.end_layer,
      n_layers=request.shard.n_layers,
    )
    # Rebuild the numpy array from raw bytes + shape + dtype metadata.
    tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
    request_id = request.request_id
    inference_state = request.inference_state
    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
    tensor_data = result.tobytes() if result is not None else None
    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

  async def GetInferenceResult(self, request, context):
    """Return the current (tensor, is_finished) state for a request id; the tensor field is omitted when there is no tensor yet."""
    request_id = request.request_id
    result = await self.node.get_inference_result(request_id)
    if DEBUG >= 5: print(f"GetInferenceResult {request_id=}: {result}")
    tensor_data = result[0].tobytes() if result[0] is not None else None
    return (
      node_service_pb2.InferenceResult(
        tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
        is_finished=result[1],
      ) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
    )

  async def CollectTopology(self, request, context):
    """Collect the local node's view of the cluster topology and serialize it."""
    max_depth = request.max_depth
    visited = set(request.visited)
    topology = await self.node.collect_topology(visited, max_depth)
    nodes = {
      node_id:
        node_service_pb2.DeviceCapabilities(
          model=cap.model,
          chip=cap.chip,
          memory=cap.memory,
          flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
        )
      for node_id, cap in topology.nodes.items()
    }
    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
    if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
    return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)

  async def SendResult(self, request, context):
    """Receive (token) results pushed by a peer and fan them out via the node's on_token callbacks."""
    request_id = request.request_id
    result = request.result
    is_finished = request.is_finished
    if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
    self.node.on_token.trigger_all(request_id, result, is_finished)
    return node_service_pb2.Empty()

  async def SendOpaqueStatus(self, request, context):
    """Receive an opaque status string and fan it out via the node's on_opaque_status callbacks."""
    request_id = request.request_id
    status = request.status
    if DEBUG >= 8: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
    self.node.on_opaque_status.trigger_all(request_id, status)
    return node_service_pb2.Empty()

  async def HealthCheck(self, request, context):
    """Liveness probe: always reports healthy while the server is serving."""
    return node_service_pb2.HealthCheckResponse(is_healthy=True)
## exo/networking/grpc/node_service.proto
syntax = "proto3";

package node_service;

// RPC surface between exo nodes: inference requests, result delivery,
// topology collection, and health checks.
service NodeService {
  rpc SendPrompt (PromptRequest) returns (Tensor) {}
  rpc SendTensor (TensorRequest) returns (Tensor) {}
  rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
  rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
  rpc SendResult (SendResultRequest) returns (Empty) {}
  rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
  rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
}

// A contiguous slice of a model's layers assigned to one node.
message Shard {
  string model_id = 1;
  int32 start_layer = 2;
  int32 end_layer = 3;
  int32 n_layers = 4;
}

message PromptRequest {
  Shard shard = 1;
  string prompt = 2;
  optional string image_str = 3;
  optional string request_id = 4;
  optional string inference_state = 5;
}

message TensorRequest {
  Shard shard = 1;
  Tensor tensor = 2;
  optional string request_id = 3;
  optional string inference_state = 4;
}

message GetInferenceResultRequest {
  string request_id = 1;
}

message InferenceResult {
  // optional so the receiver can distinguish "no tensor yet" via HasField.
  optional Tensor tensor = 1;
  bool is_finished = 2;
}

// Raw ndarray bytes plus the metadata needed to rebuild it with numpy.
message Tensor {
  bytes tensor_data = 1;
  repeated int32 shape = 2;
  string dtype = 3;
}

message CollectTopologyRequest {
  repeated string visited = 1;  // node ids already visited (cycle guard)
  int32 max_depth = 2;
}

message Topology {
  map<string, DeviceCapabilities> nodes = 1;       // node id -> capabilities
  map<string, Peers> peer_graph = 2;               // node id -> its peers
}

message Peers {
  repeated string peer_ids = 1;
}

message DeviceFlops {
  float fp32 = 1;
  float fp16 = 2;
  float int8 = 3;
}

message DeviceCapabilities {
  string model = 1;
  string chip = 2;
  int32 memory = 3;
  DeviceFlops flops = 4;
}

message SendResultRequest {
  string request_id = 1;
  repeated int32 result = 2;  // token ids produced so far
  bool is_finished = 3;
}

message SendOpaqueStatusRequest {
  string request_id = 1;
  string status = 2;
}

message HealthCheckRequest {}

message HealthCheckResponse {
  bool is_healthy = 1;
}

message Empty {}
## exo/networking/grpc/node_service_pb2.py
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: node_service.proto
# Protobuf Python Version: 5.26.1
# NOTE(review): generated module — regenerate from node_service.proto with
# protoc instead of editing by hand.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xc3\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x16\n\timage_str\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0c\n\n_image_strB\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\xb4\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
  _globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
  _globals['_SHARD']._serialized_start=36
  _globals['_SHARD']._serialized_end=119
  _globals['_PROMPTREQUEST']._serialized_start=122
  _globals['_PROMPTREQUEST']._serialized_end=317
  _globals['_TENSORREQUEST']._serialized_start=320
  _globals['_TENSORREQUEST']._serialized_end=499
  _globals['_GETINFERENCERESULTREQUEST']._serialized_start=501
  _globals['_GETINFERENCERESULTREQUEST']._serialized_end=548
  _globals['_INFERENCERESULT']._serialized_start=550
  _globals['_INFERENCERESULT']._serialized_end=642
  _globals['_TENSOR']._serialized_start=644
  _globals['_TENSOR']._serialized_end=703
  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=705
  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=765
  _globals['_TOPOLOGY']._serialized_start=768
  _globals['_TOPOLOGY']._serialized_end=1038
  _globals['_TOPOLOGY_NODESENTRY']._serialized_start=889
  _globals['_TOPOLOGY_NODESENTRY']._serialized_end=967
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=969
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1038
  _globals['_PEERS']._serialized_start=1040
  _globals['_PEERS']._serialized_end=1065
  _globals['_DEVICEFLOPS']._serialized_start=1067
  _globals['_DEVICEFLOPS']._serialized_end=1122
  _globals['_DEVICECAPABILITIES']._serialized_start=1124
  _globals['_DEVICECAPABILITIES']._serialized_end=1231
  _globals['_SENDRESULTREQUEST']._serialized_start=1233
  _globals['_SENDRESULTREQUEST']._serialized_end=1309
  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1311
  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1372
  _globals['_HEALTHCHECKREQUEST']._serialized_start=1374
  _globals['_HEALTHCHECKREQUEST']._serialized_end=1394
  _globals['_HEALTHCHECKRESPONSE']._serialized_start=1396
  _globals['_HEALTHCHECKRESPONSE']._serialized_end=1437
  _globals['_EMPTY']._serialized_start=1439
  _globals['_EMPTY']._serialized_end=1446
  _globals['_NODESERVICE']._serialized_start=1449
  _globals['_NODESERVICE']._serialized_end=2013
# @@protoc_insertion_point(module_scope)
## exo/networking/grpc/node_service_pb2_grpc.py
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from . import node_service_pb2 as node__service__pb2
# NOTE(review): generated by the gRPC protoc plugin — regenerate rather than
# edit. This preamble warns when the installed grpcio is older than the
# grpcio-tools version that generated this module.
GRPC_GENERATED_VERSION = '1.64.1'
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = '1.65.0'
SCHEDULED_RELEASE_DATE = 'June 25, 2024'
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # Very old grpcio without the version helper: treat as unsupported.
    _version_not_supported = True

if _version_not_supported:
    warnings.warn(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in node_service_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
        RuntimeWarning
    )
class NodeServiceStub(object):
    """Client-side stub for NodeService (generated by the gRPC protoc plugin — do not edit).

    Exposes one unary-unary callable per RPC declared in node_service.proto.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.SendPrompt = channel.unary_unary(
                '/node_service.NodeService/SendPrompt',
                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
                response_deserializer=node__service__pb2.Tensor.FromString,
                _registered_method=True)
        self.SendTensor = channel.unary_unary(
                '/node_service.NodeService/SendTensor',
                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
                response_deserializer=node__service__pb2.Tensor.FromString,
                _registered_method=True)
        self.GetInferenceResult = channel.unary_unary(
                '/node_service.NodeService/GetInferenceResult',
                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
                response_deserializer=node__service__pb2.InferenceResult.FromString,
                _registered_method=True)
        self.CollectTopology = channel.unary_unary(
                '/node_service.NodeService/CollectTopology',
                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
                response_deserializer=node__service__pb2.Topology.FromString,
                _registered_method=True)
        self.SendResult = channel.unary_unary(
                '/node_service.NodeService/SendResult',
                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
                response_deserializer=node__service__pb2.Empty.FromString,
                _registered_method=True)
        self.SendOpaqueStatus = channel.unary_unary(
                '/node_service.NodeService/SendOpaqueStatus',
                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
                response_deserializer=node__service__pb2.Empty.FromString,
                _registered_method=True)
        self.HealthCheck = channel.unary_unary(
                '/node_service.NodeService/HealthCheck',
                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
                _registered_method=True)
class NodeServiceServicer(object):
    """Server-side base class for NodeService (generated by the gRPC protoc plugin — do not edit).

    Subclasses (e.g. GRPCServer) override these handlers; the defaults answer
    UNIMPLEMENTED.
    """

    def SendPrompt(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendTensor(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetInferenceResult(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CollectTopology(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendResult(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendOpaqueStatus(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def HealthCheck(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
def add_NodeServiceServicer_to_server(servicer, server):
    # Generated by the gRPC protoc plugin — do not edit. Registers the
    # servicer's handlers (with their protobuf (de)serializers) on `server`.
    rpc_method_handlers = {
            'SendPrompt': grpc.unary_unary_rpc_method_handler(
                    servicer.SendPrompt,
                    request_deserializer=node__service__pb2.PromptRequest.FromString,
                    response_serializer=node__service__pb2.Tensor.SerializeToString,
            ),
            'SendTensor': grpc.unary_unary_rpc_method_handler(
                    servicer.SendTensor,
                    request_deserializer=node__service__pb2.TensorRequest.FromString,
                    response_serializer=node__service__pb2.Tensor.SerializeToString,
            ),
            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
                    servicer.GetInferenceResult,
                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
            ),
            'CollectTopology': grpc.unary_unary_rpc_method_handler(
                    servicer.CollectTopology,
                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
                    response_serializer=node__service__pb2.Topology.SerializeToString,
            ),
            'SendResult': grpc.unary_unary_rpc_method_handler(
                    servicer.SendResult,
                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
                    response_serializer=node__service__pb2.Empty.SerializeToString,
            ),
            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
                    servicer.SendOpaqueStatus,
                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
                    response_serializer=node__service__pb2.Empty.SerializeToString,
            ),
            'HealthCheck': grpc.unary_unary_rpc_method_handler(
                    servicer.HealthCheck,
                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'node_service.NodeService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class NodeService(object):
    """Stub-less convenience API for NodeService (generated by the gRPC protoc
    plugin, experimental — do not edit): each static method performs one
    unary-unary call given a target address."""

    @staticmethod
    def SendPrompt(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendPrompt',
            node__service__pb2.PromptRequest.SerializeToString,
            node__service__pb2.Tensor.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendTensor(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendTensor',
            node__service__pb2.TensorRequest.SerializeToString,
            node__service__pb2.Tensor.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def GetInferenceResult(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/GetInferenceResult',
            node__service__pb2.GetInferenceResultRequest.SerializeToString,
            node__service__pb2.InferenceResult.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def CollectTopology(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/CollectTopology',
            node__service__pb2.CollectTopologyRequest.SerializeToString,
            node__service__pb2.Topology.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendResult(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendResult',
            node__service__pb2.SendResultRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendOpaqueStatus(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendOpaqueStatus',
            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def HealthCheck(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/HealthCheck',
            node__service__pb2.HealthCheckRequest.SerializeToString,
            node__service__pb2.HealthCheckResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)