Untitled

mail@pastecode.io avatar
unknown
python
8 months ago
8.7 kB
4
Indexable
Never
import logging
import pickle
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Literal

import datasets
import fire
import numpy as np
import structlog
from torch import cuda
from tqdm import tqdm

import bocoel
from bocoel import (
    AcquisitionFunc,
    Adaptor,
    AxServiceOptimizer,
    ClassifierModel,
    ComposedCorpus,
    DatasetsStorage,
    Distance,
    EnsembleEmbedder,
    GlueAdaptor,
    HnswlibIndex,
    HuggingfaceEmbedder,
    HuggingfaceLogitsLM,
    HuggingfaceSequenceLM,
    Index,
    KMeansOptimizer,
    KMedoidsOptimizer,
    Optimizer,
    PolarIndex,
    Sst2QuestionAnswer,
    Task,
    WhiteningIndex,
)

# Global structlog setup: bound loggers drop every event below INFO.
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))

# Module-wide logger shared by every function in this script.
LOGGER = structlog.get_logger()


def main(
    *,
    ds_path: Literal[
        "SST2",
        "SetFit/mnli",
        "SetFit/mrpc",
        "SetFit/qnli",
        "SetFit/rte",
        "SetFit/qqp",
        "SetFit/sst2",
    ] = "SST2",
    ds_split: Literal["train", "validation", "test"] = "train",
    llm_model: str = "textattack/roberta-base-SST-2",
    batch_size: int = 16,
    index_name: Literal["hnswlib", "polar", "whitening"] = "hnswlib",
    sobol_steps: int = 5,
    index_threads: int = 8,
    optimizer_steps: int = 60,
    reduced: int = 32,
    device: str = "cpu",
    acqf: str = "ENTROPY",
    task: str = "EXPLORE",
    classification: Literal["logits", "seq"] = "seq",
    optimizer: Literal["ax", "kmeans", "kmedoids"] = "ax",
    corpus_cache_path: str | Path = "corpus.pickle",
    embedders: Sequence[str] = [
        # "textattack/bert-base-uncased-SST-2",
        #   "textattack/roberta-base-SST-2",
        #   "textattack/albert-base-v2-SST-2",
        #   "textattack/xlnet-large-cased-SST-2",
        #   "textattack/xlnet-base-cased-SST-2",
        #   "textattack/facebook-bart-large-SST-2",
        #   "textattack/distilbert-base-uncased-SST-2",
        "textattack/distilbert-base-cased-SST-2"
    ],
) -> None:
    # The corpus part

    sentence, label = sentence_label(ds_path)
    corpus_cache_path = Path(corpus_cache_path)
    corpus: ComposedCorpus
    if corpus_cache_path.exists():
        with open(corpus_cache_path, "rb") as f:
            corpus = pickle.load(f)
    else:
        corpus = composed_corpus(
            ds_path=ds_path,
            ds_split=ds_split,
            batch_size=batch_size,
            device=device,
            index_name=index_name,
            index_threads=index_threads,
            reduced=reduced,
            sentence=sentence,
            embedders=embedders,
        )
        with open(corpus_cache_path, "wb") as f:
            pickle.dump(corpus, f)

    # ------------------------
    # The model part

    task_name = ds_path.lower().replace("setfit/", "")

    lm: ClassifierModel
    LOGGER.info(
        "Creating LM with model", model=llm_model, device=device, task=task_name
    )
    match classification:
        case "seq":
            lm = HuggingfaceSequenceLM(
                model_path=llm_model,
                device=device,
                choices=GlueAdaptor.choices_per_task(task_name),
            )
        case "logits":
            lm = HuggingfaceLogitsLM(
                model_path=llm_model,
                batch_size=batch_size,
                device=device,
                choices=GlueAdaptor.choices_per_task(task_name),
            )
        case _:
            raise ValueError(f"Unknown classification {classification}")

    # ------------------------
    # Adaptor part

    LOGGER.info("Creating adaptor with arguments", sentence=sentence, label=label)
    adaptor: Adaptor
    if "setfit/" in ds_path.lower():
        adaptor = GlueAdaptor.task(task_name, lm)
    elif ds_path == "SST2":
        adaptor = Sst2QuestionAnswer(lm)
    else:
        raise ValueError(f"Unknown dataset {ds_path}")

    # ------------------------
    # The optimizer part.

    LOGGER.info(
        "Creating optimizer with arguments",
        corpus=corpus,
        lm=lm,
        adaptor=adaptor,
        sobol_steps=sobol_steps,
        device=device,
        acqf=acqf,
    )

    optim: Optimizer
    match optimizer:
        case "ax":
            optim = bocoel.evaluate_corpus(
                AxServiceOptimizer,
                corpus=corpus,
                adaptor=adaptor,
                sobol_steps=sobol_steps,
                device=device,
                task=Task.lookup(task),
                acqf=AcquisitionFunc.lookup(acqf),
            )
        case "kmeans":
            optim = bocoel.evaluate_corpus(
                KMeansOptimizer,
                corpus=corpus,
                adaptor=adaptor,
                batch_size=batch_size,
                embeddings=corpus.index.embeddings,
                model_kwargs={"n_clusters": optimizer_steps, "n_init": "auto"},
            )
        case "kmedoids":
            optim = bocoel.evaluate_corpus(
                KMedoidsOptimizer,
                corpus=corpus,
                adaptor=adaptor,
                batch_size=batch_size,
                embeddings=corpus.index.embeddings,
                model_kwargs={"n_clusters": optimizer_steps},
            )

    scores: list[float] = []
    for i in tqdm(range(optimizer_steps)):
        try:
            state = optim.step()
            LOGGER.info("iteration {i}: {state}", i=i, state=state)
            scores.extend(state.values())
        except StopIteration:
            break

    # Performs aggregation here.
    print("average:", np.average(scores))


