PyTorch Lightning

PyTorch Lightning provides lifecycle hooks that make metric logging seamless. Use the on_train_epoch_end hook to automatically log metrics after each training epoch without cluttering your training loop.


Quick Example

import pytorch_lightning as pl
import json

class LitModel(pl.LightningModule):
    """Skeleton LightningModule that prints its epoch metrics as one JSON line.

    ``compute_loss`` and ``compute_accuracy`` are placeholders for your own
    logic; they are not defined in this snippet.
    """

    def __init__(self):
        super().__init__()
        # Your model setup...

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log it as ``train_loss`` and return it."""
        # Your training logic...
        loss = self.compute_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation loss/accuracy and log them as ``val_loss``/``val_acc``."""
        # Your validation logic...
        loss = self.compute_loss(batch)
        acc = self.compute_accuracy(batch)
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def on_train_epoch_end(self):
        """Print the metrics logged so far as a single JSON object."""
        # Log to Valohai after each epoch
        metrics = self.trainer.callback_metrics

        # current_epoch is 0-based; report a 1-based epoch number.
        metadata = {"epoch": self.current_epoch + 1}

        # Only report metrics that were actually logged. Substituting 0 for a
        # missing metric would look like a real (and perfect) loss value.
        for name in ('train_loss', 'val_loss', 'val_acc'):
            if name in metrics:
                metadata[name] = float(metrics[name])

        print(json.dumps(metadata))

Why Use Hooks?

PyTorch Lightning already tracks metrics internally. The hook pattern lets you:

  • Access all logged metrics in one place

  • Log to Valohai without modifying training logic

  • Keep your code clean and framework-idiomatic


Complete Working Example

Here's a full training script with Valohai integration:

import json

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import valohai

class LitMNIST(pl.LightningModule):
    """Small CNN classifier that prints per-epoch metrics as JSON.

    Args:
        learning_rate: Initial learning rate for the Adam optimizer.
    """

    def __init__(self, learning_rate=0.001):
        super().__init__()
        # Makes learning_rate available as self.hparams.learning_rate.
        self.save_hyperparameters()

        # Simple CNN
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12, the flattened size for 28x28 inputs
        # after two unpadded 3x3 convolutions and one 2x2 max-pool.
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def _evaluate(self, batch):
        """Return ``(loss, accuracy)`` for one ``(inputs, targets)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # PyTorch Lightning automatically tracks this
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and accuracy for one batch."""
        loss, acc = self._evaluate(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log test loss and accuracy for one batch.

        Needed so that ``trainer.test(...)`` can run and its results contain
        the ``test_loss`` and ``test_acc`` keys.
        """
        loss, acc = self._evaluate(batch)
        self.log('test_loss', loss)
        self.log('test_acc', acc)
        return loss

    def on_train_epoch_end(self):
        """Print the epoch's logged metrics and learning rate as one JSON object."""
        # Access metrics collected during the epoch
        metrics = self.trainer.callback_metrics

        # current_epoch is 0-based; report a 1-based epoch number.
        metadata = {"epoch": self.current_epoch + 1}

        # Log all collected metrics; skip any that have not been logged yet.
        for name in ('train_loss', 'val_loss', 'val_acc'):
            if name in metrics:
                metadata[name] = float(metrics[name])

        # Log current learning rate (first param group of the first optimizer)
        metadata["learning_rate"] = self.trainer.optimizers[0].param_groups[0]['lr']

        print(json.dumps(metadata))

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule (x0.1 every 5 scheduler steps)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        return [optimizer], [scheduler]

# Training script
def main(learning_rate=0.001, batch_size=32, max_epochs=10):
    """Train LitMNIST, report final test metrics and save a checkpoint.

    Args:
        learning_rate: Learning rate passed to the model.
        batch_size: Batch size used by every DataLoader.
        max_epochs: Number of training epochs.

    The defaults reproduce the previously hard-coded values, so calling
    ``main()`` with no arguments behaves as before.
    """
    # Prepare data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(
        '/valohai/inputs/dataset',
        train=True,
        download=True,
        transform=transform
    )
    val_dataset = datasets.MNIST(
        '/valohai/inputs/dataset',
        train=False,
        transform=transform
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Initialize model
    model = LitMNIST(learning_rate=learning_rate)

    # Log hyperparameters at the start, taken from the values actually in use
    # so the logged metadata cannot drift from the real configuration.
    print(json.dumps({
        "model": "LitMNIST",
        "optimizer": "adam",
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "scheduler": "StepLR"
    }))

    # Train
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='auto',
        devices=1
    )

    trainer.fit(model, train_loader, val_loader)

    # Log final test results.
    # NOTE: this reuses the validation split as the test set.
    test_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_results = trainer.test(model, test_loader)

    with valohai.metadata.logger() as logger:
        logger.log("final_test_loss", test_results[0]['test_loss'])
        logger.log("final_test_acc", test_results[0]['test_acc'])

    # Save model
    output_path = valohai.outputs().path('model.ckpt')
    trainer.save_checkpoint(output_path)

