0920_bidaf
unknown
python
a year ago
1.7 kB
14
Indexable
import numpy as np
from nltk import word_tokenize
import nltk
import onnxruntime as ort
# Fetch the tokenizer data NLTK needs for word_tokenize
nltk.download('punkt_tab')

# Quantized BiDAF model and execution-provider preference (GPU first, CPU fallback)
model_path = "./bidaf-11-int8.onnx"
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

# Options for the CUDA execution provider
cuda_options = {
    'device_id': 0,
    'arena_extend_strategy': 'kNextPowerOfTwo',
    'gpu_mem_limit': 2 * 1024 ** 3,  # 2 GB
    'cudnn_conv_algo_search': 'EXHAUSTIVE',
    'do_copy_in_default_stream': True,
}
def preprocess(text, max_word_len=16):
    """Tokenize *text* into the word and character arrays BiDAF expects.

    Args:
        text: Input string (context paragraph or query).
        max_word_len: Characters kept per token; shorter tokens are padded
            with empty strings. Defaults to 16, the width the BiDAF ONNX
            model was exported with.

    Returns:
        Tuple of (words, chars):
        - words: object/str array of shape (n_tokens, 1), lowercased.
        - chars: array of shape (n_tokens, 1, 1, max_word_len).
    """
    tokens = word_tokenize(text)
    words = np.asarray([w.lower() for w in tokens]).reshape(-1, 1)
    # Truncate each token to max_word_len chars, then right-pad with ''
    padded = []
    for token in tokens:
        cs = list(token)[:max_word_len]
        padded.append(cs + [''] * (max_word_len - len(cs)))
    chars = np.asarray(padded).reshape(-1, 1, 1, max_word_len)
    return words, chars
# Input texts
context = 'A quick brown fox jumps over the lazy dog.'
query = 'What color is the fox?'

# Tokenize both texts into word/char arrays
cw, cc = preprocess(context)
qw, qc = preprocess(query)

# The BiDAF ONNX model takes string tensors, i.e. numpy object dtype
context_word, context_char = cw.astype(object), cc.astype(object)
query_word, query_char = qw.astype(object), qc.astype(object)
# Load model. Bug fix: cuda_options was defined but never handed to the
# session, so the GPU memory limit and conv-search settings were silently
# ignored. Pairing the provider name with its options dict applies them.
session = ort.InferenceSession(
    model_path,
    providers=[
        ('CUDAExecutionProvider', cuda_options),
        'CPUExecutionProvider',
    ],
)

# Verify which provider/device actually got selected
print("Using provider:", session.get_providers())
print(f"Using device: {ort.get_device()}")

# Run inference; keys must match the model's declared input names
inputs = {
    'context_word': context_word,
    'context_char': context_char,
    'query_word': query_word,
    'query_char': query_char
}
outputs = session.run(None, inputs)
# Post-process: the model returns the start and end token indices of the
# answer span within the context.
start_pos, end_pos = outputs[0], outputs[1]
start, end = int(start_pos[0]), int(end_pos[0])

# Slice the (lowercased) context tokens covering the span, inclusive of end
span_tokens = cw[start:end + 1].reshape(-1)
answer = ' '.join(span_tokens)
print(f'Answer: {answer}')
Editor is loading...
Leave a Comment