# Paste-site metadata captured along with this file (not part of the program):
# Untitled | avatar | unknown | plain_text | a month ago | 15 kB | 3 | Indexable
import numpy as np
import matplotlib.pyplot as plt
import time
import json
import os
from mpi4py import MPI
import argparse
import psutil
import gc
from rng import get_rng, rng_context, register_rng
from mpiwrapper import mpi
from moe import SimpleMoE, MoE_EP, MoE_TP

class MemoryTracker:
    """Remember the largest resident set size (RSS) seen across samples.

    Each call to ``update()`` samples this process's RSS via psutil and keeps
    it if it exceeds every earlier sample. ``get_peak()`` reports that maximum
    in megabytes, or 0 when nothing has been sampled since construction or the
    most recent ``reset()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all previous samples."""
        self.peak_memory = 0

    def update(self):
        """Sample the current RSS and record it if it is a new maximum."""
        bytes_per_mb = 1024 * 1024
        rss_mb = psutil.Process().memory_info().rss / bytes_per_mb
        if rss_mb > self.peak_memory:
            self.peak_memory = rss_mb

    def get_peak(self):
        """Return the highest sampled RSS in MB."""
        return self.peak_memory

def run_moe(
    moe_type="tp",
    batch_size=8,
    feature_dim=32,
    hidden_dim=128,
    output_dim=64,
    num_experts=None,
    topk=2,
    warmup=2,
    iterations=5,
    measure_memory=True
):
    """
    Benchmark one MoE configuration and return timing/memory statistics.

    Every MPI rank must call this function: the distributed variants
    broadcast the input, and all variants synchronize with barriers and
    gather per-rank measurements to rank 0.

    Args:
        moe_type: Type of MoE to run ("simple", "ep", or "tp")
        batch_size: Number of samples in the batch
        feature_dim: Dimension of input features
        hidden_dim: Hidden dimension for experts
        output_dim: Output dimension
        num_experts: Number of experts (defaults to number of processes)
        topk: Number of experts to route each input to
        warmup: Number of untimed warmup iterations
        iterations: Number of timed iterations (must be >= 1)
        measure_memory: Whether to sample process RSS after each timed
            iteration; the reported peak is the maximum of those samples

    Returns:
        On rank 0, a dictionary with benchmarking results; None on other ranks.

    Raises:
        ValueError: If ``moe_type`` is not "simple", "ep" or "tp", or if
            ``iterations`` is less than 1. Arguments are identical on all
            ranks, so every rank raises together (no rank is left waiting in
            a collective).
    """
    # Validate before touching MPI. Previously an unknown type silently ran
    # MoE_TP while the results were still labelled with the unknown name.
    if moe_type not in ("simple", "ep", "tp"):
        raise ValueError(
            f"Unknown moe_type {moe_type!r}; expected 'simple', 'ep' or 'tp'"
        )
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    # Get number of experts based on MPI world size if not specified
    if num_experts is None:
        num_experts = mpi.get_size()

    # A fresh tracker starts at 0 and is only sampled during timed iterations,
    # so warmup allocations are not counted.
    mem_tracker = MemoryTracker() if measure_memory else None

    # Generate input data. The global NumPy seed/draw is kept unconditional
    # so np.random's global state matches the previous behavior for every type.
    np.random.seed(0)
    X = np.random.randn(batch_size, feature_dim)

    if moe_type != "simple":
        # Distributed variants need identical input on every rank: rank 0
        # draws it from the shared RNG and broadcasts it.
        if mpi.get_rank() == 0:
            X = get_rng().randn(batch_size, feature_dim)
        else:
            X = None
        X = mpi.comm.bcast(X, root=0)

    model_class = {
        "simple": SimpleMoE,
        "ep": MoE_EP,
        "tp": MoE_TP,
    }[moe_type]

    moe = model_class(
        input_dim=feature_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        num_experts=num_experts,
        topk=topk
    )

    # Warmup (untimed)
    for _ in range(warmup):
        moe(X)

    # Synchronize before timing
    mpi.barrier()

    durations = []
    for _ in range(iterations):
        # Barriers on both sides: the window on each rank spans until the
        # slowest rank has finished its forward pass.
        mpi.barrier()
        start_time = time.perf_counter()

        # Kept bound so the output is still alive when memory is sampled.
        outputs = moe(X)

        mpi.barrier()
        end_time = time.perf_counter()

        # Sample outside the timed window so the psutil call is not measured.
        # The tracker is NOT reset per iteration, so the peak is the maximum
        # over all timed iterations rather than just the last one.
        if mem_tracker is not None:
            mem_tracker.update()

        durations.append(1000 * (end_time - start_time))

    # Gather results from all processes (collective: every rank participates)
    all_durations = mpi.gather(durations, root=0)

    if mem_tracker is not None:
        all_peak_memory = mpi.gather(mem_tracker.get_peak(), root=0)
    else:
        all_peak_memory = None

    # Clean up to help with memory usage between runs
    del moe, outputs
    gc.collect()

    if mpi.get_rank() != 0:
        return None

    # Statistics use rank 0's own timings; because of the barriers above they
    # already include waiting for the slowest rank.
    avg_duration = np.mean(durations)
    min_duration = np.min(durations)
    max_duration = np.max(durations)
    std_duration = np.std(durations)

    # Samples processed per second for one forward pass of batch_size inputs
    tokens_per_second = batch_size * 1000 / avg_duration

    results = {
        "moe_type": moe_type,
        "batch_size": batch_size,
        "feature_dim": feature_dim,
        "hidden_dim": hidden_dim,
        "output_dim": output_dim,
        "num_experts": num_experts,
        "topk": topk,
        "avg_duration_ms": float(avg_duration),
        "min_duration_ms": float(min_duration),
        "max_duration_ms": float(max_duration),
        "std_duration_ms": float(std_duration),
        "tokens_per_second": float(tokens_per_second),
        "all_process_durations": all_durations,
    }

    if measure_memory:
        results["peak_memory_mb"] = all_peak_memory
        results["total_memory_mb"] = sum(all_peak_memory)

    return results


