Untitled
unknown
plain_text
4 years ago
1.1 kB
5
Indexable
import argparse
import os

import torch

from mmdet.models.utils.ckpt_convert import swin_converter


def _find_state_dict_key(checkpoint):
    """Return the key under which ``checkpoint`` nests its weights.

    ``'state_dict'`` is checked first, then ``'model'``. ``None`` is returned
    when neither key is present, meaning ``checkpoint`` is itself the
    state dict.
    """
    for key in ('state_dict', 'model'):
        if key in checkpoint:
            return key
    return None


def main():
    """Convert an official pretrained Swin checkpoint to MMDetection keys.

    Loads the checkpoint at the ``src`` argument onto the CPU, renames its
    weights with ``swin_converter`` and saves the result to ``dst``. When the
    weights were nested under ``'state_dict'`` or ``'model'``, the converted
    weights are stored back under that same key so the rest of the checkpoint
    is kept; otherwise the converted weights are saved on their own.
    """
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained swin models to '
        'MMDetection style.')
    parser.add_argument('src', help='src pretrained swin model path')
    # The dst path must be a full path of the new checkpoint (a file, not a
    # directory); missing parent directories are created before saving.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    # NOTE(review): torch.load unpickles arbitrary objects, so only convert
    # checkpoints obtained from trusted sources.
    checkpoint = torch.load(args.src, map_location='cpu')

    # Look up the nesting key once and reuse it for both read and write-back.
    key = _find_state_dict_key(checkpoint)
    state_dict = checkpoint if key is None else checkpoint[key]
    weight = swin_converter(state_dict)
    if key is None:
        checkpoint = weight
    else:
        checkpoint[key] = weight

    dst_dir = os.path.dirname(args.dst)
    if dst_dir:
        # Saving into a non-existent directory would otherwise fail.
        os.makedirs(dst_dir, exist_ok=True)
    torch.save(checkpoint, args.dst)


if __name__ == '__main__':
    main()
Editor is loading...