Untitled

 avatar
unknown
plain_text
5 months ago
35 kB
4
Indexable
## exo/networking/discovery.py

from abc import ABC, abstractmethod
from typing import List
from .peer_handle import PeerHandle


class Discovery(ABC):
  """Abstract peer-discovery mechanism.

  Implementations are started and stopped explicitly and hand out
  ``PeerHandle`` objects for the peers they have found.
  """

  @abstractmethod
  async def start(self) -> None:
    """Start the discovery mechanism."""

  @abstractmethod
  async def stop(self) -> None:
    """Stop the discovery mechanism."""

  @abstractmethod
  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
    """Return handles for the currently discovered peers.

    Args:
      wait_for_peers: implementation-defined; presumably the number of peers to
        wait for before returning (0 = return immediately) — TODO confirm.
    """



## exo/networking/peer_handle.py

from abc import ABC, abstractmethod
from typing import Optional, Tuple, List
import numpy as np
from exo.inference.shard import Shard
from exo.topology.device_capabilities import DeviceCapabilities
from exo.topology.topology import Topology

class PeerHandle(ABC):
  """Abstract handle to a remote peer node.

  Concrete subclasses wrap whatever transport is used to reach the peer; this
  class only fixes the interface the rest of exo programs against.
  """

  @abstractmethod
  def id(self) -> str:
    """Return the peer's node id."""

  @abstractmethod
  def addr(self) -> str:
    """Return the address used to reach the peer."""

  @abstractmethod
  def device_capabilities(self) -> DeviceCapabilities:
    """Return the capabilities advertised for the peer's device."""

  @abstractmethod
  async def connect(self) -> None:
    """Open the connection to the peer."""

  @abstractmethod
  async def is_connected(self) -> bool:
    """Return whether the connection to the peer is currently usable."""

  @abstractmethod
  async def disconnect(self) -> None:
    """Close the connection to the peer."""

  @abstractmethod
  async def health_check(self) -> bool:
    """Return True if the peer reports itself healthy."""

  @abstractmethod
  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Ask the peer to process ``prompt`` for ``shard``.

    Returns the resulting array; implementations may return None when the peer
    produced no tensor.
    """

  @abstractmethod
  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Ask the peer to process ``tensor`` for ``shard``.

    Returns the resulting array; implementations may return None when the peer
    produced no tensor.
    """

  @abstractmethod
  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    """Deliver ``result`` for ``request_id`` to the peer, flagging whether it is final."""

  @abstractmethod
  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    """Return ``(array or None, is_finished)`` for ``request_id`` from the peer."""

  @abstractmethod
  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
    """Return the topology reachable from the peer.

    Args:
      visited: node ids already visited by the caller.
      max_depth: traversal depth limit.
    """



## exo/networking/server.py

from abc import ABC, abstractmethod


class Server(ABC):
  """Lifecycle interface for a network server: it can be started and stopped."""

  @abstractmethod
  async def start(self) -> None:
    """Start serving."""

  @abstractmethod
  async def stop(self) -> None:
    """Stop serving."""


## exo/networking/grpc/grpc_peer_handle.py

import grpc
import numpy as np
import asyncio
from typing import Optional, Tuple, List

from . import node_service_pb2
from . import node_service_pb2_grpc

from ..peer_handle import PeerHandle
from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities
from exo.helpers import DEBUG


