# TensorFlow/Keras

TensorFlow and Keras provide a callback system that makes metric logging clean and automatic. Create a custom callback to log metrics at the end of each epoch without cluttering your training code.

***

## Quick Example

```python
import tensorflow as tf
import valohai


class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)
            logger.log("accuracy", logs["accuracy"])
            logger.log("loss", logs["loss"])
            logger.log("val_accuracy", logs["val_accuracy"])
            logger.log("val_loss", logs["val_loss"])


# Use the callback
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    callbacks=[ValohaiMetricsCallback()],
)
```

***

## Why Use Callbacks?

Keras callbacks run at specific points during training. They let you:

* Access all training metrics automatically
* Keep metric logging separate from model code
* Reuse the same callback across projects (see the sketch below)
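
For instance, if the callback from the Quick Example lives in a shared module (here a hypothetical `valohai_callbacks.py`), every training script can simply import it:

```python
# Assumes the ValohaiMetricsCallback class from the Quick Example is saved
# in a shared file called valohai_callbacks.py (the name is just an example).
from valohai_callbacks import ValohaiMetricsCallback

model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    callbacks=[ValohaiMetricsCallback()],
)
```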

***

## Complete Working Example

Here's a full training script with Valohai integration:

```python
import json

import numpy as np
import tensorflow as tf
import valohai


class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    """Log training metrics to Valohai after each epoch"""

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)

            # Log all available metrics
            for key, value in logs.items():
                logger.log(key, float(value))


def main():
    # Configure Valohai
    valohai.prepare(
        step="train-model",
        image="tensorflow/tensorflow:2.13.0",
        default_inputs={
            "dataset": "https://valohaidemo.blob.core.windows.net/mnist/preprocessed_mnist.npz",
        },
        default_parameters={
            "learning_rate": 0.001,
            "epochs": 10,
            "batch_size": 32,
        },
    )

    # Load data from Valohai inputs
    input_path = valohai.inputs("dataset").path()
    with np.load(input_path, allow_pickle=True) as f:
        x_train, y_train = f["x_train"], f["y_train"]
        x_test, y_test = f["x_test"], f["y_test"]

    # Normalize
    x_train = x_train / 255.0
    x_test = x_test / 255.0

    # Build model
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation="softmax"),
        ],
    )

    # Compile model
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=valohai.parameters("learning_rate").value,
    )
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Log hyperparameters
    print(
        json.dumps(
            {
                "model": "Sequential",
                "optimizer": "adam",
                "learning_rate": valohai.parameters("learning_rate").value,
                "batch_size": valohai.parameters("batch_size").value,
                "loss": "sparse_categorical_crossentropy",
            },
        ),
    )

    # Train with Valohai callback
    history = model.fit(
        x_train,
        y_train,
        batch_size=valohai.parameters("batch_size").value,
        epochs=valohai.parameters("epochs").value,
        validation_split=0.1,
        callbacks=[ValohaiMetricsCallback()],
    )

    # Evaluate and log test results
    test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=2)

    with valohai.metadata.logger() as logger:
        logger.log("final_test_accuracy", test_accuracy)
        logger.log("final_test_loss", test_loss)

    # Save model
    output_path = valohai.outputs().path("model.h5")
    model.save(output_path)

    print(f"Model saved to {output_path}")


if __name__ == "__main__":
    main()
```

***

## valohai.yaml Configuration

Make sure the step name matches the one you pass to `valohai.prepare()`, and update the Docker image and input data to match your own environment.

```yaml
- step:
    name: train-model
    image: tensorflow/tensorflow:2.13.0
    command:
      - pip install valohai-utils
      - python train.py {parameters}
    inputs:
      - name: dataset
        default: dataset://training-data/latest
    parameters:
      - name: learning_rate
        type: float
        default: 0.001
      - name: batch_size
        type: integer
        default: 32
      - name: epochs
        type: integer
        default: 10
```

***

## Logging Without valohai-utils

You can also log metrics using plain JSON:

```python
import json
import tensorflow as tf


class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        metadata = {
            "epoch": epoch + 1,
            "accuracy": float(logs.get("accuracy", 0)),
            "loss": float(logs.get("loss", 0)),
            "val_accuracy": float(logs.get("val_accuracy", 0)),
            "val_loss": float(logs.get("val_loss", 0)),
        }

        print(json.dumps(metadata))


# Use the callback
model.fit(
    x_train,
    y_train,
    epochs=10,
    validation_split=0.1,
    callbacks=[ValohaiMetricsCallback()],
)
```

***

## Logging Learning Rate

Track learning rate changes during training:

```python
class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)

            # Log training metrics
            for key, value in logs.items():
                logger.log(key, float(value))

            # Log current learning rate
            if hasattr(self.model.optimizer, "learning_rate"):
                lr = self.model.optimizer.learning_rate
                if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
                    lr = lr(self.model.optimizer.iterations)
                logger.log("learning_rate", float(lr))
```
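
To actually see the learning rate move, pair this callback with something that changes it during training. Here is a minimal sketch using Keras's built-in `ReduceLROnPlateau`; the `model`, `x_train`, and `y_train` names are assumed from the earlier example:

```python
# Halve the learning rate when validation loss stops improving, so the
# "learning_rate" metric logged above actually changes over the run.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=2,
)

model.fit(
    x_train,
    y_train,
    epochs=20,
    validation_split=0.1,
    callbacks=[ValohaiMetricsCallback(), reduce_lr],
)
```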

***

## Combining Multiple Callbacks

Use multiple callbacks together:

```python
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Valohai metrics callback
valohai_callback = ValohaiMetricsCallback()

# Early stopping
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
)

# Model checkpointing
checkpoint = ModelCheckpoint(
    valohai.outputs().path("best_model.h5"),
    monitor="val_accuracy",
    save_best_only=True,
)

# Train with all callbacks
model.fit(
    x_train,
    y_train,
    epochs=50,
    validation_split=0.1,
    callbacks=[valohai_callback, early_stop, checkpoint],
)
```

***

## Logging Custom Metrics

Add your own computed metrics:

```python
class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)

            # Standard metrics
            for key, value in logs.items():
                logger.log(key, float(value))

            # Custom: Training/validation gap
            if "loss" in logs and "val_loss" in logs:
                gap = logs["val_loss"] - logs["loss"]
                logger.log("train_val_gap", float(gap))

            # Custom: Accuracy improvement
            if "val_accuracy" in logs and hasattr(self, "prev_val_acc"):
                improvement = logs["val_accuracy"] - self.prev_val_acc
                logger.log("accuracy_improvement", float(improvement))

            self.prev_val_acc = logs.get("val_accuracy", 0)
```

***

## Using LambdaCallback (Shorter Syntax)

For simple logging, use `LambdaCallback`:

```python
import tensorflow as tf
import valohai


def log_metrics(epoch, logs):
    with valohai.metadata.logger() as logger:
        logger.log("epoch", epoch + 1)
        logger.log("accuracy", logs["accuracy"])
        logger.log("loss", logs["loss"])
        logger.log("val_accuracy", logs["val_accuracy"])
        logger.log("val_loss", logs["val_loss"])


# Create callback
metrics_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_metrics)

# Train
model.fit(
    x_train,
    y_train,
    epochs=10,
    validation_split=0.1,
    callbacks=[metrics_callback],
)
```

***

## Logging Per-Batch Metrics (Advanced)

For very long epochs, you might want to log progress mid-epoch:

```python
class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def __init__(self, log_every_n_batches=100):
        super().__init__()
        self.log_every_n_batches = log_every_n_batches
        self.batch_count = 0

    def on_batch_end(self, batch, logs=None):
        self.batch_count += 1

        # Only log every N batches
        if self.batch_count % self.log_every_n_batches == 0:
            logs = logs or {}
            with valohai.metadata.logger() as logger:
                logger.log("batch", self.batch_count)
                logger.log("batch_loss", float(logs.get("loss", 0)))
                logger.log("batch_accuracy", float(logs.get("accuracy", 0)))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)
            for key, value in logs.items():
                logger.log(key, float(value))
```

**Use sparingly:** Frequent batch-level logging produces a lot of metadata. Reserve it for debugging or for very long epochs.

***

## Best Practices

### Always Convert to Python Types

Keras reports metric values as NumPy (and sometimes TensorFlow) scalar types, which don't always serialize cleanly to JSON. Convert them to plain Python floats before logging:

```python
# ✅ Good: Convert to float
logger.log("accuracy", float(logs["accuracy"]))

# ❌ Avoid: NumPy types
logger.log("accuracy", logs["accuracy"])  # May cause issues
```
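
If you ever log raw tensors or NumPy scalars, for example from a custom training loop, a small helper can normalize them first. This is a sketch of one possible approach; the `to_scalar` name is just an example:

```python
import numpy as np
import tensorflow as tf


def to_scalar(value):
    """Normalize NumPy / TensorFlow scalar values to a plain Python float."""
    if hasattr(value, "numpy"):        # tf.Tensor or tf.Variable
        value = value.numpy()
    if isinstance(value, np.generic):  # NumPy scalar such as np.float32
        value = value.item()
    return float(value)


print(to_scalar(np.float32(0.93)))   # 0.93...
print(to_scalar(tf.constant(0.12)))  # 0.12...
```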

***

### Handle Missing Metrics

Not all metrics are available in every callback:

```python
def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}

    with valohai.metadata.logger() as logger:
        logger.log("epoch", epoch + 1)

        # Safe: Check if metric exists
        if "val_loss" in logs:
            logger.log("val_loss", float(logs["val_loss"]))

        # Or use .get() with default
        logger.log("loss", float(logs.get("loss", 0)))
```

***

### Use Descriptive Metric Names

Keep names consistent with Keras conventions:

```python
# ✅ Good: Standard Keras names
"loss"

"val_loss"
"accuracy"
"val_accuracy"

# ❌ Avoid: Custom abbreviations
"acc"
"vloss"
"train_a"
```
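
If your compiled metrics come out with shorthand keys, you can normalize them at logging time instead of renaming them throughout your model code. A small sketch with a hypothetical rename map:

```python
import tensorflow as tf
import valohai

# Map shorthand metric keys to the standard Keras names before logging.
RENAMES = {"acc": "accuracy", "val_acc": "val_accuracy"}


class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        with valohai.metadata.logger() as logger:
            logger.log("epoch", epoch + 1)
            for key, value in logs.items():
                logger.log(RENAMES.get(key, key), float(value))
```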

***

## Common Issues

### Metrics Not Appearing

**Symptom:** Callback runs but no metrics in Valohai

**Causes & Fixes:**

* Missing `validation_data` → Add validation split or data
* Incorrect metric names → Check available keys in `logs`
* JSON serialization error → Convert NumPy/Tensor types to float

**Debug:**

```python
def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    print(f"Available metrics: {list(logs.keys())}")  # Debug

    with valohai.metadata.logger() as logger:
        logger.log("epoch", epoch + 1)
        for key, value in logs.items():
            try:
                logger.log(key, float(value))
            except Exception as e:
                print(f"Failed to log {key}: {e}")
```

***

### Validation Metrics Missing

**Symptom:** Only training metrics logged, no validation metrics

**Solution:** Make sure you provide validation data:

```python
# ✅ Good: Includes validation
model.fit(
    x_train,
    y_train,
    validation_split=0.1,  # or validation_data=(x_val, y_val)
    callbacks=[ValohaiMetricsCallback()],
)

# ❌ No validation metrics
model.fit(
    x_train,
    y_train,
    callbacks=[ValohaiMetricsCallback()],  # logs won't have val_* metrics
)
```

***

## Example Project

Check out our complete working example on GitHub:

[**valohai/tensorflow-example**](https://github.com/valohai/tensorflow-example)

The repository includes:

* Complete training script with Valohai integration
* `valohai.yaml` configuration
* Example notebooks
* Step-by-step setup instructions

***

## Next Steps

* [Visualize your metrics](https://docs.valohai.com/experiment-tracking/visualize-metrics) in Valohai
* [Compare experiments](https://docs.valohai.com/experiment-tracking/compare-executions) to find the best hyperparameters
* Learn more about [Keras callbacks](https://keras.io/guides/writing_your_own_callbacks/)
* Back to [Collect Metrics overview](https://docs.valohai.com/experiment-tracking/collect-metrics)
