CONV w/ tensor core 20240823

 avatar
user_3093867
python
20 days ago
4.2 kB
4
Indexable
Never
import torch
import torch.nn as nn
import torch.optim as optim
import time
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset
import torch.cuda.nvtx as nvtx

# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report the chosen device and the CUDA version string PyTorch was built with.
print(f"Using device: {device}")
print(f"CUDA version: {torch.version.cuda}")

class TFLiteMaximum(nn.Module):
    """Element-wise maximum of the input and a learnable per-channel floor.

    Named after TFLite's MAXIMUM op. The output is ``max(x, threshold)``,
    where ``threshold`` has shape (1, channels, 1, 1) and so broadcasts
    against a 4-D (N, C, H, W) input. Because ``threshold`` starts at zero,
    the module initially behaves like ReLU.

    Args:
        channels: size of dim 1 of the learnable ``threshold`` parameter.
    """

    def __init__(self, channels):
        super().__init__()
        floor = torch.zeros(1, channels, 1, 1)
        self.threshold = nn.Parameter(floor)

    def forward(self, x):
        return x.maximum(self.threshold)

class CNNModel(nn.Module):
    """Conv -> scale -> maximum -> max-pool -> linear classifier.

    Sized for inputs of shape (N, 8, 32, 32). The 3x3, padding-1 conv keeps
    the 32x32 spatial size, and the 2x2/stride-2 max-pool halves it to 16x16.
    ``fc`` therefore takes 64 * 16 * 16 features and produces (N, 16) logits.

    Every stage of ``forward`` runs inside a named NVTX range so it appears
    as a labelled region in NVTX-aware profilers.
    """

    def __init__(self):
        super(CNNModel, self).__init__()
        # Both input (8) and output (64) channel counts are multiples of 8,
        # presumably so FP16 convolutions are eligible for Tensor Cores.
        self.conv = nn.Conv2d(8, 64, kernel_size=3, padding=1, bias=True)
        nn.init.constant_(self.conv.bias, 16)
        self.maximum = TFLiteMaximum(64)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(64 * 16 * 16, 16)

    def forward(self, x):
        # ``nvtx.range`` pops its range in a ``finally``. If a stage raises,
        # the NVTX range stack stays balanced, unlike bare
        # range_push/range_pop pairs.
        with nvtx.range("CONV2D"):
            x = self.conv(x)

        with nvtx.range("SCALE"):
            # Constant 0.1 multiply; presumably mirrors a MUL op in the
            # graph this model reproduces.
            x = x * 0.10000000

        with nvtx.range("MAXIMUM"):
            x = self.maximum(x)

        with nvtx.range("MAXPOOL"):
            x = self.max_pool(x)

        with nvtx.range("RESHAPE"):
            # flatten (unlike view) also handles activations that are not
            # contiguous in NCHW order, e.g. after .to(memory_format=
            # torch.channels_last). For contiguous inputs it returns a view.
            x = torch.flatten(x, 1)

        with nvtx.range("FULLY_CONNECTED"):
            x = self.fc(x)

        return x

def train_model(model, criterion, optimizer, train_loader, num_epochs):
    """Train ``model`` with CUDA automatic mixed precision and NVTX ranges.

    Each epoch, each batch, and the forward / backward / optimizer-step
    phases of a batch are wrapped in named NVTX ranges. After each epoch,
    the loss of that epoch's last batch and the epoch's elapsed time are
    printed.

    Args:
        model: module to train; put in train mode at the start of each epoch.
        criterion: loss callable, invoked as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped through a ``GradScaler``.
        train_loader: iterable of ``(inputs, targets)`` pairs; both tensors
            are moved to the module-level ``device``.
        num_epochs: number of passes over ``train_loader``.
    """
    scaler = GradScaler()

    for epoch in range(num_epochs):
        # Context-managed ranges are popped even if a step raises, keeping
        # the NVTX range stack balanced.
        with nvtx.range(f"Epoch {epoch+1}"):
            print(f"Epoch {epoch+1}/{num_epochs}")
            # perf_counter is monotonic, unlike time.time, so it is the
            # right clock for measuring intervals.
            start_time = time.perf_counter()

            model.train()
            # Stays None when train_loader yields no batches; previously this
            # caused a NameError at the loss print below.
            loss = None
            for batch_x, batch_y in train_loader:
                with nvtx.range("Batch"):
                    batch_x, batch_y = batch_x.to(device), batch_y.to(device)

                    with autocast(), nvtx.range("Forward"):
                        outputs = model(batch_x)
                        loss = criterion(outputs, batch_y)

                    optimizer.zero_grad()
                    with nvtx.range("Backward"):
                        scaler.scale(loss).backward()

                    with nvtx.range("Optimizer step"):
                        scaler.step(optimizer)
                        scaler.update()

            if loss is not None:
                # .item() copies the value to the host, which waits for the
                # queued GPU work. The elapsed time below therefore includes
                # that work.
                print(f"Loss: {loss.item():.4f}")
            else:
                print("Loss: n/a (no batches)")
            print(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
            print("------")
            # NOTE(review): releasing cached allocator blocks every epoch
            # forces fresh allocations next epoch, which may show up in
            # profiles. Confirm this is intended.
            torch.cuda.empty_cache()

def main():
    """Build the model and data, then run warmup, profiled and inference passes.

    Uses 1024 random samples of shape (8, 32, 32) with random labels in
    [0, 16), batched 32 at a time. The sequence is:

    1. one warmup epoch;
    2. five profiled epochs;
    3. ten no-grad inference passes over the first batch.
    """
    num_samples = 1024  # Multiple of 32 for better GPU utilization
    batch_size = 32

    model = CNNModel().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    # Dummy data with 8 input channels, in FP16 only when running on CUDA.
    # The torch.cuda.amp autocast used during training and inference affects
    # only CUDA ops. On the CPU fallback, a half-precision input would
    # therefore reach the float32 conv weights and fail with a dtype mismatch.
    input_dtype = torch.float16 if device.type == "cuda" else torch.float32
    x_train = torch.randn(num_samples, 8, 32, 32, dtype=input_dtype, device=device)
    y_train = torch.randint(0, 16, (num_samples,), device=device)

    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Warmup run
    print("Warmup run...")
    train_model(model, criterion, optimizer, train_loader, num_epochs=1)

    # Profiled run
    print("Profiled run...")
    train_model(model, criterion, optimizer, train_loader, num_epochs=5)

    # Inference profiling: outputs are discarded. The passes exist only to
    # generate kernels for the profiler.
    model.eval()
    with torch.no_grad(), autocast():
        print("Inference profiling...")
        test_batch = x_train[:batch_size]
        for _ in range(10):  # Run inference 10 times for better profiling
            model(test_batch)


if __name__ == "__main__":
    main()
Leave a Comment