Untitled
unknown
python
a year ago
711 B
6
Indexable
class BaseOnnx(threading.Thread):
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    Calling an instance runs a three-stage pipeline:
    ``pre`` (cast every input to a float32 array) -> ``infer``
    (``session.run`` for all outputs) -> ``pos`` (keep the first output).
    Subclasses can override any single stage.

    NOTE(review): this class subclasses ``threading.Thread`` but does not
    override ``run()``, so calling ``start()`` on an instance performs no
    inference; inference only happens via ``__call__``. Confirm whether
    callers actually rely on the Thread base.
    """

    def __init__(self, path, providers=None):
        """Create the inference session for the ONNX model at *path*.

        Args:
            path: Model location, passed straight to
                ``onnxruntime.InferenceSession``.
            providers: Execution providers for the session. ``None`` (the
                default) keeps the previous behaviour of
                ``['CPUExecutionProvider']``.
        """
        super().__init__()
        # Deferred import: onnxruntime is only required once a model is loaded.
        import onnxruntime

        if providers is None:
            providers = ['CPUExecutionProvider']
        self.session = onnxruntime.InferenceSession(path, providers=list(providers))

    def pre(self, input_feed):
        """Return a new feed dict whose values are all float32 numpy arrays.

        The caller's dict is not modified. ``np.asarray`` also accepts plain
        sequences, and it does not copy values that are already float32
        arrays.

        Args:
            input_feed: Mapping of model input name -> array-like value.

        Returns:
            A new ``dict`` with the same keys and float32 array values.
        """
        return {
            name: np.asarray(value, dtype=np.float32)
            for name, value in input_feed.items()
        }

    def pos(self, out):
        """Post-process the raw session outputs by returning the first one."""
        return out[0]

    def infer(self, input_feed):
        """Run the session and return the list of all model outputs.

        ``None`` as ``output_names`` asks onnxruntime for every output.
        """
        return self.session.run(None, input_feed=input_feed)

    def __call__(self, input_feed):
        """Run ``pre`` -> ``infer`` -> ``pos`` on *input_feed* and return the result."""
        prepared = self.pre(input_feed)
        raw_outputs = self.infer(prepared)
        return self.pos(raw_outputs)
Editor is loading...
Leave a Comment