# 0920_bidaf
# (pastecode.io page metadata: mail@pastecode.io avatar, unknown, python,
#  a month ago, 1.7 kB, 4, Indexable, Never)
import numpy as np
from nltk import word_tokenize
import nltk
import onnxruntime as ort

# Make sure the Punkt tables used by nltk.word_tokenize are available.
# Only download when they are missing locally: nltk.download contacts the
# NLTK server on every call, even when the package is already installed.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')

model_path = "./bidaf-11-int8.onnx"

# Options for ONNX Runtime's CUDA execution provider.
cuda_options = {
    'device_id': 0,
    'arena_extend_strategy': 'kNextPowerOfTwo',
    'gpu_mem_limit': 2 * 1024 * 1024 * 1024,  # 2 GB
    'cudnn_conv_algo_search': 'EXHAUSTIVE',
    'do_copy_in_default_stream': True,
}

# Try CUDA first, then fall back to CPU. Options reach a provider only when
# it is given as a (name, options) tuple; a bare provider name ignores
# cuda_options entirely.
providers = [
    ('CUDAExecutionProvider', cuda_options),
    'CPUExecutionProvider',
]


def preprocess(text, *, max_chars=16, tokenize=None):
    """Turn raw text into the word and character arrays fed to the model.

    Args:
        text: Raw input string (a context passage or a query).
        max_chars: Number of characters kept per token. Longer tokens are
            truncated and shorter ones are right-padded with ''. The
            default of 16 matches the original hard-coded width.
        tokenize: Callable that splits a string into a list of token
            strings. Defaults to ``nltk.word_tokenize``.

    Returns:
        A ``(words, chars)`` pair of NumPy string arrays:
        ``words`` has shape ``(num_tokens, 1)`` and holds lower-cased tokens;
        ``chars`` has shape ``(num_tokens, 1, 1, max_chars)`` and holds the
        original-case characters of each token.

    Raises:
        ValueError: If ``max_chars`` is smaller than 1.
    """
    if max_chars < 1:
        raise ValueError(f"max_chars must be >= 1, got {max_chars}")
    if tokenize is None:
        tokenize = word_tokenize

    tokens = tokenize(text)

    # Words are lower-cased; characters below keep the token's original case.
    # dtype=str keeps the result a string array even when there are no tokens
    # (np.asarray([]) would otherwise default to float64).
    words = np.asarray([w.lower() for w in tokens], dtype=str).reshape(-1, 1)

    # Truncate or pad every token to exactly max_chars single-character strings.
    chars = []
    for token in tokens:
        kept = list(token[:max_chars])
        kept += [''] * (max_chars - len(kept))
        chars.append(kept)
    chars = np.asarray(chars, dtype=str).reshape(-1, 1, 1, max_chars)
    return words, chars

# Input texts
context = 'A quick brown fox jumps over the lazy dog.'
query = 'What color is the fox?'

# Preprocess inputs into (words, chars) arrays
cw, cc = preprocess(context)
qw, qc = preprocess(query)

# Convert the string arrays to object dtype before handing them to the session
context_word = cw.astype(object)
context_char = cc.astype(object)
query_word = qw.astype(object)
query_char = qc.astype(object)

# Load model
session = ort.InferenceSession(model_path, providers=providers)

# Report what the session actually runs on. If CUDA cannot be initialised,
# the session ends up with only the CPU provider. ort.get_device() reports
# the device this onnxruntime *build* targets, so it can say "GPU" even after
# a CPU fallback. Derive the device from the session's active providers instead.
active_providers = session.get_providers()
print("Using provider:", active_providers)
device = 'GPU' if 'CUDAExecutionProvider' in active_providers else 'CPU'
print(f"Using device: {device}")


# Run inference
inputs = {
    'context_word': context_word,
    'context_char': context_char,
    'query_word': query_word,
    'query_char': query_char,
}
outputs = session.run(None, inputs)

# Post-process outputs: the first two outputs are used as the start and end
# token indices of the answer within the context words.
start_pos = outputs[0]
end_pos = outputs[1]
# Flatten before taking the first element so this works whether each index
# comes back with shape (1,) or (1, 1).
start = int(np.ravel(start_pos)[0])
end = int(np.ravel(end_pos)[0])
# The end index is inclusive, hence end + 1.
answer_words = cw[start:end + 1].reshape(-1)
answer = ' '.join(answer_words)

print(f'Answer: {answer}')
# (pastecode.io page footer: Leave a Comment)