Untitled
unknown
python
3 years ago
4.5 kB
6
Indexable
Never
import argparse
import json
import os
import re  # NOTE(review): not used by this script; kept in case other code relies on it
import warnings
from collections import Counter

import torch

from multihead_model import MultiheadIntentModel


def get_args_parser():
    """Build the command-line parser for the dialogue intent-detection script.

    Returns:
        argparse.ArgumentParser: parser for the input dialogues path, model
        checkpoint, class/intent dictionaries, HuggingFace model name, GPU
        index and an optional domain whitelist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input",
        type=str,
        default="/data2/nlp_team/data/topic_change/pl/latest/alior_dialogs.json",
        help="Path to a JSON file with a list of dialogues; each dialogue has a "
             "'turns' list whose items contain 'speaker_id' and 'utterance'.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="/data2/amikolajczyk/share/multihead_intent/herbert-intent-intent-full.pkl",
        help="path to pretrained model",
    )
    parser.add_argument(
        "--class_dict",
        type=str,
        default="/data2/amikolajczyk/share/multihead_intent/multihead_class_dict_trans.json",
        help="class_to_idx dict in json file",
    )
    parser.add_argument(
        "--intent_dict",
        type=str,
        default="/data2/amikolajczyk/share/intents/dict/base_intent_dict.json",
        help="maps intents to domains and topics",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="allegro/herbert-base-cased",
        help="model name",
    )
    parser.add_argument(
        "--gpu_device",
        type=str,
        default="6",
        help="GPU device",
    )
    parser.add_argument(
        "--accepted_domains",
        type=str,
        nargs="*",
        default=None,
        help="Optional whitelist of domains; intents from any other domain are "
             "reported as unknown. Omit to accept all domains.",
    )
    return parser


def _load_json(path):
    """Load and return the JSON document stored at *path* (decoded as UTF-8)."""
    with open(path, encoding="utf-8") as fp:
        return json.load(fp)


def _predict_intent_id(model, text):
    """Return the index of the highest-scoring intent class for *text*.

    The model is called on a one-element batch. Its first (binary) output is
    not used by this script. Autograd is disabled because this is inference only.
    """
    with torch.no_grad():
        _, intent_out = model([text])
    # argmax over the flattened tensor: with a batch of one this equals the
    # class index (assumes intent_out is (1, num_labels) — TODO confirm).
    return torch.argmax(intent_out).item()


def _resolve_intent(intent_id, intent_prediction, intent_dict, accepted_domains):
    """Map a predicted intent to ``(intent_label, domain, topic)``.

    Args:
        intent_id: predicted class index (non-zero).
        intent_prediction: class name for ``intent_id``.
        intent_dict: mapping from ``str(intent_id)`` to a dict with
            ``"domain"`` and ``"topic"`` keys.
        accepted_domains: optional collection of allowed domains, or None.

    Returns:
        The label, domain and topic. If ``accepted_domains`` is given and the
        domain is not in it, all three are replaced by the "unknown" placeholders.

    Raises:
        KeyError: if ``intent_id`` has no entry in ``intent_dict``.
    """
    entry = intent_dict[str(intent_id)]  # JSON object keys are always strings
    domain = entry["domain"]
    topic = entry["topic"]
    if accepted_domains is not None and domain not in accepted_domains:
        return "nieznana", "nieznana", "nieznany"
    return intent_prediction, domain, topic


def main(args):
    """Run intent detection on every turn of every dialogue in ``args.input``.

    For each turn it prints the speaker, the utterance, and the predicted
    intent with its domain and topic. After each dialogue it prints the
    detected intent path. At the end it prints counters of domains, topics and
    per-dialogue topic sequences across all dialogues.
    """
    warnings.filterwarnings("ignore")
    # These only take effect if CUDA has not been initialised yet. After
    # masking, "cuda:0" refers to the GPU selected by --gpu_device.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    device = "cuda:0"

    print("Loading class dict...")
    class_dict = _load_json(args.class_dict)
    inv_class_dict = {v: k for k, v in class_dict.items()}

    print("Loading intent dict...")
    intent_dict = _load_json(args.intent_dict)

    print("Loading model...")
    model = MultiheadIntentModel(
        model_name=args.model_name,
        num_labels=len(class_dict),
        max_len=512,
        gpu_device=device,
    ).to(device)
    print('Loading pretrained model...')
    # torch.load unpickles the file: only point --checkpoint at trusted checkpoints.
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.eval()

    # Optional whitelist, e.g.
    #   --accepted_domains bankowość "obsługa klienta" bezpieczeństwo transakcje
    # An absent or empty list means "accept every domain".
    requested_domains = getattr(args, "accepted_domains", None)
    accepted_domains = set(requested_domains) if requested_domains else None

    dialogues = _load_json(args.input)

    topics = []
    domains = []
    topic_combinations = []
    for dialogue in dialogues:
        topic_combination = []  # this dialogue's topics, consecutive duplicates collapsed
        intents_path = []
        print("\n\nNew dialogue!")
        for turn in dialogue["turns"]:
            print(turn["speaker_id"])
            text = turn["utterance"]
            intent_id = _predict_intent_id(model, text)
            print(text)
            # Class 0 gets no domain/topic lookup and is left out of every
            # statistic (presumably the "no intent" class — TODO confirm).
            if intent_id == 0:
                continue
            intent_prediction, current_domain, current_topic = _resolve_intent(
                intent_id, inv_class_dict[intent_id], intent_dict, accepted_domains
            )
            topics.append(current_topic)
            domains.append(current_domain)
            # Record a topic only when it changes from the previous detected one.
            if not topic_combination or topic_combination[-1] != current_topic:
                topic_combination.append(current_topic)
            print("\n\t intent:", intent_prediction, intent_id,
                  "\n\tdomain:", current_domain, "\ttopic:", current_topic)
            intents_path.append(intent_prediction)
        print("Detected intents path:")
        print("-->".join(intents_path))
        topic_combinations.append(", ".join(topic_combination))

    print("\n\nSUMMARY FOR", len(dialogues), "DIALOGUES\n")
    print("Detected domains:", sorted(set(domains)))
    print(Counter(domains))
    print(Counter(topics))
    print(Counter(topic_combinations))


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    main(args)