class GRPCPeerHandle(PeerHandle):
  """PeerHandle implementation that reaches a remote node over gRPC.

  ``connect`` creates the channel and NodeService stub. The RPC methods connect
  lazily when no stub exists yet, so they can be used on a fresh handle.
  """

  # Channel options. The metadata and receive limits match the values
  # GRPCServer passes to grpc.aio.server, so responses carrying large tensors
  # are not rejected by the client's (much smaller) default receive limit.
  _CHANNEL_OPTIONS = (
    ("grpc.max_metadata_size", 32*1024*1024),
    ("grpc.max_receive_message_length", 128*1024*1024),
  )

  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
    self._id = _id
    self.address = address
    self._device_capabilities = device_capabilities
    self.channel = None  # grpc.aio channel, created by connect()
    self.stub = None  # NodeServiceStub bound to self.channel

  def id(self) -> str:
    return self._id

  def addr(self) -> str:
    return self.address

  def device_capabilities(self) -> DeviceCapabilities:
    return self._device_capabilities

  async def connect(self):
    """Create the channel and stub if needed, then wait until the channel is ready."""
    if self.channel is None:
      self.channel = grpc.aio.insecure_channel(self.address, options=self._CHANNEL_OPTIONS)
      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
    await self.channel.channel_ready()

  async def is_connected(self) -> bool:
    """Return True if a channel exists and is in the READY state."""
    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY

  async def disconnect(self):
    """Close the channel (if any) and drop the channel and stub."""
    if self.channel:
      await self.channel.close()
    self.channel = None
    self.stub = None

  async def _ensure_connected(self):
    """Connect (5 s timeout) unless the channel is already READY."""
    if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)

  async def _get_stub(self):
    """Return the stub, connecting first when no stub has been created yet."""
    if self.stub is None:
      await self._ensure_connected()
    return self.stub

  @staticmethod
  def _shard_to_proto(shard: Shard):
    """Convert an exo Shard into a node_service_pb2.Shard."""
    return node_service_pb2.Shard(
      model_id=shard.model_id,
      start_layer=shard.start_layer,
      end_layer=shard.end_layer,
      n_layers=shard.n_layers,
    )

  @staticmethod
  def _tensor_from_proto(tensor) -> np.ndarray:
    """Rebuild a numpy array from a node_service_pb2.Tensor (raw bytes, dtype name, shape)."""
    return np.frombuffer(tensor.tensor_data, dtype=np.dtype(tensor.dtype)).reshape(tensor.shape)

  async def health_check(self) -> bool:
    """Return the peer's ``is_healthy`` answer, or False on timeout or error.

    Connecting and the RPC itself are each bounded by a 5 s timeout. Task
    cancellation is not swallowed.
    """
    try:
      await self._ensure_connected()
      request = node_service_pb2.HealthCheckRequest()
      response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
      return response.is_healthy
    except asyncio.TimeoutError:
      return False
    except Exception:
      if DEBUG >= 4:
        print(f"Health check failed for {self._id}@{self.address}.")
        import traceback
        traceback.print_exc()
      return False

  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Send a prompt to the peer; return the resulting array, or None if the reply carries no tensor."""
    request = node_service_pb2.PromptRequest(
      prompt=prompt,
      image_str=image_str,
      shard=self._shard_to_proto(shard),
      request_id=request_id,
      inference_state=inference_state,
    )
    stub = await self._get_stub()
    response = await stub.SendPrompt(request)

    # The server replies with an empty Tensor when it has no output.
    if not response.tensor_data or not response.shape or not response.dtype:
      return None

    return self._tensor_from_proto(response)

  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    """Send a tensor to the peer; return the resulting array, or None if the reply carries no tensor."""
    request = node_service_pb2.TensorRequest(
      shard=self._shard_to_proto(shard),
      tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
      request_id=request_id,
      inference_state=inference_state,
    )
    stub = await self._get_stub()
    response = await stub.SendTensor(request)

    # The server replies with an empty Tensor when it has no output.
    if not response.tensor_data or not response.shape or not response.dtype:
      return None

    return self._tensor_from_proto(response)

  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    """Return ``(array or None, is_finished)`` for ``request_id`` from the peer."""
    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
    stub = await self._get_stub()
    response = await stub.GetInferenceResult(request)
    # Protobuf message fields are never None; presence must be checked with
    # HasField. The server omits `tensor` when there is no result yet.
    if not response.HasField("tensor"):
      return None, response.is_finished
    return self._tensor_from_proto(response.tensor), response.is_finished

  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
    """Ask the peer for its topology and convert the reply into a Topology."""
    request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
    stub = await self._get_stub()
    response = await stub.CollectTopology(request)
    topology = Topology()
    for node_id, capabilities in response.nodes.items():
      # NOTE(review): `flops` is passed through as the protobuf DeviceFlops
      # message rather than exo's own flops type — confirm DeviceCapabilities
      # consumers only read .fp32/.fp16/.int8 from it.
      device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
      topology.update_node(node_id, device_capabilities)
    for node_id, peers in response.peer_graph.items():
      for peer_id in peers.peer_ids:
        topology.add_edge(node_id, peer_id)
    return topology

  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    """Deliver result values for ``request_id`` to the peer."""
    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
    stub = await self._get_stub()
    await stub.SendResult(request)

  async def send_opaque_status(self, request_id: str, status: str) -> None:
    """Deliver an opaque status string for ``request_id`` to the peer."""
    request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
    stub = await self._get_stub()
    await stub.SendOpaqueStatus(request)


## exo/networking/grpc/grpc_server.py

import grpc
from concurrent import futures
import numpy as np
from asyncio import CancelledError

from . import node_service_pb2
from . import node_service_pb2_grpc
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node


class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
  """Serve the NodeService gRPC API on behalf of a local Node.

  Every RPC handler decodes its protobuf request, delegates to ``self.node``
  and encodes the node's answer back into a protobuf response.
  """

  def __init__(self, node: Node, host: str, port: int):
    self.node = node
    self.host, self.port = host, port
    self.server = None  # grpc.aio server, created by start()

  @staticmethod
  def _decode_shard(wire_shard) -> Shard:
    """Convert a node_service_pb2.Shard into an exo Shard."""
    return Shard(
      model_id=wire_shard.model_id,
      start_layer=wire_shard.start_layer,
      end_layer=wire_shard.end_layer,
      n_layers=wire_shard.n_layers,
    )

  @staticmethod
  def _encode_tensor(array):
    """Pack an array into a node_service_pb2.Tensor; None becomes an empty Tensor."""
    if array is None:
      return node_service_pb2.Tensor()
    return node_service_pb2.Tensor(tensor_data=array.tobytes(), shape=array.shape, dtype=str(array.dtype))

  async def start(self) -> None:
    """Create the grpc.aio server, register this servicer and listen on host:port."""
    mib = 1024*1024
    server_options = [
      ("grpc.max_metadata_size", 32*mib),
      ("grpc.max_send_message_length", 128*mib),
      ("grpc.max_receive_message_length", 128*mib),
    ]
    self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=server_options)
    node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
    listen_addr = f"{self.host}:{self.port}"
    self.server.add_insecure_port(listen_addr)
    await self.server.start()
    if DEBUG >= 1:
      print(f"Server started, listening on {listen_addr}")

  async def stop(self) -> None:
    """Stop the server (5 s grace) and wait for termination; no-op if it was never started."""
    if not self.server:
      return
    try:
      await self.server.stop(grace=5)
      await self.server.wait_for_termination()
    except CancelledError:
      pass
    if DEBUG >= 1:
      print("Server stopped and all connections are closed")

  async def SendPrompt(self, request, context):
    shard = self._decode_shard(request.shard)
    prompt, image_str, request_id = request.prompt, request.image_str, request.request_id
    # NOTE(review): request.inference_state is not forwarded to process_prompt
    # (SendTensor does forward it) — confirm whether Node.process_prompt accepts it.
    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
    if DEBUG >= 5:
      print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
    return self._encode_tensor(result)

  async def SendTensor(self, request, context):
    shard = self._decode_shard(request.shard)
    wire = request.tensor
    tensor = np.frombuffer(wire.tensor_data, dtype=np.dtype(wire.dtype)).reshape(wire.shape)
    request_id, inference_state = request.request_id, request.inference_state

    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
    if DEBUG >= 5:
      print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
    return self._encode_tensor(result)

  async def GetInferenceResult(self, request, context):
    request_id = request.request_id
    result = await self.node.get_inference_result(request_id)
    if DEBUG >= 5:
      print(f"GetInferenceResult {request_id=}: {result}")
    array, is_finished = result[0], result[1]
    if array is None:
      # Leave `tensor` unset so clients can detect "no result" via presence.
      return node_service_pb2.InferenceResult(is_finished=is_finished)
    return node_service_pb2.InferenceResult(tensor=self._encode_tensor(array), is_finished=is_finished)

  async def CollectTopology(self, request, context):
    max_depth = request.max_depth
    visited = set(request.visited)
    topology = await self.node.collect_topology(visited, max_depth)
    nodes = {}
    for node_id, cap in topology.nodes.items():
      nodes[node_id] = node_service_pb2.DeviceCapabilities(
        model=cap.model,
        chip=cap.chip,
        memory=cap.memory,
        flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
      )
    peer_graph = {}
    for node_id, peers in topology.peer_graph.items():
      peer_graph[node_id] = node_service_pb2.Peers(peer_ids=peers)
    if DEBUG >= 5:
      print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
    return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)

  async def SendResult(self, request, context):
    request_id, result, is_finished = request.request_id, request.result, request.is_finished
    if DEBUG >= 5:
      print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
    self.node.on_token.trigger_all(request_id, result, is_finished)
    return node_service_pb2.Empty()

  async def SendOpaqueStatus(self, request, context):
    request_id, status = request.request_id, request.status
    if DEBUG >= 8:
      print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
    self.node.on_opaque_status.trigger_all(request_id, status)
    return node_service_pb2.Empty()

  async def HealthCheck(self, request, context):
    # Always healthy: answering at all proves the server is up.
    return node_service_pb2.HealthCheckResponse(is_healthy=True)


## exo/networking/grpc/node_service.proto

// Wire protocol spoken between exo nodes: each node serves NodeService and
// calls its peers through NodeService stubs.
syntax = "proto3";

package node_service;

service NodeService {
  // Process a prompt for a shard; an empty Tensor reply means "no output".
  rpc SendPrompt (PromptRequest) returns (Tensor) {}
  // Process a tensor for a shard; an empty Tensor reply means "no output".
  rpc SendTensor (TensorRequest) returns (Tensor) {}
  // Fetch the current result for a request id, plus whether it is finished.
  rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
  // Ask the receiving node for the topology it can see.
  rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
  // Deliver result values; the Python server fires node.on_token with them.
  rpc SendResult (SendResultRequest) returns (Empty) {}
  // Deliver an opaque status string; the Python server fires node.on_opaque_status.
  rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
  // Liveness probe; the Python server always answers is_healthy = true.
  rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
}

// Range of layers of a model; mirrors exo.inference.shard.Shard field by field.
message Shard {
  string model_id = 1;
  int32 start_layer = 2;
  int32 end_layer = 3;
  int32 n_layers = 4;
}

// Prompt to run on the receiving node for `shard`.
message PromptRequest {
  Shard shard = 1;
  string prompt = 2;
  optional string image_str = 3;
  optional string request_id = 4;
  optional string inference_state = 5;
}

// Tensor to run through the receiving node for `shard`.
message TensorRequest {
  Shard shard = 1;
  Tensor tensor = 2;
  optional string request_id = 3;
  optional string inference_state = 4;
}

message GetInferenceResultRequest {
  string request_id = 1;
}

message InferenceResult {
  // Left unset by the server when there is no result array yet.
  optional Tensor tensor = 1;
  bool is_finished = 2;
}

// Serialized numpy array: the raw bytes from ndarray.tobytes(), the array
// shape and the numpy dtype name (str(dtype)). The receiver rebuilds it with
// np.frombuffer(...).reshape(shape). An all-default Tensor means "no array".
// NOTE(review): the dtype name carries no byte order — presumably all peers
// share native endianness; confirm.
message Tensor {
  bytes tensor_data = 1;
  repeated int32 shape = 2;
  string dtype = 3;
}

// Topology query: node ids already visited and the traversal depth limit.
message CollectTopologyRequest {
  repeated string visited = 1;
  int32 max_depth = 2;
}

// Known nodes keyed by node id, plus each node's peer ids (adjacency lists).
message Topology {
  map<string, DeviceCapabilities> nodes = 1;
  map<string, Peers> peer_graph = 2;
}

// Peer ids adjacent to one node in Topology.peer_graph.
message Peers {
    repeated string peer_ids = 1;
}

// FLOPS figures per precision for a device.
message DeviceFlops {
  float fp32 = 1;
  float fp16 = 2;
  float int8 = 3;
}

message DeviceCapabilities {
  string model = 1;
  string chip = 2;
  // Unit not stated here; presumably MB — TODO confirm against DeviceCapabilities.
  int32 memory = 3;
  DeviceFlops flops = 4;
}

// Result values for a request; the Python server forwards them to node.on_token.
message SendResultRequest {
  string request_id = 1;
  repeated int32 result = 2;
  bool is_finished = 3;
}

// Free-form status string for a request; forwarded to node.on_opaque_status.
message SendOpaqueStatusRequest {
  string request_id = 1;
  string status = 2;
}

message HealthCheckRequest {}

message HealthCheckResponse {
  bool is_healthy = 1;
}

// Placeholder response for RPCs that return nothing.
message Empty {}


## exo/networking/grpc/node_service_pb2.py

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler.  DO NOT EDIT!
# source: node_service.proto
# Protobuf Python Version: 5.26.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# Serialized FileDescriptorProto for node_service.proto. It must stay a single
# bytes literal exactly as protoc emitted it: a line break inside it is a
# SyntaxError, and any byte change corrupts the descriptor.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xc3\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x16\n\timage_str\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0c\n\n_image_strB\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\xb4\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
  _globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
  _globals['_SHARD']._serialized_start=36
  _globals['_SHARD']._serialized_end=119
  _globals['_PROMPTREQUEST']._serialized_start=122
  _globals['_PROMPTREQUEST']._serialized_end=317
  _globals['_TENSORREQUEST']._serialized_start=320
  _globals['_TENSORREQUEST']._serialized_end=499
  _globals['_GETINFERENCERESULTREQUEST']._serialized_start=501
  _globals['_GETINFERENCERESULTREQUEST']._serialized_end=548
  _globals['_INFERENCERESULT']._serialized_start=550
  _globals['_INFERENCERESULT']._serialized_end=642
  _globals['_TENSOR']._serialized_start=644
  _globals['_TENSOR']._serialized_end=703
  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=705
  _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=765
  _globals['_TOPOLOGY']._serialized_start=768
  _globals['_TOPOLOGY']._serialized_end=1038
  _globals['_TOPOLOGY_NODESENTRY']._serialized_start=889
  _globals['_TOPOLOGY_NODESENTRY']._serialized_end=967
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=969
  _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1038
  _globals['_PEERS']._serialized_start=1040
  _globals['_PEERS']._serialized_end=1065
  _globals['_DEVICEFLOPS']._serialized_start=1067
  _globals['_DEVICEFLOPS']._serialized_end=1122
  _globals['_DEVICECAPABILITIES']._serialized_start=1124
  _globals['_DEVICECAPABILITIES']._serialized_end=1231
  _globals['_SENDRESULTREQUEST']._serialized_start=1233
  _globals['_SENDRESULTREQUEST']._serialized_end=1309
  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1311
  _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1372
  _globals['_HEALTHCHECKREQUEST']._serialized_start=1374
  _globals['_HEALTHCHECKREQUEST']._serialized_end=1394
  _globals['_HEALTHCHECKRESPONSE']._serialized_start=1396
  _globals['_HEALTHCHECKRESPONSE']._serialized_end=1437
  _globals['_EMPTY']._serialized_start=1439
  _globals['_EMPTY']._serialized_end=1446
  _globals['_NODESERVICE']._serialized_start=1449
  _globals['_NODESERVICE']._serialized_end=2013
# @@protoc_insertion_point(module_scope)


## exo/networking/grpc/node_service_pb2_grpc.py

# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

from . import node_service_pb2 as node__service__pb2

# grpcio-tools version that generated this module. The generated code passes
# `_registered_method=True` and calls add_registered_method_handlers, so an
# older installed grpc runtime triggers the RuntimeWarning below.
GRPC_GENERATED_VERSION = '1.64.1'
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = '1.65.0'
SCHEDULED_RELEASE_DATE = 'June 25, 2024'
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    # Presumably True when the installed grpc is older than the generator version.
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # Runtimes that lack this helper are treated as unsupported.
    _version_not_supported = True

if _version_not_supported:
    warnings.warn(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in node_service_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
        RuntimeWarning
    )


class NodeServiceStub(object):
    """Client stub for node_service.NodeService.

    Each attribute is a unary-unary callable for one RPC that serializes the
    request message and deserializes the response message.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel (exo passes a grpc.aio channel, which makes
                the callables awaitable).
        """
        self.SendPrompt = channel.unary_unary(
                '/node_service.NodeService/SendPrompt',
                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
                response_deserializer=node__service__pb2.Tensor.FromString,
                _registered_method=True)
        self.SendTensor = channel.unary_unary(
                '/node_service.NodeService/SendTensor',
                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
                response_deserializer=node__service__pb2.Tensor.FromString,
                _registered_method=True)
        self.GetInferenceResult = channel.unary_unary(
                '/node_service.NodeService/GetInferenceResult',
                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
                response_deserializer=node__service__pb2.InferenceResult.FromString,
                _registered_method=True)
        self.CollectTopology = channel.unary_unary(
                '/node_service.NodeService/CollectTopology',
                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
                response_deserializer=node__service__pb2.Topology.FromString,
                _registered_method=True)
        self.SendResult = channel.unary_unary(
                '/node_service.NodeService/SendResult',
                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
                response_deserializer=node__service__pb2.Empty.FromString,
                _registered_method=True)
        self.SendOpaqueStatus = channel.unary_unary(
                '/node_service.NodeService/SendOpaqueStatus',
                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
                response_deserializer=node__service__pb2.Empty.FromString,
                _registered_method=True)
        self.HealthCheck = channel.unary_unary(
                '/node_service.NodeService/HealthCheck',
                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
                _registered_method=True)


class NodeServiceServicer(object):
    """Base servicer for node_service.NodeService.

    Every handler reports UNIMPLEMENTED and raises NotImplementedError;
    subclasses override the RPCs they actually serve.
    """

    def SendPrompt(self, request, context):
        """Default handler: reply UNIMPLEMENTED."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendTensor(self, request, context):
        """Default handler: reply UNIMPLEMENTED."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetInferenceResult(self, request, context):
        """Default handler: reply UNIMPLEMENTED."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CollectTopology(self, request, context):
        """Default handler: reply UNIMPLEMENTED."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendResult(self, request, context):
        """Default handler: reply UNIMPLEMENTED."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendOpaqueStatus(self, request, context):
        """Default handler: reply UNIMPLEMENTED."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def HealthCheck(self, request, context):
        """Default handler: reply UNIMPLEMENTED."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_NodeServiceServicer_to_server(servicer, server):
    """Register ``servicer``'s RPC handlers for node_service.NodeService on ``server``.

    Handlers are registered both as a generic handler and as registered method
    handlers; each entry pairs a servicer method with its request deserializer
    and response serializer.
    """
    rpc_method_handlers = {
            'SendPrompt': grpc.unary_unary_rpc_method_handler(
                    servicer.SendPrompt,
                    request_deserializer=node__service__pb2.PromptRequest.FromString,
                    response_serializer=node__service__pb2.Tensor.SerializeToString,
            ),
            'SendTensor': grpc.unary_unary_rpc_method_handler(
                    servicer.SendTensor,
                    request_deserializer=node__service__pb2.TensorRequest.FromString,
                    response_serializer=node__service__pb2.Tensor.SerializeToString,
            ),
            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
                    servicer.GetInferenceResult,
                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
            ),
            'CollectTopology': grpc.unary_unary_rpc_method_handler(
                    servicer.CollectTopology,
                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
                    response_serializer=node__service__pb2.Topology.SerializeToString,
            ),
            'SendResult': grpc.unary_unary_rpc_method_handler(
                    servicer.SendResult,
                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
                    response_serializer=node__service__pb2.Empty.SerializeToString,
            ),
            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
                    servicer.SendOpaqueStatus,
                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
                    response_serializer=node__service__pb2.Empty.SerializeToString,
            ),
            'HealthCheck': grpc.unary_unary_rpc_method_handler(
                    servicer.HealthCheck,
                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'node_service.NodeService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)


 # This class is part of an EXPERIMENTAL API.
class NodeService(object):
    """One-shot static client calls for node_service.NodeService.

    Each static method delegates to grpc.experimental.unary_unary with the
    RPC's method path, request serializer and response deserializer. No
    long-lived stub is needed.
    """

    @staticmethod
    def SendPrompt(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call NodeService/SendPrompt on ``target`` (PromptRequest -> Tensor)."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendPrompt',
            node__service__pb2.PromptRequest.SerializeToString,
            node__service__pb2.Tensor.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendTensor(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call NodeService/SendTensor on ``target`` (TensorRequest -> Tensor)."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendTensor',
            node__service__pb2.TensorRequest.SerializeToString,
            node__service__pb2.Tensor.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def GetInferenceResult(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call NodeService/GetInferenceResult on ``target`` (GetInferenceResultRequest -> InferenceResult)."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/GetInferenceResult',
            node__service__pb2.GetInferenceResultRequest.SerializeToString,
            node__service__pb2.InferenceResult.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def CollectTopology(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call NodeService/CollectTopology on ``target`` (CollectTopologyRequest -> Topology)."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/CollectTopology',
            node__service__pb2.CollectTopologyRequest.SerializeToString,
            node__service__pb2.Topology.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendResult(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call NodeService/SendResult on ``target`` (SendResultRequest -> Empty)."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendResult',
            node__service__pb2.SendResultRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendOpaqueStatus(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call NodeService/SendOpaqueStatus on ``target`` (SendOpaqueStatusRequest -> Empty)."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendOpaqueStatus',
            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def HealthCheck(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Call NodeService/HealthCheck on ``target`` (HealthCheckRequest -> HealthCheckResponse)."""
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/HealthCheck',
            node__service__pb2.HealthCheckRequest.SerializeToString,
            node__service__pb2.HealthCheckResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)


Editor is loading...
Leave a Comment