# PyTorch Lightning

PyTorch Lightning provides lifecycle hooks that make metric logging seamless. Valohai collects any single-line JSON printed to stdout as execution metadata, so a `print(json.dumps(...))` inside the `on_train_epoch_end` hook is enough to log metrics after each training epoch without cluttering your training loop.

***

### Quick Example

```python
import pytorch_lightning as pl
import json


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Your model setup...

    def training_step(self, batch, batch_idx):
        # Your training logic...
        loss = self.compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Your validation logic...
        loss = self.compute_loss(batch)
        acc = self.compute_accuracy(batch)
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def on_train_epoch_end(self):
        # Log to Valohai after each epoch
        metrics = self.trainer.callback_metrics

        print(
            json.dumps(
                {
                    "epoch": self.current_epoch,
                    "train_loss": float(metrics.get("train_loss", 0)),
                    "val_loss": float(metrics.get("val_loss", 0)),
                    "val_acc": float(metrics.get("val_acc", 0)),
                },
            ),
        )
```

***

### Why Use Hooks?

PyTorch Lightning already tracks metrics internally. The hook pattern lets you:

* Access all logged metrics in one place
* Log to Valohai without modifying training logic
* Keep your code clean and framework-idiomatic
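
To see exactly what a hook can access, print the keys of `trainer.callback_metrics`; they match the names you passed to `self.log(...)`. A quick sketch:

```python
def on_train_epoch_end(self):
    # Every metric logged via self.log(...) appears here as a tensor,
    # keyed by the name it was logged under.
    print(sorted(self.trainer.callback_metrics.keys()))
```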

***

### Complete Working Example

Here's a full training script with Valohai integration:

```python
import json

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import valohai


class LitMNIST(pl.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()

        # Simple CNN
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # PyTorch Lightning automatically tracks this
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Required for trainer.test() below; logs the metrics that the
        # final Valohai logging step reads.
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def on_train_epoch_end(self):
        # Access metrics collected during the epoch
        metrics = self.trainer.callback_metrics

        # Log to Valohai
        metadata = {
            "epoch": self.current_epoch,
        }

        # Log all collected metrics
        if "train_loss" in metrics:
            metadata["train_loss"] = float(metrics["train_loss"])
        if "val_loss" in metrics:
            metadata["val_loss"] = float(metrics["val_loss"])
        if "val_acc" in metrics:
            metadata["val_acc"] = float(metrics["val_acc"])

        # Log current learning rate
        current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        metadata["learning_rate"] = current_lr

        print(json.dumps(metadata))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        return [optimizer], [scheduler]


# Training script
def main():
    # Prepare data
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ],
    )

    # Valohai mounts the `dataset` input at /valohai/inputs/dataset;
    # download=True is only a fallback for local runs
    train_dataset = datasets.MNIST(
        "/valohai/inputs/dataset",
        train=True,
        download=True,
        transform=transform,
    )
    val_dataset = datasets.MNIST(
        "/valohai/inputs/dataset",
        train=False,
        transform=transform,
    )

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)

    # Initialize model
    model = LitMNIST(learning_rate=0.001)

    # Log hyperparameters at the start
    print(
        json.dumps(
            {
                "model": "LitMNIST",
                "optimizer": "adam",
                "learning_rate": 0.001,
                "batch_size": 32,
                "scheduler": "StepLR",
            },
        ),
    )

    # Train
    trainer = pl.Trainer(
        max_epochs=10,
        accelerator="auto",
        devices=1,
    )

    trainer.fit(model, train_loader, val_loader)

    # Log final test results (this demo reuses the validation set as the test set)
    test_loader = DataLoader(val_dataset, batch_size=32)
    test_results = trainer.test(model, test_loader)

    with valohai.metadata.logger() as logger:
        logger.log("final_test_loss", test_results[0]["test_loss"])
        logger.log("final_test_acc", test_results[0]["test_acc"])

    # Save model
    output_path = valohai.outputs().path("model.ckpt")
    trainer.save_checkpoint(output_path)


if __name__ == "__main__":
    main()
```

***

### valohai.yaml Configuration

Make sure to change the input data and environment to match your own values. The `parameters` block below declares the step's tunable values; wiring them into the script is shown after the configuration.

```yaml
- step:
    name: train-lightning
    image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
    command:
      - pip install pytorch-lightning valohai-utils
      - python train.py
    inputs:
      - name: dataset
        default: dataset://training-data/latest
    parameters:
      - name: learning_rate
        type: float
        default: 0.001
      - name: batch_size
        type: integer
        default: 32
      - name: epochs
        type: integer
        default: 10
    environment: aws-eu-west-1-g4dn-xlarge
```
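
The declared parameters aren't consumed by the example script, which hardcodes its values. To wire them up, add Valohai's `{parameters}` placeholder to the command (at runtime it expands to `--learning_rate=... --batch_size=... --epochs=...`) and parse the flags, for example with `argparse`. A minimal sketch for `main()` in the complete example above:

```python
import argparse

# Sketch: parse the flags that the {parameters} placeholder appends to
# the command, e.g. `python train.py --learning_rate=0.001 --epochs=10`
parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=0.001)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=10)
args = parser.parse_args()

# Feed the parsed values into the model, DataLoaders, and Trainer
model = LitMNIST(learning_rate=args.learning_rate)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
trainer = pl.Trainer(max_epochs=args.epochs, accelerator="auto", devices=1)
```

You can then launch the step with the Valohai CLI, for example `vh execution run train-lightning --adhoc`.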

***

### Alternative: Using Callbacks

If you prefer not to modify your LightningModule, create a custom callback:

```python
import pytorch_lightning as pl
import json
import torch


class ValohaiMetricsCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics

        metadata = {
            "epoch": trainer.current_epoch + 1,
        }

        # Log all metrics tracked by Lightning
        for key, value in metrics.items():
            if isinstance(value, torch.Tensor):
                value = float(value)
            metadata[key] = value

        # Log learning rate
        if trainer.optimizers:
            lr = trainer.optimizers[0].param_groups[0]["lr"]
            metadata["learning_rate"] = lr

        print(json.dumps(metadata))


# Use the callback
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[ValohaiMetricsCallback()],
)
```

**Benefits of the callback approach:**

* No changes to your LightningModule
* Reusable across projects
* Easier to enable/disable Valohai logging (see the sketch below)
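
For example, you can gate the callback behind an environment variable. A minimal sketch, assuming a hypothetical `LOG_TO_VALOHAI` variable that you define yourself (Valohai does not set it):

```python
import os

# LOG_TO_VALOHAI is a hypothetical toggle; set it in your own
# execution environment to turn the metadata logging on or off.
callbacks = []
if os.environ.get("LOG_TO_VALOHAI", "1") == "1":
    callbacks.append(ValohaiMetricsCallback())

trainer = pl.Trainer(max_epochs=10, callbacks=callbacks)
```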

***

### Logging Custom Metrics

You can log any metric you compute:

```python
def on_train_epoch_end(self):
    metadata = {
        "epoch": self.current_epoch + 1,
        "train_loss": float(self.trainer.callback_metrics["train_loss"]),
    }

    # Custom computations (assumes 4-byte float32 parameters)
    model_size_mb = sum(p.numel() for p in self.parameters()) * 4 / 1024 / 1024
    metadata["model_size_mb"] = model_size_mb

    # Gradient norms
    total_norm = 0
    for p in self.parameters():
        if p.grad is not None:
            total_norm += p.grad.data.norm(2).item() ** 2
    total_norm = total_norm**0.5
    metadata["gradient_norm"] = total_norm

    print(json.dumps(metadata))
```
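
One caveat: depending on your Lightning version and optimizer configuration, gradients may already have been zeroed by the time `on_train_epoch_end` runs, making the norm above come out as 0. If you need reliable gradient statistics, a hook such as `on_before_optimizer_step` (Lightning 2.x signature shown) fires while gradients are still populated:

```python
def on_before_optimizer_step(self, optimizer):
    # Gradients are guaranteed to exist here, right before the step.
    # Sample occasionally to avoid flooding the logs.
    if self.global_step % 100 == 0:
        total_norm = sum(
            p.grad.detach().norm(2).item() ** 2
            for p in self.parameters()
            if p.grad is not None
        ) ** 0.5
        print(json.dumps({"step": self.global_step, "gradient_norm": total_norm}))
```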

***

### Available Hooks

PyTorch Lightning provides many hooks you can use:

**Training hooks:**

* `on_train_epoch_start` — Before each epoch
* `on_train_epoch_end` — After each epoch (most common)
* `on_train_batch_end` — After each batch (use sparingly)

**Validation hooks:**

* `on_validation_epoch_end` — After validation completes

**Testing hooks:**

* `on_test_epoch_end` — After testing completes

**Example using multiple hooks:**

```python
def on_train_epoch_start(self):
    # Log at the start of each epoch
    print(
        json.dumps(
            {
                "epoch_start": self.current_epoch + 1,
                "learning_rate": self.trainer.optimizers[0].param_groups[0]["lr"],
            },
        ),
    )


def on_train_epoch_end(self):
    # Log at the end of each epoch
    print(
        json.dumps(
            {
                "epoch": self.current_epoch + 1,
                "train_loss": float(self.trainer.callback_metrics["train_loss"]),
            },
        ),
    )


def on_validation_epoch_end(self):
    # Log validation results
    print(
        json.dumps(
            {
                "val_loss": float(self.trainer.callback_metrics["val_loss"]),
                "val_acc": float(self.trainer.callback_metrics["val_acc"]),
            },
        ),
    )
```

***

### Best Practices

#### Convert Tensors to Python Types

PyTorch Lightning metrics are often tensors. Convert them before logging:

```python
# ✅ Good: Convert to float
metadata["loss"] = float(metrics["train_loss"])

# ❌ Avoid: Logging tensors directly
metadata["loss"] = metrics["train_loss"]  # May cause JSON serialization issues
```
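
For scalar tensors, `tensor.item()` is equivalent:

```python
metadata["loss"] = metrics["train_loss"].item()  # same result as float(...)
```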

***

#### Use Consistent Metric Names

Keep metric names consistent with PyTorch Lightning's conventions:

```python
# ✅ Good: Clear naming
self.log("train_loss", loss)  # In training_step
logger.log("train_loss", float(metrics["train_loss"]))  # In hook

# ❌ Avoid: Renaming metrics
self.log("loss", loss)  # In training_step
logger.log("training_loss", float(metrics["loss"]))  # In hook
```

***

#### Don't Log Every Batch

Logging after every batch creates too much data:

```python
# ✅ Good: Log per epoch
def on_train_epoch_end(self):
    with valohai.metadata.logger() as logger:
        logger.log("epoch", self.current_epoch + 1)
        logger.log("train_loss", float(self.trainer.callback_metrics["train_loss"]))


# ❌ Avoid: Log every batch (unless debugging)
def on_train_batch_end(self, outputs, batch, batch_idx):
    with valohai.metadata.logger() as logger:
        logger.log("batch", batch_idx)
        logger.log("loss", float(outputs["loss"]))
```
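
If you do need intra-epoch visibility, a reasonable middle ground is to sample every Nth batch:

```python
# Middle ground: log a sample of batches instead of all of them
def on_train_batch_end(self, outputs, batch, batch_idx):
    if batch_idx % 100 == 0:
        with valohai.metadata.logger() as logger:
            logger.log("batch", batch_idx)
            logger.log("loss", float(outputs["loss"]))
```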

***

#### Log Hyperparameters at Start

Log hyperparameters once before training begins:

```python
import json


def main():
    model = LitModel(learning_rate=0.001)

    # Log hyperparameters once, before training starts
    print(
        json.dumps(
            {
                "model": model.__class__.__name__,
                "learning_rate": model.hparams.learning_rate,
                "optimizer": "adam",
                "batch_size": 32,
            },
        ),
    )

    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, train_loader, val_loader)
```

***

### Common Issues

#### Metrics Not Available in Hook

**Symptom:** `KeyError` when accessing metrics

**Cause:** Metric not logged in `training_step` or `validation_step`

**Solution:** Make sure you call `self.log()` for metrics you want to access:

```python
def validation_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    acc = self.compute_accuracy(batch)

    # Log metrics so they're available in hooks
    self.log("val_loss", loss)
    self.log("val_acc", acc)


def on_train_epoch_end(self):
    metrics = self.trainer.callback_metrics

    # Now these are available
    if "val_loss" in metrics:
        with valohai.metadata.logger() as logger:
            logger.log("val_loss", float(metrics["val_loss"]))
```

***

#### Tensor Serialization Errors

**Symptom:** `TypeError: Object of type Tensor is not JSON serializable`

**Solution:** Convert tensors to Python types:

```python
# ✅ Good
logger.log("loss", float(metrics["train_loss"]))

# ❌ Wrong
logger.log("loss", metrics["train_loss"])  # Still a tensor
```

***

### Example Project

Check out our complete working example on GitHub:

[**valohai/valohai-pytorch-lightning-example**](https://github.com/valohai/valohai-pytorch-lightning-example)

The repository includes:

* Complete training script with Valohai integration
* `valohai.yaml` configuration
* Example notebooks
* Step-by-step setup instructions

***

### Next Steps

* [Visualize your metrics](/experiment-tracking/visualize-metrics.md) in Valohai
* [Compare experiments](/experiment-tracking/compare-executions.md) to find the best hyperparameters
* Learn more about [PyTorch Lightning callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html)
* Back to [Collect Metrics overview](/experiment-tracking/collect-metrics.md)

