# Untitled
#
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 9.5 kB
# 0
# Indexable
# Never
import os
from typing import Dict
import transformers
from transformers import RobertaTokenizer, T5ForConditionalGeneration
from transformers.generation_logits_process import LogitsProcessor, LogitsProcessorList
import torch
from torch import Tensor
from .basenet import BaseNet
from .metadata import *

class CodeT5():
    """Wrapper around ``Salesforce/codet5-base`` for length-constrained generation.

    The constructor loads the pretrained tokenizer and model, then tries to
    overlay the weights of a checkpoint file. ``predict`` runs ``generate``
    with two custom logits processors defined in this module.
    """

    def __init__(self, model_name):
        """Load tokenizer and model, then apply the checkpoint at ``model_name``.

        Args:
            model_name: Path of a ``state_dict`` checkpoint file; it is handed
                to :meth:`load_checkpoint`.
        """
        self.tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
        self.model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
        self.load_checkpoint(model_name)
        self.default_beam_size = 1
        # Maximum length of the generated (target) sequence.
        self.max_length = 512

        if torch.cuda.is_available():
            print("Using this GPU device for prediction : ", torch.cuda.get_device_name(0), "....")
        else:
            print("Using CPU for prediction ......")

    def load_checkpoint(self, filename):
        """Best-effort load of a ``state_dict`` checkpoint into ``self.model``.

        Failures are reported on stdout instead of being raised, so the
        pretrained base weights stay in place when the checkpoint is missing
        or unreadable.

        Args:
            filename: Path of the checkpoint file.
        """
        try:
            if not os.path.isfile(filename):
                print(f"Checkpoint {filename} not found; keeping pretrained weights")
                return
            # Always deserialize onto the CPU: the parameters are copied into
            # the model anyway, and this does not depend on the device the
            # checkpoint was saved from.
            state_dict = torch.load(filename, map_location=torch.device('cpu'))
            self.model.load_state_dict(state_dict)
            # NOTE: as before, the model is moved to the GPU only after a
            # checkpoint was loaded successfully.
            if torch.cuda.is_available():
                self.model.to('cuda')
        except Exception as err:
            print(f"Failed to load model from {filename}: {err}")

    def predict(self, input_ids: torch.Tensor, **kwargs):
        """Generate token ids for ``input_ids`` under sub-sequence length bounds.

        Args:
            input_ids: Input token ids; dimension 0 is used as the batch size.
            **kwargs: Optional settings:
                ``max_length_mapper`` / ``min_length_mapper`` - dicts mapping
                an anchor token to the maximum / minimum length of the
                sub-sequence that follows it. When absent or ``None`` the
                module-level ``MAX_SUB_LENGTH_MAPPER`` /
                ``MIN_SUB_LENGTH_MAPPER`` are used.
                ``penalty_alpha`` / ``top_k`` - forwarded to ``generate``
                only when provided.

        Returns:
            The value returned by ``self.model.generate``.
        """
        # Resolve into *new* local names. Assigning to MAX_SUB_LENGTH_MAPPER
        # inside this function would make it local and raise
        # UnboundLocalError whenever the keyword argument is omitted.
        max_length_mapper = kwargs.get("max_length_mapper")
        if max_length_mapper is None:
            max_length_mapper = MAX_SUB_LENGTH_MAPPER
        min_length_mapper = kwargs.get("min_length_mapper")
        if min_length_mapper is None:
            min_length_mapper = MIN_SUB_LENGTH_MAPPER

        print("DEvice: ", self.model.device)
        input_ids = input_ids.to(self.model.device)
        batch_size = input_ids.size(0)

        # The length processors track one argmax token per row, so decoding
        # is pinned to a single beam (self.default_beam_size is not consulted).
        beam_size = 1
        temperature = 1.0

        to_id = self.tokenizer.convert_tokens_to_ids
        anchor_token_ids = [to_id(SPECIAL_TOKENS_MAP[comment_type]) for comment_type in COMMENT_TYPES]
        anchor_token_ids = torch.LongTensor(anchor_token_ids).to(input_ids.device)
        max_sub_length_mapper = {to_id(token): length for token, length in max_length_mapper.items()}
        min_sub_length_mapper = {to_id(token): length for token, length in min_length_mapper.items()}

        length_processor = MaxMinSubLengthLogitsProcessor(
            anchor_token_ids,
            max_sub_length_mapper,
            min_sub_length_mapper,
            self.tokenizer.vocab_size,
            batch_size,
            beam_size,
            self.tokenizer.eos_token_id,
            bound_max_length=True,
            bound_min_length=True,
        )
        dot_token_id = to_id(self.tokenizer.tokenize("."))[0]
        dot_processor = DotGenerationLogitsProcessor(self.tokenizer.eos_token_id, dot_token_id=dot_token_id)
        logits_processor = LogitsProcessorList([length_processor, dot_processor])

        # Forward the optional search settings only when the caller supplied
        # them, so a missing key no longer raises KeyError.
        search_kwargs = {
            name: kwargs[name]
            for name in ("penalty_alpha", "top_k")
            if kwargs.get(name) is not None
        }

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids,
                max_length=self.max_length,
                num_beams=beam_size,
                temperature=temperature,
                logits_processor=logits_processor,
                **search_kwargs,
            )

        return generated_ids

    def cuda(self):
        """No-op; device placement is handled in :meth:`load_checkpoint`."""
        pass

