Untitled

 avatar
unknown
python
3 years ago
2.8 kB
3
Indexable
import tensorflow as tf
from tensorflow.keras.layers import *

def part_of_model_creation():
    """Excerpt of the model-building code (other parts elided with ``# ...``).

    Wires an optional "locality" auxiliary loss into a Keras model:

    * ``decoder`` maps ``[decoderInpH, decoderInpC, decoderPrevInput]`` to
      ``out``; ``model`` feeds ``[h, c, embed2]`` through ``decoder`` and
      takes ``[inp, inp2]`` as inputs, plus ``inp3`` when ``locality_term``
      is truthy.
    * With ``locality_term``, ``locality_layer`` holds the correlation
      (``locality_term_op``) between the pairwise distances of the
      z-scored, concatenated ``[h, c]`` (``locality1_op``) and the strict
      upper triangle taken from ``inp3`` (``locality2_op``). The value
      ``(1 - correlation) * locality_power`` is added to ``model`` as an
      extra loss and reported as the ``locality`` metric.
    * Both ``model`` and ``decoder`` are compiled with
      ``Adam(lr, clipnorm=1.0, clipvalue=0.5)`` and categorical
      cross-entropy.

    NOTE(review): ``locality_term``, ``locality_power``, ``lr``, the input /
    state tensors, ``Model`` and ``Adam`` all come from code not shown here;
    ``Model`` and ``Adam`` are not provided by the visible
    ``tensorflow.keras.layers`` star import -- confirm they are imported.
    """
    # ...
    if locality_term:
        # Auxiliary branch: compare latent-space distances with the matrix
        # supplied in inp3 (locality2_op calls it genos_dist -- presumably a
        # precomputed distance matrix; confirm against the data pipeline).
        locality1 = Lambda(locality1_op)([h, c])
        locality2 = Lambda(locality2_op)(inp3)
        locality_layer = Lambda(locality_term_op)([locality1, locality2])
    # ...
    decoder = Model(inputs=[decoderInpH, decoderInpC, decoderPrevInput], outputs=out)
    # inp3 only becomes a model input when the locality branch uses it.
    if locality_term:
        model = Model(inputs=[inp, inp2, inp3], outputs=decoder([h, c, embed2]))
    else:
        model = Model(inputs=[inp, inp2], outputs=decoder([h, c, embed2]))

    model.summary()
    decoder.summary()

    if locality_term:
        print("Using locality term! Locality power: " , locality_power)
        # For a positive locality_power this loss shrinks as the correlation
        # approaches 1, i.e. it rewards agreement between the two distance sets.
        locality_loss = (1-locality_layer)*tf.constant(locality_power)
        model.add_loss(locality_loss)

        # NOTE(review): Model.add_metric is not available in newer Keras
        # releases (Keras 3); verify against the installed TF version.
        model.add_metric(locality_loss, name='locality', aggregation='mean')
    
    # NOTE(review): newer Keras optimizers reject setting clipnorm and
    # clipvalue at the same time; confirm this works on the installed version.
    model.compile(optimizer=Adam(lr, clipnorm=1.0, clipvalue=0.5), loss='categorical_crossentropy')
    # decoder is compiled on its own, presumably so it can be used standalone.
    decoder.compile(optimizer=Adam(lr, clipnorm=1.0, clipvalue=0.5), loss='categorical_crossentropy')
    # ..
    

def zscore(x, *, axis=0, eps=0.01):
    """Standardize ``x`` to zero mean and roughly unit spread along ``axis``.

    Args:
        x: Tensor (or tensor-convertible value) of a floating dtype;
            ``tf.math.reduce_std`` rejects integer inputs.
        axis: Axis over which mean and standard deviation are computed.
            The default 0 standardizes every column across the rows.
        eps: Added to the standard deviation so a constant slice (std == 0)
            is divided by ``eps`` instead of zero.

    Returns:
        ``(x - mean) / (std + eps)``, with the same shape as ``x``.
    """
    # keepdims=True keeps the reduced axis so the statistics broadcast back
    # against x for any ``axis``; for axis=0 the result equals the
    # non-keepdims broadcast.
    mean = tf.reduce_mean(x, axis=axis, keepdims=True)
    # NOTE(review): the gradient of reduce_std is infinite where the std is
    # exactly 0, even though the forward value is protected by ``eps``.
    std = tf.math.reduce_std(x, axis=axis, keepdims=True)
    return (x - mean) / (std + eps)


