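"""Valence/arousal prediction for face crops.

A ResNet-18 feature extractor (classifier replaced by Identity) is combined
with a small fully connected head that also consumes an emotion feature;
a per-face stress score is derived from the two outputs.
"""
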
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import cv2
from tqdm import tqdm

def create_emotion_model(num_ftrs, num_emotions):
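    """Build the head mapping ResNet features concatenated with emotion
    features to two outputs (arousal and valence, as indexed in va_predict)."""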
    return nn.Sequential(
        nn.Linear(num_ftrs + num_emotions, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 2),
    )

def load_model(model_path, device):
    """Load the ResNet-18 feature extractor with its classifier removed."""
    model = models.resnet18(weights=None)  # architecture only; weights come from the checkpoint
    num_ftrs = model.fc.in_features
    model.fc = nn.Identity()  # expose the pooled 512-dim features instead of class logits
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device).eval()
    return model, num_ftrs

def load_emotion_model(num_ftrs, num_emotions, model_path, device):
    """Load the trained valence/arousal head."""
    emotion_model = create_emotion_model(num_ftrs, num_emotions).to(device)
    emotion_model.load_state_dict(torch.load(model_path, map_location=device))
    emotion_model.eval()
    return emotion_model

def va_predict(val_model_path, val_featmodel_path, faces, emotions):
    """Predict valence, arousal, and a derived stress score for each face.

    faces: iterable of BGR face crops (e.g., OpenCV images).
    emotions: iterable of per-face emotion feature tensors.
    Returns three lists: valence, arousal, stress.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the feature extractor and the valence/arousal head
    resnet, num_ftrs = load_model(val_featmodel_path, device)
    num_emotions = 1  # assuming a single scalar emotion feature per face
    emotion_model = load_emotion_model(num_ftrs, num_emotions, val_model_path, device)

    # Standard ImageNet preprocessing for the ResNet input
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def model_forward(images, emotions):
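        # Fuse image features with the emotion features, then run the head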
        resnet_features = resnet(images)
        batch_size = resnet_features.size(0)
        emotions = emotions.view(batch_size, -1)
        x = torch.cat((resnet_features, emotions), dim=1)
        return emotion_model(x)

    arousal_list, valence_list, stress_list = [], [], []

    for face, emotion in tqdm(zip(faces, emotions), total=len(faces)):
        face_pil = Image.fromarray(cv2.cvtColor(face, cv2.COLOR_BGR2RGB))
        face_tensor = transform(face_pil).unsqueeze(0).to(device)
        emotion = emotion.to(device)
        
        with torch.no_grad():
            output_va = model_forward(face_tensor, emotion)
        
        # Raw outputs are assumed to lie in [-1, 1]; rescale to [0, 1]
        arousal = output_va[0][0].item() / 2 + 0.5
        valence = output_va[0][1].item() / 2 + 0.5
        # Heuristic: stress is high when arousal is high and valence is low
        stress = (1 - valence) * arousal

        arousal_list.append(arousal)
        valence_list.append(valence)
        stress_list.append(stress)

    return valence_list, arousal_list, stress_list
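
# --- Hypothetical usage sketch ---
# Everything below is illustrative: the checkpoint paths are assumptions,
# and the dummy face crops / emotion tensors stand in for real detector
# output. Running this requires checkpoints matching the architectures above.
if __name__ == "__main__":
    import numpy as np

    faces = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(2)]  # dummy BGR crops
    emotions = [torch.zeros(1) for _ in range(2)]  # one scalar emotion feature per face

    valence, arousal, stress = va_predict(
        "va_head.pth",            # assumed path to the trained head checkpoint
        "resnet18_features.pth",  # assumed path to the feature-extractor checkpoint
        faces,
        emotions,
    )
    print(valence, arousal, stress)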