Untitled

mail@pastecode.io avatar
unknown
plain_text
a month ago
2.5 kB
3
Indexable
Never
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import cv2
from tqdm import tqdm

def create_emotion_model(num_ftrs, num_emotions):
    """Build the MLP head that regresses two values from image + emotion features.

    The input is the concatenation of ``num_ftrs`` image features and
    ``num_emotions`` extra emotion features. Two hidden layers of width 128 and
    64 (each followed by ReLU) lead to a 2-unit linear output.

    Args:
        num_ftrs: Number of image features fed into the head.
        num_emotions: Number of extra emotion features appended to them.

    Returns:
        An ``nn.Sequential`` whose Linear layers sit at indices 0, 2 and 4.
    """
    hidden_widths = (128, 64)
    layers = []
    prev_width = num_ftrs + num_emotions
    # Layers are created in the same order as a literal Sequential would create
    # them, so parameter names ("0.weight", "2.weight", ...) match saved checkpoints.
    for width in hidden_widths:
        layers.append(nn.Linear(prev_width, width))
        layers.append(nn.ReLU())
        prev_width = width
    layers.append(nn.Linear(prev_width, 2))
    return nn.Sequential(*layers)

def load_model(model_path, device):
    """Load a ResNet-18 feature extractor (classifier removed) from a state dict.

    Args:
        model_path: Path to a state dict saved from a ResNet-18 whose ``fc`` was
            already replaced by ``nn.Identity`` (strict loading is used, so the
            checkpoint must match that layout).
        device: Target device for ``map_location`` and for the returned model.

    Returns:
        Tuple ``(model, num_ftrs)``: the model in eval mode on ``device`` and the
        ``in_features`` of the original ``fc`` layer, i.e. the width of the
        features the model now outputs.
    """
    backbone = models.resnet18(pretrained=False)
    feature_dim = backbone.fc.in_features
    # Replace the classifier so the forward pass returns the pooled features.
    backbone.fc = nn.Identity()
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state = torch.load(model_path, map_location=device)
    backbone.load_state_dict(state)
    backbone = backbone.to(device)
    backbone.eval()
    return backbone, feature_dim

def load_emotion_model(num_ftrs, num_emotions, model_path, device):
    """Rebuild the regression head and restore its trained weights.

    Args:
        num_ftrs: Image-feature width the head was trained with.
        num_emotions: Emotion-feature width the head was trained with.
        model_path: Path to a state dict saved from a head of the same layout.
        device: Device the head is moved to and the weights are mapped onto.

    Returns:
        The head produced by ``create_emotion_model``, in eval mode.
    """
    head = create_emotion_model(num_ftrs, num_emotions).to(device)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    head.load_state_dict(checkpoint)
    head.eval()
    return head

def va_predict(val_model_path, val_featmodel_path, faces, emotions, num_emotions=1):
    """Predict valence, arousal and a derived stress score for each face crop.

    Args:
        val_model_path: Path to the state dict of the regression head built by
            ``create_emotion_model``.
        val_featmodel_path: Path to the state dict of the ResNet-18 feature
            extractor loaded by ``load_model``.
        faces: Sized iterable of images passed to
            ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``; presumably HxWx3 BGR arrays
            as produced by OpenCV — confirm with callers.
        emotions: Iterable aligned with ``faces``; each item holds the
            ``num_emotions`` extra features for that face. Tensors, numpy arrays,
            Python numbers and lists are accepted and converted to float32.
            Items are paired with ``zip``, so surplus items in the longer
            sequence are ignored.
        num_emotions: Width of each emotion feature vector; must match the head
            stored at ``val_model_path``. Defaults to 1.

    Returns:
        Tuple ``(valence_list, arousal_list, stress_list)`` of Python floats, one
        entry per processed face. Output unit 0 is treated as arousal and unit 1
        as valence; each is mapped with ``x / 2 + 0.5`` (no clamping), and
        ``stress = (1 - valence) * arousal``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the models
    resnet, num_ftrs = load_model(val_featmodel_path, device)
    emotion_model = load_emotion_model(num_ftrs, num_emotions, val_model_path, device)

    # ImageNet preprocessing for the ResNet backbone.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def model_forward(images, emotion_feats):
        """Concatenate backbone features with emotion features and run the head."""
        resnet_features = resnet(images)
        batch_size = resnet_features.size(0)
        # reshape (not view) so non-contiguous emotion tensors are accepted too.
        emotion_feats = emotion_feats.reshape(batch_size, -1)
        x = torch.cat((resnet_features, emotion_feats), dim=1)
        return emotion_model(x)

    arousal_list, valence_list, stress_list = [], [], []

    with torch.no_grad():
        for face, emotion in tqdm(zip(faces, emotions), total=len(faces)):
            face_pil = Image.fromarray(cv2.cvtColor(face, cv2.COLOR_BGR2RGB))
            face_tensor = transform(face_pil).unsqueeze(0).to(device)
            # as_tensor accepts tensors, numpy arrays and plain numbers (which have
            # no .to()); forcing float32 matches the Linear weights, which would
            # otherwise reject integer/float64 emotion inputs.
            emotion_tensor = torch.as_tensor(emotion, dtype=torch.float32, device=device)

            output_va = model_forward(face_tensor, emotion_tensor)

            # One device->host transfer instead of one .item() call per value.
            raw_arousal, raw_valence = output_va[0, :2].tolist()
            arousal = raw_arousal / 2 + 0.5
            valence = raw_valence / 2 + 0.5
            stress = (1 - valence) * arousal

            arousal_list.append(arousal)
            valence_list.append(valence)
            stress_list.append(stress)

    return valence_list, arousal_list, stress_list
Leave a Comment