class MaxMinSubLengthLogitsProcessor(LogitsProcessor):
    """Bound the length of each sub-sequence that follows an anchor token.

    For every row of ``scores`` the processor tracks how many tokens were
    produced since the last anchor token, plus the minimum and maximum length
    that the last anchor allows:

    * while a row is shorter than its minimum, anchor tokens and EOS are
      masked out;
    * when the next token would reach the maximum, every non-ending token is
      masked out, leaving only anchors and EOS.

    The state is advanced with the argmax of the processed scores, so it only
    mirrors the real output if the decoder picks that token.
    NOTE(review): confirm this holds for the decoding strategy in use.
    """

    def __init__(self, anchor_token_ids: Tensor, max_sub_length_mapper: Dict, min_sub_length_mapper: Dict, vocab_size: int, batch_size: int, beam_size: int, eos_token_id: int, bound_max_length=True, bound_min_length=True):
        """Set up the masks and per-row trackers.

        Args:
            anchor_token_ids: 1-D tensor of token ids that start a sub-sequence.
            max_sub_length_mapper: Anchor token id -> maximum sub-sequence length.
            min_sub_length_mapper: Anchor token id -> minimum sub-sequence length.
            vocab_size: Number of token ids considered when building the
                non-ending ("word") token list.
            batch_size: Number of input sequences.
            beam_size: Beams per sequence; ``batch_size * beam_size`` rows are tracked.
            eos_token_id: End-of-sequence token id.
            bound_max_length: Enforce the maximum length when true.
            bound_min_length: Enforce the minimum length when true.
        """
        device = anchor_token_ids.device
        num_rows = batch_size * beam_size

        self.vocab_size = vocab_size
        self.anchor_token_ids = anchor_token_ids
        self.max_sub_length_mapper = max_sub_length_mapper
        self.min_sub_length_mapper = min_sub_length_mapper

        # Build the exclusion collection once; the membership test runs for
        # every id in the vocabulary.
        ending_ids = anchor_token_ids.cpu().tolist() + [eos_token_id]
        ending_id_set = set(ending_ids)
        self.word_token_ids = torch.LongTensor([i for i in range(vocab_size) if i not in ending_id_set]).to(device)
        self.ending_token_ids = torch.LongTensor(ending_ids).to(device)

        self.current_sub_sequence_lenths = torch.zeros(num_rows).to(device)
        # Sentinels that never trigger a mask before the first anchor is seen.
        self.max_sub_sequence_length_tracker = torch.full((num_rows,), 10000, dtype=torch.long, device=device)
        self.min_sub_sequence_length_tracker = torch.full((num_rows,), -1, dtype=torch.long, device=device)
        self.bound_max_length = bound_max_length
        self.bound_min_length = bound_min_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask ``scores`` in place according to the length bounds.

        Args:
            input_ids: Tokens generated so far (unused).
            scores: Next-token scores, one row per tracked sequence.

        Returns:
            ``scores`` with the forbidden entries set to ``-inf``.
        """
        neg_inf = -float("inf")

        # MIN LENGTH: rows that are still too short may not end the sub-sequence.
        if self.bound_min_length:
            too_short = self.current_sub_sequence_lenths.lt(self.min_sub_sequence_length_tracker)
            rows = torch.nonzero(too_short).view(-1)
            # (row, column) indexing works for any width and memory layout of
            # ``scores``, unlike flat indices built from ``self.vocab_size``.
            scores[rows.unsqueeze(-1), self.ending_token_ids.unsqueeze(0)] = neg_inf

        # MAX LENGTH: rows about to hit their limit may only emit an ending token.
        if self.bound_max_length:
            at_limit = (self.current_sub_sequence_lenths + 1).eq(self.max_sub_sequence_length_tracker)
            rows = torch.nonzero(at_limit).view(-1)
            # Columns with ids >= the constructor's ``vocab_size`` are not in
            # ``word_token_ids`` and are therefore left untouched.
            scores[rows.unsqueeze(-1), self.word_token_ids.unsqueeze(0)] = neg_inf

        # Advance the state on every call. Previously this happened only when
        # ``bound_max_length`` was set, which froze the trackers (and returned
        # flattened scores) when that flag was off.
        next_tokens = torch.argmax(scores, dim=-1)
        self.update_tracker(next_tokens)

        return scores

    @staticmethod
    def _lengths_for(mapper: Dict, token_ids: Tensor) -> Tensor:
        """Look up ``mapper`` for every id in ``token_ids``; raises KeyError if one is missing."""
        lengths = [mapper[token] for token in token_ids.cpu().tolist()]
        return torch.tensor(lengths, dtype=torch.long, device=token_ids.device)

    def update_tracker(self, next_tokens):
        """Update the per-row counters and bounds from the chosen tokens.

        Rows whose token is an anchor restart their length at 0 and take the
        anchor's min/max bounds; all other rows increase their length by 1.

        Args:
            next_tokens: 1-D tensor with one token id per tracked row.
        """
        anchor_positions = torch.isin(next_tokens, self.anchor_token_ids)
        beginning_token_ids = next_tokens[anchor_positions]
        self.current_sub_sequence_lenths = (self.current_sub_sequence_lenths + 1) * ~anchor_positions
        self.max_sub_sequence_length_tracker[anchor_positions] = self._lengths_for(self.max_sub_length_mapper, beginning_token_ids)
        self.min_sub_sequence_length_tracker[anchor_positions] = self._lengths_for(self.min_sub_length_mapper, beginning_token_ids)

class DotGenerationLogitsProcessor(LogitsProcessor):
    """Allow EOS only directly after a dot token.

    On every step the EOS score is set to ``-inf`` for each row whose
    previous argmax token was not ``dot_token_id``; on the first step EOS is
    blocked for all rows.
    """

    def __init__(self, eos_token_id: int, dot_token_id):
        """Store the token ids.

        Args:
            eos_token_id: End-of-sequence token id.
            dot_token_id: Id of the token that must precede EOS.

        Raises:
            ValueError: If either id is not an ``int`` or is negative.
        """
        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
        if not isinstance(dot_token_id, int) or dot_token_id < 0:
            raise ValueError(f"`dot_token_id` has to be a positive integer, but is {dot_token_id}")

        self.eos_token_id = eos_token_id
        self.dot_token_id = dot_token_id
        # Per-row flag: True when the previous argmax token was the dot token.
        self.allowed_eos_indices = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask the EOS score of every row that did not just produce a dot.

        Args:
            input_ids: Tokens generated so far (unused).
            scores: Next-token scores, one row per sequence.

        Returns:
            ``scores`` with the blocked EOS entries set to ``-inf``.
        """
        num_rows = scores.shape[0]
        # (Re)initialise when first called or when the number of rows changed.
        if self.allowed_eos_indices is None or self.allowed_eos_indices.shape[0] != num_rows:
            self.allowed_eos_indices = torch.zeros(num_rows, dtype=torch.bool, device=scores.device)

        # Index by (row, column). The earlier flat-index version multiplied the
        # row number by the length of the *flattened* tensor, which pointed
        # past the end of the buffer for every row after the first.
        blocked_rows = ~self.allowed_eos_indices.to(scores.device)
        scores[blocked_rows, self.eos_token_id] = -float("inf")

        next_tokens = torch.argmax(scores, dim=-1)
        self.allowed_eos_indices = next_tokens.eq(self.dot_token_id)
        return scores

# Smoke check: report that the module was executed as a script.
# NOTE(review): the relative imports at the top of this file presumably
# require running it as part of its package (``python -m ...``) - confirm.
if __name__ == "__main__":
    print("Success")