Untitled
unknown
plain_text
2 years ago
1.2 kB
39
Indexable
class SiameseModel(Model):
    """Keras model that trains a siamese network with a triplet margin loss.

    The wrapped ``siamese_network`` is expected to return a pair
    ``(ap_distance, an_distance)``, as unpacked in ``_compute_loss``.
    Presumably these are the anchor-positive and anchor-negative distances;
    confirm against the network definition. Training minimises::

        loss = max(ap_distance - an_distance + margin, 0)

    Args:
        siamese_network: The underlying network. Only its
            ``trainable_weights`` are updated by ``train_step``.
        margin: Constant added to ``ap_distance - an_distance`` before
            clamping at zero.
        **kwargs: Forwarded to ``Model.__init__`` (e.g. ``name``).
    """

    def __init__(self, siamese_network, margin=0.5, **kwargs):
        super().__init__(**kwargs)
        self.siamese_network = siamese_network
        self.margin = margin
        # Running mean of the loss values passed to update_state(); exposed
        # through the ``metrics`` property below.
        self.loss_tracker = metrics.Mean(name="loss")

    def call(self, inputs, training=None):
        """Forward ``inputs`` to the wrapped network.

        ``training`` is passed through so callers can control the mode of
        layers such as Dropout/BatchNormalization. ``None`` keeps Keras'
        default inference of the mode.
        """
        return self.siamese_network(inputs, training=training)

    def train_step(self, data):
        """Run one optimisation step on ``data``.

        Returns:
            ``{"loss": <running mean of the loss>}``.
        """
        with tf.GradientTape() as tape:
            # A custom train_step does not set the training flag for us.
            # Without it, Dropout/BatchNormalization in the wrapped network
            # would run in inference mode while being optimised.
            loss = self._compute_loss(data, training=True)
        # Only the wrapped network's weights are optimised.
        trainable = self.siamese_network.trainable_weights
        # NOTE(review): if the network yields per-example distances, ``loss``
        # is a vector and tape.gradient differentiates its sum, so gradient
        # scale grows with batch size. Kept as-is to preserve behaviour.
        gradients = tape.gradient(loss, trainable)
        self.optimizer.apply_gradients(zip(gradients, trainable))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        """Evaluate the loss on ``data`` without updating any weights.

        Returns:
            ``{"loss": <running mean of the loss>}``.
        """
        loss = self._compute_loss(data, training=False)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data, training=None):
        """Return ``max(ap_distance - an_distance + margin, 0)`` for ``data``.

        The loss is not reduced over the batch. Its shape matches the
        distances returned by ``siamese_network``.
        """
        ap_distance, an_distance = self.siamese_network(data, training=training)
        loss = ap_distance - an_distance
        return tf.maximum(loss + self.margin, 0.0)

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it automatically at the
        # start of each epoch and of each evaluate() call.
        return [self.loss_tracker]
Editor is loading...