
# TensorFlow/Keras

TensorFlow and Keras provide a callback system that makes metric logging clean and automatic. Create a custom callback to log metrics at the end of each epoch without cluttering your training code.

***

## Quick Example

```python
import tensorflow as tf
import valohai


class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Assumes metrics=["accuracy"] and validation data, so all four keys exist
        logs = logs or {}
        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)  # Keras epochs are zero-based
            logger.log("accuracy", logs["accuracy"])
            logger.log("loss", logs["loss"])
            logger.log("val_accuracy", logs["val_accuracy"])
            logger.log("val_loss", logs["val_loss"])


# Use the callback
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    callbacks=[ValohaiMetricsCallback()],
)
```

***

## Why Use Callbacks?

Keras callbacks run at specific points during training. They let you:

* Access all training metrics automatically
* Keep metric logging separate from model code
* Reuse the same callback across projects
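
To make the hook mechanism concrete, here is a minimal pure-Python sketch of how a training loop might dispatch `on_epoch_end` to its callbacks. `MiniCallback`, `RecordingCallback`, and `run_epochs` are hypothetical stand-ins for illustration, not Keras classes, and the loss values are made up.

```python
class MiniCallback:
    """Hypothetical stand-in for tf.keras.callbacks.Callback."""

    def on_epoch_end(self, epoch, logs=None):
        pass


class RecordingCallback(MiniCallback):
    """Collects the metrics it receives, in place of sending them to Valohai."""

    def __init__(self):
        self.records = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Epochs are zero-based in the hook, so add 1 for display
        self.records.append({"epoch": epoch + 1, **logs})


def run_epochs(num_epochs, callbacks):
    """Toy training loop: computes a fake loss, then notifies every callback."""
    for epoch in range(num_epochs):
        logs = {"loss": 1.0 / (epoch + 1)}  # Placeholder value, not a real loss
        for callback in callbacks:
            callback.on_epoch_end(epoch, logs)


recorder = RecordingCallback()
run_epochs(3, [recorder])
print(recorder.records)
```

The training loop knows nothing about what each callback does with the metrics, which is what keeps logging separate from model code.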

***

## Complete Working Example

Here's a full training script with Valohai integration:

```python
import numpy as np
import tensorflow as tf
import valohai
import json


class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    """Log training metrics to Valohai after each epoch"""

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)

            # Log all available metrics
            for key, value in logs.items():
                logger.log(key, float(value))


def main():
    # Configure Valohai
    valohai.prepare(
        step="train-model",
        image="tensorflow/tensorflow:2.13.0",
        default_inputs={
            "dataset": "https://valohaidemo.blob.core.windows.net/mnist/preprocessed_mnist.npz",
        },
        default_parameters={
            "learning_rate": 0.001,
            "epochs": 10,
            "batch_size": 32,
        },
    )

    # Load data from Valohai inputs
    input_path = valohai.inputs("dataset").path()
    with np.load(input_path, allow_pickle=True) as f:
        x_train, y_train = f["x_train"], f["y_train"]
        x_test, y_test = f["x_test"], f["y_test"]

    # Normalize
    x_train = x_train / 255.0
    x_test = x_test / 255.0

    # Build model
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation="softmax"),
        ],
    )

    # Compile model
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=valohai.parameters("learning_rate").value,
    )
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Log hyperparameters
    print(
        json.dumps(
            {
                "model": "Sequential",
                "optimizer": "adam",
                "learning_rate": valohai.parameters("learning_rate").value,
                "batch_size": valohai.parameters("batch_size").value,
                "loss": "sparse_categorical_crossentropy",
            },
        ),
    )

    # Train with Valohai callback
    history = model.fit(
        x_train,
        y_train,
        batch_size=valohai.parameters("batch_size").value,
        epochs=valohai.parameters("epochs").value,
        validation_split=0.1,
        callbacks=[ValohaiMetricsCallback()],
    )

    # Evaluate and log test results
    test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=2)

    with valohai.metadata.logger() as logger:
        logger.log("final_test_accuracy", test_accuracy)
        logger.log("final_test_loss", test_loss)

    # Save model
    output_path = valohai.outputs().path("model.h5")
    model.save(output_path)

    print(f"Model saved to {output_path}")


if __name__ == "__main__":
    main()
```

***

## valohai.yaml Configuration

Make sure to change the input data and environment to match your own values.

```yaml
- step:
    name: train-tensorflow
    image: tensorflow/tensorflow:2.13.0
    command:
      - pip install valohai-utils
      - python train.py {parameters}
    inputs:
      - name: dataset
        default: dataset://training-data/latest
    parameters:
      - name: learning_rate
        type: float
        default: 0.001
      - name: batch_size
        type: integer
        default: 32
      - name: epochs
        type: integer
        default: 10
```
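
If you are not using `valohai-utils`, the script has to read the parameters itself. The sketch below assumes `{parameters}` expands to `--name=value` flags (for example `--learning_rate=0.001`); check this against the command line of one of your own executions. `parse_parameters` is a hypothetical helper name.

```python
import argparse


def parse_parameters(argv=None):
    """Parse the three parameters declared in valohai.yaml."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=10)
    return parser.parse_args(argv)


# Simulate the flags the platform is assumed to pass to train.py
args = parse_parameters(["--learning_rate=0.01", "--epochs=5"])
print(args.learning_rate, args.batch_size, args.epochs)
```

Parameters that are not passed fall back to the defaults, which here mirror the ones in the YAML above.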

***

## Logging Without valohai-utils

You can also log metrics using plain JSON:

```python
import json
import tensorflow as tf


class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        metadata = {
            "epoch": epoch + 1,
            "accuracy": float(logs.get("accuracy", 0)),
            "loss": float(logs.get("loss", 0)),
            "val_accuracy": float(logs.get("val_accuracy", 0)),
            "val_loss": float(logs.get("val_loss", 0)),
        }

        print(json.dumps(metadata))


# Use the callback
model.fit(
    x_train,
    y_train,
    epochs=10,
    validation_split=0.1,
    callbacks=[ValohaiMetricsCallback()],
)
```
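
The callback above reports `0` for any metric missing from `logs`, which can look like a real value in charts. A variant, sketched here with a hypothetical `build_metadata` helper, includes only the keys that are actually present:

```python
import json


def build_metadata(epoch, logs=None):
    """Return a JSON string with the epoch and only the metrics found in logs."""
    logs = logs or {}
    metadata = {"epoch": epoch + 1}
    for key in ("accuracy", "loss", "val_accuracy", "val_loss"):
        if key in logs:
            metadata[key] = float(logs[key])
    return json.dumps(metadata)


# No validation data: val_* keys are simply omitted
line = build_metadata(0, {"accuracy": 0.5, "loss": 1.25})
print(line)
```

Inside the callback, `print(build_metadata(epoch, logs))` would replace the hand-built dictionary.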

***

## Logging Learning Rate

Track learning rate changes during training:

```python
class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)

            # Log training metrics
            for key, value in logs.items():
                logger.log(key, float(value))

            # Log current learning rate
            if hasattr(self.model.optimizer, "learning_rate"):
                lr = self.model.optimizer.learning_rate
                if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
                    lr = lr(self.model.optimizer.iterations)
                logger.log("learning_rate", float(lr))
```
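
The schedule branch works because a learning rate schedule is a callable that maps the current step to a rate. The sketch below mirrors that idea in plain Python; `ToyExponentialDecay` and `resolve_learning_rate` are hypothetical stand-ins and not part of TensorFlow.

```python
class ToyExponentialDecay:
    """Stand-in schedule: initial_rate * decay_rate ** (step / decay_steps)."""

    def __init__(self, initial_rate, decay_steps, decay_rate):
        self.initial_rate = initial_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate

    def __call__(self, step):
        return self.initial_rate * self.decay_rate ** (step / self.decay_steps)


def resolve_learning_rate(lr, step):
    """Return a plain float whether lr is a constant or a callable schedule."""
    if callable(lr):
        lr = lr(step)
    return float(lr)


schedule = ToyExponentialDecay(initial_rate=0.1, decay_steps=100, decay_rate=0.5)
print(resolve_learning_rate(0.001, step=500))     # Constant rate passes through
print(resolve_learning_rate(schedule, step=100))  # Schedule evaluated at step 100
```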

***

## Combining Multiple Callbacks

Use multiple callbacks together:

```python
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Valohai metrics callback
valohai_callback = ValohaiMetricsCallback()

# Early stopping
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
)

# Model checkpointing
checkpoint = ModelCheckpoint(
    valohai.outputs().path("best_model.h5"),
    monitor="val_accuracy",
    save_best_only=True,
)

# Train with all callbacks
model.fit(
    x_train,
    y_train,
    epochs=50,
    validation_split=0.1,
    callbacks=[valohai_callback, early_stop, checkpoint],
)
```

***

## Logging Custom Metrics

Add your own computed metrics:

```python
class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)

            # Standard metrics
            for key, value in logs.items():
                logger.log(key, float(value))

            # Custom: Training/validation gap
            if "loss" in logs and "val_loss" in logs:
                gap = logs["val_loss"] - logs["loss"]
                logger.log("train_val_gap", float(gap))

            # Custom: Accuracy improvement over the previous epoch
            if "val_accuracy" in logs:
                if hasattr(self, "prev_val_acc"):
                    improvement = logs["val_accuracy"] - self.prev_val_acc
                    logger.log("accuracy_improvement", float(improvement))
                # Remember this epoch's value only when it was actually reported
                self.prev_val_acc = logs["val_accuracy"]
```
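
Keeping the derived-metric arithmetic in a plain function makes it easy to check outside of training. This is a sketch; `derived_metrics` is a hypothetical helper and the numbers are invented.

```python
def derived_metrics(logs, prev_val_acc=None):
    """Compute the train/validation gap and the change in validation accuracy."""
    result = {}
    if "loss" in logs and "val_loss" in logs:
        result["train_val_gap"] = float(logs["val_loss"] - logs["loss"])
    if "val_accuracy" in logs and prev_val_acc is not None:
        result["accuracy_improvement"] = float(logs["val_accuracy"] - prev_val_acc)
    return result


# First epoch: there is no previous accuracy to compare against
first = derived_metrics({"loss": 0.5, "val_loss": 0.75, "val_accuracy": 0.5})
second = derived_metrics(
    {"loss": 0.25, "val_loss": 0.5, "val_accuracy": 0.75},
    prev_val_acc=0.5,
)
print(first)
print(second)
```

A positive `train_val_gap` that keeps growing across epochs is a common sign of overfitting.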

***

## Using LambdaCallback (Shorter Syntax)

For simple logging, use `LambdaCallback`:

```python
import tensorflow as tf
import valohai


def log_metrics(epoch, logs):
    with valohai.metadata.logger() as logger:
        logger.log("epoch", epoch + 1)
        logger.log("accuracy", logs["accuracy"])
        logger.log("loss", logs["loss"])
        logger.log("val_accuracy", logs["val_accuracy"])
        logger.log("val_loss", logs["val_loss"])


# Create callback
metrics_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_metrics)

# Train
model.fit(
    x_train,
    y_train,
    epochs=10,
    validation_split=0.1,
    callbacks=[metrics_callback],
)
```

***

## Logging Per-Batch Metrics (Advanced)

For very long epochs, you might want to log progress mid-epoch:

```python
class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def __init__(self, log_every_n_batches=100):
        super().__init__()
        self.log_every_n_batches = log_every_n_batches
        self.batch_count = 0

    def on_batch_end(self, batch, logs=None):
        self.batch_count += 1

        # Only log every N batches
        if self.batch_count % self.log_every_n_batches == 0:
            logs = logs or {}
            with valohai.metadata.logger() as logger:
                logger.log("batch", self.batch_count)
                logger.log("batch_loss", float(logs.get("loss", 0)))
                logger.log("batch_accuracy", float(logs.get("accuracy", 0)))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)
            for key, value in logs.items():
                logger.log(key, float(value))
```

**Use sparingly:** Logging every batch creates a lot of data. Only use for debugging or very long epochs.
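
Note that `batch_count` in the callback above is never reset, so it counts batches across the whole run rather than within one epoch. A minimal sketch of the throttling check, with a hypothetical `should_log` helper:

```python
def should_log(batch_count, log_every_n_batches=100):
    """Return True on every N-th batch, counting from 1."""
    return batch_count % log_every_n_batches == 0


# With 250 batches seen so far and N=100, only two log entries are written
logged = [count for count in range(1, 251) if should_log(count)]
print(logged)
```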

***

## Best Practices

### Always Convert to Python Types

Values in `logs` may be NumPy scalars or tensors rather than built-in Python numbers, depending on your Keras version. Convert them to Python types so they serialize to JSON:

```python
# ✅ Good: Convert to float
logger.log("accuracy", float(logs["accuracy"]))

# ❌ Avoid: Passing NumPy or tensor values directly
logger.log("accuracy", logs["accuracy"])  # May fail JSON serialization
```
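
The failure mode is easy to reproduce without TensorFlow. In this sketch `fractions.Fraction` stands in for a numeric type the `json` module does not know how to serialize (an assumption for illustration; the actual types in `logs` depend on your Keras version):

```python
import json
from fractions import Fraction

value = Fraction(3, 4)  # Stand-in for a non-native numeric type

try:
    json.dumps({"accuracy": value})
    serialized_raw = True
except TypeError:
    # json only handles built-in types such as int, float, str, list, dict
    serialized_raw = False

# Converting to float first produces valid JSON
line = json.dumps({"accuracy": float(value)})
print(serialized_raw, line)
```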

***

### Handle Missing Metrics

Not every metric is present in `logs` on every call. Validation metrics, for example, appear only when validation data is provided:

```python
def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}

    with valohai.metadata.logger() as logger:
        logger.log("epoch", epoch + 1)

        # Safe: Check if metric exists
        if "val_loss" in logs:
            logger.log("val_loss", float(logs["val_loss"]))

        # Or use .get() with default
        logger.log("loss", float(logs.get("loss", 0)))
```

***

### Use Descriptive Metric Names

Keep names consistent with Keras conventions:

```python
# ✅ Good: Standard Keras names
"loss"
"val_loss"
"accuracy"
"val_accuracy"

# ❌ Avoid: Custom abbreviations
"acc"
"vloss"
"train_a"
```

***

## Common Issues

### Metrics Not Appearing

**Symptom:** Callback runs but no metrics in Valohai

**Causes & Fixes:**

* Missing `validation_data` → Add validation split or data
* Incorrect metric names → Check available keys in `logs`
* JSON serialization error → Convert NumPy/Tensor types to float

**Debug:**

```python
def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    print(f"Available metrics: {list(logs.keys())}")  # Debug

    with valohai.metadata.logger() as logger:
        logger.log("epoch", epoch + 1)
        for key, value in logs.items():
            try:
                logger.log(key, float(value))
            except Exception as e:
                print(f"Failed to log {key}: {e}")
```

***

### Validation Metrics Missing

**Symptom:** Only training metrics logged, no validation metrics

**Solution:** Make sure you provide validation data:

```python
# ✅ Good: Includes validation
model.fit(
    x_train,
    y_train,
    validation_split=0.1,  # or validation_data=(x_val, y_val)
    callbacks=[ValohaiMetricsCallback()],
)

# ❌ No validation metrics
model.fit(
    x_train,
    y_train,
    callbacks=[ValohaiMetricsCallback()],  # logs won't have val_* metrics
)
```

***

## Example Project

Check out our complete working example on GitHub:

[**valohai/tensorflow-example**](https://github.com/valohai/tensorflow-example)

The repository includes:

* Complete training script with Valohai integration
* `valohai.yaml` configuration
* Example notebooks
* Step-by-step setup instructions

***

## Next Steps

* [Visualize your metrics](/experiment-tracking/visualize-metrics.md) in Valohai
* [Compare experiments](/experiment-tracking/compare-executions.md) to find the best hyperparameters
* Learn more about [Keras callbacks](https://keras.io/guides/writing_your_own_callbacks/)
* Back to [Collect Metrics overview](/experiment-tracking/collect-metrics.md)

