PyTorch Lightning
PyTorch Lightning provides lifecycle hooks that make metric logging seamless. Use the on_train_epoch_end hook to automatically log metrics after each training epoch without cluttering your training loop.
Quick Example
import pytorch_lightning as pl
import json
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
# Your model setup...
def training_step(self, batch, batch_idx):
# Your training logic...
loss = self.compute_loss(batch)
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
# Your validation logic...
loss = self.compute_loss(batch)
acc = self.compute_accuracy(batch)
self.log('val_loss', loss)
self.log('val_acc', acc)
def on_train_epoch_end(self):
# Log to Valohai after each epoch
metrics = self.trainer.callback_metrics
print(json.dumps({
"epoch": self.current_epoch,
"train_loss": float(metrics.get('train_loss', 0)),
"val_loss": float(metrics.get('val_loss', 0)),
"val_acc": float(metrics.get('val_acc', 0))
}))
Why Use Hooks?
PyTorch Lightning already tracks metrics internally. The hook pattern lets you:
Access all logged metrics in one place
Log to Valohai without modifying training logic
Keep your code clean and framework-idiomatic
Complete Working Example
Here's a full training script with Valohai integration:
import argparse
import json

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import valohai
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class LitMNIST(pl.LightningModule):
    """Small convolutional MNIST classifier that reports metrics to Valohai.

    Metrics registered with ``self.log`` are read back from
    ``trainer.callback_metrics`` at the end of every training epoch and
    printed as a single JSON line.

    Args:
        learning_rate: Learning rate handed to Adam; stored in
            ``self.hparams`` by ``save_hyperparameters()``.
    """

    # Logged metric names that are forwarded to Valohai when present.
    _EPOCH_METRICS = ("train_loss", "val_loss", "val_acc")

    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()

        # Simple CNN
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12, the flattened size for 28x28 inputs.
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)
        x = self.dropout2(F.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, y = batch
        loss = F.nll_loss(self(x), y)
        # PyTorch Lightning automatically tracks this
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def _evaluate(self, batch, stage):
        """Log ``<stage>_loss`` and ``<stage>_acc`` for a batch; return the loss."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        acc = (torch.argmax(log_probs, dim=1) == y).float().mean()
        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` for one validation batch."""
        return self._evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_acc`` for one test batch.

        Without this hook ``trainer.test()`` has nothing to run, and the
        ``test_loss`` / ``test_acc`` keys would be missing from its results.
        """
        return self._evaluate(batch, 'test')

    def on_train_epoch_end(self):
        """Print the epoch's collected metrics and learning rate as JSON."""
        # Access metrics collected during the epoch
        metrics = self.trainer.callback_metrics

        metadata = {"epoch": self.current_epoch}
        # Only forward metrics that were actually logged; tensors are
        # converted to plain floats so they are JSON serializable.
        for name in self._EPOCH_METRICS:
            if name in metrics:
                metadata[name] = float(metrics[name])

        # Log current learning rate
        metadata["learning_rate"] = (
            self.trainer.optimizers[0].param_groups[0]['lr']
        )

        print(json.dumps(metadata))

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule (x0.1 every 5 scheduler steps)."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams.learning_rate
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=5, gamma=0.1
        )
        return [optimizer], [scheduler]
