Patching the ViT attention module in Transformers via subclassing
Patch the VIT_ATTENTION_CLASSES dictionary directly. This works because the dictionary is looked up directly in __init__ (each ViTLayer reads its attention class from it when the model is built), so the replacement only takes effect if you swap the dict before instantiating the model.
```python
import torch
from transformers import ViTForImageClassification
from transformers.models.vit import modeling_vit
from transformers.models.vit.modeling_vit import ViTAttention


class NewViTAttention(ViTAttention):
    """Subclass of ViTAttention that announces every forward pass."""

    def forward(self, *args, **kwargs):
        print("My patched attention module!!!")
        return super().forward(*args, **kwargs)


# Swap the module-level registry so that ViTLayer.__init__ picks up the
# subclass for both the "eager" and "sdpa" attention implementations.
VIT_ATTENTION_CLASSES = {
    "eager": NewViTAttention,
    "sdpa": NewViTAttention,
}
modeling_vit.VIT_ATTENTION_CLASSES = VIT_ATTENTION_CLASSES

# Instantiate the model AFTER patching; otherwise the original classes are used.
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
inputs = torch.randn(1, 3, 224, 224)
one_batch = dict(pixel_values=inputs)
outputs = model(**one_batch)

# Output:
# Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
# You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# My patched attention module!!!
# My patched attention module!!!
# My patched attention module!!!
# My patched attention module!!!
# My patched attention module!!!
# My patched attention module!!!
# My patched attention module!!!
# My patched attention module!!!
# My patched attention module!!!
# My patched attention module!!!
```
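As a quick sanity check you can confirm that the patched class actually reached every layer. A minimal sketch, assuming the standard transformers ViT attribute layout (model.vit.encoder.layer[i].attention):

```python
# Sanity check: every ViTLayer should now hold a NewViTAttention instance.
# (Assumes the standard transformers layout: ViTForImageClassification
#  -> .vit -> .encoder -> .layer[i] -> .attention.)
print(type(model.vit.encoder.layer[0].attention))
# <class '__main__.NewViTAttention'>

assert all(
    isinstance(layer.attention, NewViTAttention)
    for layer in model.vit.encoder.layer
), "the patch did not reach every layer"
```

Note that the patch only affects models built after the dict is swapped; any model instantiated earlier keeps the original ViTAttention modules.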