import argparse
import os

from datasets import Dataset
from datatrove.pipeline.readers import JsonlReader, ParquetReader, HuggingFaceDatasetReader
from tqdm import tqdm
from transformers import AutoTokenizer
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig

from llmosaic.configs import GentaskConfig
from llmosaic.utils import set_seed, split_into_batches

set_seed()

DROP_FLAG = "[NO QA]"


def execute_meta_operations(output):
    """Post-process a raw model output: drop it entirely if the model flagged
    the chunk as containing no extractable QA; otherwise strip Markdown code
    fences so only the generated content remains."""
    program = output.strip()
    if DROP_FLAG.lower() in program.lower():
        return ""
    return program.replace("```python", "").replace("```", "")
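
# Illustrative behavior (hypothetical inputs/outputs, for documentation only):
#   execute_meta_operations("```python\nqa = [...]\n```")   -> "\nqa = [...]\n"
#   execute_meta_operations("[NO QA] nothing extractable")  -> ""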

# Split a long text into chunks of at most `chunk_size` whitespace-separated words.
def split_into_chunks(text, chunk_size=6000):
    words = text.split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
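
# Example (illustrative): a 13,000-word document becomes chunks of 6000,
# 6000, and 1000 words; a short document yields a single-element list.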



SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."


def main(args):
    # load config
    config = GentaskConfig.from_yaml(args.config_path)
    config.__post_init__()

    tokenizer = AutoTokenizer.from_pretrained(config.model_path)


    # prepare data
    if config.data_format == "parquet":
        # data_reader = HuggingFaceDatasetReader(
        #     dataset=config.data_path,
        #     dataset_options={"split": "train"},
        #     doc_progress=True,
        #     batch_size=config.batch_size,
        #     limit=config.limit,
        # )
        # Use ParquetReader here so parquet files are actually parsed; the
        # JsonlReader originally in this branch cannot read parquet input.
        data_reader = ParquetReader(
            data_folder=config.data_path,
            doc_progress=True,
            batch_size=config.batch_size,
            limit=config.limit,
            text_key=config.text_key,
        )
    elif config.data_format == "jsonl.gz":
        data_reader = JsonlReader(
            data_folder=config.data_path,
            doc_progress=True,
            # glob_pattern="output_worker_*.jsonl.gz",
            limit=config.limit,
            text_key=config.text_key,
        )
    else:
        raise ValueError(f"Unsupported data_format: {config.data_format}")
    arguments = []
    # for _, doc in enumerate(
    #     data_reader.run(rank=config.GLOBAL_RANK, world_size=config.TOTAL_SPLIT)
    # ):
    #     arguments.append({"text": doc.text, "metadata": doc.metadata})
    # Shard documents across workers: rank r keeps docs where idx % TOTAL_SPLIT == r.
    for idx, doc in enumerate(data_reader.run()):
        if idx % config.TOTAL_SPLIT == config.GLOBAL_RANK:
            arguments.append({"text": doc.text, "metadata": doc.metadata})

    dir_path = config.get_save_dir_path()
    base_name = config.save_name
    os.makedirs(dir_path, exist_ok=True)
    batches = split_into_batches(arguments, config.save_interval)

    gen_config = GenerationConfig(
        temperature=config.temperature,
        max_new_tokens=config.max_tokens,
    )
    # Note: `use_tqdm` is an argument of the generation call below, not of
    # pipeline(), so it is not passed here.
    engine = pipeline(
        model_path=config.model_path,
        backend_config=TurbomindEngineConfig(
            tp=config.tp_size,
            session_len=32000,
        ),
    )
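    # Assumption: session_len=32000 leaves room for a 6000-word chunk plus the
    # prompt template; adjust if the tokenizer expands words into far more tokens.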

    with open(config.prompt_path, "r") as f:
        score_prompt = f.read()

    for batch_idx, batch in enumerate(tqdm(batches, desc="Processing batches")):
        all_prompts = []  # every message list sent to the engine (one per chunk)
        mapping = []      # (sample index in batch, chunk index) for each prompt
        # Iterate over every sample in the batch.
        for sample_idx, sample in enumerate(tqdm(batch, total=len(batch), unit="sample", desc="Tokenizing samples")):
            text = sample["text"]
            # Split long texts into chunks (a short text yields a single-element list).
            chunks = split_into_chunks(text)
            # Build one prompt message per chunk.
            for chunk_idx, chunk in enumerate(chunks):
                total_msg = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": score_prompt.replace("<EXAMPLE>", chunk)},
                ]
                # Debug-print a few rendered prompts (first samples/chunks only).
                if sample_idx <= 100 and chunk_idx <= 10:
                    print("Debug - Prompt message:")
                    # print(total_msg)
                    print(tokenizer.apply_chat_template(total_msg, tokenize=False))
                all_prompts.append(total_msg)
                # Record which document and which chunk this prompt came from.
                mapping.append((sample_idx, chunk_idx))
        
        # Run the generation engine over all prompts.
        outputs = engine(
            all_prompts,
            gen_config,
            use_tqdm=True,
        )
        # Strip surrounding whitespace from each generation.
        outputs = [item.text.strip() for item in outputs]
        
        # Debug-print the first few generations.
        for idx_out, item in enumerate(outputs):
            if idx_out > 100:
                break
            print(f"Output {idx_out}:")
            print(item)
            print("-" * 100)
        
        # Group outputs back by source document: the key is the sample's index
        # in the batch, the value is the list of generations for its chunks.
        doc_outputs = {}
        for (doc_idx, chunk_idx), output in zip(mapping, outputs):
            if doc_idx not in doc_outputs:
                doc_outputs[doc_idx] = []

            # Drop flagged outputs and strip code fences.
            output = execute_meta_operations(output)
            # Prompts were built and returned in chunk order, so appending here
            # preserves chunk_idx order; sort by chunk_idx if that ever changes.
            doc_outputs[doc_idx].append(output)
        
        # Build one result per source document, carrying the generations for
        # all of its chunks.
        new_results = []
        for idx, sample in enumerate(batch):
            chunks_output = doc_outputs.get(idx, [])
            new_sample = {
                "text": sample["text"],
                "metadata": sample["metadata"],
                # Store the per-chunk generations as a list.
                "extraction_qa": chunks_output,
            }
            new_results.append(new_sample)

        # Convert this batch's results to a Dataset and save them as a parquet
        # file named "<base>_<batch index>_<total batch count>.parquet".
        intermediate_ds = Dataset.from_list(new_results)
        out_path = os.path.join(
            dir_path,
            f"{base_name}_{batch_idx + 1}_{(len(arguments) - 1) // config.save_interval + 1}.parquet",
        )
        intermediate_ds.to_parquet(out_path)

        # Append a progress marker to the logging file (one line per finished
        # batch per rank).
        os.makedirs(os.path.join(config.logging_path, config.save_name), exist_ok=True)
        with open(
            os.path.join(config.logging_path, config.save_name, "finished.log"), "a"
        ) as f:
            f.write(f"Rank [{config.GLOBAL_RANK}] Finished\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
    )
    args = parser.parse_args()
    main(args)
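
# Usage (illustrative; the exact schema is whatever GentaskConfig defines):
#   python this_script.py --config_path configs/gentask.yaml
#
# A minimal YAML sketch covering only the fields this script reads; every
# value below is a hypothetical placeholder:
#   model_path: /path/to/model
#   data_format: jsonl.gz              # or "parquet"
#   data_path: /path/to/input
#   text_key: text
#   batch_size: 1000
#   limit: -1
#   GLOBAL_RANK: 0                     # this worker's shard index
#   TOTAL_SPLIT: 1                     # total number of worker shards
#   save_name: extraction_qa
#   save_interval: 10000               # samples per output parquet file
#   logging_path: /path/to/logs
#   prompt_path: /path/to/prompt.txt   # must contain an <EXAMPLE> placeholder
#   temperature: 0.0
#   max_tokens: 2048
#   tp_size: 1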