import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import time
import random
import numpy as np

# ====================================================================
# 1. Configuration - All benchmark parameters are here
# ====================================================================
# --- Benchmark Parameters ---
NUM_SAMPLES = 100_000     # synthetic rows generated for each run
NUM_WARMUP_EPOCHS = 5     # untimed epochs executed before measurement begins
NUM_EPOCHS = 50           # epochs included in the timed measurement
BATCH_SIZE = 256          # mini-batch size handed to the DataLoader
LEARNING_RATE = 0.1       # SGD step size
SEED = 42                 # seed handed to set_deterministic() for reproducible runs

# ====================================================================
# 2. Helper Functions
# ====================================================================

def set_deterministic(seed):
    """Seed every RNG the benchmark touches and pin cuDNN to deterministic kernels.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG,
            and PyTorch (CPU generator plus the current and all CUDA devices).
    """
    # Host-side generators first.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU generator, then the current CUDA device, then every CUDA device.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Block until queued GPU work (if any GPU exists) has finished.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

def generate_data(num_samples, input_dim=2, noise_std=0.5):
    """Generate a synthetic noisy linear-regression dataset.

    The target is ``y = 3 * x[:, 0] + 2 * x[:, 1] + 1 + noise``, where
    ``noise = noise_std * randn(num_samples)``. Noise keeps the training loss
    from collapsing to zero. Any feature columns beyond the first two are
    present in ``x`` but do not influence ``y``.

    Note: ``x`` is a single ``linspace(-1, 1)`` reshaped row-wise, so the two
    features in each row are adjacent grid points and nearly collinear. Only
    the sum of the two coefficients is well determined by this data.

    Args:
        num_samples: Number of rows to generate.
        input_dim: Number of feature columns; must be at least 2 because the
            target uses two features.
        noise_std: Scale of the Gaussian noise added to the target
            (default 0.5, the historical value).

    Returns:
        ``TensorDataset(x, y)`` with ``x`` of shape ``(num_samples, input_dim)``
        and ``y`` of shape ``(num_samples,)``.

    Raises:
        ValueError: If ``input_dim < 2``.
    """
    if input_dim < 2:
        raise ValueError(
            f"input_dim must be >= 2 because the target uses two features, got {input_dim}"
        )

    x = torch.linspace(-1, 1, num_samples * input_dim).reshape(-1, input_dim)
    # Draw from the global torch RNG so set_deterministic() controls the noise.
    noise = noise_std * torch.randn(num_samples)
    y = 3 * x[:, 0] + 2 * x[:, 1] + 1 + noise
    return TensorDataset(x, y)

class SimpleModel(nn.Module):
    """Linear regressor: one affine layer mapping ``input_dim`` features to one output."""

    def __init__(self, input_dim=2):
        super(SimpleModel, self).__init__()
        # The attribute name ``linear`` determines the state_dict keys
        # (``linear.weight`` / ``linear.bias``), so it must stay as-is.
        self.linear = nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
        """Apply the affine layer to a batch whose last dimension is ``input_dim``."""
        prediction = self.linear(x)
        return prediction

# ====================================================================
# 3. Main Benchmark Logic
# ====================================================================

