Untitled

mail@pastecode.io avatar
unknown
python
3 years ago
4.5 kB
6
Indexable
Never
import argparse
import warnings
import torch
import os
import re
from multihead_model import MultiheadIntentModel
import json
from collections import Counter

def get_args_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the intent/topic analysis script.

    Every option is a string flag with a default, so the script can be run
    without any arguments.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    # The value is opened as a JSON file by main(), so the help text describes
    # the expected file layout rather than a raw sentence.
    parser.add_argument(
        "--input",
        type=str,
        default="/data2/nlp_team/data/topic_change/pl/latest/alior_dialogs.json",
        help=(
            "path to a JSON file with a list of dialogues; each dialogue has a "
            "'turns' list whose items carry 'speaker_id' and 'utterance'"
        ),
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="/data2/amikolajczyk/share/multihead_intent/herbert-intent-intent-full.pkl",
        help="path to pretrained model",
    )
    parser.add_argument(
        "--class_dict",
        type=str,
        default="/data2/amikolajczyk/share/multihead_intent/multihead_class_dict_trans.json",
        help="class_to_idx dict in json file",
    )
    parser.add_argument(
        "--intent_dict",
        type=str,
        default="/data2/amikolajczyk/share/intents/dict/base_intent_dict.json",
        help="maps intents to domains and topics",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="allegro/herbert-base-cased",
        help="model name",
    )
    # Kept as a string because it is written verbatim into CUDA_VISIBLE_DEVICES.
    parser.add_argument("--gpu_device", type=str, default="6", help="GPU device")
    return parser
    
def _load_json(path):
    """Read and parse the JSON file at ``path``, decoding it as UTF-8."""
    # Explicit UTF-8: the data contains Polish text and must not depend on
    # the platform's default locale encoding.
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def main(args):
    """Run the intent model over every dialogue turn and print a summary.

    For each utterance the predicted intent is printed together with the
    domain and topic looked up in the intent dictionary. After all dialogues
    are processed, counters of domains, topics and per-dialogue topic
    sequences are printed.

    Args:
        args: Parsed CLI namespace providing ``gpu_device``, ``class_dict``,
            ``intent_dict``, ``model_name``, ``checkpoint`` and ``input``.
    """
    warnings.filterwarnings("ignore")
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    print("Loading class dict...")
    class_dict = _load_json(args.class_dict)
    # Reverse mapping: class index -> class name.
    inv_class_dict = {v: k for k, v in class_dict.items()}

    print("Loading intent dict...")
    intent_dict = _load_json(args.intent_dict)

    print("Loading model...")
    # Only one GPU is exposed through CUDA_VISIBLE_DEVICES above, so it is
    # always addressed as "cuda:0" inside this process.
    model = MultiheadIntentModel(
        model_name=args.model_name,
        num_labels=len(class_dict),
        max_len=512,
        gpu_device="cuda:0",
    ).to("cuda:0")

    print('Loading pretrained model...')
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(args.checkpoint, map_location="cuda"))
    model.eval()

    # Set to a list of domain names to map every other domain to "nieznana",
    # e.g. ["bankowość", "obsługa klienta", "bezpieczeństwo", "transakcje"].
    accepted_domains = None

    topics = []
    topic_combinations = []
    domains = []

    dialogues = _load_json(args.input)

    # Pure inference: disable autograd so no graph is built per utterance.
    with torch.no_grad():
        for dialogue in dialogues:
            topic_combination = []
            intents_path = []
            print("\n\nNew dialogue!")
            for turn in dialogue["turns"]:
                print(turn["speaker_id"])
                text = turn["utterance"]
                # The first (binary) head output is not used by this report.
                _binary_out, intent_out = model([text])
                intent_id = torch.argmax(intent_out).item()
                intent_prediction = inv_class_dict[intent_id]
                print(text)
                # Intent id 0 is excluded from the domain/topic statistics
                # (presumably the "no intent" class -- confirm in class dict).
                if intent_id != 0:
                    intent_info = intent_dict[str(intent_id)]
                    current_domain = intent_info["domain"]
                    current_topic = intent_info["topic"]
                    if (accepted_domains is not None
                            and current_domain not in accepted_domains):
                        current_domain = "nieznana"
                        current_topic = "nieznany"
                        intent_prediction = "nieznana"
                    topics.append(current_topic)
                    domains.append(current_domain)
                    # Record a topic only when it differs from the previous
                    # one, so consecutive repeats collapse into one entry.
                    if not topic_combination or topic_combination[-1] != current_topic:
                        topic_combination.append(current_topic)

                    print("\n\t intent:", intent_prediction, intent_id,
                          "\n\tdomain:", current_domain,
                          "\ttopic:", current_topic)

                intents_path.append(intent_prediction)
            print("Detected intents path:")
            print("-->".join(intents_path))
            topic_combinations.append(", ".join(topic_combination))

    print("\n\nSUMMARY FOR", len(dialogues), "DIALOGUES\n")
    print("Detected domains:", list(set(domains)))
    print(Counter(domains))
    print(Counter(topics))
    print(Counter(topic_combinations))


if __name__ == '__main__':
    # Script entry point: parse the command-line flags and run the analysis.
    main(get_args_parser().parse_args())