# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import concurrent.futures
import dataclasses
import json
import os
import random
import re
from contextlib import asynccontextmanager, closing
from copy import deepcopy
from functools import partial
from itertools import count
from queue import Queue
from threading import Thread
from typing import (Any, AsyncIterator, Dict, Iterator, List, Literal,
Optional, Tuple, Union)
import tqdm
from lmdeploy.logger import RequestLogger
from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, Response,
ResponseType, TurbomindEngineConfig)
from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model
from lmdeploy.serve.utils import LogitsMixin
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_logger
logger = get_logger('lmdeploy')
def get_names_from_model(model_path: str, model_name: str = None):
"""Get model name and chat template name from workspace model."""
triton_model_path = os.path.join(model_path, 'triton_models', 'weights')
if not os.path.exists(triton_model_path):
chat_template_name = best_match_model(model_path)
else:
# `model_path` refers to a turbomind model, reading
# chat_template_name from the config
config_path = os.path.join(triton_model_path, 'config.yaml')
with open(config_path, 'r') as f:
import yaml
config = yaml.safe_load(f)
chat_template_name = config['model_config']['chat_template']
model_name = model_name if model_name else model_path
return model_name, chat_template_name
@dataclasses.dataclass
class GenOut:
"""Pack all response information together."""
response: str
history_token_len: int
input_token_len: int
generate_token_len: int
finish_reason: Optional[Literal['stop', 'length', 'error']] = None
token_ids: List[int] = None
logprobs: List[Dict[int, float]] = None
def _gen_out_to_response(out: GenOut, index) -> Response:
return Response(text=out.response,
generate_token_len=out.generate_token_len,
input_token_len=out.input_token_len,
finish_reason=out.finish_reason,
token_ids=out.token_ids,
logprobs=out.logprobs,
index=index)
def _append_response(dst: Response, src: Response):
"""dst += src."""
if not dst:
return src
dst.text += src.text
dst.generate_token_len = src.generate_token_len
dst.input_token_len = src.input_token_len
dst.finish_reason = src.finish_reason
dst.index = src.index
if src.token_ids:
dst.token_ids += src.token_ids
if src.logprobs:
dst.logprobs = dst.logprobs or []
dst.logprobs += src.logprobs
return dst
class Session:
"""Session for AsyncEngine.chat.
Args:
_id (int): session_id for internal use.
_step (int): the offset of the k/v cache for internal use.
_prompt (Any): input prompt for internal use.
        _response (Response): model output for prompt.
        _engine (Any): engine for internal use.
        history (List[Tuple[Any, str]]): chat history.
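    Example (an illustrative sketch; assumes `engine` is a constructed
    AsyncEngine backed by a locally available model):
        >>> with engine.session() as sess:
        >>>     resp = sess('Hi, pls intro yourself', stream_response=False)
        >>>     print(resp.text)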
"""
def __init__(self,
session_id: int,
engine: Any,
gen_config: GenerationConfig = None):
self._id: int = session_id
self._engine = engine
self._step: int = 0
self._prompt: Any = None
self._response: Response = None
self._gen_config = gen_config
self.history: List[Tuple[Any, str]] = []
def _merge_response(self, resp: Response, step: Union[Response, GenOut]):
"""merge response."""
resp.text += step.text if isinstance(step, Response) else step.response
resp.input_token_len = step.input_token_len
resp.generate_token_len = step.generate_token_len
resp.finish_reason = step.finish_reason
return resp
@property
def response(self) -> Response:
"""return response."""
return self._response
def close(self):
"""release engine storage for this session."""
if self._engine:
self._engine._run(coro=self._engine.end_session(self._id)).result()
self._engine = None
def __repr__(self) -> str:
res = ''
for user, assistant in self.history:
if isinstance(user, list):
user = str(user)
res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n'
return res
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __call__(
self,
prompt: str,
gen_config: Optional[GenerationConfig] = None,
stream_response: bool = True,
do_preprocess: bool = True) -> Union[Response, Iterator[Response]]:
self._engine.chat(prompt=prompt,
gen_config=gen_config or self._gen_config,
stream_response=stream_response,
do_preprocess=do_preprocess,
session=self)
if stream_response:
return self.generator
else:
return self.response
class _EventLoopThread:
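    """Run a private asyncio event loop in a background thread so that
    synchronous callers can schedule coroutines on it via
    `asyncio.run_coroutine_threadsafe` (see `AsyncEngine._run`).
    """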
def __init__(self):
fut = concurrent.futures.Future()
self.thread = Thread(
target=partial(_EventLoopThread._thread_entry, fut))
self.thread.start()
self.loop: asyncio.AbstractEventLoop = fut.result()
self.closed = False
@staticmethod
def _thread_entry(fut):
loop = asyncio.new_event_loop()
fut.set_result(loop)
try:
loop.run_forever()
except BaseException as e:
logger.error(f'[internal_thread] {type(e).__name__} {e}')
finally:
loop.close()
def close(self):
if self.closed:
return
self.closed = True
self.loop.call_soon_threadsafe(self.loop.stop)
self.thread.join()
def __del__(self):
self.close()
class AsyncEngine(LogitsMixin):
"""Async inference engine. Maintaining a bunch of tm_model instances.
Args:
model_path (str): the path of a model.
It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by the `lmdeploy convert` command or
                    downloaded from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
backend (str): either `turbomind` or `pytorch` backend. Default to
`turbomind` backend.
        backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend
            config instance. Default to None.
chat_template_config (ChatTemplateConfig): chat template configuration.
Default to None.
max_log_len (int): Max number of prompt characters or prompt tokens
being printed in log. Default: Unlimited
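    Example (a minimal usage sketch; the model id below is only an
    illustration and must exist locally or on huggingface.co):
        >>> from lmdeploy import GenerationConfig
        >>> engine = AsyncEngine('internlm/internlm2-chat-7b',
        >>>                      backend='turbomind')
        >>> resp = engine('Hi, pls intro yourself',
        >>>               gen_config=GenerationConfig(max_new_tokens=64))
        >>> print(resp.text)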
"""
def __init__(self,
model_path: str,
model_name: Optional[str] = None,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
max_log_len: int = None,
**kwargs) -> None:
logger.info(
f'input backend={backend}, backend_config={backend_config}')
logger.info(f'input chat_template_config={chat_template_config}')
self.model_name, chat_template_name = get_names_from_model(
model_path, model_name)
if chat_template_config is None:
chat_template_config = ChatTemplateConfig(chat_template_name)
elif chat_template_config.model_name is None:
chat_template_config.model_name = chat_template_name
self.chat_template = chat_template_config.chat_template
        logger.info(f'updated chat_template_config={chat_template_config}')
# build backend engine
if backend == 'turbomind':
self._build_turbomind(model_path=model_path,
backend_config=backend_config,
**kwargs)
elif backend == 'pytorch':
self._build_pytorch(model_path=model_path,
backend_config=backend_config,
**kwargs)
else:
raise ValueError(f'unsupported backend {backend}')
logger.info(f'updated backend_config={self.backend_config}')
# parameters for member functions
self.session_len = _get_and_verify_max_len(
self.hf_tm_cfg, self.backend_config.session_len)
self.stop_words = _stop_words(self.chat_template.stop_words,
self.engine.tokenizer)
if self.stop_words is not None:
self.stop_words = self.stop_words[0][0].tolist()
self.backend = backend
self.instance_num = self.backend_config.max_batch_size
self.tokenizer = self.engine.tokenizer
self.id2step = {}
self.id2inst = {}
self.free_insts: asyncio.Queue = None
self.instances = [
self.engine.create_instance() for _ in range(self.instance_num)
]
self._session_id = count(0)
self.request_logger = RequestLogger(max_log_len)
self.internal_thread = _EventLoopThread()
self.limiter: asyncio.Semaphore = None
def close(self):
self.internal_thread.close()
def _get_free_insts(self):
if self.free_insts is None:
# `asyncio.Queue` must be created in an async context
self.free_insts = asyncio.Queue()
for inst in self.instances:
self.free_insts.put_nowait(inst)
return self.free_insts
def _build_turbomind(
self,
model_path: str,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
**kwargs):
"""Innter build method for turbomind backend."""
from lmdeploy import turbomind as tm
self.engine = tm.TurboMind.from_pretrained(
model_path, engine_config=backend_config, **kwargs)
self.backend_config = self.engine.engine_config
self.hf_tm_cfg = self.engine.config
def _build_pytorch(
self,
model_path: str,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None,
**kwargs):
"""Innter build method for pytorch backend."""
from lmdeploy.pytorch.engine import Engine
self.engine = Engine(model_path=model_path,
engine_config=backend_config)
self.backend_config = self.engine.engine_config
self.hf_tm_cfg = getattr(self.engine.model_config, 'hf_config', None)
def __call__(self,
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
gen_config: Optional[GenerationConfig] = None,
do_preprocess: bool = True,
adapter_name: Optional[str] = None,
use_tqdm: bool = False,
**kwargs):
"""Inference a batch of prompts.
Args:
prompts (List[str] | str | List[Dict] | List[List[Dict]]]): a
batch of prompts. It accepts: string prompt, a list of string
prompts, a chat history in OpenAI format or a list of chat
history.
gen_config (GenerationConfig | None): a instance of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
adapter_name (str): the adapter name of slora for pytorch backend.
Pick one from adapters. Default to None, using the base model.
use_tqdm (bool): Whether use the progress bar. Default to False
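        Example (illustrative; `engine` is a constructed AsyncEngine):
            >>> responses = engine(['Hi, pls intro yourself', 'Shanghai is'])
            >>> print([r.text for r in responses])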
"""
if gen_config is None:
gen_config = GenerationConfig()
return self.batch_infer(prompts,
gen_config=gen_config,
do_preprocess=do_preprocess,
adapter_name=adapter_name,
use_tqdm=use_tqdm,
**kwargs)
async def stop_session(self, session_id: int):
"""Stop a session by a session_id."""
generator = self.id2inst.get(session_id)
if generator:
await generator.async_cancel(session_id)
# else it's not running at all
async def end_session(self, session_id: int):
"""For ending a session that is not running."""
inst = self.id2inst.get(session_id)
if inst:
await inst._active.wait()
assert session_id not in self.id2inst
inst = await self._get_free_insts().get()
try:
await inst.async_end(session_id)
self.id2step[session_id] = 0
except (Exception, asyncio.CancelledError, GeneratorExit) as e: # noqa
logger.error(f'[end_session] exception caught: {e}')
finally:
self._get_free_insts().put_nowait(inst)
def _get_limiter(self):
if not self.limiter:
self.limiter = asyncio.Semaphore(self.instance_num)
return self.limiter
async def _async_infer(self, requests: AsyncIterator[Dict],
**kwargs) -> AsyncIterator[AsyncIterator[Response]]:
async for req in requests:
gen = self.generate(**req, **kwargs)
yield gen
def _infer(self,
requests: Iterator[Dict],
multiplex: bool,
pbar=None,
loop=None) -> Iterator[Iterator[Response]]:
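        # With multiplex=False, every request gets its own inner queue; the
        # outer iterator yields one inner iterator per request (each closed
        # by its own `None` sentinel) and is itself closed by a final `None`
        # once all requests have been submitted. With multiplex=True, all
        # responses are funnelled into the single shared queue, which is
        # closed with one `None` after every request finishes. The semaphore
        # from `_get_limiter()` keeps at most `max_batch_size` requests
        # in flight.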
async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore):
async for out in g:
que.put(_gen_out_to_response(out, idx))
sem.release()
if not multiplex:
que.put(None) # sentinel of inner generator
if pbar:
pbar.update(1)
que = Queue()
async def _infer():
sem = self._get_limiter()
tasks = []
for idx, req in enumerate(requests):
await sem.acquire()
gen = self.generate(**req)
dst = que if multiplex else Queue()
if not multiplex:
que.put(iter(dst.get, None))
# create a task to send the responses
task = asyncio.create_task(_sync_resp(gen, dst, idx, sem))
tasks.append(task)
if not multiplex: # sentinel of outer generator
que.put(None)
await asyncio.gather(*tasks)
if multiplex:
que.put(None) # sentinel of inner generator
loop = loop or self.internal_thread.loop
# submit the coroutine to async world
asyncio.run_coroutine_threadsafe(
_infer(), loop).add_done_callback(lambda x: x.result())
return iter(que.get, None)
@staticmethod
def _is_single(prompts):
return isinstance(prompts, str) or isinstance(prompts[0], Dict)
def infer(self,
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
gen_config: Optional[Union[GenerationConfig,
List[GenerationConfig]]] = None,
do_preprocess: bool = True,
adapter_name: Optional[str] = None,
stream_response: bool = False,
multiplex: bool = False,
pbar: Optional[tqdm.tqdm] = None,
**kwargs):
prompts = [prompts] if AsyncEngine._is_single(prompts) else prompts
assert isinstance(prompts, List), 'prompts should be a list'
gen_config = gen_config or GenerationConfig()
if not isinstance(gen_config, List):
gen_config = [gen_config] * len(prompts)
        assert len(prompts) == len(gen_config), \
            'input gen_config length differs from the length of prompts'
def requests():
for prompt, gen_cfg in zip(prompts, gen_config):
r = dict(messages=prompt,
gen_config=gen_cfg,
do_preprocess=do_preprocess,
adapter_name=adapter_name,
stream_response=stream_response,
**kwargs)
r.setdefault('sequence_start', True)
r.setdefault('sequence_end', True)
if 'session_id' not in r:
r['session_id'] = next(self._session_id)
yield r
return self._infer(requests(), multiplex, pbar)
def batch_infer(self,
prompts: Union[List[str], str, List[Dict],
List[List[Dict]]],
gen_config: Optional[Union[GenerationConfig,
List[GenerationConfig]]] = None,
do_preprocess: bool = True,
adapter_name: Optional[str] = None,
use_tqdm: bool = False,
**kwargs):
"""Inference a batch of prompts.
Args:
prompts (List[str] | str | List[Dict] | List[List[Dict]]]): a
batch of prompts. It accepts: string prompt, a list of string
prompts, a chat history in OpenAI format or a list of chat
history.
gen_config (GenerationConfig | None): a instance of or a list of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
adapter_name (str): the adapter name of slora for pytorch backend.
Pick one from adapters. Default to None, using the base model.
use_tqdm (bool): Whether use the progress bar. Default to False
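        Example (illustrative; `engine` is a constructed AsyncEngine):
            >>> gen_config = GenerationConfig(max_new_tokens=64)
            >>> responses = engine.batch_infer(
            >>>     ['Hi, pls intro yourself', 'Shanghai is'],
            >>>     gen_config=gen_config)
            >>> print([r.text for r in responses])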
"""
is_single = AsyncEngine._is_single(prompts)
outputs = []
pbar = tqdm.tqdm(
total=1 if is_single else len(prompts)) if use_tqdm else None
try:
for g in self.infer(prompts,
gen_config,
do_preprocess,
adapter_name,
stream_response=False,
pbar=pbar,
**kwargs):
res = None
for out in g:
res = _append_response(res, out)
outputs.append(res)
finally:
if pbar: pbar.close() # noqa
if is_single:
return outputs[0]
return outputs
def stream_infer(
self,
prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
gen_config: Optional[Union[GenerationConfig,
List[GenerationConfig]]] = None,
do_preprocess: bool = True,
adapter_name: Optional[str] = None,
stream_response: bool = True,
**kwargs):
"""Inference a batch of prompts with stream mode.
Args:
prompts (List[str] | str | List[Dict] | List[List[Dict]]]):a
batch of prompts. It accepts: string prompt, a list of string
prompts, a chat history in OpenAI format or a list of chat
history.
gen_config (GenerationConfig | None): a instance of or a list of
GenerationConfig. Default to None.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
adapter_name (str): the adapter name of slora for pytorch backend.
Pick one from adapters. Default to None, using the base model.
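        Example (illustrative; `engine` is a constructed AsyncEngine; outputs
        of different prompts are interleaved and can be told apart by
        `Response.index`):
            >>> for out in engine.stream_infer(['Hi, pls intro yourself',
            >>>                                 'Shanghai is']):
            >>>     print(out.index, out.text)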
"""
return self.infer(prompts,
gen_config,
do_preprocess,
adapter_name,
stream_response,
multiplex=True,
**kwargs)
async def _get_prompt_input(self,
prompt: str,
do_preprocess: bool,
sequence_start: bool,
adapter_name: str,
tools: Optional[List[object]] = None,
**kwargs):
if do_preprocess:
# use adapter's chat template if possible
chat_template = self.chat_template
if adapter_name in MODELS.module_dict:
chat_template = MODELS.module_dict[adapter_name]()
prompt = chat_template.messages2prompt(prompt,
sequence_start,
tools=tools)
if prompt is None:
raise ValueError(
f'You are using base template to handle chat task. Please specify a `--chat-template` name chosen from `lmdeploy list` if you want to use OpenAI messages input.' # noqa
)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
return {'prompt': prompt, 'input_ids': input_ids}
@asynccontextmanager
async def model_inst(self, session_id: int):
"""A context manager to make sure server's safe running."""
assert session_id not in self.id2inst
free_insts = self._get_free_insts()
inst = await free_insts.get()
inst._active = asyncio.Event()
self.id2inst[session_id] = inst
try:
yield inst
finally:
self.id2inst.pop(session_id)
inst._active.set()
free_insts.put_nowait(inst)
@asynccontextmanager
async def safe_run(self, inst, session_id, **kwargs):
# generator = inst.async_stream_infer(session_id, **kwargs)
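        # NOTE: the real backend call above is commented out; `fake` below is
        # a stand-in generator that yields `max_new_tokens` dummy values and
        # never touches the engine instance.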
async def fake():
gen_config = kwargs['gen_config']
for i in range(gen_config.max_new_tokens):
yield i
await asyncio.sleep(0.0)
generator = fake()
try:
yield generator
except (Exception, asyncio.CancelledError, GeneratorExit) as e: # noqa
logger.error(f'[safe_run] exception caught: {e}')
# TODO: remove session_id from async cancel
await inst.async_cancel(session_id)
finally:
await generator.aclose()
async def generate(
self,
messages,
session_id: int,
gen_config: Optional[GenerationConfig] = None,
tools: Optional[List[object]] = None,
stream_response: bool = True,
sequence_start: bool = True,
sequence_end: bool = True, # no interactive mode by default
step: int = 0,
do_preprocess: bool = True,
adapter_name: Optional[str] = None,
**kwargs):
"""Generate responses.
Args:
messages (str | List): chat history or prompt
session_id (int): the session id
            gen_config (GenerationConfig | None): an instance of
                GenerationConfig. Default to None.
            stream_response (bool): whether to return responses in a
                streaming manner
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            step (int): the offset of the k/v cache
            do_preprocess (bool): whether to pre-process the messages. Default
                to True, which means chat_template will be applied.
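        Example (illustrative; `engine` is a constructed AsyncEngine and this
        coroutine must be iterated inside an event loop):
            >>> gen_config = GenerationConfig(max_new_tokens=64)
            >>> async for out in engine.generate('Hi, pls intro yourself',
            >>>                                  session_id=0,
            >>>                                  gen_config=gen_config):
            >>>     print(out.response, end='')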
"""
if session_id not in self.id2step:
self.id2step[session_id] = 0
if step != 0:
self.id2step[session_id] = step
if gen_config is None:
gen_config = GenerationConfig()
else:
gen_config = deepcopy(gen_config)
gen_config.convert_stop_bad_words_to_ids(self.tokenizer)
if gen_config.stop_token_ids is None:
gen_config.stop_token_ids = self.stop_words
if not gen_config.do_sample:
logger.warning(f'GenerationConfig: {gen_config}')
            logger.warning(
                'Since v0.6.0, lmdeploy added `do_sample` to '
                'GenerationConfig. It defaults to False, meaning greedy '
                'decoding. Please set `do_sample=True` if sampling '
                'decoding is needed')
# greedy decode
gen_config.top_k = 1
# avoid unnecessary process
gen_config.temperature = 1.0
gen_config.repetition_penalty = 1.0
# set random if it is not set and sequence_start is True
elif gen_config.random_seed is None and sequence_start:
gen_config.random_seed = random.getrandbits(64)
if gen_config.n > 1:
logger.ERROR(f"n({gen_config.n}) > 1 hasn't been supported yet. "
f'Fallback to 1')
gen_config.n = 1
prompt = messages
self.request_logger.log_prompt(session_id=session_id, prompt=prompt)
prompt_input = await self._get_prompt_input(prompt,
do_preprocess,
sequence_start,
adapter_name,
tools=tools)
prompt = prompt_input['prompt']
input_ids = prompt_input['input_ids']
finish_reason = None
self.request_logger.log_inputs(session_id=session_id,
prompt=prompt,
prompt_token_ids=input_ids,
gen_config=gen_config,
adapter_name=adapter_name)
logger.info(f'session_id={session_id}, '
f'history_tokens={self.id2step[session_id]}, '
f'input_tokens={len(input_ids)}, '
f'max_new_tokens={gen_config.max_new_tokens}, '
f'seq_start={sequence_start}, seq_end={sequence_end}, '
f'step={step}, prep={do_preprocess}')
if gen_config.max_new_tokens is None:
# for interactive endpoint, will try maximum possible token num
gen_config.max_new_tokens = max(
128,
self.session_len - self.id2step[session_id] - len(input_ids))
elif self.id2step[session_id] + len(
input_ids) + gen_config.max_new_tokens > self.session_len:
gen_config.max_new_tokens = max(
self.session_len - self.id2step[session_id] - len(input_ids),
128)
logger.error(
f'Truncate max_new_tokens to {gen_config.max_new_tokens}')
if self.id2step[session_id] + len(
input_ids) + gen_config.max_new_tokens > self.session_len:
logger.error(f'run out of tokens. session_id={session_id}.')
yield GenOut('', self.id2step[session_id], len(input_ids), 0,
'length')
if sequence_end is True and sequence_start is False:
await self.end_session(session_id)
return
def is_error(status):
return status not in [ResponseType.SUCCESS, ResponseType.FINISH]
async with self.model_inst(session_id) as inst:
state = DetokenizeState(len(input_ids))
token_ids = input_ids.copy()
prev_len = 0
start_ids_offset = state.ids_offset
response = ''
async with self.safe_run(inst,
session_id=session_id,
**prompt_input,
gen_config=gen_config,
adapter_name=adapter_name,
stream_output=stream_response,
sequence_start=sequence_start,
sequence_end=sequence_end,
step=self.id2step[session_id]) as gen:
async for _ in gen:
tokens = 0
yield GenOut(response='',
history_token_len=0,
input_token_len=len(input_ids),
generate_token_len=0,
finish_reason='length',
token_ids=[])
# async for outputs in gen:
# # decode res
# if is_error(outputs.status):
# tokens = 0
# break
# tokens = outputs.num_token
# token_ids += outputs.token_ids[prev_len - tokens:]
# prev_len = tokens
# if len(token_ids) <= state.ids_offset:
# continue
# ids_offset = state.ids_offset
# response, state = self.tokenizer.detokenize_incrementally(
# token_ids,
# state,
# skip_special_tokens=gen_config.skip_special_tokens)
# res = token_ids[ids_offset:]
# logprobs = None
# if outputs.logprobs:
# log_offset = ids_offset - start_ids_offset
# logprobs = outputs.logprobs[log_offset:]
# # response, history token len,
# # input token len, gen token len
# yield GenOut(response, self.id2step[session_id],
# len(input_ids), tokens, finish_reason, res,
# logprobs)
# # end of generator loop
# if not is_error(outputs.status):
# finish_reason = 'length' \
# if tokens >= gen_config.max_new_tokens else 'stop'
# # utf-8 char at the end means it's a potential unfinished
# # byte sequence
# if not response.endswith('�'):
                #     # avoid returning the last response twice
# response = ''
# yield GenOut(response, self.id2step[session_id],
# len(input_ids), tokens, finish_reason)
# else:
# yield GenOut(response='internal error happened',
# history_token_len=self.id2step[session_id],
# input_token_len=len(input_ids),
# generate_token_len=0,
# finish_reason='error',
# token_ids=[])
# update step
if sequence_end:
self.id2step[session_id] = 0
if self.backend == 'pytorch':
# manually end pytorch session
await inst.async_end(session_id)
else:
self.id2step[session_id] += len(input_ids) + tokens
def parse_tool_response(self, text, tools, **kwargs):
"""Parse model response containing tool information.
Args:
            text (str): model response in string format
            tools (List): tools from user request
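        Example (an illustrative internlm2-style response; assumes `tools`
        holds a single tool whose function name is `get_weather`):
            >>> text = ('I will call the weather API.'
            >>>         '<|action_start|><|plugin|>'
            >>>         '{"name": "get_weather", '
            >>>         '"parameters": {"city": "Beijing"}}<|action_end|>')
            >>> text, calls = engine.parse_tool_response(text, tools)
            >>> # text  -> 'I will call the weather API.'
            >>> # calls -> [(0, 'get_weather', '{"city": "Beijing"}')]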
"""
if '<|plugin|>' in text: # internlm2
text, action = text.split('<|action_start|><|plugin|>')
            action = action.split('<|action_end|>')[0]
action = action[action.find('{'):]
action = json.loads(action)
name, parameters = action['name'], json.dumps(action.get(
'parameters', action.get('arguments', {})),
ensure_ascii=False)
call_info_list = [(name, parameters)]
elif '<function=' in text: # llama3.1
action, _ = text.split('</function>')
parameters = action[action.find('{'):]
name = action.split('<function=')[1].split('>{')[0]
call_info_list = [(name, parameters)]
elif '<tool_call>' in text and '</tool_call>' in text: # qwen2.5
# get tool_call in text
pattern = r'<tool_call>(.*?)</tool_call>'
match_result_list = re.findall(pattern, text, re.DOTALL)
call_info_list = []
for match_result in match_result_list:
action = json.loads(match_result)
call_info_list.append((action['name'],
json.dumps(action['arguments'],
ensure_ascii=False)))
# get text outside of tags
if not text.startswith('<tool_call>'):
text = text[:text.find('<tool_call>')]
elif not text.endswith('</tool_call>'):
text = text[text.rfind('</tool_call>') + len('</tool_call>'):]
else:
text = ''
else:
raise RuntimeError(f'Unexpected model response: {text}')
call_info_list = [([tool.function.name for tool in tools
].index(call_info[0]), call_info[0], call_info[1])
for call_info in call_info_list]
return text, call_info_list
def _run(self, fn=None, coro=None, loop=None):
assert (fn or coro) and not (fn and coro)
loop = loop or self.internal_thread.loop
if fn:
async def _coro():
return fn()
coro = _coro()
return asyncio.run_coroutine_threadsafe(coro, loop)
def session(self, gen_config: GenerationConfig = None):
return Session(self._run(fn=lambda: next(self._session_id)).result(),
engine=self,
gen_config=gen_config)
def chat(self,
prompt: str,
session=None,
gen_config: Optional[GenerationConfig] = None,
stream_response=False,
**kwargs) -> Union[Session, Iterator]:
"""Chat.
Args:
prompt (str): prompt
session (Session): the chat session
            gen_config (GenerationConfig | None): an instance of
                GenerationConfig. Default to None.
            do_preprocess (bool): whether to pre-process the messages. Default
                to True, which means chat_template will be applied.
            **kwargs (dict): ad hoc parametrization of `gen_config`
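        Example (illustrative; `engine` is a constructed AsyncEngine; passing
        the returned session back in keeps the conversation state):
            >>> sess = engine.chat('Hi, pls intro yourself')
            >>> print(sess.response.text)
            >>> sess = engine.chat('Please elaborate', session=sess)
            >>> print(sess.response.text)
            >>> sess.close()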
"""
if session is None:
session = self.session()
# sync & init
session._prompt = prompt
session._response = None
sequence_start = session._step == 0
generator = self.infer(prompt,
gen_config,
sequence_start=sequence_start,
sequence_end=False,
session_id=session._id,
stream_response=stream_response,
multiplex=True)
def _gen():
resp = None
try:
for out in generator:
resp = _append_response(resp, out)
yield out
except: # noqa
self._run(coro=self.stop_session(session._id)).result()
raise
else:
session._response = resp
session._step += resp.generate_token_len + resp.input_token_len
session.history.append((session._prompt, resp.text))
if stream_response:
session.generator = _gen()
else:
# run the generator until finish
with closing(_gen()) as gen:
for _ in gen:
pass
session.generator = None
return session