engine2.py
import os
import torch
import pandas as pd
from torchvision.transforms import Compose, ToTensor, Normalize
from PIL import Image
import gc
import sys
from facenet_pytorch import MTCNN
from huggingface_model_utils import load_model_from_local_path
import inspect
# Kiểm tra thiết bị (GPU hoặc CPU)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class ModelLoader():
def __init__(self, model_path: str):
"""Khởi tạo và load mô hình từ đường dẫn đã chỉ định."""
self.__model = load_model_from_local_path(f'/home/quoc14/Code/HeThongNhanDang/model/{model_path}').to(DEVICE)
def instance(self):
"""Trả về mô hình đã load."""
return self.__model
class FaceEngine():
def __init__(self, csv_file='/home/quoc14/Code/HeThongNhanDang/face_features_400.csv') -> None:
"""Khởi tạo FaceEngine và thiết lập file CSV dùng để lưu và so sánh."""
self.__database_path = csv_file # File CSV mặc định để lưu đặc trưng khuôn mặt
self.__available_models = {
"ir18_webface4m": "minchul/cvlface_adaface_ir18_webface4m",
"ir50_webface4m": "minchul/cvlface_adaface_ir50_webface4m",
"ir101_webface4m": "minchul/cvlface_adaface_ir101_webface4m",
"ir101_webface12m": "minchul/cvlface_adaface_ir101_webface12m",
"vit_base_kprpe_webface4m": "minchul/cvlface_adaface_vit_base_kprpe_webface4m",
"vit_base_kprpe_webface12m": "minchul/cvlface_adaface_vit_base_kprpe_webface12m",
"vit_base_webface4m": "minchul/cvlface_adaface_vit_base_webface4m"
}
self._extractor = None
self._aligner = ModelLoader("minchul/cvlface_DFA_resnet50").instance()
self.__model_name = "ir101_webface4m" # Mặc định sử dụng mô hình này
self.threshold = 0.3 # Ngưỡng để nhận diện
self.mtcnn = MTCNN(keep_all=False, device=DEVICE) # Sử dụng MTCNN cho phát hiện khuôn mặt
self.load_csv() # Tải cơ sở dữ liệu từ file CSV
def load_csv(self):
"""Load file CSV chứa các đặc trưng khuôn mặt và chuyển đổi đặc trưng từ JSON sang tensor."""
if not os.path.exists(self.__database_path):
# Tạo file CSV nếu chưa tồn tại
columns = ['id', 'feat']
df = pd.DataFrame(columns=columns)
df.to_csv(self.__database_path, index=False)
self.__db = pd.DataFrame(columns=columns) # Khởi tạo database rỗng
else:
self.__db = pd.read_csv(self.__database_path, dtype={'id': str})
# Chuyển đổi tất cả các đặc trưng từ chuỗi JSON sang tensor một lần
self.__db['feat_tensor'] = self.__db['feat'].apply(lambda x: torch.tensor(eval(x), device=DEVICE))
def reset_model(self):
"""Reset mô hình hiện tại và giải phóng bộ nhớ."""
if hasattr(self, '_extractor'):
del self._extractor
# Dọn dẹp bộ nhớ và giải phóng GPU
gc.collect()
torch.cuda.empty_cache()
def load_model(self, model_key):
"""Load mô hình dựa trên từ khóa model_key."""
model_path = self.__available_models[model_key]
return ModelLoader(model_path).instance()
def set_model(self, model_key):
"""Đặt mô hình hiện tại dựa trên lựa chọn của người dùng."""
if model_key in self.__available_models:
self.reset_model() # Reset mô hình trước khi chuyển đổi
self._extractor = self.load_model(model_key) # Tải mô hình mới
self.__model_name = model_key # Cập nhật tên mô hình hiện tại
else:
print(f"Mô hình {model_key} không tồn tại!")
# Hàm chuẩn hóa tensor từ MTCNN
def normalize_tensor(self, tensor):
normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
return normalize(tensor).unsqueeze(0).to(DEVICE) # Thêm chiều batch
def get_feat(self, pil_image):
"""Trích xuất đặc trưng từ ảnh."""
# Dùng MTCNN để phát hiện khuôn mặt trước khi trích xuất
face_crop = self.mtcnn(pil_image)
if face_crop is None:
return None
input_tensor = self.normalize_tensor(face_crop)
aligned_x, _, aligned_ldmks, _, _, _ = self._aligner(input_tensor)
input_signature = inspect.signature(self._extractor.model.net.forward)
if input_signature.parameters.get('keypoints') is not None:
feat = self._extractor(aligned_x, aligned_ldmks)
else:
feat = self._extractor(aligned_x)
return feat
def compute_cosine_similarity(self, feat_input, feats_db):
"""Tính toàn bộ cosine similarity giữa đặc trưng đầu vào và tất cả các đặc trưng trong cơ sở dữ liệu."""
return torch.nn.functional.cosine_similarity(feat_input, feats_db)
def get_id(self, pil_image):
"""So sánh ảnh với cơ sở dữ liệu và trả về ID nếu tìm thấy."""
if not os.path.exists(self.__database_path):
return {"id": None, "status": "not_found"}
# Trích xuất đặc trưng của ảnh đầu vào
feat_input = self.get_feat(pil_image)
if feat_input is None:
return {"id": None, "status": "not_found"}
# Tính toán cosine similarity giữa ảnh đầu vào và tất cả các ảnh trong cơ sở dữ liệu
feats_db = torch.stack(self.__db['feat_tensor'].values.tolist()).to(DEVICE)
similarities = self.compute_cosine_similarity(feat_input, feats_db)
# Lấy ra ảnh có cosine lớn nhất
max_sim, idx_max = similarities.max(0)
if max_sim.item() > self.threshold:
best_match_id = self.__db.iloc[idx_max.item()]['id']
return {"id": best_match_id, "status": "found"}
else:
return {"id": None, "status": "not_found"}
Editor is loading...
Leave a Comment