# Training script
def main():
    """Train ``LitMNIST``, report results to Valohai and save a checkpoint.

    Hyperparameters are read from the ``--learning_rate``, ``--batch_size``
    and ``--epochs`` command-line flags. Their defaults are 0.001, 32 and 10,
    so running the script without arguments trains with those values.
    """
    parser = argparse.ArgumentParser(description="Train LitMNIST on MNIST")
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=10)
    # Tolerate flags this script does not know about instead of exiting.
    args, _unknown = parser.parse_known_args()

    # Prepare data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = datasets.MNIST(
        '/valohai/inputs/dataset',
        train=True,
        download=True,
        transform=transform,
    )
    val_dataset = datasets.MNIST(
        '/valohai/inputs/dataset',
        train=False,
        transform=transform,
    )
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True
    )
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    # Initialize model
    model = LitMNIST(learning_rate=args.learning_rate)

    # Log hyperparameters at the start. The values come from the same
    # variables that configure training, so the record cannot drift from
    # what was actually used.
    print(json.dumps({
        "model": "LitMNIST",
        "optimizer": "adam",
        "learning_rate": args.learning_rate,
        "batch_size": args.batch_size,
        "scheduler": "StepLR",
    }))

    # Train
    trainer = pl.Trainer(
        max_epochs=args.epochs,
        accelerator='auto',
        devices=1,
    )
    trainer.fit(model, train_loader, val_loader)

    # Log final test results (the held-out split doubles as the test set).
    test_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_results = trainer.test(model, test_loader)
    # These keys are present only when the module logs 'test_loss' and
    # 'test_acc' from a test_step.
    final_metrics = test_results[0]
    with valohai.metadata.logger() as logger:
        logger.log("final_test_loss", final_metrics['test_loss'])
        logger.log("final_test_acc", final_metrics['test_acc'])

    # Save model
    output_path = valohai.outputs().path('model.ckpt')
    trainer.save_checkpoint(output_path)
if __name__ == '__main__':
main()
valohai.yaml Configuration
Make sure to change the input data and environment to match your own values.
- step:
name: train-lightning
image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
command:
  - pip install pytorch-lightning
  # {parameters} expands to the step's parameter values as command-line
  # flags; without it the parameters declared below never reach the script.
  - python train.py {parameters}
inputs:
- name: dataset
default: dataset://training-data/latest
parameters:
- name: learning_rate
type: float
default: 0.001
- name: batch_size
type: integer
default: 32
- name: epochs
type: integer
default: 10
environment: aws-eu-west-1-g4dn-xlarge
Alternative: Using Callbacks
If you prefer not to modify your LightningModule, create a custom callback:
import pytorch_lightning as pl
import json
import torch
class ValohaiMetricsCallback(pl.Callback):
    """Callback that prints each epoch's metrics as a single JSON line."""

    def on_train_epoch_end(self, trainer, pl_module):
        """Print epoch number, all callback metrics and the learning rate.

        Args:
            trainer: Trainer whose ``callback_metrics`` and ``optimizers``
                are read.
            pl_module: The module being trained (unused).
        """
        record = {"epoch": trainer.current_epoch + 1}

        # Log all metrics tracked by Lightning; tensors become plain floats
        # so that json.dumps can serialize them.
        record.update({
            name: float(raw) if isinstance(raw, torch.Tensor) else raw
            for name, raw in trainer.callback_metrics.items()
        })

        # Log learning rate (skipped when no optimizer is attached).
        if trainer.optimizers:
            first_group = trainer.optimizers[0].param_groups[0]
            record["learning_rate"] = first_group['lr']

        print(json.dumps(record))
# Use the callback
trainer = pl.Trainer(
max_epochs=10,
callbacks=[ValohaiMetricsCallback()]
)
Benefits of the callback approach:
No changes to your LightningModule
Reusable across projects
Easier to enable/disable Valohai logging
Logging Custom Metrics
You can log any metric you compute:
def on_train_epoch_end(self):
metadata = {
"epoch": self.current_epoch + 1,
"train_loss": float(self.trainer.callback_metrics['train_loss'])
}
# Custom computations
model_size_mb = sum(p.numel() for p in self.parameters()) * 4 / 1024 / 1024
metadata["model_size_mb"] = model_size_mb
# Gradient norms
total_norm = 0
for p in self.parameters():
if p.grad is not None:
total_norm += p.grad.data.norm(2).item() ** 2
total_norm = total_norm ** 0.5
metadata["gradient_norm"] = total_norm
print(json.dumps(metadata))
Available Hooks
PyTorch Lightning provides many hooks you can use:
Training hooks:
on_train_epoch_start — Before each epoch
on_train_epoch_end — After each epoch (most common)
on_train_batch_end — After each batch (use sparingly)
Validation hooks:
on_validation_epoch_end — After validation completes
Testing hooks:
on_test_epoch_end — After testing completes
Example using multiple hooks:
def on_train_epoch_start(self):
    """Print the upcoming epoch number and current learning rate as JSON."""
    # Log at the start of each epoch
    epoch_number = self.current_epoch + 1
    learning_rate = self.trainer.optimizers[0].param_groups[0]['lr']
    payload = {"epoch_start": epoch_number, "learning_rate": learning_rate}
    print(json.dumps(payload))
