Untitled

 avatar
unknown
python
a year ago
711 B
3
Indexable
class BaseOnnx(threading.Thread):
    """Thin wrapper around a CPU-only ONNX Runtime inference session.

    Calling an instance runs the pipeline ``pre`` -> ``infer`` -> ``pos``;
    subclasses customise a stage by overriding the corresponding method.

    NOTE(review): the class derives from ``threading.Thread`` but defines no
    ``run`` here; presumably subclasses supply one — confirm with callers.
    """

    def __init__(self, path):
        """Create an inference session for the model at *path*.

        Args:
            path: Model location, passed unchanged to
                ``onnxruntime.InferenceSession``.
        """
        super().__init__()
        # Imported here rather than at module level, so onnxruntime is only
        # required once an instance is actually constructed.
        import onnxruntime
        self.session = onnxruntime.InferenceSession(
            path, providers=['CPUExecutionProvider'])

    def pre(self, input_feed):
        """Cast every input in *input_feed* to a float32 NumPy array.

        Args:
            input_feed: Mapping of input name -> array-like value (NumPy
                array, nested list, scalar, ...).

        Returns:
            A new dict with the same keys and float32 array values. The
            caller's mapping and its arrays are left unmodified.
        """
        # Build a fresh dict: rebinding keys on the caller's mapping would
        # leak the float32 casts back to them. ``np.array`` always copies,
        # matching the copy semantics of the former ``astype`` call.
        return {name: np.array(value, dtype=np.float32)
                for name, value in input_feed.items()}

    def pos(self, out):
        """Post-process the raw session outputs by keeping only the first."""
        return out[0]

    def infer(self, input_feed):
        """Run the session on *input_feed* and return all of its outputs."""
        # ``None`` as output names asks ONNX Runtime for every graph output.
        return self.session.run(None, input_feed=input_feed)

    def __call__(self, input_feed):
        """Run ``pre`` -> ``infer`` -> ``pos`` on *input_feed*.

        Args:
            input_feed: Mapping of input name -> array-like value; it is not
                modified.

        Returns:
            Whatever ``pos`` returns (by default, the first model output).
        """
        prepared = self.pre(input_feed)
        raw = self.infer(prepared)
        return self.pos(raw)
Editor is loading...
Leave a Comment