PyTorch Lightning

PyTorch Lightning provides lifecycle hooks that make metric logging straightforward. Use the on_train_epoch_end hook to log metrics after each training epoch without cluttering your training loop. Valohai collects these metrics by parsing JSON objects printed to standard output, so the hook only needs to print one JSON line per epoch.


Quick Example

import pytorch_lightning as pl
import json


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Your model setup...

    def training_step(self, batch, batch_idx):
        # Your training logic...
        loss = self.compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Your validation logic...
        loss = self.compute_loss(batch)
        acc = self.compute_accuracy(batch)
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def on_train_epoch_end(self):
        # Log to Valohai after each epoch
        metrics = self.trainer.callback_metrics

        print(
            json.dumps(
                {
                    "epoch": self.current_epoch,
                    "train_loss": float(metrics.get("train_loss", 0)),
                    "val_loss": float(metrics.get("val_loss", 0)),
                    "val_acc": float(metrics.get("val_acc", 0)),
                },
            ),
        )

Why Use Hooks?

PyTorch Lightning already tracks metrics internally. The hook pattern lets you:

  • Access all logged metrics in one place

  • Log to Valohai without modifying training logic

  • Keep your code clean and framework-idiomatic


Complete Working Example

Here's a full training script with Valohai integration:
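
A minimal sketch of what such a script can look like, using a small synthetic dataset and a two-layer classifier as stand-ins for your own data and model (in a real execution you would typically read data from your Valohai inputs instead). The Valohai-specific piece is the on_train_epoch_end hook printing one JSON line per epoch:

import json

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split


class LitClassifier(pl.LightningModule):
    def __init__(self, input_dim=20, hidden_dim=64, num_classes=2, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fn(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def on_train_epoch_end(self):
        # Print one JSON line per epoch; Valohai parses it as execution metadata
        metrics = self.trainer.callback_metrics
        print(json.dumps({
            "epoch": self.current_epoch,
            "train_loss": float(metrics.get("train_loss", 0)),
            "val_loss": float(metrics.get("val_loss", 0)),
            "val_acc": float(metrics.get("val_acc", 0)),
        }))


def main():
    # Synthetic data stands in for your dataset; in a Valohai execution you
    # would typically load files from your declared inputs instead.
    x = torch.randn(1000, 20)
    y = (x.sum(dim=1) > 0).long()
    train_set, val_set = random_split(TensorDataset(x, y), [800, 200])

    model = LitClassifier()
    trainer = pl.Trainer(max_epochs=5, enable_progress_bar=False)
    trainer.fit(
        model,
        DataLoader(train_set, batch_size=32, shuffle=True),
        DataLoader(val_set, batch_size=32),
    )


if __name__ == "__main__":
    main()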


valohai.yaml Configuration

Make sure to change the input data and environment to match your own values.
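
A minimal sketch of the step definition, assuming a step named train-model and a train.py entry point; the image, input URL, and environment below are placeholders, not required values:

- step:
    name: train-model
    image: python:3.11  # placeholder; use an image with the CUDA/PyTorch stack you need
    command:
      - pip install torch pytorch-lightning
      - python train.py
    inputs:
      - name: dataset
        default: s3://your-bucket/path/to/dataset.csv  # placeholder input URL
    # environment: your-environment-slug  # placeholder; pick a machine type from your project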


Alternative: Using Callbacks

If you prefer not to modify your LightningModule, create a custom callback:
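
A sketch of such a callback, using the same print-JSON-to-stdout approach as the quick example; ValohaiMetricsCallback is just an illustrative name:

import json

import pytorch_lightning as pl


class ValohaiMetricsCallback(pl.Callback):
    """Prints epoch-level metrics as JSON so Valohai can pick them up."""

    def on_train_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        print(json.dumps({
            "epoch": trainer.current_epoch,
            "train_loss": float(metrics.get("train_loss", 0)),
            "val_loss": float(metrics.get("val_loss", 0)),
            "val_acc": float(metrics.get("val_acc", 0)),
        }))


# Attach it to the Trainer instead of editing the LightningModule
trainer = pl.Trainer(max_epochs=5, callbacks=[ValohaiMetricsCallback()])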

Benefits of the callback approach:

  • No changes to your LightningModule

  • Reusable across projects

  • Easier to enable/disable Valohai logging


Logging Custom Metrics

You can log any metric you compute:
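
For example, inside the on_train_epoch_end hook of the LightningModule above you could add the current learning rate or global step; the exact fields here are only illustrations:

def on_train_epoch_end(self):
    metrics = self.trainer.callback_metrics

    # Anything you can turn into a plain Python number can go into the JSON line
    current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]

    print(json.dumps({
        "epoch": self.current_epoch,
        "train_loss": float(metrics.get("train_loss", 0)),
        "learning_rate": float(current_lr),
        "global_step": int(self.trainer.global_step),
    }))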


Available Hooks

PyTorch Lightning provides many hooks you can use:

Training hooks:

  • on_train_epoch_start — Before each epoch

  • on_train_epoch_end — After each epoch (most common)

  • on_train_batch_end — After each batch (use sparingly)

Validation hooks:

  • on_validation_epoch_end — After validation completes

Testing hooks:

  • on_test_epoch_end — After testing completes

Example using multiple hooks:
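
A sketch using a shared helper (here called _print_metrics, an illustrative name) so each hook emits one JSON line tagged with the stage it came from:

class LitModel(pl.LightningModule):
    # ... __init__, training_step, validation_step, test_step as before ...

    def _print_metrics(self, stage):
        # callback_metrics holds everything logged via self.log() as 0-dim tensors
        metrics = {k: float(v) for k, v in self.trainer.callback_metrics.items()}
        print(json.dumps({"epoch": self.current_epoch, "stage": stage, **metrics}))

    def on_train_epoch_end(self):
        self._print_metrics("train")

    def on_validation_epoch_end(self):
        self._print_metrics("validation")

    def on_test_epoch_end(self):
        self._print_metrics("test")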


Best Practices

Convert Tensors to Python Types

PyTorch Lightning metrics are often tensors. Convert them before logging:
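
A short illustration; the metric name is whatever you passed to self.log():

loss_tensor = self.trainer.callback_metrics["train_loss"]

# json.dumps({"train_loss": loss_tensor})  # TypeError: Tensor is not JSON serializable
print(json.dumps({"train_loss": loss_tensor.item()}))  # .item() returns a plain float
print(json.dumps({"train_loss": float(loss_tensor)}))  # float() works as well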


Use Consistent Metric Names

Keep metric names consistent with PyTorch Lightning's conventions:
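
For example, reuse the same train_/val_ prefixed keys both when calling self.log() and in the JSON you print, rather than mixing variants such as loss, training_loss, and train_loss:

# In training_step / validation_step
self.log("train_loss", loss)
self.log("val_loss", val_loss)
self.log("val_acc", val_acc)

# In on_train_epoch_end, print under the same keys so the Valohai charts line up
metrics = self.trainer.callback_metrics
print(json.dumps({
    "train_loss": float(metrics["train_loss"]),
    "val_loss": float(metrics["val_loss"]),
    "val_acc": float(metrics["val_acc"]),
}))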


Don't Log Every Batch

Printing a metadata line after every batch generates thousands of data points per run and makes the metric charts noisy; prefer logging once per epoch, as sketched below:
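
Both methods below live inside your LightningModule:

# Avoid: this runs once per batch, producing thousands of metadata rows per epoch
def on_train_batch_end(self, outputs, batch, batch_idx):
    metrics = self.trainer.callback_metrics
    print(json.dumps({"train_loss": float(metrics.get("train_loss", 0))}))

# Prefer: one JSON line per epoch
def on_train_epoch_end(self):
    metrics = self.trainer.callback_metrics
    print(json.dumps({
        "epoch": self.current_epoch,
        "train_loss": float(metrics.get("train_loss", 0)),
    }))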


Log Hyperparameters at Start

Log hyperparameters once before training begins:
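
One option is the on_fit_start hook, assuming you called self.save_hyperparameters() in __init__; the lr and hidden_dim names are examples:

def on_fit_start(self):
    # Runs once before training; the printed JSON appears alongside the metrics in Valohai
    print(json.dumps({
        "learning_rate": float(self.hparams.lr),
        "hidden_dim": int(self.hparams.hidden_dim),
        "max_epochs": int(self.trainer.max_epochs),
    }))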


Common Issues

Metrics Not Available in Hook

Symptom: KeyError when accessing metrics

Cause: Metric not logged in training_step or validation_step

Solution: Make sure you call self.log() for metrics you want to access:
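
For example, make sure validation_step logs what the hook later reads:

def validation_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    acc = self.compute_accuracy(batch)
    self.log("val_loss", loss)  # without this, "val_loss" never reaches callback_metrics
    self.log("val_acc", acc)

def on_train_epoch_end(self):
    # Using .get() with a default also avoids a KeyError if a metric is missing
    val_loss = float(self.trainer.callback_metrics.get("val_loss", 0))
    print(json.dumps({"epoch": self.current_epoch, "val_loss": val_loss}))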


Tensor Serialization Errors

Symptom: TypeError: Object of type Tensor is not JSON serializable

Solution: Convert tensors to Python types:
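
For example:

value = self.trainer.callback_metrics["val_acc"]

print(json.dumps({"val_acc": value.item()}))  # or float(value); both give a JSON-safe number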


Example Project

Check out our complete working example on GitHub:

valohai/pytorch-lightning-example

The repository includes:

  • Complete training script with Valohai integration

  • valohai.yaml configuration

  • Example notebooks

  • Step-by-step setup instructions