def on_train_epoch_end(self):
    """Print the finished epoch number and its training loss as JSON."""
    # Log at the end of each epoch
    epoch_number = self.current_epoch + 1
    train_loss = float(self.trainer.callback_metrics['train_loss'])
    payload = {"epoch": epoch_number, "train_loss": train_loss}
    print(json.dumps(payload))
def on_validation_epoch_end(self):
# Log validation results
print(json.dumps({
"val_loss": float(self.trainer.callback_metrics['val_loss']),
"val_acc": float(self.trainer.callback_metrics['val_acc'])
}))
Best Practices
Convert Tensors to Python Types
PyTorch Lightning metrics are often tensors. Convert them before logging:
# ✅ Good: Convert to float
metadata["loss"] = float(metrics['train_loss'])
# ❌ Avoid: Logging tensors directly
metadata["loss"] = metrics['train_loss'] # May cause JSON serialization issues
Use Consistent Metric Names
Keep metric names consistent with PyTorch Lightning's conventions:
# ✅ Good: Clear naming
self.log('train_loss', loss) # In training_step
logger.log("train_loss", float(metrics['train_loss'])) # In hook
# ❌ Avoid: Renaming metrics
self.log('loss', loss) # In training_step
logger.log("training_loss", float(metrics['loss'])) # In hook
Don't Log Every Batch
Logging after every batch creates too much data:
# ✅ Good: Log per epoch
def on_train_epoch_end(self):
    """Log the epoch number and training loss to Valohai once per epoch."""
    with valohai.metadata.logger() as epoch_log:
        epoch_number = self.current_epoch + 1
        epoch_log.log("epoch", epoch_number)
        # Convert to float so the logger receives a plain Python number.
        train_loss = float(self.trainer.callback_metrics['train_loss'])
        epoch_log.log("train_loss", train_loss)
# ❌ Avoid: Log every batch (unless debugging)
def on_train_batch_end(self, outputs, batch, batch_idx):
with valohai.metadata.logger() as logger:
logger.log("batch", batch_idx)
logger.log("loss", float(outputs['loss']))
Log Hyperparameters at Start
Log hyperparameters once before training begins:
def main():
model = LitModel(learning_rate=0.001)
# Log hyperparameters
import json
print(json.dumps({
"model": model.__class__.__name__,
"learning_rate": model.hparams.learning_rate,
"optimizer": "adam",
"batch_size": 32
}))
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)
Common Issues
Metrics Not Available in Hook
Symptom: KeyError when accessing metrics
Cause: Metric not logged in training_step or validation_step
Solution: Make sure you call self.log() for metrics you want to access:
def validation_step(self, batch, batch_idx):
    """Compute validation loss and accuracy and register both via self.log."""
    val_loss = self.compute_loss(batch)
    val_acc = self.compute_accuracy(batch)
    # Log metrics so they're available in hooks
    for metric_name, metric_value in (('val_loss', val_loss), ('val_acc', val_acc)):
        self.log(metric_name, metric_value)
def on_train_epoch_end(self):
metrics = self.trainer.callback_metrics
# Now these are available
if 'val_loss' in metrics:
with valohai.metadata.logger() as logger:
logger.log("val_loss", float(metrics['val_loss']))
Tensor Serialization Errors
Symptom: TypeError: Object of type Tensor is not JSON serializable
Solution: Convert tensors to Python types:
# ✅ Good
logger.log("loss", float(metrics['train_loss']))
# ❌ Wrong
logger.log("loss", metrics['train_loss']) # Still a tensor
Example Project
Check out our complete working example on GitHub:
valohai/pytorch-lightning-example
The repository includes:
Complete training script with Valohai integration
valohai.yaml configuration
Example notebooks
Step-by-step setup instructions
Next Steps
Visualize your metrics in Valohai
Compare experiments to find the best hyperparameters
Learn more about PyTorch Lightning callbacks
Back to Collect Metrics overview
Last updated
Was this helpful?
