Conv2d with Tensor Cores (2024-08-23)

 avatar
user_3093867
python
a year ago
4.2 kB
10
Indexable
import torch
import torch.nn as nn
import torch.optim as optim
import time
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset
import torch.cuda.nvtx as nvtx

# Run on the current CUDA device whenever PyTorch can see a GPU; otherwise use
# the CPU. torch.version.cuda is reported as None on CPU-only PyTorch builds.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
print(f"CUDA version: {torch.version.cuda}")

class TFLiteMaximum(nn.Module):
    """Element-wise ``max(x, threshold)`` with a learnable per-channel threshold.

    The threshold is a parameter of shape ``(1, channels, 1, 1)`` initialised to
    zero, so at construction time the forward output equals ``relu(x)``. Its
    shape lets it broadcast over inputs laid out as ``(N, channels, H, W)``.

    Args:
        channels: Number of channels, i.e. the size of dimension 1 of the
            threshold parameter.
    """

    def __init__(self, channels):
        super().__init__()
        initial_floor = torch.zeros(1, channels, 1, 1)
        self.threshold = nn.Parameter(initial_floor)

    def forward(self, x):
        # Tensor.maximum is the method form of torch.maximum (element-wise, broadcasting).
        return x.maximum(self.threshold)

class CNNModel(nn.Module):
    """Small conv -> scale -> maximum -> max-pool -> linear classifier.

    Layer sizes are fixed: 8 input channels, 64 convolution filters, and a
    16-way linear head sized for 32x32 inputs. The 3x3 conv with padding 1
    keeps the spatial size, and the 2x2/stride-2 max-pool halves it to 16x16.
    Each stage of ``forward`` runs inside a named NVTX range so it can be found
    in a profiler timeline.
    """

    def __init__(self):
        super().__init__()
        # 8 in / 64 out channels: multiples of 8, which NVIDIA's guidance
        # recommends so FP16 convolutions can use Tensor Core kernels.
        self.conv = nn.Conv2d(8, 64, kernel_size=3, padding=1, bias=True)
        # Every conv bias starts at the constant 16 instead of PyTorch's default init.
        nn.init.constant_(self.conv.bias, 16)
        self.maximum = TFLiteMaximum(64)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # in_features = 64 channels * 16 * 16 positions, i.e. assumes 32x32 inputs.
        self.fc = nn.Linear(64 * 16 * 16, 16)

    def forward(self, x):
        """Return logits of shape ``(N, 16)`` for the input batch ``x``.

        ``x`` is expected to be ``(N, 8, 32, 32)``. Other spatial sizes produce a
        flattened width that does not match ``self.fc``.
        """
        # nvtx.range pops its range even if a stage raises. The manual
        # range_push/range_pop pairs would leave the NVTX stack unbalanced.
        with nvtx.range("CONV2D"):
            x = self.conv(x)
        with nvtx.range("SCALE"):
            x = x * 0.1
        with nvtx.range("MAXIMUM"):
            x = self.maximum(x)
        with nvtx.range("MAXPOOL"):
            x = self.max_pool(x)
        with nvtx.range("RESHAPE"):
            # flatten has reshape semantics: it also works for non-contiguous
            # activations (e.g. channels_last), where .view(N, -1) raises.
            x = torch.flatten(x, 1)
        with nvtx.range("FULLY_CONNECTED"):
            x = self.fc(x)
        return x

def train_model(model, criterion, optimizer, train_loader, num_epochs):
    """Train ``model`` for ``num_epochs`` with mixed precision and NVTX annotations.

    Every epoch, batch, forward pass, backward pass and optimizer step runs in a
    named NVTX range so the phases can be told apart in a profiler timeline.
    Autocast and loss scaling come from ``torch.cuda.amp``.

    Args:
        model: Module to train; put into training mode at the start of each epoch.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer over the model's parameters.
        train_loader: Iterable of ``(inputs, targets)`` pairs. Both tensors are
            moved to the module-level ``device``.
        num_epochs: Number of passes over ``train_loader``.

    After each epoch this prints the loss of the last batch and the elapsed
    wall-clock time. An epoch that yields no batches is reported instead of
    crashing.
    """
    scaler = GradScaler()

    for epoch in range(num_epochs):
        with nvtx.range(f"Epoch {epoch+1}"):
            print(f"Epoch {epoch+1}/{num_epochs}")
            start_time = time.perf_counter()

            model.train()
            # Reset per epoch so an empty loader cannot reuse (or lack) a loss.
            loss = None
            for batch_x, batch_y in train_loader:
                with nvtx.range("Batch"):
                    batch_x, batch_y = batch_x.to(device), batch_y.to(device)

                    with autocast(), nvtx.range("Forward"):
                        outputs = model(batch_x)
                        loss = criterion(outputs, batch_y)

                    optimizer.zero_grad()
                    with nvtx.range("Backward"):
                        scaler.scale(loss).backward()

                    with nvtx.range("Optimizer step"):
                        scaler.step(optimizer)
                        scaler.update()

            if loss is None:
                print("Loss: n/a (train_loader yielded no batches)")
            else:
                # .item() copies the value to the host, which waits for work already
                # queued on the stream, so the time below includes it.
                print(f"Loss: {loss.item():.4f}")
            print(f"Time taken: {time.perf_counter() - start_time:.2f} seconds")
            print("------")
            torch.cuda.empty_cache()

def main():
    """Build the model and dummy data, then run warm-up, profiled training and inference.

    Everything runs on the module-level ``device``. On CUDA the inputs are FP16
    and ``autocast`` runs the convolution in half precision. On the CPU fallback
    ``torch.cuda.amp.autocast`` disables itself, so the inputs are FP32 to match
    the model's FP32 weights. The NVTX calls may still require a CUDA-enabled
    PyTorch build (NOTE(review): confirm on CPU-only installs).
    """
    num_samples = 1024  # Multiple of 32 for better GPU utilization
    batch_size = 32

    model = CNNModel().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    # Dummy data with 8 input channels. FP16 only on CUDA: without an active
    # autocast (CPU fallback), a half-precision input would not match the FP32
    # conv weights and the forward pass would fail.
    input_dtype = torch.float16 if device.type == "cuda" else torch.float32
    x_train = torch.randn(num_samples, 8, 32, 32, dtype=input_dtype, device=device)
    y_train = torch.randint(0, 16, (num_samples,), device=device)

    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Warmup run
    print("Warmup run...")
    train_model(model, criterion, optimizer, train_loader, num_epochs=1)

    # Profiled run
    print("Profiled run...")
    train_model(model, criterion, optimizer, train_loader, num_epochs=5)

    # Inference profiling
    model.eval()
    with torch.no_grad(), autocast():
        print("Inference profiling...")
        test_batch = x_train[:batch_size]
        with nvtx.range("Inference"):
            for _ in range(10):  # Run inference 10 times for better profiling
                model(test_batch)
            if device.type == "cuda":
                # Kernels launch asynchronously; wait so the queued inference
                # work completes before the "Inference" range closes.
                torch.cuda.synchronize()

if __name__ == "__main__":
    # Run the workload only when the file is executed directly (for example
    # under a profiler), never as a side effect of importing it.
    main()
Editor is loading...
Leave a Comment