Untitled

 avatar
unknown
python
12 days ago
5.5 kB
3
Indexable
import random

""" 			  		 			     			  	   		   	  			  	
Seq2Seq model.  (c) 2021 Georgia Tech

Copyright 2021, Georgia Institute of Technology (Georgia Tech)
Atlanta, Georgia 30332
All Rights Reserved

Template code for CS 7643 Deep Learning

Georgia Tech asserts copyright ownership of this template and all derivative
works, including solutions to the projects assigned in this course. Students
and other users of this template code are advised not to share it with others
or to make it available on publicly viewable websites including repositories
such as Github, Bitbucket, and Gitlab.  This copyright statement should
not be removed or edited.

Sharing solutions with current or future students of CS 7643 Deep Learning is
prohibited and subject to being investigated as a GT honor code violation.

-----do not edit anything above this line---
"""

import torch
import torch.nn as nn
import torch.optim as optim


# import custom models


class Seq2Seq(nn.Module):
    """ The Sequence to Sequence model.

        Wraps an encoder and a decoder and runs greedy, step-by-step decoding:
        the encoder's final recurrent state seeds the decoder, and each decoder
        prediction is fed back in as the next decoder input.

        Expectations on the wrapped modules, as used by this class:
            * ``encoder(source)`` returns ``(encoder_output, hidden)``.
            * ``decoder(input, hidden)`` -- or
              ``decoder(input, hidden, encoder_output)`` when
              ``decoder.attention`` is truthy -- returns ``(output, hidden)``.
            * ``decoder.output_size`` is the size of the decoder's output
              (last) dimension.
    """

    def __init__(self, encoder, decoder, device):
        """ Store the sub-modules and move both onto ``device``.
            Args:
                encoder (nn.Module): the source-sequence encoder
                decoder (nn.Module): the step-wise decoder
                device: device (CPU/GPU) the sub-modules are moved to
        """
        super(Seq2Seq, self).__init__()
        self.device = device
        #############################################################################
        # TODO:                                                                     #
        #    Initialize the Seq2Seq model. You should use .to(device) to make sure  #
        #    that the models are on the same device (CPU/GPU). This should take no  #
        #    more than 2 lines of code.                                             #
        #############################################################################
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)

        #############################################################################
        #                              END OF YOUR CODE                             #
        #############################################################################

    def forward(self, source):
        """ The forward pass of the Seq2Seq model.
            Args:
                source (tensor): sequences in source language of shape (batch_size, seq_len)
            Returns:
                outputs (tensor): per-step decoder outputs of shape
                    (batch_size, seq_len, decoder.output_size)
        """
        #############################################################################
        # TODO:                                                                     #
        #   Implement the forward pass of the Seq2Seq model. Please refer to the    #
        #   following steps:                                                        #
        #       1) Get the last hidden representation from the encoder. Use it as   #
        #          the first hidden state of the decoder                            #
        #       2) The first input for the decoder should be the <sos> token, which #
        #          is the first in the source sequence.                             #
        #       3) Feed this first input and hidden state into the decoder          #
        #          one step at a time in the sequence, adding the output to the     #
        #          final outputs.                                                   #
        #       4) Update the input and hidden being fed into the decoder           #
        #          at each time step. The decoder output at the previous time step  #
        #          will have to be manipulated before being fed in as the decoder   #
        #          input at the next time step.                                     #
        #############################################################################
        batch_size, seq_len = source.shape[0], source.shape[1]

        # Step 1: encode. Modules are invoked via __call__ (not .forward) so
        # that any registered hooks run. `hidden` is kept opaque: it may be a
        # single tensor (RNN/GRU) or an (h, c) tuple (LSTM) and is handed to
        # the decoder exactly as the encoder/decoder produced it.
        encoder_output, hidden = self.encoder(source)

        # Step 2: the first decoder input is the first token of every source
        # sequence (<sos>); slicing keeps the length-1 time axis -> (batch_size, 1).
        decoder_input = source[:, :1]

        # Step 3: buffer for the decoder outputs, allocated directly on the
        # same device as the input.
        outputs = torch.zeros(
            batch_size, seq_len, self.decoder.output_size, device=source.device
        )

        # Loop-invariant; a decoder without the attribute is treated as a
        # plain (non-attention) decoder.
        use_attention = getattr(self.decoder, "attention", False)

        for t in range(seq_len):
            # Feed back the *entire* state the decoder returned on the previous
            # step. Splitting an LSTM state into (hidden, cell) here and only
            # rebinding `hidden` would nest tuples and freeze the cell state.
            if use_attention:
                output, hidden = self.decoder(decoder_input, hidden, encoder_output)
            else:
                output, hidden = self.decoder(decoder_input, hidden)

            # Store this step's decoder output; assumes `output` is
            # (batch_size, output_size) -- TODO confirm against the decoder.
            outputs[:, t, :] = output

            # Greedy decoding: the highest-scoring token becomes the next
            # input, reshaped back to (batch_size, 1).
            decoder_input = torch.argmax(output, dim=1).unsqueeze(1)
        #############################################################################
        #                              END OF YOUR CODE                             #
        #############################################################################
        return outputs
Editor is loading...
Leave a Comment