engine2.py
quoc14
python
a year ago
6.1 kB
13
Indexable
FaceRC
import os
import torch
import pandas as pd
from torchvision.transforms import Compose, ToTensor, Normalize
from PIL import Image
import gc
import sys
from facenet_pytorch import MTCNN
from huggingface_model_utils import load_model_from_local_path
import inspect
# Kiểm tra thiết bị (GPU hoặc CPU)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class ModelLoader():
def __init__(self, model_path: str):
"""Khởi tạo và load mô hình từ đường dẫn đã chỉ định."""
self.__model = load_model_from_local_path(f'/home/quoc14/Code/HeThongNhanDang/model/{model_path}').to(DEVICE)
def instance(self):
"""Trả về mô hình đã load."""
return self.__model
class FaceEngine():
def __init__(self, csv_file='/home/quoc14/Code/HeThongNhanDang/face_features_400.csv') -> None:
"""Khởi tạo FaceEngine và thiết lập file CSV dùng để lưu và so sánh."""
self.__database_path = csv_file # File CSV mặc định để lưu đặc trưng khuôn mặt
self.__available_models = {
"ir18_webface4m": "minchul/cvlface_adaface_ir18_webface4m",
"ir50_webface4m": "minchul/cvlface_adaface_ir50_webface4m",
"ir101_webface4m": "minchul/cvlface_adaface_ir101_webface4m",
"ir101_webface12m": "minchul/cvlface_adaface_ir101_webface12m",
"vit_base_kprpe_webface4m": "minchul/cvlface_adaface_vit_base_kprpe_webface4m",
"vit_base_kprpe_webface12m": "minchul/cvlface_adaface_vit_base_kprpe_webface12m",
"vit_base_webface4m": "minchul/cvlface_adaface_vit_base_webface4m"
}
self._extractor = None
self._aligner = ModelLoader("minchul/cvlface_DFA_resnet50").instance()
self.__model_name = "ir101_webface4m" # Mặc định sử dụng mô hình này
self.threshold = 0.3 # Ngưỡng để nhận diện
self.mtcnn = MTCNN(keep_all=False, device=DEVICE) # Sử dụng MTCNN cho phát hiện khuôn mặt
self.load_csv() # Tải cơ sở dữ liệu từ file CSV
def load_csv(self):
"""Load file CSV chứa các đặc trưng khuôn mặt và chuyển đổi đặc trưng từ JSON sang tensor."""
if not os.path.exists(self.__database_path):
# Tạo file CSV nếu chưa tồn tại
columns = ['id', 'feat']
df = pd.DataFrame(columns=columns)
df.to_csv(self.__database_path, index=False)
self.__db = pd.DataFrame(columns=columns) # Khởi tạo database rỗng
else:
self.__db = pd.read_csv(self.__database_path, dtype={'id': str})
# Chuyển đổi tất cả các đặc trưng từ chuỗi JSON sang tensor một lần
self.__db['feat_tensor'] = self.__db['feat'].apply(lambda x: torch.tensor(eval(x), device=DEVICE))
def reset_model(self):
"""Reset mô hình hiện tại và giải phóng bộ nhớ."""
if hasattr(self, '_extractor'):
del self._extractor
# Dọn dẹp bộ nhớ và giải phóng GPU
gc.collect()
torch.cuda.empty_cache()
def load_model(self, model_key):
"""Load mô hình dựa trên từ khóa model_key."""
model_path = self.__available_models[model_key]
return ModelLoader(model_path).instance()
def set_model(self, model_key):
"""Đặt mô hình hiện tại dựa trên lựa chọn của người dùng."""
if model_key in self.__available_models:
self.reset_model() # Reset mô hình trước khi chuyển đổi
self._extractor = self.load_model(model_key) # Tải mô hình mới
self.__model_name = model_key # Cập nhật tên mô hình hiện tại
else:
print(f"Mô hình {model_key} không tồn tại!")
# Hàm chuẩn hóa tensor từ MTCNN
def normalize_tensor(self, tensor):
normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
return normalize(tensor).unsqueeze(0).to(DEVICE) # Thêm chiều batch
def get_feat(self, pil_image):
"""Trích xuất đặc trưng từ ảnh."""
# Dùng MTCNN để phát hiện khuôn mặt trước khi trích xuất
face_crop = self.mtcnn(pil_image)
if face_crop is None:
return None
input_tensor = self.normalize_tensor(face_crop)
aligned_x, _, aligned_ldmks, _, _, _ = self._aligner(input_tensor)
input_signature = inspect.signature(self._extractor.model.net.forward)
if input_signature.parameters.get('keypoints') is not None:
feat = self._extractor(aligned_x, aligned_ldmks)
else:
feat = self._extractor(aligned_x)
return feat
def compute_cosine_similarity(self, feat_input, feats_db):
"""Tính toàn bộ cosine similarity giữa đặc trưng đầu vào và tất cả các đặc trưng trong cơ sở dữ liệu."""
return torch.nn.functional.cosine_similarity(feat_input, feats_db)
def get_id(self, pil_image):
"""So sánh ảnh với cơ sở dữ liệu và trả về ID nếu tìm thấy."""
if not os.path.exists(self.__database_path):
return {"id": None, "status": "not_found"}
# Trích xuất đặc trưng của ảnh đầu vào
feat_input = self.get_feat(pil_image)
if feat_input is None:
return {"id": None, "status": "not_found"}
# Tính toán cosine similarity giữa ảnh đầu vào và tất cả các ảnh trong cơ sở dữ liệu
feats_db = torch.stack(self.__db['feat_tensor'].values.tolist()).to(DEVICE)
similarities = self.compute_cosine_similarity(feat_input, feats_db)
# Lấy ra ảnh có cosine lớn nhất
max_sim, idx_max = similarities.max(0)
if max_sim.item() > self.threshold:
best_match_id = self.__db.iloc[idx_max.item()]['id']
return {"id": best_match_id, "status": "found"}
else:
return {"id": None, "status": "not_found"}
Editor is loading...
Leave a Comment