0920_bidaf
unknown
python
a month ago
1.7 kB
4
Indexable
Never
"""Extractive question answering with a quantized BiDAF ONNX model.

Tokenizes a context/query pair with NLTK, feeds the word- and
character-level tensors to the model via onnxruntime, and prints the
answer span predicted from the context.
"""

import numpy as np
from nltk import word_tokenize
import nltk
import onnxruntime as ort

# Download NLTK tokenizer data required by word_tokenize.
nltk.download('punkt_tab')

model_path = "./bidaf-11-int8.onnx"

cuda_options = {
    'device_id': 0,
    'arena_extend_strategy': 'kNextPowerOfTwo',
    'gpu_mem_limit': 2 * 1024 * 1024 * 1024,  # 2 GB
    'cudnn_conv_algo_search': 'EXHAUSTIVE',
    'do_copy_in_default_stream': True,
}

# BUG FIX: the original defined `cuda_options` but never handed it to the
# session, so the CUDA provider silently ran with defaults. onnxruntime
# accepts (provider_name, options_dict) tuples in the providers list —
# pair them up so the options are actually applied. CPU remains the
# fallback when CUDA is unavailable.
providers = [
    ('CUDAExecutionProvider', cuda_options),
    'CPUExecutionProvider',
]

# BiDAF's char-level input expects exactly 16 characters per word.
MAX_WORD_CHARS = 16


def preprocess(text):
    """Tokenize *text* into the word and char tensors BiDAF expects.

    Args:
        text: Raw input string (context or query).

    Returns:
        words: np.ndarray of shape (seq, 1) holding lowercased tokens.
        chars: np.ndarray of shape (seq, 1, 1, MAX_WORD_CHARS) holding
            each token's characters, truncated to MAX_WORD_CHARS and
            right-padded with empty strings.
    """
    tokens = word_tokenize(text)
    words = np.asarray([w.lower() for w in tokens]).reshape(-1, 1)
    # Truncate each token to the char budget, then pad to a fixed width.
    chars = [list(t)[:MAX_WORD_CHARS] for t in tokens]
    chars = [cs + [''] * (MAX_WORD_CHARS - len(cs)) for cs in chars]
    chars = np.asarray(chars).reshape(-1, 1, 1, MAX_WORD_CHARS)
    return words, chars


# Input texts
context = 'A quick brown fox jumps over the lazy dog.'
query = 'What color is the fox?'

# Preprocess inputs
cw, cc = preprocess(context)
qw, qc = preprocess(query)

# The model declares string tensors; onnxruntime requires dtype=object
# arrays for string inputs.
context_word = cw.astype(object)
context_char = cc.astype(object)
query_word = qw.astype(object)
query_char = qc.astype(object)

# Load model with the configured providers.
session = ort.InferenceSession(model_path, providers=providers)

# Verify which provider was actually selected at runtime.
print("Using provider:", session.get_providers())
print(f"Using device: {ort.get_device()}")

# Run inference
inputs = {
    'context_word': context_word,
    'context_char': context_char,
    'query_word': query_word,
    'query_char': query_char,
}
outputs = session.run(None, inputs)

# Post-process: the model emits start/end word indices into the context.
start = int(outputs[0][0])
end = int(outputs[1][0])

# Reconstruct the answer from the (lowercased) context tokens,
# inclusive of the end index.
answer = ' '.join(cw[start:end + 1].reshape(-1))
print(f'Answer: {answer}')
Leave a Comment