engine.py
quoc14
python
5 months ago
4.7 kB
2
Indexable
FaceRC
import os import torch import pandas as pd from torchvision.transforms import Compose, ToTensor, Normalize import inspect from face_engine.huggingface_model_utils import load_model_from_local_path # Check devices DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' class ModelLoader(): def __init__(self, model_path: str): self.__model = load_model_from_local_path(f'face_engine/models/{model_path}').to(DEVICE) def instance(self): return self.__model class FaceEngine(): def __init__(self) -> None: self.__database_path = './face_db.csv' self.__extractor = ModelLoader("minchul/cvlface_adaface_vit_base_webface4m").instance() self.__aligner = ModelLoader("minchul/cvlface_DFA_mobilenet").instance() self.threshold = 0.3 self.load_csv() pass def reset_csv(self): # Tạo một DataFrame rỗng với các cột tiêu đề columns = ['id', 'feat'] df = pd.DataFrame(columns=columns) df.to_csv(self.__database_path, index=False) def load_csv(self): # Kiểm tra nếu file CSV không tồn tại hoặc trống if not os.path.exists(self.__database_path) or os.stat(self.__database_path).st_size == 0: # Nếu không tồn tại hoặc trống, bắt đầu ID từ 1 self.__next_id = 1 self.__db = pd.DataFrame(columns=['id', 'feat']) self.__db.to_csv(self.__database_path, index=False) else: self.__db = pd.read_csv(self.__database_path) # Đảm bảo rằng cột 'id' không bị lỗi và có thể chuyển thành số if pd.to_numeric(self.__db['id'], errors='coerce').isna().all(): self.__next_id = 1 # Bắt đầu lại nếu tất cả giá trị trong 'id' là NaN else: self.__next_id = self.__db['id'].max() + 1 # Tính ID kế tiếp # Hàm chuẩn hóa ảnh và căn chỉnh def pil_to_input(self, pil_image): trans = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) return trans(pil_image).unsqueeze(0).to(DEVICE) def get_feat(self, pil_image): # Chuẩn hóa ảnh và căn chỉnh input_tensor = self.pil_to_input(pil_image) # Căn chỉnh khuôn mặt aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, bbox = self.__aligner(input_tensor) # Nhận diện đặc trưng (feature) input_signature = inspect.signature(self.__extractor.model.net.forward) if input_signature.parameters.get('keypoints') is not None: feat = self.__extractor(aligned_x, aligned_ldmks) else: feat = self.__extractor(aligned_x) return feat def compute_cosine_similarity(self, pil_image_1, pil_image_2): # Lấy đặc trưng của hai ảnh feature_1 = self.get_feat(pil_image_1) feature_2 = self.get_feat(pil_image_2) # Tính toán cosine similarity cosine_similarity = torch.nn.functional.cosine_similarity(feature_1, feature_2).item() return cosine_similarity def get_id(self, pil_image): if not os.path.exists(self.__database_path): return None # Lấy đặc trưng (feature) feat_input = self.get_feat(pil_image) # So sánh với từng ảnh trong CSDL max_sim = -1 # Biến lưu giá trị cosine similarity lớn nhất matched_id = None for i, row in self.__db.iterrows(): # Chuyển đổi đặc trưng từ CSDL thành tensor feat_db = torch.tensor(eval(row['feat']), device=DEVICE) # Tính cosine similarity giữa ảnh đầu vào và ảnh trong CSDL cossim = torch.nn.functional.cosine_similarity(feat_input, feat_db).item() # Kiểm tra nếu cosine similarity lớn hơn ngưỡng và lớn hơn giá trị max_sim hiện tại if cossim > self.threshold and cossim > max_sim: max_sim = cossim matched_id = row['id'] return matched_id # Lưu feature mới vào CSDL def save_to_db(self, pil_image): new_id = self.__next_id feature = self.get_feat(pil_image) new_row = pd.DataFrame({'id': [new_id], 'feat': [str(feature.squeeze().cpu().detach().numpy().tolist())]}) self.__db = pd.concat([self.__db, new_row], ignore_index=True) self.__db.to_csv(self.__database_path, index=False) self.__next_id = self.__next_id + 1 return new_id
Editor is loading...
Leave a Comment