def run_benchmark(device):
    """
    Train ``SimpleModel`` on synthetic data and time the post-warmup epochs.

    Runs ``NUM_WARMUP_EPOCHS`` untimed epochs followed by ``NUM_EPOCHS`` timed
    epochs. The same model keeps training throughout, and the average
    per-batch loss of every epoch is printed.

    Args:
        device: Device to train on (``torch.device`` or device string such as
            ``"cuda:1"``).

    Returns:
        Tuple ``(total_measured_time, final_avg_loss)``:
        - ``total_measured_time`` is the elapsed seconds of the measured epochs
          (0.0 if ``NUM_EPOCHS`` is 0).
        - ``final_avg_loss`` is the average batch loss of the last epoch run
          (NaN if no epochs ran at all).
    """
    device = torch.device(device)

    def _sync():
        # Wait for work queued on *this* device. A bare torch.cuda.synchronize()
        # waits only on the current device (cuda:0 by default), which is wrong
        # when benchmarking any other GPU. Non-CUDA devices need no sync.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # --- Setup ---
    set_deterministic(SEED)
    dataset = generate_data(NUM_SAMPLES)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    model = SimpleModel().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

    total_run_epochs = NUM_WARMUP_EPOCHS + NUM_EPOCHS
    start_time = None               # set when the first measured epoch begins
    final_avg_loss = float("nan")   # overwritten by every completed epoch

    # --- Core Training Loop ---
    print(f"Running {NUM_WARMUP_EPOCHS} warmup epochs + {NUM_EPOCHS} benchmark epochs...")

    for epoch in range(total_run_epochs):
        # Warmup just finished: drain the device, then start the clock.
        if epoch == NUM_WARMUP_EPOCHS:
            _sync()
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()

        epoch_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device).unsqueeze(1)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() copies to host and blocks on the device each batch; kept
            # as-is so timings remain comparable with earlier runs.
            epoch_loss += loss.item()

        final_avg_loss = epoch_loss / len(dataloader)

        # Print Log
        if epoch < NUM_WARMUP_EPOCHS:
            print(f"[Warmup]    Epoch {epoch+1}/{NUM_WARMUP_EPOCHS} | Average Loss: {final_avg_loss:.6f}")
        else:
            # Renumber so the benchmark phase is displayed starting at 1.
            bench_epoch = epoch - NUM_WARMUP_EPOCHS + 1
            print(f"[Benchmark] Epoch {bench_epoch}/{NUM_EPOCHS} | Average Loss: {final_avg_loss:.6f}")

    _sync()
    end_time = time.perf_counter()

    # Only the measured epochs count; with no measured epochs the duration is 0.
    total_measured_time = end_time - start_time if start_time is not None else 0.0

    return total_measured_time, final_avg_loss

# ====================================================================
# 4. Execution and Reporting
# ====================================================================

def print_summary_report(device, total_time, final_loss):
    """Print a formatted summary report for one benchmark run.

    Args:
        device: Device the benchmark ran on (``torch.device`` or device string).
            CUDA devices are labelled with their hardware name; other device
            types are shown as-is.
        total_time: Seconds spent in the measured (post-warmup) epochs.
        final_loss: Average batch loss of the last epoch.
    """
    device = torch.device(device)
    if device.type == "cuda":
        device_name = f"{device} ({torch.cuda.get_device_name(device.index)})"
    else:
        # get_device_name would describe some GPU (or raise) for a non-CUDA run.
        device_name = str(device)

    # Throughput counts only the measured epochs (NUM_EPOCHS), not the warmup ones.
    total_samples_processed = NUM_SAMPLES * NUM_EPOCHS
    # A zero duration (e.g. no measured epochs) yields NaN instead of ZeroDivisionError.
    throughput = total_samples_processed / total_time if total_time > 0 else float("nan")

    report_width = 70
    key_width = 23

    print("\n" + "="*report_width)
    print("Benchmark Summary Report".center(report_width))
    print("="*report_width)
    print(f"{'Device:':<{key_width}} {device_name}")
    print("-" * report_width)
    print("Parameters:")
    print(f"  {'Data Samples:':<{key_width-2}} {NUM_SAMPLES:,}")
    print(f"  {'Warmup Epochs:':<{key_width-2}} {NUM_WARMUP_EPOCHS}")
    print(f"  {'Measured Epochs:':<{key_width-2}} {NUM_EPOCHS}")
    print(f"  {'Batch Size:':<{key_width-2}} {BATCH_SIZE}")
    print("-" * report_width)
    print("Results (Excluding Warmup):")
    print(f"  {'Measured Time:':<{key_width-2}} {total_time:.4f} seconds")
    print(f"  {'Final Average Loss:':<{key_width-2}} {final_loss:.6f}")
    print(f"  {'Throughput:':<{key_width-2}} {throughput:,.2f} samples/sec")
    print("="*report_width)


if __name__ == "__main__":
    if not torch.cuda.is_available():
        # SystemExit with a message writes it to stderr and exits with status 1,
        # so shells/CI can detect the failure; it previously exited 0.
        raise SystemExit("Error: No NVIDIA GPU detected. This benchmark requires at least one CUDA-enabled GPU.")

    num_gpus = torch.cuda.device_count()
    print(f"Detected {num_gpus} CUDA device(s).")

    header_width = 70  # width of the per-GPU banner; loop-invariant

    # Benchmark every visible GPU one after another, each with a fresh model and data.
    for i in range(num_gpus):
        device = torch.device(f'cuda:{i}')
        device_name = torch.cuda.get_device_name(i)

        print("\n" + "="*header_width)
        print(f"Starting Benchmark on GPU {i}: {device_name}".center(header_width))
        print("="*header_width + "\n")

        # Run the benchmark
        total_time, final_loss = run_benchmark(device)

        # Print the summary report for this device
        print_summary_report(device, total_time, final_loss)