import argparse
import os

from datasets import Dataset
from datatrove.pipeline.readers import HuggingFaceDatasetReader, JsonlReader
from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline
from tqdm import tqdm
from transformers import AutoTokenizer

from llmosaic.configs import GentaskConfig
from llmosaic.utils import set_seed, split_into_batches

set_seed()

# Marker the model emits when a chunk yields no usable QA pairs.
DROP_FLAG = "[NO QA]"


def execute_meta_operations(text, operations):
    """Post-process a generation: drop it if flagged, else strip code fences.

    Note: the `text` argument is unused; callers pass an empty string.
    """
    program = operations.strip(" ")
    if DROP_FLAG.lower() in program.lower():
        return ""
    return program.replace("```python", "").replace("```", "")


# Split a long text into chunks by word count.
def split_into_chunks(text, chunk_size=6000):
    words = text.split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]


SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."


def main(args):
    # Load config.
    config = GentaskConfig.from_yaml(args.config_path)
    config.__post_init__()
    tokenizer = AutoTokenizer.from_pretrained(config.model_path)

    # Prepare data.
    if config.data_format == "parquet":
        # data_reader = HuggingFaceDatasetReader(
        #     dataset=config.data_path,
        #     dataset_options={"split": "train"},
        #     doc_progress=True,
        #     batch_size=config.batch_size,
        #     limit=config.limit,
        # )
        data_reader = JsonlReader(
            data_folder=config.data_path,
            doc_progress=True,
            batch_size=config.batch_size,
            limit=config.limit,
            text_key=config.text_key,
        )
    elif config.data_format == "jsonl.gz":
        data_reader = JsonlReader(
            data_folder=config.data_path,
            doc_progress=True,
            # glob_pattern="output_worker_*.jsonl.gz",
            limit=config.limit,
            text_key=config.text_key,
        )
    else:
        raise ValueError(f"Unsupported data_format: {config.data_format}")

    # Shard the documents across workers: this rank keeps every
    # TOTAL_SPLIT-th document, offset by GLOBAL_RANK.
    arguments = []
    # for _, doc in enumerate(
    #     data_reader.run(rank=config.GLOBAL_RANK, world_size=config.TOTAL_SPLIT)
    # ):
    #     arguments.append({"text": doc.text, "metadata": doc.metadata})
    for idx, doc in enumerate(data_reader.run()):
        if idx % config.TOTAL_SPLIT == config.GLOBAL_RANK:
            arguments.append({"text": doc.text, "metadata": doc.metadata})

    dir_path = config.get_save_dir_path()
    base_name = config.save_name
    os.makedirs(dir_path, exist_ok=True)
    batches = split_into_batches(arguments, config.save_interval)

    gen_config = GenerationConfig(
        temperature=config.temperature,
        max_new_tokens=config.max_tokens,
    )
    engine = pipeline(
        model_path=config.model_path,
        backend_config=TurbomindEngineConfig(
            tp=config.tp_size,
            session_len=32000,
        ),
        use_tqdm=True,
    )
    with open(config.prompt_path, "r") as f:
        score_prompt = f.read()

    for batch_idx, batch in enumerate(tqdm(batches, desc="Processing batches")):
        all_prompts = []  # All messages to send to the engine (one per chunk).
        mapping = []  # (sample index in batch, chunk index) for each prompt.

        # Iterate over every sample in the batch.
        for sample_idx, sample in enumerate(
            tqdm(batch, total=len(batch), unit="sample", desc="Tokenizing samples")
        ):
            text = sample["text"]
            # Split long texts into chunks (short texts yield a single-element list).
            chunks = split_into_chunks(text)
            # Build a prompt message for each chunk.
            for chunk_idx, chunk in enumerate(chunks):
                total_msg = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": score_prompt.replace("<EXAMPLE>", chunk)},
                ]
                # Debug-print the first few prompts.
                if sample_idx <= 100 and chunk_idx <= 10:
                    print("Debug - Prompt message:")
                    # print(total_msg)
                    print(tokenizer.apply_chat_template(total_msg, tokenize=False))
                all_prompts.append(total_msg)
                # Record which document and which chunk this prompt came from.
                mapping.append((sample_idx, chunk_idx))

        # Run the generation engine on all prompts.
        outputs = engine(
            all_prompts,
            gen_config,
            use_tqdm=True,
        )
        # Strip surrounding spaces from each output.
        outputs = [item.text.strip(" ") for item in outputs]

        # Debug-print some of the generations.
        for idx_out, item in enumerate(outputs):
            if idx_out > 100:
                break
            print(f"Output {idx_out}:")
            print(item)
            print("-" * 100)

        # Group outputs by source document: the dict maps a sample's index in
        # the batch to the list of generations for all of its chunks.
        doc_outputs = {}
        for (doc_idx, chunk_idx), output in zip(mapping, outputs):
            if doc_idx not in doc_outputs:
                doc_outputs[doc_idx] = []
            # Rule out empty/flagged output and strip code fences.
            output = execute_meta_operations("", output)
            # If ordering must be guaranteed, sort by chunk_idx; here we assume
            # the engine returns outputs in the same order as `mapping`.
            doc_outputs[doc_idx].append(output)

        # Build the new result list: one entry per original document, carrying
        # the generations for all of its chunks.
        new_results = []
        for idx, sample in enumerate(batch):
            chunks_output = doc_outputs.get(idx, [])
            new_sample = {
                "text": sample["text"],
                "metadata": sample["metadata"],
                # Store the generations for all chunks of the document as a list.
                "extraction_qa": chunks_output,
            }
            new_results.append(new_sample)

        # Convert this batch's results to a Dataset and save as parquet.
        intermediate_ds = Dataset.from_list(new_results)
        out_path = os.path.join(
            dir_path,
            f"{base_name}_{batch_idx + 1}_{(len(arguments) - 1) // config.save_interval + 1}.parquet",
        )
        intermediate_ds.to_parquet(out_path)

    # Write into the logging file so finished ranks can be tracked.
    os.makedirs(os.path.join(config.logging_path, config.save_name), exist_ok=True)
    with open(
        os.path.join(config.logging_path, config.save_name, "finished.log"), "a"
    ) as f:
        f.write(f"Rank [{config.GLOBAL_RANK}] Finished\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
    )
    args = parser.parse_args()
    main(args)
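The script is launched as `python <script>.py --config_path <config>.yaml`. The real GentaskConfig lives in llmosaic.configs and is not shown here; the following is a hypothetical sketch of the fields it must expose, inferred purely from the attribute accesses in the script above. Field names match the code; comments and types are assumptions.

from dataclasses import dataclass

@dataclass
class GentaskConfigSketch:
    model_path: str     # HF model id or local path (tokenizer + TurboMind engine)
    data_format: str    # "parquet" or "jsonl.gz"
    data_path: str      # folder read by JsonlReader
    text_key: str       # JSON field holding the document text
    batch_size: int
    limit: int          # max docs to read (datatrove convention: -1 = no limit)
    prompt_path: str    # prompt template containing the "<EXAMPLE>" placeholder
    save_name: str      # basename of the output parquet shards
    save_interval: int  # documents per output shard
    logging_path: str
    temperature: float
    max_tokens: int     # mapped to GenerationConfig.max_new_tokens
    tp_size: int        # tensor parallelism degree for TurboMind
    GLOBAL_RANK: int    # this worker's shard index
    TOTAL_SPLIT: int    # total number of worker shards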
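Likewise, set_seed and split_into_batches come from the internal llmosaic.utils package. Below is a minimal stand-in for each, assuming only what their call sites above require; the real implementations may differ.

import random

def set_seed(seed: int = 42) -> None:
    # Seed Python's global RNG (the real helper may also seed numpy/torch).
    random.seed(seed)

def split_into_batches(items: list, batch_size: int) -> list:
    # Slice a list into consecutive sublists of at most `batch_size` items.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]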