import os
from typing import Dict
from transformers import RobertaTokenizer, T5ForConditionalGeneration
from transformers import LogitsProcessor, LogitsProcessorList
import torch
from torch import Tensor
from .basenet import BaseNet
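# The metadata module is expected to provide SPECIAL_TOKENS_MAP, COMMENT_TYPES and the
# MAX_SUB_LENGTH_MAPPER / MIN_SUB_LENGTH_MAPPER dictionaries used in CodeT5.predict below.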
from .metadata import *
class CodeT5:
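    """Thin wrapper around Salesforce/codet5-base: loads a fine-tuned checkpoint and
    generates token ids with custom logits processors applied during decoding."""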
def __init__(self, model_name):
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base-multi-sum').to(self.device)
# self.tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base-multi-sum')
# self.model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base-multi-sum')
self.tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
self.model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
self.load_checkpoint(model_name)
self.default_beam_size = 1
        # Maximum length of the generated (target) sequence.
self.max_length = 512
        if torch.cuda.is_available():
            print("Using GPU device for prediction:", torch.cuda.get_device_name(0))
        else:
            print("Using CPU for prediction")
def load_checkpoint(self, filename):
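        """Load fine-tuned weights from `filename` (if it exists) and move the model to the GPU when available."""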
        try:
            if os.path.isfile(filename):
                if torch.cuda.is_available():
                    self.model.load_state_dict(torch.load(filename))
                    self.model.to('cuda')
                else:
                    self.model.load_state_dict(torch.load(filename, map_location=torch.device('cpu')))
            else:
                print(f"Checkpoint file {filename} was not found; keeping the pretrained weights")
        except Exception as e:
            print(f"Failed to load model from {filename}: {e}")
def predict(self, input_ids: torch.Tensor, **kwargs):
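        """Generate token ids for `input_ids`.

        Recognised kwargs: `max_length_mapper` / `min_length_mapper` (per-section length
        bounds keyed by special token) and `penalty_alpha` / `top_k` (contrastive search).
        """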
if kwargs.get("max_length_mapper") is not None:
MAX_SUB_LENGTH_MAPPER = kwargs["max_length_mapper"]
if kwargs.get("min_length_mapper") is not None:
MIN_SUB_LENGTH_MAPPER = kwargs["min_length_mapper"]
print("DEvice: ", self.model.device)
input_ids = input_ids.to(self.model.device)
batch_size = input_ids.size(0)
beam_size = self.default_beam_size
beam_size = 1
temperature = 1.0
anchor_token_ids = [self.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS_MAP[comment_type]) for comment_type in COMMENT_TYPES]
anchor_token_ids = torch.LongTensor(anchor_token_ids).to(input_ids.device)
max_sub_length_mapper = {self.tokenizer.convert_tokens_to_ids(token): length for token, length in MAX_SUB_LENGTH_MAPPER.items()}
min_sub_length_mapper = {self.tokenizer.convert_tokens_to_ids(token): length for token, length in MIN_SUB_LENGTH_MAPPER.items()}
# csearch or greedy?
with torch.no_grad():
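            # Constrain every comment section opened by an anchor (special) token to the
            # minimum/maximum number of tokens configured for that section.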
logits_processor = MaxMinSubLengthLogitsProcessor(anchor_token_ids,
max_sub_length_mapper,
min_sub_length_mapper,
self.tokenizer.vocab_size,
batch_size,
beam_size,
self.tokenizer.eos_token_id,
bound_max_length=True,
bound_min_length=True,
)
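            # Additionally forbid EOS unless the previously generated token was a period,
            # so the output ends with a complete sentence.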
            dot_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize("."))[0]
            dot_processor = DotGenerationLogitsProcessor(self.tokenizer.eos_token_id, dot_token_id=dot_token_id)
            logits_processor = LogitsProcessorList([logits_processor, dot_processor])
            generated_ids = self.model.generate(input_ids,
                                                max_length=self.max_length,
                                                num_beams=beam_size,
                                                temperature=temperature,
                                                # num_return_sequences=num_samples,
                                                logits_processor=logits_processor,
                                                # Contrastive-search settings; None leaves generate's defaults in place.
                                                penalty_alpha=kwargs.get("penalty_alpha"),
                                                top_k=kwargs.get("top_k"),
                                                )
return generated_ids
    def cuda(self):
        # No-op: the model is moved to the GPU in load_checkpoint when CUDA is available.
        pass
class MaxMinSubLengthLogitsProcessor(LogitsProcessor):
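    """LogitsProcessor that keeps each generated sub-sequence (a section opened by an
    anchor/special token) within a per-section minimum and maximum length by masking
    the tokens that would violate those bounds."""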
    def __init__(self, anchor_token_ids: Tensor, max_sub_length_mapper: Dict, min_sub_length_mapper: Dict,
                 vocab_size: int, batch_size: int, beam_size: int, eos_token_id: int,
                 bound_max_length=True, bound_min_length=True):
        self.vocab_size = vocab_size
        self.anchor_token_ids = anchor_token_ids
        self.max_sub_length_mapper = max_sub_length_mapper
        self.min_sub_length_mapper = min_sub_length_mapper
        # Ordinary vocabulary tokens (everything except anchors and EOS) and the tokens that can end a section.
        self.word_token_ids = torch.LongTensor([i for i in range(vocab_size) if i not in self.anchor_token_ids.cpu().tolist() + [eos_token_id]]).to(anchor_token_ids.device)
        self.ending_token_ids = torch.LongTensor(self.anchor_token_ids.cpu().tolist() + [eos_token_id]).to(anchor_token_ids.device)
        # Per-sequence trackers: current section length and its allowed maximum/minimum.
        self.current_sub_sequence_lengths = torch.zeros(batch_size * beam_size).to(self.anchor_token_ids.device)
        self.max_sub_sequence_length_tracker = torch.empty(batch_size * beam_size).long().to(self.anchor_token_ids.device).fill_(1e4)
        self.min_sub_sequence_length_tracker = torch.empty(batch_size * beam_size).long().to(self.anchor_token_ids.device).fill_(-1)
        self.bound_max_length = bound_max_length
        self.bound_min_length = bound_min_length
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
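        """Mask the scores so that no section ends before its minimum length or runs past its maximum."""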
        original_shape = scores.shape
        scores = scores.view(-1)
        # MIN LENGTH: while a section is still shorter than its minimum, forbid every
        # token that could end it (anchor tokens and EOS).
        if self.bound_min_length:
            violate_indices = self.current_sub_sequence_lengths.lt(self.min_sub_sequence_length_tracker)
            violate_indices = torch.nonzero(violate_indices)
            violate_indices = violate_indices.unsqueeze(-1) * self.vocab_size
            violate_indices = violate_indices + self.ending_token_ids.unsqueeze(0)
            violate_indices = violate_indices.view(-1)
            scores[violate_indices] = -float("inf")
        # MAX LENGTH: once a section would reach its maximum, forbid every ordinary word
        # token so that only an anchor token or EOS can be produced next.
        if self.bound_max_length:
            violate_indices = (self.current_sub_sequence_lengths + 1).eq(self.max_sub_sequence_length_tracker)
            violate_indices = torch.nonzero(violate_indices)
            violate_indices = violate_indices.unsqueeze(-1) * self.vocab_size
            violate_indices = violate_indices + self.word_token_ids.unsqueeze(0)
            violate_indices = violate_indices.view(-1)
            scores[violate_indices] = -float("inf")
        scores = scores.view(original_shape)
        # Peek at the greedy choice to keep the per-section trackers up to date;
        # this is only exact for greedy decoding (num_beams == 1), which CodeT5.predict uses.
        next_tokens = torch.argmax(scores, dim=-1)
        self.update_tracker(next_tokens)
        return scores
def update_tracker(self, next_tokens):
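        """Update the per-sequence section-length counters and bounds from the tokens just chosen."""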
        # Positions where the chosen token opens a new section.
        anchor_positions = torch.isin(next_tokens, self.anchor_token_ids)
        beginning_token_ids = next_tokens[anchor_positions]
        # Reset the length counter where a new section starts, otherwise increment it.
        self.current_sub_sequence_lengths = (self.current_sub_sequence_lengths + 1) * ~anchor_positions
        # Look up the bounds for the section each anchor token opens
        # (every anchor token id is expected to have an entry in both mappers).
        self.max_sub_sequence_length_tracker[anchor_positions] = beginning_token_ids.clone().cpu().apply_(lambda token: self.max_sub_length_mapper.get(token)).to(beginning_token_ids.device)
        self.min_sub_sequence_length_tracker[anchor_positions] = beginning_token_ids.clone().cpu().apply_(lambda token: self.min_sub_length_mapper.get(token)).to(beginning_token_ids.device)
class DotGenerationLogitsProcessor(LogitsProcessor):
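    """LogitsProcessor that only allows EOS immediately after a dot token was generated,
    forcing the output to end with a period."""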
def __init__(self, eos_token_id: int, dot_token_id):
        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a non-negative integer, but is {eos_token_id}")
        if not isinstance(dot_token_id, int) or dot_token_id < 0:
            raise ValueError(f"`dot_token_id` has to be a non-negative integer, but is {dot_token_id}")
self.eos_token_id = eos_token_id
self.dot_token_id = dot_token_id
self.allowed_eos_indices = None
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
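        """Mask EOS for every sequence whose previously chosen token was not a dot."""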
        if self.allowed_eos_indices is None:
            # One flag per sequence: EOS is only allowed right after a dot was generated.
            self.allowed_eos_indices = torch.zeros(scores.shape[0]).bool().to(scores.device)
        original_shape = scores.shape
        scores = scores.contiguous().view(-1)
        # For every sequence whose previous token was not a dot, mask out EOS.
        violated_indices = torch.nonzero(~self.allowed_eos_indices).to(scores.device)
        # Flat index of EOS for row i in a (batch, vocab) score matrix is i * vocab + eos.
        violated_indices = violated_indices * original_shape[-1] + self.eos_token_id
        violated_indices = violated_indices.view(-1)
        scores[violated_indices] = -float("inf")
        scores = scores.view(original_shape)
        # Peek at the greedy choice to know whether a dot was just produced
        # (only exact for greedy decoding, num_beams == 1).
        next_tokens = torch.argmax(scores, dim=-1)
        self.allowed_eos_indices = next_tokens.eq(self.dot_token_id)
        return scores
if __name__ == '__main__':
    print("Success")