if __name__ == '__main__':
    main()

valohai.yaml Configuration

Make sure to change the input data and environment to match your own values.

- step:
    name: train-lightning
    image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
    command:
      # train.py imports `valohai` (the valohai-utils package), so install it too.
      - pip install pytorch-lightning valohai-utils
      - python train.py
    inputs:
      - name: dataset
        default: dataset://training-data/latest
    # NOTE: train.py must read these values (for example through valohai-utils
    # or command-line arguments) for a changed parameter to take effect.
    parameters:
      - name: learning_rate
        type: float
        default: 0.001
      - name: batch_size
        type: integer
        default: 32
      - name: epochs
        type: integer
        default: 10
    environment: aws-eu-west-1-g4dn-xlarge

Alternative: Using Callbacks

If you prefer not to modify your LightningModule, create a custom callback:

import pytorch_lightning as pl
import json
import torch

class ValohaiMetricsCallback(pl.Callback):
    """Callback that prints every tracked metric as one JSON line per epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        """Print the epoch number, all callback metrics and the learning rate."""
        payload = {"epoch": trainer.current_epoch + 1}

        # Copy every metric Lightning tracks, turning tensors into floats so
        # json.dumps can serialize them.
        payload.update(
            (name, float(metric) if isinstance(metric, torch.Tensor) else metric)
            for name, metric in trainer.callback_metrics.items()
        )

        # Add the learning rate when at least one optimizer is configured.
        if trainer.optimizers:
            first_group = trainer.optimizers[0].param_groups[0]
            payload["learning_rate"] = first_group['lr']

        print(json.dumps(payload))

# Use the callback
valohai_callback = ValohaiMetricsCallback()
trainer = pl.Trainer(callbacks=[valohai_callback], max_epochs=10)

Benefits of the callback approach:

  • No changes to your LightningModule

  • Reusable across projects

  • Easier to enable/disable Valohai logging


Logging Custom Metrics

You can log any metric you compute:

def on_train_epoch_end(self):
    """Print epoch, train loss, model size and gradient norm as one JSON object.

    Raises:
        KeyError: If ``train_loss`` has not been logged.
    """
    metadata = {
        "epoch": self.current_epoch + 1,
        "train_loss": float(self.trainer.callback_metrics['train_loss'])
    }

    params = list(self.parameters())

    # Custom computations
    # Use each parameter's real element size instead of assuming 4 bytes, so
    # the figure is also right for half- or double-precision weights.
    size_bytes = sum(p.numel() * p.element_size() for p in params)
    metadata["model_size_mb"] = size_bytes / 1024 / 1024

    # Gradient norms: global L2 norm over all parameters that have a gradient.
    squared_norm = sum(
        p.grad.detach().norm(2).item() ** 2
        for p in params
        if p.grad is not None
    )
    metadata["gradient_norm"] = squared_norm ** 0.5

    print(json.dumps(metadata))

Available Hooks

PyTorch Lightning provides many hooks you can use:

Training hooks:

  • on_train_epoch_start — Before each epoch

  • on_train_epoch_end — After each epoch (most common)

  • on_train_batch_end — After each batch (use sparingly)

Validation hooks:

  • on_validation_epoch_end — After validation completes

Testing hooks:

  • on_test_epoch_end — After testing completes

Example using multiple hooks:

def on_train_epoch_start(self):
    """Print the 1-based epoch number and current learning rate as JSON."""
    # Learning rate of the first param group of the first optimizer.
    first_group = self.trainer.optimizers[0].param_groups[0]
    payload = {
        "epoch_start": self.current_epoch + 1,
        "learning_rate": first_group['lr'],
    }
    print(json.dumps(payload))

def on_train_epoch_end(self):
    """Print the 1-based epoch number and the logged training loss as JSON."""
    payload = {"epoch": self.current_epoch + 1}
    payload["train_loss"] = float(self.trainer.callback_metrics['train_loss'])
    print(json.dumps(payload))

def on_validation_epoch_end(self):
    """Print the logged validation loss and accuracy as one JSON object."""
    tracked = self.trainer.callback_metrics
    payload = {key: float(tracked[key]) for key in ('val_loss', 'val_acc')}
    print(json.dumps(payload))

Best Practices

Convert Tensors to Python Types

PyTorch Lightning metrics are often tensors. Convert them before logging:

# ✅ Good: Convert to a plain Python float so json.dumps can serialize it
metadata["loss"] = float(metrics['train_loss'])

# ❌ Avoid: Logging tensors directly
metadata["loss"] = metrics['train_loss']  # Still a tensor; may cause JSON serialization issues

Use Consistent Metric Names

Use the same metric name in self.log() and in the metadata you send to Valohai:

# ✅ Good: Same name in self.log() and in the hook
self.log('train_loss', loss)  # In training_step
logger.log("train_loss", float(metrics['train_loss']))  # In hook

# ❌ Avoid: Renaming metrics between self.log() and the hook
self.log('loss', loss)  # In training_step
logger.log("training_loss", float(metrics['loss']))  # In hook

Don't Log Every Batch

Logging after every batch creates too much data:

# ✅ Good: Log per epoch
def on_train_epoch_end(self):
    """Log the 1-based epoch number and training loss once per epoch."""
    tracked = self.trainer.callback_metrics
    with valohai.metadata.logger() as valohai_logger:
        valohai_logger.log("epoch", self.current_epoch + 1)
        valohai_logger.log("train_loss", float(tracked['train_loss']))

# ❌ Avoid: Log every batch (unless debugging)
def on_train_batch_end(self, outputs, batch, batch_idx):
    """Log the batch index and loss after every batch (pattern to avoid unless debugging)."""
    with valohai.metadata.logger() as valohai_logger:
        valohai_logger.log("batch", batch_idx)
        batch_loss = float(outputs['loss'])
        valohai_logger.log("loss", batch_loss)

Log Hyperparameters at Start

Log hyperparameters once before training begins:

def main():
    """Create the model, print its hyperparameters once as JSON, then train."""
    import json

    model = LitModel(learning_rate=0.001)

    # Log hyperparameters before training begins
    hyperparameters = {
        "model": model.__class__.__name__,
        "learning_rate": model.hparams.learning_rate,
        "optimizer": "adam",
        "batch_size": 32,
    }
    print(json.dumps(hyperparameters))

    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, train_loader, val_loader)

Common Issues

Metrics Not Available in Hook

Symptom: KeyError when accessing metrics

Cause: The metric was never logged with self.log() in training_step or validation_step, so it is missing from trainer.callback_metrics

Solution: Make sure you call self.log() for metrics you want to access:

def validation_step(self, batch, batch_idx):
    """Compute validation loss and accuracy and log both under fixed names."""
    computed = {
        'val_loss': self.compute_loss(batch),
        'val_acc': self.compute_accuracy(batch),
    }

    # Log metrics so they're available in hooks
    for metric_name, metric_value in computed.items():
        self.log(metric_name, metric_value)

def on_train_epoch_end(self):
    """Log ``val_loss`` to Valohai when it has been logged; otherwise do nothing."""
    tracked = self.trainer.callback_metrics

    # Nothing to report until validation_step has logged the metric.
    if 'val_loss' not in tracked:
        return

    with valohai.metadata.logger() as valohai_logger:
        valohai_logger.log("val_loss", float(tracked['val_loss']))

Tensor Serialization Errors

Symptom: TypeError: Object of type Tensor is not JSON serializable

Solution: Convert tensors to Python types:

# ✅ Good: float() converts the metric to a plain Python number first
logger.log("loss", float(metrics['train_loss']))

# ❌ Wrong: the metric is passed through unconverted
logger.log("loss", metrics['train_loss'])  # Still a tensor

Example Project

Check out our complete working example on GitHub:

valohai/pytorch-lightning-example

The repository includes:

  • Complete training script with Valohai integration

  • valohai.yaml configuration

  • Example notebooks

  • Step-by-step setup instructions


Next Steps

Last updated

Was this helpful?