Untitled

 avatar
unknown
plain_text
a month ago
386 B
4
Indexable
class STEncoder(tf.keras.layers.Layer):
    """Stack of ``STAttBlock`` layers followed by a single-unit dense projection.

    The input ``X`` is passed through ``L`` ``STAttBlock`` instances in order.
    Each block receives the previous block's output together with the same
    ``STE`` tensor and ``is_training`` flag. The final result is projected
    with ``tf.keras.layers.Dense(1)``, which acts on the last axis.

    Args:
        L: Number of ``STAttBlock`` layers to stack. With ``L == 0`` the input
            goes straight to the final projection.
        K, d, bn, bn_decay: Forwarded unchanged to every ``STAttBlock``.
            NOTE(review): presumably head count, per-head size and
            batch-norm settings; confirm against ``STAttBlock``.
        **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments
            (e.g. ``name``, ``dtype``, ``trainable``), forwarded to the base
            class.
    """

    def __init__(self, L, K, d, bn, bn_decay, **kwargs):
        # Forward base-layer kwargs so callers can configure this layer
        # (name, dtype, trainable, ...) like any built-in Keras layer.
        super().__init__(**kwargs)
        # Keras tracks layers held in a list attribute, so each block's
        # weights are included in this layer's weights.
        self.blocks = [STAttBlock(K, d, bn, bn_decay) for _ in range(L)]
        self.final_fc = tf.keras.layers.Dense(1)

    def call(self, X, STE, is_training):
        """Apply every block in sequence, then the single-unit projection.

        Args:
            X: Input tensor fed to the first block.
                NOTE(review): shape not visible here; confirm with caller.
            STE: Tensor passed unchanged to every block.
            is_training: Flag passed positionally to every block. It is a
                plain argument, not Keras' ``training`` keyword, so Keras
                does not set it automatically.

        Returns:
            Output of ``Dense(1)`` applied to the last block's output.
            Its last axis has size 1.
        """
        for block in self.blocks:
            X = block(X, STE, is_training)
        return self.final_fc(X)
Editor is loading...
Leave a Comment