def _run_moe_sweep(moe_types, param_name, values, base_config, all_results):
    """Benchmark every MoE type at each value of one parameter, others fixed.

    Must be called on every MPI rank because run_moe is collective. Only rank 0
    receives results; there each result is tagged with ``"experiment"`` (the
    swept parameter name), appended to *all_results* and printed.
    """
    for moe_type in moe_types:
        for value in values:
            config = {**base_config, param_name: value}
            results = run_moe(moe_type=moe_type, **config)
            if results is None:  # non-root ranks get no results
                continue
            # Tag the sweep explicitly: the baseline configuration is produced
            # by every sweep, so parameter values alone are ambiguous.
            results["experiment"] = param_name
            all_results.append(results)
            print(f"  {moe_type}, {param_name}={value}: {results['avg_duration_ms']:.2f} ms")


def run_scaling_experiment(
    moe_types=("simple", "ep", "tp"),
    batch_sizes=(8, 16, 32, 64, 128, 256),
    feature_dims=(32, 64, 128, 256),
    hidden_dims=(128, 256, 512),
    output_dims=(64, 128, 256),
    topk_values=(1, 2, 4, 8),
    output_dir="benchmark_results"
):
    """
    Run one-parameter-at-a-time scaling sweeps and save the results as JSON.

    Around a fixed baseline (batch_size=32, feature_dim=64, hidden_dim=128,
    output_dim=64, topk=2) three sweeps vary batch size, feature dimension
    and topk in turn. Must be called on every MPI rank; rank 0 prints progress
    and writes ``<output_dir>/moe_benchmark_results_<timestamp>.json``.

    Args:
        moe_types: MoE variants to benchmark.
        batch_sizes: Values for the batch-size sweep.
        feature_dims: Values for the feature-dimension sweep.
        hidden_dims: Currently unused; kept for API compatibility.
        output_dims: Currently unused; kept for API compatibility.
        topk_values: Values for the topk sweep. Values larger than the number
            of experts (run_moe's default: the MPI world size) are skipped.
        output_dir: Directory for the JSON results file.
    """
    is_root = mpi.get_rank() == 0

    if is_root:
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        print(f"Starting scaling experiments. Results will be saved to {output_dir}")

    base_config = {
        "batch_size": 32,
        "feature_dim": 64,
        "hidden_dim": 128,
        "output_dim": 64,
        "topk": 2,
    }
    all_results = []

    # Experiment 1: Varying batch size
    if is_root:
        print("\nExperiment 1: Varying batch size")
    _run_moe_sweep(moe_types, "batch_size", batch_sizes, base_config, all_results)

    # Experiment 2: Varying feature dimension
    if is_root:
        print("\nExperiment 2: Varying feature dimension")
    _run_moe_sweep(moe_types, "feature_dim", feature_dims, base_config, all_results)

    # Experiment 3: Varying topk. run_moe defaults to one expert per process,
    # and routing each input to more experts than exist is not meaningful, so
    # those values are dropped. Every rank computes the same list, which keeps
    # the collective calls inside run_moe matched across ranks.
    num_experts = mpi.get_size()
    usable_topk = [k for k in topk_values if k <= num_experts]

    if is_root:
        print("\nExperiment 3: Varying topk (number of experts per token)")
        skipped_topk = [k for k in topk_values if k > num_experts]
        if skipped_topk:
            print(f"  Skipping topk={skipped_topk}: exceeds num_experts={num_experts}")
    _run_moe_sweep(moe_types, "topk", usable_topk, base_config, all_results)

    # Save all results to disk
    if is_root:
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        filename = os.path.join(output_dir, f"moe_benchmark_results_{timestamp}.json")

        with open(filename, 'w') as f:
            json.dump(all_results, f, indent=2)

        print(f"\nAll results saved to {filename}")