def tf_corr(x, y, *, eps=0.01):
    """Return the Pearson correlation coefficient of ``x`` and ``y``.

    Args:
        x: Tensor of values; cast to float32 before any statistic is taken.
        y: Tensor with the same shape as ``x``; also cast to float32.
        eps: Added to the denominator to avoid dividing by zero when either
            input is constant.

    Returns:
        Scalar float32 tensor
        ``sum((x - mean_x) * (y - mean_y)) / (n * std_x * std_y + eps)``,
        where ``n`` is the total number of elements and the standard
        deviations are population (not sample) deviations.
    """
    # Cast once up front. Otherwise reduce_mean truncates integer inputs and
    # reduce_std raises on them.
    x = tf.cast(x, tf.float32)
    y = tf.cast(y, tf.float32)
    # Count every element: reduce_sum below reduces over all axes, so using
    # only the first dimension would be wrong for inputs of rank > 1.
    n = tf.cast(tf.size(x), tf.float32)
    # The centered form equals sum(x*y) - n*mean_x*mean_y mathematically but
    # avoids catastrophic cancellation in float32 when the means are large.
    num = tf.reduce_sum((x - tf.reduce_mean(x)) * (y - tf.reduce_mean(y)))
    den = n * tf.math.reduce_std(x) * tf.math.reduce_std(y)
    return num / (den + eps)


def squared_dist(A):
    """Return the matrix of pairwise squared Euclidean distances of ``A``.

    For a 2-D ``A`` of shape (n, d), entry ``[i, j]`` is
    ``sum_k (A[i, k] - A[j, k]) ** 2``. The differences are formed by
    broadcasting, so an (n, n, d) intermediate is materialized.
    """
    as_column = tf.expand_dims(A, axis=1)  # (n, 1, d) for 2-D A
    as_row = tf.expand_dims(A, axis=0)  # (1, n, d) for 2-D A
    return tf.reduce_sum(
        tf.math.squared_difference(as_column, as_row),
        2,
    )


def euclidean(A):
    """Return the pairwise Euclidean distance matrix of ``A`` as float32.

    Computed as ``sqrt(squared_dist(A))``. Entries whose squared distance is
    exactly zero (always the diagonal, plus any duplicate rows) are returned
    as-is and get a finite gradient.

    Args:
        A: Tensor accepted by ``squared_dist``; for 2-D input the rows are
            the points.

    Returns:
        float32 tensor of pairwise distances.
    """
    sq_dist = tf.cast(squared_dist(A), tf.float32)
    nonzero = sq_dist != 0
    # Double-where: sqrt's gradient is infinite at 0, and a single tf.where
    # still back-propagates through the unselected branch, giving
    # 0 * inf = NaN. Feeding sqrt a harmless 1.0 wherever the distance is
    # zero keeps every gradient finite without changing any forward value.
    safe_sq_dist = tf.where(nonzero, sq_dist, tf.ones_like(sq_dist))
    return tf.where(nonzero, tf.sqrt(safe_sq_dist), sq_dist)


def upper_triangular(A):
    """Return the strictly-upper-triangular entries of ``A`` as a flat tensor.

    The diagonal is excluded. Elements come out in the row-major order
    produced by ``tf.boolean_mask``.
    """
    ones = tf.ones_like(A)
    # band_part(.., 0, -1) keeps the diagonal and everything above it;
    # band_part(.., 0, 0) keeps only the diagonal.
    on_or_above_diag = tf.cast(tf.linalg.band_part(ones, 0, -1), dtype=tf.bool)
    on_diag = tf.cast(tf.linalg.band_part(ones, 0, 0), dtype=tf.bool)
    strictly_above = tf.logical_and(on_or_above_diag, tf.logical_not(on_diag))
    return tf.boolean_mask(A, strictly_above)


def locality1_op(x):
    """Flattened pairwise distances in the (standardized) latent space.

    ``x[0]`` and ``x[1]`` are concatenated along the last axis, z-scored
    column-wise with ``zscore``, turned into a Euclidean distance matrix,
    and reduced to its strictly-upper-triangular entries.
    """
    first, second = x[0], x[1]
    standardized = zscore(tf.concat([first, second], axis=-1))
    pairwise = euclidean(standardized)
    return upper_triangular(pairwise)


def locality2_op(x):
    """Return the strictly-upper-triangular entries of ``x[0]``, flattened.

    NOTE(review): this is applied as ``Lambda(locality2_op)(inp3)`` to what
    appears to be a single tensor, so ``x[0]`` selects the first slice along
    the leading (batch) axis rather than a list item. Presumably every sample
    carries the same precomputed matrix; confirm against the data pipeline.
    """
    matrix = x[0]
    return upper_triangular(matrix)

@tf.function
def locality_term_op(x):
    """Return the correlation (via ``tf_corr``) between ``x[0]`` and ``x[1]``.

    In the visible model code ``x`` is ``[locality1, locality2]``, i.e. the
    two flattened distance vectors from ``locality1_op`` and ``locality2_op``.
    """
    lhs, rhs = x[0], x[1]
    return tf_corr(lhs, rhs)