import os
from typing import Dict

import torch
from torch import Tensor
from transformers import (
    LogitsProcessor,
    LogitsProcessorList,
    RobertaTokenizer,
    T5ForConditionalGeneration,
)

from .basenet import BaseNet
from .metadata import *


class CodeT5:
    """Wrapper around Salesforce CodeT5 that generates structured comments with
    per-section length constraints enforced at decoding time."""

    def __init__(self, model_name):
        # Alternative checkpoint: 'Salesforce/codet5-base-multi-sum' (code summarization).
        self.tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
        self.model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
        self.load_checkpoint(model_name)
        self.default_beam_size = 1
        # Maximum length of the generated target sequence.
        self.max_length = 512
        if torch.cuda.is_available():
            print("Using GPU for prediction:", torch.cuda.get_device_name(0))
        else:
            print("Using CPU for prediction")

    def load_checkpoint(self, filename):
        try:
            if os.path.isfile(filename):
                if torch.cuda.is_available():
                    self.model.load_state_dict(torch.load(filename))
                    self.model.to('cuda')
                else:
                    self.model.load_state_dict(torch.load(filename, map_location=torch.device('cpu')))
        except Exception:
            print(f"Failed to load model from {filename}")

    def predict(self, input_ids: torch.Tensor, **kwargs):
        # Per-comment-type length bounds, passed as {special_token: length} dicts.
        MAX_SUB_LENGTH_MAPPER = kwargs["max_length_mapper"]
        MIN_SUB_LENGTH_MAPPER = kwargs["min_length_mapper"]
        print("Device:", self.model.device)
        input_ids = input_ids.to(self.model.device)
        batch_size = input_ids.size(0)
        beam_size = self.default_beam_size
        temperature = 1.0
        # Anchor tokens mark the start of each comment section (from .metadata).
        anchor_token_ids = [self.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS_MAP[comment_type])
                            for comment_type in COMMENT_TYPES]
        anchor_token_ids = torch.LongTensor(anchor_token_ids).to(input_ids.device)
        # Re-key the length bounds by token id instead of token string.
        max_sub_length_mapper = {self.tokenizer.convert_tokens_to_ids(token): length
                                 for token, length in MAX_SUB_LENGTH_MAPPER.items()}
        min_sub_length_mapper = {self.tokenizer.convert_tokens_to_ids(token): length
                                 for token, length in MIN_SUB_LENGTH_MAPPER.items()}
        # Contrastive search vs. greedy decoding is selected via penalty_alpha / top_k below.
        with torch.no_grad():
            logits_processor = MaxMinSubLengthLogitsProcessor(anchor_token_ids,
                                                              max_sub_length_mapper,
                                                              min_sub_length_mapper,
                                                              self.tokenizer.vocab_size,
                                                              batch_size,
                                                              beam_size,
                                                              self.tokenizer.eos_token_id,
                                                              bound_max_length=True,
                                                              bound_min_length=True)
            dot_processor = DotGenerationLogitsProcessor(
                self.tokenizer.eos_token_id,
                dot_token_id=self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize("."))[0])
            logits_processor = LogitsProcessorList([logits_processor, dot_processor])
            generated_ids = self.model.generate(input_ids,
                                                max_length=self.max_length,
                                                num_beams=beam_size,
                                                temperature=temperature,
                                                logits_processor=logits_processor,
                                                penalty_alpha=kwargs["penalty_alpha"],
                                                top_k=kwargs["top_k"])
        return generated_ids

    def cuda(self):
        pass


