Untitled

 avatar
unknown
plain_text
2 years ago
4.4 kB
66
Indexable

def encode_text(text, tokenizer, max_length=50):
    """Tokenize a batch of strings into fixed-length, padded BERT inputs.

    Args:
        text: list of strings to encode, one output row per string.
        tokenizer: a Hugging Face tokenizer exposing ``batch_encode_plus``.
        max_length: length every sequence is padded/truncated to. The default
            of 50 matches the ``shape=(50,)`` Keras ``Input`` layers this
            output is fed into.

    Returns:
        dict with keys ``"input_ids"`` and ``"attention_masks"``, each an
        int32 NumPy array with one row per input string and ``max_length``
        columns.
    """
    encoded = tokenizer.batch_encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        # Ask for NumPy directly instead of building TF tensors only to copy
        # them back into NumPy below.
        return_tensors="np",
    )

    return {
        "input_ids": np.asarray(encoded["input_ids"], dtype="int32"),
        "attention_masks": np.asarray(encoded["attention_mask"], dtype="int32"),
    }


model_checkpoint = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# 80/20 train/validation split. The cut point is derived from len(data) rather
# than a hard-coded row count, so the ratio holds for any dataset size.
TRAIN_FRACTION = 0.80
split_index = int(len(data) * TRAIN_FRACTION)
train = data.iloc[:split_index, :]
val = data.iloc[split_index:, :]

# Encode each triplet column (anchor / positive / negative) for both splits.
X1_train = encode_text(train['Anchor'].tolist(), tokenizer)
X2_train = encode_text(train['Positive'].tolist(), tokenizer)
X3_train = encode_text(train['Negative'].tolist(), tokenizer)

X1_val = encode_text(val['Anchor'].tolist(), tokenizer)
X2_val = encode_text(val['Positive'].tolist(), tokenizer)
X3_val = encode_text(val['Negative'].tolist(), tokenizer)



class DistanceLayer(Layer):
    """Keras layer producing the two squared L2 distances of a triplet.

    ``call`` returns ``(d(anchor, positive), d(anchor, negative))``, where
    ``d`` is the sum of squared element-wise differences over the last axis.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, anchor, positive, negative):
        def squared_distance(lhs, rhs):
            # Squared Euclidean distance, reduced over the last axis.
            return tf.reduce_sum(tf.square(lhs - rhs), axis=-1)

        return (
            squared_distance(anchor, positive),
            squared_distance(anchor, negative),
        )

# Pick a distribution strategy: a TPU when one is attached, otherwise the
# default strategy.
try:
    # TPUClusterResolver() is expected to raise ValueError when no TPU is
    # found. Only this probe is guarded, so real TPU set-up failures below
    # surface instead of silently falling back.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    strategy = tf.distribute.get_strategy()
    BATCH_SIZE = 32
    print(f"Running on {strategy.num_replicas_in_sync} replicas")
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # Stable API; tf.distribute.experimental.TPUStrategy is a deprecated alias.
    strategy = tf.distribute.TPUStrategy(tpu)
    # Global batch size: 4 examples per replica.
    BATCH_SIZE = strategy.num_replicas_in_sync * 4
    print("Running on TPU:", tpu.master())

print(f"Batch Size: {BATCH_SIZE}")



def _triplet_inputs(anchor, positive, negative):
    """Flatten three ``encode_text`` results into the siamese network's input order.

    Returns the 6-tuple (anchor ids, anchor mask, positive ids, positive mask,
    negative ids, negative mask). This is the same order as the ``inputs=``
    list of ``siamese_network`` below.
    """
    return tuple(
        np.asarray(encoded[key])
        for encoded in (anchor, positive, negative)
        for key in ("input_ids", "attention_masks")
    )


with strategy.scope():
    # Load the encoder from the same checkpoint as the tokenizer so vocabulary
    # and weights cannot drift apart.
    transformer_model = TFBertModel.from_pretrained(model_checkpoint)

    # Inputs of the shared encoder sub-model. The length of 50 matches the
    # max_length used when tokenizing.
    input_ids_in1 = Input(shape=(50,), name='input_ids1', dtype='int32')
    input_masks_in1 = Input(shape=(50,), name='attention_mask1', dtype='int32')

    # One (ids, mask) input pair per triplet member.
    anchor_input = Input(name="anchor_ids", shape=(50,), dtype='int32')
    anchor_masks = Input(name="anchor_mask", shape=(50,), dtype='int32')
    positive_input = Input(name="positive_ids", shape=(50,), dtype='int32')
    positive_masks = Input(name="positive_mask", shape=(50,), dtype='int32')
    negative_input = Input(name="negative_ids", shape=(50,), dtype='int32')
    negative_masks = Input(name="negative_mask", shape=(50,), dtype='int32')

    embedding_layer = transformer_model(input_ids_in1, attention_mask=input_masks_in1).last_hidden_state

    # Mean-pool over real tokens only. Inputs are padded to max_length, and
    # without the mask the padding positions are averaged in and dilute the
    # sentence embedding (most severely for short sentences).
    average = GlobalAveragePooling1D()(embedding_layer, mask=input_masks_in1)
    embeds = Dense(512, activation='relu')(average)

    # Shared encoder applied identically to anchor, positive and negative.
    embeddings = Model(inputs=[input_ids_in1, input_masks_in1], outputs=embeds)

    # Freeze every layer except the last one in embeddings.layers (the Dense
    # projection head that produces the model output).
    for layer in embeddings.layers[:-1]:
        layer.trainable = False

    embeds1 = embeddings([anchor_input, anchor_masks])
    embeds2 = embeddings([positive_input, positive_masks])
    embeds3 = embeddings([negative_input, negative_masks])

    distances = DistanceLayer()(embeds1, embeds2, embeds3)

    siamese_network = Model(
        inputs=[anchor_input, anchor_masks, positive_input, positive_masks, negative_input, negative_masks],
        outputs=distances,
    )

    siamese_model = SiameseModel(siamese_network)
    siamese_model.compile(optimizer=tf.keras.optimizers.Adam(0.00001))

    history = siamese_model.fit(
        _triplet_inputs(X1_train, X2_train, X3_train),
        epochs=10,
        # Use the batch size chosen for the active distribution strategy;
        # without it, fit() falls back to Keras' default of 32.
        batch_size=BATCH_SIZE,
        # The outer 1-tuple makes Keras treat the six arrays as inputs only
        # (no targets / sample weights).
        validation_data=(_triplet_inputs(X1_val, X2_val, X3_val),),
    )
Editor is loading...