engine3.py

 avatar
quoc14
python
5 months ago
8.4 kB
2
Indexable
FaceRC
import os
import torch
import pandas as pd
from torchvision.transforms import Compose, ToTensor, Normalize
import inspect
from PIL import Image
import gc
import sys
from face_engine.huggingface_model_utils import load_model_from_local_path

# Kiểm tra thiết bị (GPU hoặc CPU)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

class ModelLoader:
    def __init__(self, model_path: str):
        """Khởi tạo và load mô hình từ đường dẫn đã chỉ định."""
        self.__model = load_model_from_local_path(f'/home/quoc14/Code/HeThongNhanDang/model/{model_path}').to(DEVICE)
    
    def instance(self):
        """Trả về mô hình đã load."""
        return self.__model

class FaceEngine:
    def __init__(self, csv_file='/home/quoc14/Code/HeThongNhanDang/face_features_400.csv') -> None:
        """Khởi tạo FaceEngine và thiết lập file CSV dùng để lưu và so sánh."""
        self.__database_path = csv_file  # File CSV mặc định để lưu đặc trưng khuôn mặt
        self.__available_models = {
            "ir18_webface4m": "minchul/cvlface_adaface_ir18_webface4m",
            "ir50_webface4m": "minchul/cvlface_adaface_ir50_webface4m",
            "ir101_webface4m": "minchul/cvlface_adaface_ir101_webface4m",
            "ir101_webface12m": "minchul/cvlface_adaface_ir101_webface12m",
            "vit_base_kprpe_webface4m": "minchul/cvlface_adaface_vit_base_kprpe_webface4m",
            "vit_base_kprpe_webface12m": "minchul/cvlface_adaface_vit_base_webface12m"
        }
        self._extractor = None
        self._aligner = ModelLoader("minchul/cvlface_DFA_resnet50").instance()
        self.__model_name = "ir101_webface4m"  # Mặc định sử dụng mô hình này
        self.threshold = 0.3  # Ngưỡng để nhận diện
        self.load_csv()  # Tải cơ sở dữ liệu từ file CSV và xử lý các tensor đặc trưng

    def load_csv(self):
        """Load file CSV chứa các đặc trưng khuôn mặt và chuyển đổi đặc trưng từ JSON sang tensor."""
        if not os.path.exists(self.__database_path):
            # Tạo file CSV nếu chưa tồn tại
            columns = ['id', 'feat', 'name', 'image_path', 'date_registered']
            df = pd.DataFrame(columns=columns)
            df.to_csv(self.__database_path, index=False)
            self.__db = pd.DataFrame(columns=columns)  # Khởi tạo database rỗng
        else:
            # Chỉ load các cột cần thiết để tiết kiệm bộ nhớ
            self.__db = pd.read_csv(self.__database_path, dtype={'id': str}, usecols=['id', 'feat'])
            # Chuyển đổi tất cả các đặc trưng từ chuỗi JSON sang tensor một lần
            def convert_to_tensor(feat_str):
                try:
                    # Chuyển chuỗi JSON thành danh sách sau đó thành tensor
                    feat_list = eval(feat_str)
                    if isinstance(feat_list, list):
                        return torch.tensor(feat_list, device=DEVICE)
                    else:
                        raise ValueError("Feature is not a valid list.")
                except Exception as e:
                    print(f"Error converting feature: {e}")
                    return None
            
            self.__db['feat_tensor'] = self.__db['feat'].apply(convert_to_tensor)
            # Loại bỏ các hàng có đặc trưng không hợp lệ
            self.__db = self.__db.dropna(subset=['feat_tensor'])


    def reset_model(self):
        """Reset mô hình hiện tại và giải phóng bộ nhớ."""
        if hasattr(self, '_extractor'):
            del self._extractor
        
        # Dọn dẹp bộ nhớ và giải phóng GPU
        gc.collect()
        torch.cuda.empty_cache()

        # Xóa các module đã tải từ mô hình trước
        modules_to_delete = [key for key in sys.modules if "model" in key or "transformers" in key]
        for module in modules_to_delete:
            del sys.modules[module]

    def load_model(self, model_key):
        """Load mô hình dựa trên từ khóa model_key."""
        model_path = self.__available_models[model_key]
        print(f"Đang tải mô hình: {model_key} từ {model_path}")
        return ModelLoader(model_path).instance()

    def set_model(self, model_key):
        """Đặt mô hình hiện tại dựa trên lựa chọn của người dùng."""
        if model_key in self.__available_models:
            self.reset_model()  # Reset mô hình trước khi chuyển đổi
            self._extractor = self.load_model(model_key)  # Tải mô hình mới
            self.__model_name = model_key  # Cập nhật tên mô hình hiện tại
            print(f"Đã chuyển đổi sang mô hình: {model_key}")
        else:
            print(f"Mô hình {model_key} không tồn tại!")

    # Hàm chuyển đổi ảnh PIL thành tensor và chuẩn hóa
    def pil_to_input(self, pil_image):
        trans = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        return trans(pil_image).unsqueeze(0).to(DEVICE)

    def get_feat(self, pil_image):
        """Trích xuất đặc trưng từ ảnh."""
        input_tensor = self.pil_to_input(pil_image)
        aligned_x, _, aligned_ldmks, _, _, _ = self._aligner(input_tensor)

        # Kiểm tra mô hình xem có cần keypoints không
        input_signature = inspect.signature(self._extractor.model.net.forward)
        if input_signature.parameters.get('keypoints') is not None:
            feat = self._extractor(aligned_x, aligned_ldmks)
        else:
            feat = self._extractor(aligned_x)

        return feat

    def compute_cosine_similarity(self, feat_input, feats_db):
        """Tính toàn bộ cosine similarity giữa đặc trưng đầu vào và tất cả các đặc trưng trong cơ sở dữ liệu."""
        return torch.nn.functional.cosine_similarity(feat_input, feats_db)

    def get_id(self, pil_image):
        """So sánh ảnh với cơ sở dữ liệu và trả về ID nếu tìm thấy."""
        if not os.path.exists(self.__database_path):
            return {"id": None, "status": "not_found"}

        # Trích xuất đặc trưng của ảnh đầu vào
        feat_input = self.get_feat(pil_image)

        # Chuyển danh sách đặc trưng của cơ sở dữ liệu thành tensor để tính toán
        feats_db = torch.stack(self.__db['feat_tensor'].values.tolist()).to(DEVICE)

        # Tính toán cosine similarity giữa ảnh đầu vào và tất cả các ảnh trong cơ sở dữ liệu
        similarities = self.compute_cosine_similarity(feat_input, feats_db)

        # Lấy ra ảnh có cosine lớn nhất
        max_sim, idx_max = similarities.max(0)
        
        if max_sim.item() > self.threshold:
            # Đã nhận diện thành công
            best_match_id = self.__db.iloc[idx_max.item()]['id']
            return {"id": best_match_id, "status": "found"}
        else:
            return {"id": None, "status": "not_found"}
    def save_to_db(self, pil_image, user_id):
        """Lưu đặc trưng của ảnh mới với ID người dùng."""
        feature = self.get_feat(pil_image)

        if not isinstance(feature, torch.Tensor):
            return "Feature extraction failed, not saving to database."

        # Chuyển đổi tensor sang danh sách để lưu thành chuỗi JSON
        feature_list = feature.squeeze().cpu().detach().numpy().tolist()

        new_row = pd.DataFrame({
            'id': [user_id], 
            'name': ["quoc"], 
            'feat': [feature_list],  # Lưu đặc trưng dưới dạng danh sách JSON
             # Để trống nếu không có
            'image_path': ["quoc"],  # Để trống nếu không có
            'date_registered': [pd.Timestamp.now()]
        })
        self.__db = pd.concat([self.__db, new_row], ignore_index=True)
        self.__db.loc[self.__db['id'] == user_id, 'feat_tensor'] = self.__db.loc[self.__db['id'] == user_id, 'feat'].apply(lambda x: torch.tensor(x, device=DEVICE))


        # Thay vì ghi lại toàn bộ file, chỉ ghi thêm dòng mới
        with open(self.__database_path, 'a') as f:
            new_row.to_csv(f, header=False, index=False)
        
        return f"Data for user {user_id} saved successfully."

Editor is loading...
Leave a Comment