# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import concurrent.futures
import dataclasses
import json
import os
import random
import re
from contextlib import asynccontextmanager, closing
from copy import deepcopy
from functools import partial
from itertools import count
from queue import Queue
from threading import Thread
from typing import (Any, AsyncIterator, Dict, Iterator, List, Literal,
                    Optional, Tuple, Union)

import tqdm

from lmdeploy.logger import RequestLogger
from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
                               Response, ResponseType, TurbomindEngineConfig)
from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model
from lmdeploy.serve.utils import LogitsMixin
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_logger

logger = get_logger('lmdeploy')


def get_names_from_model(model_path: str, model_name: str = None):
    """Get model name and chat template name from workspace model."""
    triton_model_path = os.path.join(model_path, 'triton_models', 'weights')
    if not os.path.exists(triton_model_path):
        chat_template_name = best_match_model(model_path)
    else:
        # `model_path` refers to a turbomind model, reading
        # chat_template_name from the config
        config_path = os.path.join(triton_model_path, 'config.yaml')
        with open(config_path, 'r') as f:
            import yaml
            config = yaml.safe_load(f)
        chat_template_name = config['model_config']['chat_template']
    model_name = model_name if model_name else model_path
    return model_name, chat_template_name


@dataclasses.dataclass
class GenOut:
    """Pack all response information together."""
    response: str
    history_token_len: int
    input_token_len: int
    generate_token_len: int
    finish_reason: Optional[Literal['stop', 'length', 'error']] = None
    token_ids: List[int] = None
    logprobs: List[Dict[int, float]] = None


def _gen_out_to_response(out: GenOut, index) -> Response:
    return Response(text=out.response,
                    generate_token_len=out.generate_token_len,
                    input_token_len=out.input_token_len,
                    finish_reason=out.finish_reason,
                    token_ids=out.token_ids,
                    logprobs=out.logprobs,
                    index=index)


def _append_response(dst: Response, src: Response):
    """dst += src."""
    if not dst:
        return src
    dst.text += src.text
    dst.generate_token_len = src.generate_token_len
    dst.input_token_len = src.input_token_len
    dst.finish_reason = src.finish_reason
    dst.index = src.index
    if src.token_ids:
        dst.token_ids += src.token_ids
    if src.logprobs:
        dst.logprobs = dst.logprobs or []
        dst.logprobs += src.logprobs
    return dst


class Session:
    """Session for AsyncEngine.chat.

    Args:
        _id (int): session_id for internal use.
        _step (int): the offset of the k/v cache for internal use.
        _prompt (Any): input prompt for internal use.
        _response (Response): model output for prompt.
        _engine (Any): engine for internal use.
        history (List[Tuple[Any, str]]): chat history.
""" def __init__(self, session_id: int, engine: Any, gen_config: GenerationConfig = None): self._id: int = session_id self._engine = engine self._step: int = 0 self._prompt: Any = None self._response: Response = None self._gen_config = gen_config self.history: List[Tuple[Any, str]] = [] def _merge_response(self, resp: Response, step: Union[Response, GenOut]): """merge response.""" resp.text += step.text if isinstance(step, Response) else step.response resp.input_token_len = step.input_token_len resp.generate_token_len = step.generate_token_len resp.finish_reason = step.finish_reason return resp @property def response(self) -> Response: """return response.""" return self._response def close(self): """release engine storage for this session.""" if self._engine: self._engine._run(coro=self._engine.end_session(self._id)).result() self._engine = None def __repr__(self) -> str: res = '' for user, assistant in self.history: if isinstance(user, list): user = str(user) res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n' return res def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() def __call__( self, prompt: str, gen_config: Optional[GenerationConfig] = None, stream_response: bool = True, do_preprocess: bool = True) -> Union[Response, Iterator[Response]]: self._engine.chat(prompt=prompt, gen_config=gen_config or self._gen_config, stream_response=stream_response, do_preprocess=do_preprocess, session=self) if stream_response: return self.generator else: return self.response class _EventLoopThread: def __init__(self): fut = concurrent.futures.Future() self.thread = Thread( target=partial(_EventLoopThread._thread_entry, fut)) self.thread.start() self.loop: asyncio.AbstractEventLoop = fut.result() self.closed = False @staticmethod def _thread_entry(fut): loop = asyncio.new_event_loop() fut.set_result(loop) try: loop.run_forever() except BaseException as e: logger.error(f'[internal_thread] {type(e).__name__} {e}') finally: loop.close() def close(self): if self.closed: return self.closed = True self.loop.call_soon_threadsafe(self.loop.stop) self.thread.join() def __del__(self): self.close() class AsyncEngine(LogitsMixin): """Async inference engine. Maintaining a bunch of tm_model instances. Args: model_path (str): the path of a model. It could be one of the following options: - i) A local directory path of a turbomind model which is converted by `lmdeploy convert` command or download from ii) and iii). - ii) The model_id of a lmdeploy-quantized model hosted inside a model repo on huggingface.co, such as "InternLM/internlm-chat-20b-4bit", "lmdeploy/llama2-chat-70b-4bit", etc. - iii) The model_id of a model hosted inside a model repo on huggingface.co, such as "internlm/internlm-chat-7b", "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. model_name (str): needed when model_path is a pytorch model on huggingface.co, such as "internlm/internlm-chat-7b", "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. backend (str): either `turbomind` or `pytorch` backend. Default to `turbomind` backend. backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend config instance. Default to none. chat_template_config (ChatTemplateConfig): chat template configuration. Default to None. max_log_len (int): Max number of prompt characters or prompt tokens being printed in log. 
            Default: Unlimited
    """

    def __init__(self,
                 model_path: str,
                 model_name: Optional[str] = None,
                 backend: Literal['turbomind', 'pytorch'] = 'turbomind',
                 backend_config: Optional[Union[TurbomindEngineConfig,
                                                PytorchEngineConfig]] = None,
                 chat_template_config: Optional[ChatTemplateConfig] = None,
                 max_log_len: int = None,
                 **kwargs) -> None:
        logger.info(
            f'input backend={backend}, backend_config={backend_config}')
        logger.info(f'input chat_template_config={chat_template_config}')
        self.model_name, chat_template_name = get_names_from_model(
            model_path, model_name)
        if chat_template_config is None:
            chat_template_config = ChatTemplateConfig(chat_template_name)
        elif chat_template_config.model_name is None:
            chat_template_config.model_name = chat_template_name
        self.chat_template = chat_template_config.chat_template
        logger.info(f'updated chat_template_config={chat_template_config}')

        # build backend engine
        if backend == 'turbomind':
            self._build_turbomind(model_path=model_path,
                                  backend_config=backend_config,
                                  **kwargs)
        elif backend == 'pytorch':
            self._build_pytorch(model_path=model_path,
                                backend_config=backend_config,
                                **kwargs)
        else:
            raise ValueError(f'unsupported backend {backend}')

        logger.info(f'updated backend_config={self.backend_config}')

        # parameters for member functions
        self.session_len = _get_and_verify_max_len(
            self.hf_tm_cfg, self.backend_config.session_len)
        self.stop_words = _stop_words(self.chat_template.stop_words,
                                      self.engine.tokenizer)
        if self.stop_words is not None:
            self.stop_words = self.stop_words[0][0].tolist()
        self.backend = backend
        self.instance_num = self.backend_config.max_batch_size
        self.tokenizer = self.engine.tokenizer
        self.id2step = {}
        self.id2inst = {}
        self.free_insts: asyncio.Queue = None
        self.instances = [
            self.engine.create_instance() for _ in range(self.instance_num)
        ]
        self._session_id = count(0)
        self.request_logger = RequestLogger(max_log_len)
        self.internal_thread = _EventLoopThread()
        self.limiter: asyncio.Semaphore = None

    def close(self):
        self.internal_thread.close()

    def _get_free_insts(self):
        if self.free_insts is None:
            # `asyncio.Queue` must be created in an async context
            self.free_insts = asyncio.Queue()
            for inst in self.instances:
                self.free_insts.put_nowait(inst)
        return self.free_insts

    def _build_turbomind(
            self,
            model_path: str,
            backend_config: Optional[Union[TurbomindEngineConfig,
                                           PytorchEngineConfig]] = None,
            **kwargs):
        """Inner build method for turbomind backend."""
        from lmdeploy import turbomind as tm
        self.engine = tm.TurboMind.from_pretrained(
            model_path, engine_config=backend_config, **kwargs)
        self.backend_config = self.engine.engine_config
        self.hf_tm_cfg = self.engine.config

    def _build_pytorch(
            self,
            model_path: str,
            backend_config: Optional[Union[TurbomindEngineConfig,
                                           PytorchEngineConfig]] = None,
            **kwargs):
        """Inner build method for pytorch backend."""
        from lmdeploy.pytorch.engine import Engine
        self.engine = Engine(model_path=model_path,
                             engine_config=backend_config)
        self.backend_config = self.engine.engine_config
        self.hf_tm_cfg = getattr(self.engine.model_config, 'hf_config', None)

    def __call__(self,
                 prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
                 gen_config: Optional[GenerationConfig] = None,
                 do_preprocess: bool = True,
                 adapter_name: Optional[str] = None,
                 use_tqdm: bool = False,
                 **kwargs):
        """Inference a batch of prompts.

        Args:
            prompts (List[str] | str | List[Dict] | List[List[Dict]]): a
                batch of prompts. It accepts: string prompt, a list of string
                prompts, a chat history in OpenAI format or a list of chat
                history.
            gen_config (GenerationConfig | None): an instance of
                GenerationConfig. Default to None.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
            adapter_name (str): the adapter name of slora for pytorch backend.
                Pick one from adapters. Default to None, using the base model.
            use_tqdm (bool): Whether use the progress bar. Default to False
        """
        if gen_config is None:
            gen_config = GenerationConfig()
        return self.batch_infer(prompts,
                                gen_config=gen_config,
                                do_preprocess=do_preprocess,
                                adapter_name=adapter_name,
                                use_tqdm=use_tqdm,
                                **kwargs)

    async def stop_session(self, session_id: int):
        """Stop a session by a session_id."""
        generator = self.id2inst.get(session_id)
        if generator:
            await generator.async_cancel(session_id)
        # else it's not running at all

    async def end_session(self, session_id: int):
        """For ending a session that is not running."""
        inst = self.id2inst.get(session_id)
        if inst:
            await inst._active.wait()
            assert session_id not in self.id2inst
        inst = await self._get_free_insts().get()
        try:
            await inst.async_end(session_id)
            self.id2step[session_id] = 0
        except (Exception, asyncio.CancelledError, GeneratorExit) as e:  # noqa
            logger.error(f'[end_session] exception caught: {e}')
        finally:
            self._get_free_insts().put_nowait(inst)

    def _get_limiter(self):
        if not self.limiter:
            self.limiter = asyncio.Semaphore(self.instance_num)
        return self.limiter

    async def _async_infer(
            self, requests: AsyncIterator[Dict],
            **kwargs) -> AsyncIterator[AsyncIterator[Response]]:
        async for req in requests:
            gen = self.generate(**req, **kwargs)
            yield gen

    def _infer(self,
               requests: Iterator[Dict],
               multiplex: bool,
               pbar=None,
               loop=None) -> Iterator[Iterator[Response]]:

        async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore):
            async for out in g:
                que.put(_gen_out_to_response(out, idx))
            sem.release()
            if not multiplex:
                que.put(None)  # sentinel of inner generator
            if pbar:
                pbar.update(1)

        que = Queue()

        async def _infer():
            sem = self._get_limiter()
            tasks = []
            for idx, req in enumerate(requests):
                await sem.acquire()
                gen = self.generate(**req)
                dst = que if multiplex else Queue()
                if not multiplex:
                    que.put(iter(dst.get, None))
                # create a task to send the responses
                task = asyncio.create_task(_sync_resp(gen, dst, idx, sem))
                tasks.append(task)
            if not multiplex:  # sentinel of outer generator
                que.put(None)
            await asyncio.gather(*tasks)
            if multiplex:
                que.put(None)  # sentinel of inner generator

        loop = loop or self.internal_thread.loop
        # submit the coroutine to async world
        asyncio.run_coroutine_threadsafe(
            _infer(), loop).add_done_callback(lambda x: x.result())

        return iter(que.get, None)

    @staticmethod
    def _is_single(prompts):
        return isinstance(prompts, str) or isinstance(prompts[0], Dict)

    def infer(self,
              prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
              gen_config: Optional[Union[GenerationConfig,
                                         List[GenerationConfig]]] = None,
              do_preprocess: bool = True,
              adapter_name: Optional[str] = None,
              stream_response: bool = False,
              multiplex: bool = False,
              pbar: Optional[tqdm.tqdm] = None,
              **kwargs):
        prompts = [prompts] if AsyncEngine._is_single(prompts) else prompts
        assert isinstance(prompts, List), 'prompts should be a list'
        gen_config = gen_config or GenerationConfig()
        if not isinstance(gen_config, List):
            gen_config = [gen_config] * len(prompts)
        assert len(prompts) == len(gen_config), \
            'input gen_config length differs from the length of prompts'

        def requests():
            for prompt, gen_cfg in zip(prompts, gen_config):
                r = dict(messages=prompt,
                         gen_config=gen_cfg,
                         do_preprocess=do_preprocess,
                         adapter_name=adapter_name,
                         stream_response=stream_response,
                         **kwargs)
                r.setdefault('sequence_start', True)
                r.setdefault('sequence_end', True)
                if 'session_id' not in r:
                    r['session_id'] = next(self._session_id)
                yield r

        return self._infer(requests(), multiplex, pbar)

    def batch_infer(self,
                    prompts: Union[List[str], str, List[Dict],
                                   List[List[Dict]]],
                    gen_config: Optional[Union[
                        GenerationConfig, List[GenerationConfig]]] = None,
                    do_preprocess: bool = True,
                    adapter_name: Optional[str] = None,
                    use_tqdm: bool = False,
                    **kwargs):
        """Inference a batch of prompts.

        Args:
            prompts (List[str] | str | List[Dict] | List[List[Dict]]): a
                batch of prompts. It accepts: string prompt, a list of string
                prompts, a chat history in OpenAI format or a list of chat
                history.
            gen_config (GenerationConfig | None): an instance of or a list of
                GenerationConfig. Default to None.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
            adapter_name (str): the adapter name of slora for pytorch backend.
                Pick one from adapters. Default to None, using the base model.
            use_tqdm (bool): Whether use the progress bar. Default to False
        """
        is_single = AsyncEngine._is_single(prompts)
        outputs = []
        pbar = tqdm.tqdm(
            total=1 if is_single else len(prompts)) if use_tqdm else None
        try:
            for g in self.infer(prompts,
                                gen_config,
                                do_preprocess,
                                adapter_name,
                                stream_response=False,
                                pbar=pbar,
                                **kwargs):
                res = None
                for out in g:
                    res = _append_response(res, out)
                outputs.append(res)
        finally:
            if pbar:
                pbar.close()  # noqa
        if is_single:
            return outputs[0]
        return outputs

    def stream_infer(
            self,
            prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
            gen_config: Optional[Union[GenerationConfig,
                                       List[GenerationConfig]]] = None,
            do_preprocess: bool = True,
            adapter_name: Optional[str] = None,
            stream_response: bool = True,
            **kwargs):
        """Inference a batch of prompts with stream mode.

        Args:
            prompts (List[str] | str | List[Dict] | List[List[Dict]]): a
                batch of prompts. It accepts: string prompt, a list of string
                prompts, a chat history in OpenAI format or a list of chat
                history.
            gen_config (GenerationConfig | None): an instance of or a list of
                GenerationConfig. Default to None.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
            adapter_name (str): the adapter name of slora for pytorch backend.
                Pick one from adapters. Default to None, using the base model.
        """
        return self.infer(prompts,
                          gen_config,
                          do_preprocess,
                          adapter_name,
                          stream_response,
                          multiplex=True,
                          **kwargs)

    async def _get_prompt_input(self,
                                prompt: str,
                                do_preprocess: bool,
                                sequence_start: bool,
                                adapter_name: str,
                                tools: Optional[List[object]] = None,
                                **kwargs):
        if do_preprocess:
            # use adapter's chat template if possible
            chat_template = self.chat_template
            if adapter_name in MODELS.module_dict:
                chat_template = MODELS.module_dict[adapter_name]()
            prompt = chat_template.messages2prompt(prompt,
                                                   sequence_start,
                                                   tools=tools)
            if prompt is None:
                raise ValueError(
                    'You are using base template to handle chat task. '
                    'Please specify a `--chat-template` name chosen from '
                    '`lmdeploy list` if you want to use OpenAI messages '
                    'input.')
        input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
        return {'prompt': prompt, 'input_ids': input_ids}

    @asynccontextmanager
    async def model_inst(self, session_id: int):
        """A context manager to make sure server's safe running."""
        assert session_id not in self.id2inst
        free_insts = self._get_free_insts()
        inst = await free_insts.get()
        inst._active = asyncio.Event()
        self.id2inst[session_id] = inst
        try:
            yield inst
        finally:
            self.id2inst.pop(session_id)
            inst._active.set()
            free_insts.put_nowait(inst)

    @asynccontextmanager
    async def safe_run(self, inst, session_id, **kwargs):
        # NOTE: the backend call below is commented out and replaced by a
        # fake generator that yields `max_new_tokens` dummy items
        # generator = inst.async_stream_infer(session_id, **kwargs)

        async def fake():
            gen_config = kwargs['gen_config']
            for i in range(gen_config.max_new_tokens):
                yield i
                await asyncio.sleep(0.0)

        generator = fake()
        try:
            yield generator
        except (Exception, asyncio.CancelledError, GeneratorExit) as e:  # noqa
            logger.error(f'[safe_run] exception caught: {e}')
            # TODO: remove session_id from async cancel
            await inst.async_cancel(session_id)
        finally:
            await generator.aclose()

    async def generate(
            self,
            messages,
            session_id: int,
            gen_config: Optional[GenerationConfig] = None,
            tools: Optional[List[object]] = None,
            stream_response: bool = True,
            sequence_start: bool = True,
            sequence_end: bool = True,  # no interactive mode by default
            step: int = 0,
            do_preprocess: bool = True,
            adapter_name: Optional[str] = None,
            **kwargs):
        """Generate responses.

        Args:
            messages (str | List): chat history or prompt
            session_id (int): the session id
            gen_config (GenerationConfig | None): an instance of
                GenerationConfig. Default to None.
            stream_response (bool): whether to stream the responses
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            step (int): the offset of the k/v cache
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
        """
        if session_id not in self.id2step:
            self.id2step[session_id] = 0
        if step != 0:
            self.id2step[session_id] = step
        if gen_config is None:
            gen_config = GenerationConfig()
        else:
            gen_config = deepcopy(gen_config)
        gen_config.convert_stop_bad_words_to_ids(self.tokenizer)
        if gen_config.stop_token_ids is None:
            gen_config.stop_token_ids = self.stop_words
        if not gen_config.do_sample:
            logger.warning(f'GenerationConfig: {gen_config}')
            logger.warning(
                'Since v0.6.0, lmdeploy adds `do_sample` in '
                'GenerationConfig. It defaults to False, meaning greedy '
                'decoding. Please set `do_sample=True` if sampling '
                'decoding is needed')
            # greedy decode
            gen_config.top_k = 1
            # avoid unnecessary process
            gen_config.temperature = 1.0
            gen_config.repetition_penalty = 1.0
        # set random if it is not set and sequence_start is True
        elif gen_config.random_seed is None and sequence_start:
            gen_config.random_seed = random.getrandbits(64)
        if gen_config.n > 1:
            logger.error(f"n({gen_config.n}) > 1 hasn't been supported yet. "
" f'Fallback to 1') gen_config.n = 1 prompt = messages self.request_logger.log_prompt(session_id=session_id, prompt=prompt) prompt_input = await self._get_prompt_input(prompt, do_preprocess, sequence_start, adapter_name, tools=tools) prompt = prompt_input['prompt'] input_ids = prompt_input['input_ids'] finish_reason = None self.request_logger.log_inputs(session_id=session_id, prompt=prompt, prompt_token_ids=input_ids, gen_config=gen_config, adapter_name=adapter_name) logger.info(f'session_id={session_id}, ' f'history_tokens={self.id2step[session_id]}, ' f'input_tokens={len(input_ids)}, ' f'max_new_tokens={gen_config.max_new_tokens}, ' f'seq_start={sequence_start}, seq_end={sequence_end}, ' f'step={step}, prep={do_preprocess}') if gen_config.max_new_tokens is None: # for interactive endpoint, will try maximum possible token num gen_config.max_new_tokens = max( 128, self.session_len - self.id2step[session_id] - len(input_ids)) elif self.id2step[session_id] + len( input_ids) + gen_config.max_new_tokens > self.session_len: gen_config.max_new_tokens = max( self.session_len - self.id2step[session_id] - len(input_ids), 128) logger.error( f'Truncate max_new_tokens to {gen_config.max_new_tokens}') if self.id2step[session_id] + len( input_ids) + gen_config.max_new_tokens > self.session_len: logger.error(f'run out of tokens. session_id={session_id}.') yield GenOut('', self.id2step[session_id], len(input_ids), 0, 'length') if sequence_end is True and sequence_start is False: await self.end_session(session_id) return def is_error(status): return status not in [ResponseType.SUCCESS, ResponseType.FINISH] async with self.model_inst(session_id) as inst: state = DetokenizeState(len(input_ids)) token_ids = input_ids.copy() prev_len = 0 start_ids_offset = state.ids_offset response = '' async with self.safe_run(inst, session_id=session_id, **prompt_input, gen_config=gen_config, adapter_name=adapter_name, stream_output=stream_response, sequence_start=sequence_start, sequence_end=sequence_end, step=self.id2step[session_id]) as gen: async for _ in gen: tokens = 0 yield GenOut(response='', history_token_len=0, input_token_len=len(input_ids), generate_token_len=0, finish_reason='length', token_ids=[]) # async for outputs in gen: # # decode res # if is_error(outputs.status): # tokens = 0 # break # tokens = outputs.num_token # token_ids += outputs.token_ids[prev_len - tokens:] # prev_len = tokens # if len(token_ids) <= state.ids_offset: # continue # ids_offset = state.ids_offset # response, state = self.tokenizer.detokenize_incrementally( # token_ids, # state, # skip_special_tokens=gen_config.skip_special_tokens) # res = token_ids[ids_offset:] # logprobs = None # if outputs.logprobs: # log_offset = ids_offset - start_ids_offset # logprobs = outputs.logprobs[log_offset:] # # response, history token len, # # input token len, gen token len # yield GenOut(response, self.id2step[session_id], # len(input_ids), tokens, finish_reason, res, # logprobs) # # end of generator loop # if not is_error(outputs.status): # finish_reason = 'length' \ # if tokens >= gen_config.max_new_tokens else 'stop' # # utf-8 char at the end means it's a potential unfinished # # byte sequence # if not response.endswith('�'): # # avaid returning the last response twice # response = '' # yield GenOut(response, self.id2step[session_id], # len(input_ids), tokens, finish_reason) # else: # yield GenOut(response='internal error happened', # history_token_len=self.id2step[session_id], # input_token_len=len(input_ids), # generate_token_len=0, # 
                #                  finish_reason='error',
                #                  token_ids=[])

            # update step
            if sequence_end:
                self.id2step[session_id] = 0
                if self.backend == 'pytorch':
                    # manually end pytorch session
                    await inst.async_end(session_id)
            else:
                self.id2step[session_id] += len(input_ids) + tokens

    def parse_tool_response(self, text, tools, **kwargs):
        """Parse model response containing tool information.

        Args:
            text (str): model response in string format
            tools (List): tools from user request
        """
        if '<|plugin|>' in text:  # internlm2
            text, action = text.split('<|action_start|><|plugin|>')
            action = action.split('<|action_end|>'.strip())[0]
            action = action[action.find('{'):]
            action = json.loads(action)
            name, parameters = action['name'], json.dumps(
                action.get('parameters', action.get('arguments', {})),
                ensure_ascii=False)
            call_info_list = [(name, parameters)]
        elif '<function=' in text:  # llama3.1
            action, _ = text.split('</function>')
            parameters = action[action.find('{'):]
            name = action.split('<function=')[1].split('>{')[0]
            call_info_list = [(name, parameters)]
        elif '<tool_call>' in text and '</tool_call>' in text:  # qwen2.5
            # get tool_call in text
            pattern = r'<tool_call>(.*?)</tool_call>'
            match_result_list = re.findall(pattern, text, re.DOTALL)
            call_info_list = []
            for match_result in match_result_list:
                action = json.loads(match_result)
                call_info_list.append(
                    (action['name'],
                     json.dumps(action['arguments'], ensure_ascii=False)))
            # get text outside of tags
            if not text.startswith('<tool_call>'):
                text = text[:text.find('<tool_call>')]
            elif not text.endswith('</tool_call>'):
                text = text[text.rfind('</tool_call>') + len('</tool_call>'):]
            else:
                text = ''
        else:
            raise RuntimeError(f'Unexpected model response: {text}')
        call_info_list = [([tool.function.name for tool in tools
                            ].index(call_info[0]), call_info[0], call_info[1])
                          for call_info in call_info_list]
        return text, call_info_list

    def _run(self, fn=None, coro=None, loop=None):
        assert (fn or coro) and not (fn and coro)
        loop = loop or self.internal_thread.loop
        if fn:

            async def _coro():
                return fn()

            coro = _coro()
        return asyncio.run_coroutine_threadsafe(coro, loop)

    def session(self, gen_config: GenerationConfig = None):
        return Session(self._run(fn=lambda: next(self._session_id)).result(),
                       engine=self,
                       gen_config=gen_config)

    def chat(self,
             prompt: str,
             session=None,
             gen_config: Optional[GenerationConfig] = None,
             stream_response=False,
             **kwargs) -> Union[Session, Iterator]:
        """Chat.

        Args:
            prompt (str): prompt
            session (Session): the chat session
            gen_config (GenerationConfig | None): an instance of
                GenerationConfig. Default to None.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
            **kwargs (dict): ad hoc parametrization of `gen_config`
        """
        if session is None:
            session = self.session()

        # sync & init
        session._prompt = prompt
        session._response = None

        sequence_start = session._step == 0

        generator = self.infer(prompt,
                               gen_config,
                               sequence_start=sequence_start,
                               sequence_end=False,
                               session_id=session._id,
                               stream_response=stream_response,
                               multiplex=True)

        def _gen():
            resp = None
            try:
                for out in generator:
                    resp = _append_response(resp, out)
                    yield out
            except:  # noqa
                self._run(coro=self.stop_session(session._id)).result()
                raise
            else:
                session._response = resp
                session._step += resp.generate_token_len + \
                    resp.input_token_len
                session.history.append((session._prompt, resp.text))

        if stream_response:
            session.generator = _gen()
        else:
            # run the generator until finish
            with closing(_gen()) as gen:
                for _ in gen:
                    pass
            session.generator = None

        return session
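

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how the blocking helpers defined above
# can be driven. The model id and the sampling parameters are placeholders,
# assuming a chat model that lmdeploy can load with the default turbomind
# backend; swap in your own model path before running.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    engine = AsyncEngine('internlm/internlm2-chat-7b',  # placeholder model id
                         backend='turbomind')
    gen_config = GenerationConfig(max_new_tokens=64, do_sample=True,
                                  top_p=0.8)

    # batch inference over a list of prompts; returns one Response per prompt
    responses = engine.batch_infer(['Hi, how are you?', 'What is lmdeploy?'],
                                   gen_config=gen_config)
    for resp in responses:
        print(resp.text)

    # multi-turn chat through a Session; the context manager releases the
    # session's k/v cache on exit
    with engine.session(gen_config=gen_config) as sess:
        print(sess('Tell me a joke.', stream_response=False).text)

    engine.close()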