class MaxMinSubLengthLogitsProcessor(LogitsProcessor):
    """Bounds the length of each sub-sequence between anchor (comment-type) tokens.

    While a sub-sequence is shorter than its minimum length, anchor and EOS tokens
    are masked out; once it reaches its maximum length, ordinary word tokens are
    masked so the model must emit an anchor or EOS. The processor is stateful and
    assumes single-sequence (greedy / contrastive) decoding.
    """

    def __init__(self, anchor_token_ids: Tensor, max_sub_length_mapper: Dict, min_sub_length_mapper: Dict,
                 vocab_size: int, batch_size: int, beam_size: int, eos_token_id: int,
                 bound_max_length=True, bound_min_length=True):
        self.vocab_size = vocab_size
        self.anchor_token_ids = anchor_token_ids
        self.max_sub_length_mapper = max_sub_length_mapper
        self.min_sub_length_mapper = min_sub_length_mapper
        # Ordinary vocabulary tokens (everything except anchors and EOS).
        self.word_token_ids = torch.LongTensor(
            [i for i in range(vocab_size) if i not in self.anchor_token_ids.cpu().tolist() + [eos_token_id]]
        ).to(anchor_token_ids.device)
        # Tokens that may end a sub-sequence: the anchors and EOS.
        self.ending_token_ids = torch.LongTensor(
            self.anchor_token_ids.cpu().tolist() + [eos_token_id]).to(anchor_token_ids.device)
        # Running length of the current sub-sequence for every row in the batch.
        self.current_sub_sequence_lengths = torch.zeros(batch_size * beam_size).to(self.anchor_token_ids.device)
        self.max_sub_sequence_length_tracker = torch.empty(batch_size * beam_size).long().to(
            self.anchor_token_ids.device).fill_(10000)
        self.min_sub_sequence_length_tracker = torch.empty(batch_size * beam_size).long().to(
            self.anchor_token_ids.device).fill_(-1)
        self.bound_max_length = bound_max_length
        self.bound_min_length = bound_min_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        original_shape = scores.shape
        scores = scores.view(-1)
        # MIN LENGTH: while a sub-sequence is too short, forbid anchor/EOS tokens.
        if self.bound_min_length:
            violate_indices = self.current_sub_sequence_lengths.lt(self.min_sub_sequence_length_tracker)
            violate_indices = torch.nonzero(violate_indices)
            violate_indices = violate_indices.unsqueeze(-1) * self.vocab_size
            violate_indices = violate_indices + self.ending_token_ids.unsqueeze(0)
            violate_indices = violate_indices.view(-1)
            scores[violate_indices] = -float("inf")
        # MAX LENGTH: once a sub-sequence hits its limit, forbid ordinary word tokens.
        if self.bound_max_length:
            violate_indices = (self.current_sub_sequence_lengths + 1).eq(self.max_sub_sequence_length_tracker)
            violate_indices = torch.nonzero(violate_indices)
            violate_indices = violate_indices.unsqueeze(-1) * self.vocab_size
            violate_indices = violate_indices + self.word_token_ids.unsqueeze(0)
            violate_indices = violate_indices.view(-1)
            scores[violate_indices] = -float("inf")
        scores = scores.view(original_shape)
        # Track the token that greedy decoding will pick so the state stays in sync.
        next_tokens = torch.argmax(scores, dim=-1)
        self.update_tracker(next_tokens)
        return scores

    def update_tracker(self, next_tokens):
        anchor_positions = torch.isin(next_tokens, self.anchor_token_ids)
        beginning_token_ids = next_tokens[anchor_positions]
        # Reset the running length whenever a new anchor token starts a sub-sequence.
        self.current_sub_sequence_lengths = (self.current_sub_sequence_lengths + 1) * ~anchor_positions
        # Tensor.apply_ only works on CPU tensors, hence the round-trip to CPU.
        self.max_sub_sequence_length_tracker[anchor_positions] = beginning_token_ids.clone().cpu().apply_(
            lambda token: self.max_sub_length_mapper.get(token)).to(beginning_token_ids.device)
{}".format(beginning_token_ids)) self.min_sub_sequence_length_tracker[anchor_positions] = beginning_token_ids.clone().cpu().apply_(lambda token: self.min_sub_length_mapper.get(token)).to(beginning_token_ids.device) class DotGenerationLogitsProcessor(LogitsProcessor): def __init__(self, eos_token_id: int, dot_token_id): if not isinstance(eos_token_id, int) or eos_token_id < 0: raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") if not isinstance(dot_token_id, int) or dot_token_id < 0: raise ValueError(f"`dot_token_id` has to be a positive integer, but is {dot_token_id}") self.eos_token_id = eos_token_id self.dot_token_id = dot_token_id self.allowed_eos_indices = None def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if self.allowed_eos_indices is None: self.allowed_eos_indices = torch.zeros(scores.shape[0]).bool().to(scores.device) original_shape = scores.shape scores = scores.contiguous().view(-1) violated_indices = torch.nonzero(~self.allowed_eos_indices).to(scores.device) violated_indices = violated_indices * scores.shape[-1] + self.eos_token_id violated_indices = violated_indices.view(-1) scores[violated_indices] = -float("inf") scores = scores.view(original_shape) next_tokens = torch.argmax(scores, dim=-1) self.allowed_eos_indices = next_tokens.eq(self.dot_token_id) return scores if __name__=='__main__': print("Success")