def sentence_label(ds_path: str) -> tuple[str, str]:
    """Return the ``(sentence_columns, label_column)`` names for a dataset.

    ``sentence_columns`` is a whitespace-separated list of text column names.
    ``label_column`` is the name of the label column.

    SetFit datasets are matched case-insensitively by the ``"setfit/"`` prefix,
    the same test ``main`` uses to pick the GLUE adaptor. ``"SST2"`` is matched
    exactly.

    Raises:
        ValueError: If ``ds_path`` is not a supported dataset.
    """
    normalized = ds_path.lower()
    if "setfit/" in normalized:
        # SetFit/sst2 has a single ``text`` column; the other SetFit tasks use
        # the ``text1``/``text2`` pair. Compare case-insensitively, matching
        # the detection above.
        sentence = "text" if normalized == "setfit/sst2" else "text1 text2"
        return sentence, "label"
    if ds_path == "SST2":
        return "sentence", "label"
    raise ValueError(f"Unknown dataset {ds_path}")


def ensemble_embedder(embedders: Sequence[str], batch_size: int):
    """Combine several HuggingFace embedding models into one EnsembleEmbedder.

    Each model is assigned a device by its position. When CUDA is available,
    models are spread round-robin over the visible GPUs (``cuda:<pos % n>``);
    otherwise every model runs on ``"cpu"``.

    Args:
        embedders: Model paths, one HuggingfaceEmbedder per entry.
        batch_size: Batch size given to every member embedder.
    """
    LOGGER.info("Creating embedder")

    use_cuda = cuda.is_available()
    n_gpus = cuda.device_count()

    members = [
        HuggingfaceEmbedder(
            path=model_path,
            device=f"cuda:{pos % n_gpus}" if use_cuda else "cpu",
            batch_size=batch_size,
        )
        for pos, model_path in enumerate(embedders)
    ]
    return EnsembleEmbedder(members)


def index_backend_and_kwargs(
    name: str, index_threads: int, batch_size: int, reduced: int
) -> tuple[type[Index], dict[str, Any]]:
    """Map an index name to its class and constructor keyword arguments.

    Every backend receives ``threads`` and ``batch_size``. "polar" and
    "whitening" wrap an HnswlibIndex backend, and "whitening" additionally
    receives ``reduced``.

    Raises:
        ValueError: If ``name`` is not "hnswlib", "polar" or "whitening".
    """
    shared = {"threads": index_threads, "batch_size": batch_size}

    if name == "hnswlib":
        return HnswlibIndex, shared
    if name == "polar":
        return PolarIndex, {"polar_backend": HnswlibIndex, **shared}
    if name == "whitening":
        return WhiteningIndex, {
            "whitening_backend": HnswlibIndex,
            "reduced": reduced,
            **shared,
        }
    raise ValueError(f"Unknown index backend {name}")


def composed_corpus(
    ds_path: str,
    ds_split: str,
    batch_size: int,
    device: str,
    index_name: str,
    index_threads: int,
    reduced: int,
    sentence: str,
    embedders: Sequence[str],
) -> ComposedCorpus:
    """Load one dataset split, embed its text columns, and index the result.

    Args:
        ds_path: Dataset identifier passed to ``datasets.load_dataset``.
        ds_split: Split selected from the loaded dataset dict.
        batch_size: Batch size for the embedders and the index backend.
        device: Only included in the log record here; embedder devices are
            chosen by ``ensemble_embedder``.
        index_name: Index backend name.
        index_threads: Thread count for the index backend.
        reduced: Requested reduced dimensionality, capped at ``embedder.dims``.
        sentence: Whitespace-separated text column names to embed.
        embedders: Embedding model paths.

    Returns:
        A ComposedCorpus built with ``ComposedCorpus.index_storage``, using
        inner-product distance. Multiple text columns are joined with " [SEP] ".
    """
    LOGGER.info("Loading datasets...", dataset=ds_path, split=ds_split)
    storage = DatasetsStorage(datasets.load_dataset(ds_path)[ds_split])

    embedder = ensemble_embedder(batch_size=batch_size, embedders=embedders)
    LOGGER.info(
        "Creating corpus with storage and embedder",
        storage=storage,
        embedder=embedder,
        device=device,
    )

    # Never request more reduced dimensions than the embedder actually produces.
    backend_cls, backend_kwargs = index_backend_and_kwargs(
        name=index_name,
        index_threads=index_threads,
        batch_size=batch_size,
        reduced=min(reduced, embedder.dims),
    )
    return ComposedCorpus.index_storage(
        storage=storage,
        embedder=embedder,
        keys=sentence.split(),
        index_backend=backend_cls,
        concat=" [SEP] ".join,
        distance=Distance.INNER_PRODUCT,
        **backend_kwargs,
    )


if __name__ == "__main__":
    # python-fire exposes main's keyword-only parameters as command-line flags.
    fire.Fire(component=main)
Leave a Comment