def _collect_sweep_points(results, param, baseline):
    """Group one sweep's results as {moe_type: [(x, mean avg_duration_ms), ...]}.

    A record belongs to the sweep over *param* when it carries
    ``"experiment": param``. Untagged (older) records are matched instead by
    having every other baseline parameter at its baseline value. Records with
    the same moe_type and x are averaged, so each x appears once. This matters
    because the baseline point is produced by every sweep. Points are sorted by x.
    """
    grouped = {}
    for r in results:
        tag = r.get("experiment")
        if tag is not None:
            if tag != param:
                continue
        elif any(r[key] != value for key, value in baseline.items() if key != param):
            continue
        by_x = grouped.setdefault(r["moe_type"], {})
        by_x.setdefault(r[param], []).append(r["avg_duration_ms"])

    return {
        moe_type: sorted(
            ((x, float(np.mean(durations))) for x, durations in by_x.items()),
            key=lambda point: point[0],
        )
        for moe_type, by_x in grouped.items()
    }


def _save_sweep_plot(series, xlabel, title, path, colors):
    """Draw one duration-vs-parameter line per MoE type and save it to *path*."""
    fig = plt.figure(figsize=(10, 6))
    for moe_type, points in series.items():
        x = [p[0] for p in points]
        y = [p[1] for p in points]
        plt.plot(x, y, 'o-', label=f"{moe_type} MoE", color=colors.get(moe_type, "black"))

    plt.xlabel(xlabel)
    plt.ylabel("Avg Duration (ms)")
    plt.title(title)
    plt.grid(True, linestyle='--', alpha=0.7)
    if series:  # legend() with no labelled lines would only emit a warning
        plt.legend()
    plt.savefig(path, dpi=300, bbox_inches='tight')
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig)


