Untitled
unknown
python
4 years ago
4.5 kB
12
Indexable
import argparse
import warnings
import torch
import os
import re
from multihead_model import MultiheadIntentModel
import json
from collections import Counter
def get_args_parser():
    """Build the command-line parser for the intent/topic analysis script.

    Every option has a default pointing at shared team paths, so the script
    can be run without arguments on the original host.

    Returns:
        argparse.ArgumentParser: parser for ``--input``, ``--checkpoint``,
        ``--class_dict``, ``--intent_dict``, ``--model_name`` and
        ``--gpu_device`` (all string-valued).
    """
    parser = argparse.ArgumentParser()
    # main() always json-loads this path; each dialogue is read via
    # dialogue["turns"], and each turn via turn["speaker_id"] / turn["utterance"].
    parser.add_argument(
        "--input",
        type=str,
        default="/data2/nlp_team/data/topic_change/pl/latest/alior_dialogs.json",
        help="Path to a JSON file holding a list of dialogues; each dialogue has "
             "a 'turns' list whose items contain 'speaker_id' and 'utterance'.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="/data2/amikolajczyk/share/multihead_intent/herbert-intent-intent-full.pkl",
        help="path to pretrained model",
    )
    parser.add_argument(
        "--class_dict",
        type=str,
        default="/data2/amikolajczyk/share/multihead_intent/multihead_class_dict_trans.json",
        help="class_to_idx dict in json file",
    )
    parser.add_argument(
        "--intent_dict",
        type=str,
        default="/data2/amikolajczyk/share/intents/dict/base_intent_dict.json",
        help="maps intents to domains and topics",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="allegro/herbert-base-cased",
        help="model name",
    )
    # Exported as CUDA_VISIBLE_DEVICES by main(), so it may be any value that
    # variable accepts (e.g. "6").
    parser.add_argument("--gpu_device", type=str, default="6", help="GPU device")
    return parser
def _load_json(path):
    """Read *path* as UTF-8 text and return the parsed JSON document."""
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


def main(args):
    """Predict an intent for every turn of every dialogue in ``args.input``.

    For each turn this prints the speaker id, the utterance, the predicted
    intent, and the domain/topic that ``args.intent_dict`` assigns to it. A
    frequency summary of domains, topics and per-dialogue topic sequences is
    printed at the end.

    Args:
        args: Namespace produced by ``get_args_parser().parse_args()``. The
            fields read are ``input``, ``checkpoint``, ``class_dict``,
            ``intent_dict``, ``model_name`` and ``gpu_device``.
    """
    warnings.filterwarnings("ignore")
    # These must be set before CUDA is first initialised. Only the chosen GPU
    # is then visible, and torch addresses it as "cuda:0".
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_device
    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    print("Loading class dict...")
    class_dict = _load_json(args.class_dict)
    # class_dict maps label -> index; invert it to decode argmax indices.
    inv_class_dict = {v: k for k, v in class_dict.items()}
    print("Loading intent dict...")
    # Keyed by the stringified intent index (JSON object keys are strings).
    intent_dict = _load_json(args.intent_dict)

    print("Loading model...")
    model = MultiheadIntentModel(
        model_name=args.model_name,
        num_labels=len(class_dict),
        max_len=512,
        gpu_device="cuda:0",
    ).to("cuda:0")
    print('Loading pretrained model...')
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(args.checkpoint, map_location="cuda"))
    model.eval()

    # Set to a collection of domain names to report every other domain as
    # "nieznana" (unknown), e.g.
    # ["bankowość", "obsługa klienta", "bezpieczeństwo", "transakcje"].
    # None accepts all domains.
    accepted_domains = None
    topics = []
    topic_combinations = []
    domains = []

    dialogues = _load_json(args.input)
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for dialogue in dialogues:
            topic_combination = []
            intents_path = []
            print("\n\nNew dialogue!")
            for turn in dialogue["turns"]:
                print(turn["speaker_id"])
                text = turn["utterance"]
                # The first (binary) head's output is not used by this script.
                _, intent_out = model([text])
                intent_id = torch.argmax(intent_out).item()
                intent_prediction = inv_class_dict[intent_id]
                print(text)
                # Reset every turn. Otherwise an intent-0 turn would report the
                # previous turn's domain/topic, or raise NameError on the first turn.
                current_domain = None
                current_topic = None
                # Presumably index 0 is the "no intent" class; confirm against class_dict.
                if intent_id != 0:
                    current_domain = intent_dict[str(intent_id)]["domain"]
                    current_topic = intent_dict[str(intent_id)]["topic"]
                    if accepted_domains is not None and current_domain not in accepted_domains:
                        current_domain = "nieznana"
                        current_topic = "nieznany"
                        intent_prediction = "nieznana"
                    topics.append(current_topic)
                    domains.append(current_domain)
                    # Record a topic only when it differs from the previous one,
                    # so the list is the dialogue's sequence of topic changes.
                    if not topic_combination or topic_combination[-1] != current_topic:
                        topic_combination.append(current_topic)
                print("\n\t intent:", intent_prediction, intent_id,
                      "\n\tdomain:", current_domain,
                      "\ttopic:", current_topic)
                intents_path.append(intent_prediction)
            print("Detected intents path:")
            print("-->".join(intents_path))
            topic_combinations.append(", ".join(topic_combination))

    print("\n\nSUMMARY FOR", len(dialogues), "DIALOGUES\n")
    # Sorted so the listing is stable between runs (set order is not).
    print("Detected domains:", sorted(set(domains), key=str))
    print(Counter(domains))
    print(Counter(topics))
    print(Counter(topic_combinations))
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the analysis.
    main(get_args_parser().parse_args())
Editor is loading...