import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import yaml
from pathlib import Path
import json
def load_config(config_path="experiments/config.yaml"):
    with open(config_path) as f:
        return yaml.safe_load(f)
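
# Illustrative config layout assumed by this script (values are examples only);
# the actual experiments/config.yaml may differ, but these are the keys read below:
#
#   data:
#     normalize: true
#     augmentation:
#       random_rotation: 10
#   model:
#     hidden_size: 128
#     dropout: 0.2
#   training:
#     batch_size: 64
#     learning_rate: 0.001
#     weight_decay: 0.0001
#     epochs: 10
#     use_amp: true
#     gradient_clip: 1.0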
def get_dataloaders(config):
    transform_list = [transforms.ToTensor()]
    if config['data'].get('normalize', True):
        transform_list.append(transforms.Normalize((0.1307,), (0.3081,)))
    if config['data']['augmentation'].get('random_rotation'):
        transform_list.append(transforms.RandomRotation(
            config['data']['augmentation']['random_rotation']
        ))
    transform = transforms.Compose(transform_list)

    full_train = datasets.MNIST(root="data", train=True, download=True, transform=transform)
    train_size = int(0.8 * len(full_train))
    val_size = len(full_train) - train_size
    train_dataset, val_dataset = random_split(full_train, [train_size, val_size])
    test_dataset = datasets.MNIST(root="data", train=False, download=True, transform=transform)

    batch_size = config['training']['batch_size']
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    return train_loader, val_loader, test_loader
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden = config['model']['hidden_size']
        dropout = config['model']['dropout']
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden // 2, 10),
        )

    def forward(self, x):
        return self.net(x)
def train_model(config_path="experiments/config.yaml"):
    config = load_config(config_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Seed for reproducibility
    torch.manual_seed(42)
    if device == "cuda":
        torch.cuda.manual_seed_all(42)

    model = MLP(config).to(device)
    opt = torch.optim.Adam(
        model.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=config['training']['weight_decay']
    )
    loss_fn = nn.CrossEntropyLoss()
    train_loader, val_loader, test_loader = get_dataloaders(config)

    # AMP + scheduler
    use_amp = config['training']['use_amp'] and device == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=config['training']['epochs']
    )

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    for epoch in range(config['training']['epochs']):
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['training']['epochs']}"):
            x, y = x.to(device), y.to(device)
            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(x)
                loss = loss_fn(logits, y)
            scaler.scale(loss).backward()
            if config['training']['gradient_clip'] > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['training']['gradient_clip'])
            scaler.step(opt)
            scaler.update()

            train_loss += loss.item() * x.size(0)
            train_correct += (logits.argmax(1) == y).sum().item()
            train_total += x.size(0)
        # Validation pass
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                loss = loss_fn(logits, y)
                val_loss += loss.item() * x.size(0)
                val_correct += (logits.argmax(1) == y).sum().item()
                val_total += x.size(0)

        train_loss = train_loss / train_total
        train_acc = train_correct / train_total
        val_loss = val_loss / val_total
        val_acc = val_correct / val_total
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, train_acc={train_acc:.3f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.3f}")
        scheduler.step()
        # Periodic checkpointing
        if (epoch + 1) % 5 == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history,
                'config': config
            }
            Path('checkpoints').mkdir(exist_ok=True)
            torch.save(checkpoint, f'checkpoints/model_epoch_{epoch+1}.pt')

    Path('results').mkdir(exist_ok=True)
    with open('results/training_history.json', 'w') as f:
        json.dump(history, f, indent=2)
    return model, history, test_loader
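
# Minimal entry-point sketch (not part of the original script): assumes
# experiments/config.yaml exists with the keys documented above, runs training,
# then reports accuracy on the test loader returned by train_model().
if __name__ == "__main__":
    trained_model, training_history, test_loader = train_model()
    device = next(trained_model.parameters()).device
    trained_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            correct += (trained_model(x).argmax(1) == y).sum().item()
            total += x.size(0)
    print(f"Test accuracy: {correct / total:.3f}")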