def generate_plots(results_file, output_dir="benchmark_plots"):
    """
    Generate visualization plots from benchmark results

    Produces batch_size_scaling.png, feature_dim_scaling.png and
    topk_scaling.png (average duration vs. the swept parameter, one line per
    MoE type). Only rank 0 does any work; other ranks return immediately.

    Args:
        results_file: Path to JSON file with benchmark results (a list of
            result dicts as written by run_scaling_experiment)
        output_dir: Directory to save plots
    """
    # Only root process should generate plots
    if mpi.get_rank() != 0:
        return

    os.makedirs(output_dir, exist_ok=True)

    with open(results_file, 'r') as f:
        results = json.load(f)

    # Baseline configuration that each sweep holds fixed except for its own
    # parameter (matches the fixed values used by run_scaling_experiment).
    baseline = {
        "batch_size": 32,
        "feature_dim": 64,
        "hidden_dim": 128,
        "output_dim": 64,
        "topk": 2,
    }

    colors = {
        "simple": "blue",
        "ep": "green",
        "tp": "red"
    }

    # (swept parameter, x-axis label, title, output file name)
    sweeps = (
        ("batch_size", "Batch Size",
         "MoE Performance vs Batch Size", "batch_size_scaling.png"),
        ("feature_dim", "Feature Dimension",
         "MoE Performance vs Feature Dimension", "feature_dim_scaling.png"),
        ("topk", "TopK (Experts per token)",
         "MoE Performance vs TopK", "topk_scaling.png"),
    )

    for param, xlabel, title, filename in sweeps:
        series = _collect_sweep_points(results, param, baseline)
        _save_sweep_plot(series, xlabel, title, os.path.join(output_dir, filename), colors)

    print(f"Plots saved to {output_dir}")


def benchmark_scaling(world_sizes):
    """
    Benchmark the "ep" and "tp" MoE variants at the current MPI world size.

    Every rank must call this function because run_moe performs collective
    MPI operations (broadcast, barriers, gathers). Previously only rank 0
    called run_moe, which deadlocks as soon as there is more than one process.
    Only rank 0 prints and writes ``scaling_results_<world size>.json``.
    To compare world sizes, launch the script once per process count.

    Args:
        world_sizes: Currently unused; retained for API compatibility.
    """
    rank = mpi.get_rank()
    world_size = mpi.get_size()

    if rank == 0:
        print(f"Benchmarking with {world_size} processes")

    results = {}
    for moe_type in ["ep", "tp"]:
        # Collective call: all ranks participate; only rank 0 gets a result.
        result = run_moe(moe_type=moe_type)
        if result is not None:
            results[moe_type] = result

    # Record world size and results for later comparison
    if rank == 0:
        with open(f"scaling_results_{world_size}.json", 'w') as f:
            json.dump(results, f)


def main():
    """Parse command-line options and run the selected benchmark or plotting task.

    Modes:
        basic    -- run simple, TP and EP MoE once each and print average times
        scaling  -- benchmark_scaling at the current MPI world size
        detailed -- run_scaling_experiment, writing JSON into --output-dir
    With ``--plot --results-file FILE`` plots are generated instead.
    """
    parser = argparse.ArgumentParser(description='Benchmark MoE models')
    parser.add_argument('--mode', choices=['basic', 'scaling', 'detailed'], default='basic',
                        help='Benchmarking mode')
    parser.add_argument('--plot', action='store_true', help='Generate plots from results')
    parser.add_argument('--results-file', type=str, help='Results file for plotting')
    parser.add_argument('--output-dir', type=str, default='benchmark_results',
                        help='Output directory for results and plots')

    args = parser.parse_args()

    if args.plot:
        if not args.results_file:
            # Previously --plot without --results-file was silently ignored and
            # a benchmark ran instead; report the mistake to the user.
            parser.error("--plot requires --results-file")
        generate_plots(args.results_file, output_dir=args.output_dir)
        return

    if args.mode == 'basic':
        is_root = mpi.get_rank() == 0
        if is_root:
            print("Running basic benchmark")

        # run_moe is collective, so every rank runs each variant; only rank 0
        # receives a result dict and prints it.
        for label, moe_type in (("Simple", "simple"), ("TP", "tp"), ("EP", "ep")):
            result = run_moe(moe_type=moe_type)
            if is_root:
                print(f"{label} MoE: {result['avg_duration_ms']:.2f} ms")

    elif args.mode == 'scaling':
        # Run benchmark with current world size
        benchmark_scaling([mpi.get_size()])

    elif args.mode == 'detailed':
        # Run detailed scaling experiments
        run_scaling_experiment(output_dir=args.output_dir)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
# Paste-site residue captured along with this file (not part of the program):
# Editor is loading... | Leave a Comment