# hugg1.py

from transformers import AutoModel
from huggingface_hub import hf_hub_download
import shutil
import os
import sys
import gc
import torch

# CVLFace face-recognition model repos on the Hugging Face Hub
list_all_models = ["minchul/cvlface_DFA_mobilenet",
                   "minchul/cvlface_DFA_resnet50",
                   "minchul/cvlface_adaface_vit_base_webface4m",
                   "minchul/cvlface_adaface_ir18_vgg2",
                   "minchul/cvlface_adaface_ir18_webface4m",
                   "minchul/cvlface_adaface_ir50_webface4m",
                   "minchul/cvlface_adaface_ir50_casia",
                   "minchul/cvlface_adaface_ir50_ms1mv2",
                   "minchul/cvlface_adaface_ir101_ms1mv2",
                   "minchul/cvlface_adaface_ir101_ms1mv3",
                   "minchul/cvlface_adaface_ir101_webface4m",
                   "minchul/cvlface_adaface_vit_base_kprpe_webface12m",
                   "minchul/cvlface_adaface_ir101_webface12m",
                   "minchul/cvlface_adaface_vit_base_kprpe_webface4m",
                   ]

# helper function to download a Hugging Face model repo into a local directory
def download(repo_id, path, HF_TOKEN=None):
    os.makedirs(path, exist_ok=True)
    # files.txt in each repo lists the extra files the remote code needs
    files_path = os.path.join(path, 'files.txt')
    if not os.path.exists(files_path):
        hf_hub_download(repo_id, 'files.txt', token=HF_TOKEN, local_dir=path, local_dir_use_symlinks=False)
    with open(files_path, 'r') as f:
        files = f.read().split('\n')
    # fetch every listed file plus the config, the remote-code wrapper, and the weights
    for file in [f for f in files if f] + ['config.json', 'wrapper.py', 'model.safetensors']:
        full_path = os.path.join(path, file)
        if not os.path.exists(full_path):
            hf_hub_download(repo_id, file, token=HF_TOKEN, local_dir=path, local_dir_use_symlinks=False)

def download_all_models():
    for model in list_all_models:
        print("-----------------Downloading model: ", model, "-----------------")
        download(model, os.path.abspath(f"model/{model}"))


def load_model_from_local_path(path, HF_TOKEN=None):
    path = os.path.abspath(path)
    cwd = os.getcwd()
    try:
        # the repo's remote code expects to be imported from the model directory
        os.chdir(path)
        sys.path.insert(0, path)

        # load the model from the given local path
        model = AutoModel.from_pretrained(path, trust_remote_code=True)
        print(f"Successfully loaded model from: {path}")
        return model

    except Exception as e:
        print(f"Failed to load model from {path}: {str(e)}")
        raise

    finally:
        # return to the original working directory and clean up sys.path
        os.chdir(cwd)
        if path in sys.path:
            sys.path.remove(path)
        torch.cuda.empty_cache()  # release cached GPU memory if any was used
        gc.collect()



# helper function: download the Hugging Face repo (if missing) and load the model from the local copy
def load_model_by_repo_id(repo_id, save_path, HF_TOKEN=None, force_download=False):
    if force_download:
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
    download(repo_id, save_path, HF_TOKEN)
    return load_model_from_local_path(save_path, HF_TOKEN)
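

# --- Minimal usage sketch (not part of the original helpers) ---
# Assumptions, hedged: the CVLFace recognition wrappers loaded via trust_remote_code
# are assumed here to take a batch of aligned 112x112 RGB face crops (values roughly
# in [-1, 1]) and return one embedding per image. Check the model card of the specific
# repo before relying on this; KPRPE variants may additionally expect facial keypoints.
if __name__ == "__main__":
    repo_id = "minchul/cvlface_adaface_ir101_webface4m"
    model = load_model_by_repo_id(repo_id, os.path.abspath(f"model/{repo_id}"))
    model.eval()

    # random tensor standing in for a preprocessed, aligned face crop
    dummy_face = torch.randn(1, 3, 112, 112)
    with torch.no_grad():
        embedding = model(dummy_face)
    print("Embedding shape:", tuple(embedding.shape))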
