Untitled
unknown
plain_text
7 months ago
386 B
4
Indexable
class STEncoder(tf.keras.layers.Layer):
    """Stack of ``STAttBlock`` layers followed by a dense output projection.

    The input is passed through ``L`` ``STAttBlock`` instances in sequence.
    Each block receives the running activations, the same ``STE`` tensor and
    the same ``is_training`` flag. The result is then projected by a ``Dense``
    layer, which acts on the last axis.

    Args:
        L: Number of ``STAttBlock`` layers to stack. With ``L == 0`` only the
            final projection is applied.
        K, d, bn, bn_decay: Forwarded unchanged to every ``STAttBlock``.
            NOTE(review): presumably K = number of attention heads,
            d = per-head dimension, bn = batch-norm flag and
            bn_decay = batch-norm momentum; confirm against STAttBlock.
        out_units: Width of the final ``Dense`` projection. Defaults to 1,
            which matches the original hard-coded behaviour.
        **kwargs: Standard ``tf.keras.layers.Layer`` options
            (e.g. ``name``, ``dtype``).
    """

    def __init__(self, L, K, d, bn, bn_decay, *, out_units=1, **kwargs):
        super().__init__(**kwargs)
        # Keras auto-tracks layers stored in a list attribute, so the blocks'
        # weights are registered with this layer.
        self.blocks = [STAttBlock(K, d, bn, bn_decay) for _ in range(L)]
        self.final_fc = tf.keras.layers.Dense(out_units)

    def call(self, X, STE, is_training):
        """Run ``X`` through every block, then apply the output projection.

        Args:
            X: Input activations. Assumes a shape that STAttBlock accepts
                (TODO confirm, e.g. (batch, time, nodes, features)).
            STE: Passed unchanged to every block. Presumably a
                spatio-temporal embedding; verify against caller.
            is_training: Passed positionally to every block.

        Returns:
            The output of the last block, with its final axis projected to
            ``out_units`` features by the ``Dense`` layer.
        """
        for block in self.blocks:
            X = block(X, STE, is_training)
        return self.final_fc(X)
Leave a Comment