Subclass and patch the ViT attention module in Transformers

Patch the VIT_ATTENTION_CLASSES dict directly: ViTLayer.__init__ looks the attention class up in that dict at model construction time, so the override only takes effect if the dict is swapped out before from_pretrained is called.
import torch
from transformers.models.vit import modeling_vit
from transformers.models.vit.modeling_vit import ViTAttention
from transformers import ViTForImageClassification
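# Assumes a transformers release where modeling_vit still exposes
# VIT_ATTENTION_CLASSES; newer releases may route attention through
# ALL_ATTENTION_FUNCTIONS instead, in which case this patch point changes.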


class NewViTAttention(ViTAttention):
    """Thin wrapper that logs every forward call, then defers to the stock ViTAttention."""

    def forward(self, *args, **kwargs):
        print("My patched module!!!")
        return super().forward(*args, **kwargs)

# Point both attention implementations at the subclass, then rebind the
# module-level dict so ViTLayer.__init__ picks it up at construction time.
VIT_ATTENTION_CLASSES = {
    "eager": NewViTAttention,
    "sdpa": NewViTAttention,
}
modeling_vit.VIT_ATTENTION_CLASSES = VIT_ATTENTION_CLASSES
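# Caveat: mapping "sdpa" to an eager-attention subclass silently falls back to
# eager attention math even when SDPA is selected; to keep SDPA you could
# instead subclass modeling_vit.ViTSdpaAttention for that key.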

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")

# Dummy batch: a single 224x224 RGB image.
inputs = torch.randn(1, 3, 224, 224)
one_batch = dict(pixel_values=inputs)

outputs = model(**one_batch)

# Output:
# Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
# You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# My patched module!!!
# My patched module!!!
# My patched module!!!
# My patched module!!!
# My patched module!!!
# My patched module!!!
# My patched module!!!
# My patched module!!!
# My patched module!!!
# My patched module!!!
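
# Optional sanity check (a sketch, assuming the patch above ran before
# from_pretrained): every encoder layer should now hold the subclass.
assert all(
    isinstance(layer.attention, NewViTAttention)
    for layer in model.vit.encoder.layer
)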