TensorFlow/Keras

TensorFlow and Keras provide a callback system that makes metric logging clean and automatic. Create a custom callback to log metrics at the end of each epoch without cluttering your training code.


Quick Example

import tensorflow as tf
import valohai

class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        with valohai.metadata.logger() as logger:
            logger.log('epoch', epoch + 1)
            logger.log('accuracy', logs['accuracy'])
            logger.log('loss', logs['loss'])
            logger.log('val_accuracy', logs['val_accuracy'])
            logger.log('val_loss', logs['val_loss'])

# Use the callback
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    callbacks=[ValohaiMetricsCallback()]
)

Why Use Callbacks?

Keras callbacks run at specific points during training; the sketch after this list shows the main hooks. They let you:

  • Access all training metrics automatically

  • Keep metric logging separate from model code

  • Reuse the same callback across projects

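For context, here is a bare-bones callback that only prints when each standard hook fires (the class name is illustrative and nothing here is Valohai-specific); the logging examples in this guide override on_epoch_end:

import tensorflow as tf

class HookDemoCallback(tf.keras.callbacks.Callback):
    """Illustrative only: prints when each standard Keras hook fires."""

    def on_train_begin(self, logs=None):
        print("training started")

    def on_epoch_begin(self, epoch, logs=None):
        print(f"epoch {epoch + 1} starting")

    def on_epoch_end(self, epoch, logs=None):
        # logs holds the metrics Keras computed for this epoch
        print(f"epoch {epoch + 1} done, metrics: {sorted((logs or {}).keys())}")

    def on_train_end(self, logs=None):
        print("training finished")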

Complete Working Example

Here's a full training script with Valohai integration:

import numpy as np
import tensorflow as tf
import valohai
import json

class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    """Log training metrics to Valohai after each epoch"""
    
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        
        with valohai.metadata.logger() as logger:
            logger.log('epoch', epoch + 1)
            
            # Log all available metrics
            for key, value in logs.items():
                logger.log(key, float(value))

def main():
    # Configure Valohai
    valohai.prepare(
        step='train-model',
        image='tensorflow/tensorflow:2.13.0',
        default_inputs={
            'dataset': 'https://valohaidemo.blob.core.windows.net/mnist/preprocessed_mnist.npz',
        },
        default_parameters={
            'learning_rate': 0.001,
            'epochs': 10,
            'batch_size': 32,
        },
    )
    
    # Load data from Valohai inputs
    input_path = valohai.inputs('dataset').path()
    with np.load(input_path, allow_pickle=True) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    
    # Normalize
    x_train = x_train / 255.0
    x_test = x_test / 255.0
    
    # Build model
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    
    # Compile model
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=valohai.parameters('learning_rate').value
    )
    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
    # Log hyperparameters
    print(json.dumps({
        "model": "Sequential",
        "optimizer": "adam",
        "learning_rate": valohai.parameters('learning_rate').value,
        "batch_size": valohai.parameters('batch_size').value,
        "loss": "sparse_categorical_crossentropy"
    }))
    
    # Train with Valohai callback
    history = model.fit(
        x_train,
        y_train,
        batch_size=valohai.parameters('batch_size').value,
        epochs=valohai.parameters('epochs').value,
        validation_split=0.1,
        callbacks=[ValohaiMetricsCallback()]
    )
    
    # Evaluate and log test results
    test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=2)
    
    with valohai.metadata.logger() as logger:
        logger.log('final_test_accuracy', test_accuracy)
        logger.log('final_test_loss', test_loss)
    
    # Save model
    output_path = valohai.outputs().path('model.h5')
    model.save(output_path)
    
    print(f"Model saved to {output_path}")

if __name__ == '__main__':
    main()

valohai.yaml Configuration

Make sure to change the input data and environment to match your own values.

- step:
    name: train-model
    image: tensorflow/tensorflow:2.13.0
    command:
      - pip install valohai-utils
      - python train.py {parameters}
    inputs:
      - name: dataset
        default: dataset://training-data/latest
    parameters:
      - name: learning_rate
        type: float
        default: 0.001
      - name: batch_size
        type: integer
        default: 32
      - name: epochs
        type: integer
        default: 10

Logging Without valohai-utils

You can also log metrics using plain JSON:

import json
import tensorflow as tf

class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        
        metadata = {
            'epoch': epoch + 1,
            'accuracy': float(logs.get('accuracy', 0)),
            'loss': float(logs.get('loss', 0)),
            'val_accuracy': float(logs.get('val_accuracy', 0)),
            'val_loss': float(logs.get('val_loss', 0))
        }
        
        print(json.dumps(metadata))

# Use the callback
model.fit(
    x_train,
    y_train,
    epochs=10,
    validation_split=0.1,
    callbacks=[ValohaiMetricsCallback()]
)

Logging Learning Rate

Track learning rate changes during training:

class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        
        with valohai.metadata.logger() as logger:
            logger.log('epoch', epoch + 1)
            
            # Log training metrics
            for key, value in logs.items():
                logger.log(key, float(value))
            
            # Log current learning rate
            if hasattr(self.model.optimizer, 'learning_rate'):
                lr = self.model.optimizer.learning_rate
                if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
                    lr = lr(self.model.optimizer.iterations)
                logger.log('learning_rate', float(lr))

Combining Multiple Callbacks

Use multiple callbacks together:

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Valohai metrics callback
valohai_callback = ValohaiMetricsCallback()

# Early stopping
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

# Model checkpointing
checkpoint = ModelCheckpoint(
    valohai.outputs().path('best_model.h5'),
    monitor='val_accuracy',
    save_best_only=True
)

# Train with all callbacks
model.fit(
    x_train,
    y_train,
    epochs=50,
    validation_split=0.1,
    callbacks=[valohai_callback, early_stop, checkpoint]
)

Logging Custom Metrics

Add your own computed metrics:

class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        
        with valohai.metadata.logger() as logger:
            logger.log('epoch', epoch + 1)
            
            # Standard metrics
            for key, value in logs.items():
                logger.log(key, float(value))
            
            # Custom: Training/validation gap
            if 'loss' in logs and 'val_loss' in logs:
                gap = logs['val_loss'] - logs['loss']
                logger.log('train_val_gap', float(gap))
            
            # Custom: Accuracy improvement
            if 'val_accuracy' in logs and hasattr(self, 'prev_val_acc'):
                improvement = logs['val_accuracy'] - self.prev_val_acc
                logger.log('accuracy_improvement', float(improvement))
            
            self.prev_val_acc = logs.get('val_accuracy', 0)

Using LambdaCallback (Shorter Syntax)

For simple logging, use LambdaCallback:

import tensorflow as tf
import valohai

def log_metrics(epoch, logs):
    with valohai.metadata.logger() as logger:
        logger.log('epoch', epoch + 1)
        logger.log('accuracy', logs['accuracy'])
        logger.log('loss', logs['loss'])
        logger.log('val_accuracy', logs['val_accuracy'])
        logger.log('val_loss', logs['val_loss'])

# Create callback
metrics_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_metrics)

# Train
model.fit(
    x_train,
    y_train,
    epochs=10,
    validation_split=0.1,
    callbacks=[metrics_callback]
)

Logging Per-Batch Metrics (Advanced)

For very long epochs, you might want to log progress mid-epoch:

class ValohaiMetricsCallback(tf.keras.callbacks.Callback):
    def __init__(self, log_every_n_batches=100):
        super().__init__()
        self.log_every_n_batches = log_every_n_batches
        self.batch_count = 0
    
    def on_batch_end(self, batch, logs=None):
        self.batch_count += 1
        
        # Only log every N batches
        if self.batch_count % self.log_every_n_batches == 0:
            logs = logs or {}
            with valohai.metadata.logger() as logger:
                logger.log('batch', self.batch_count)
                logger.log('batch_loss', float(logs.get('loss', 0)))
                logger.log('batch_accuracy', float(logs.get('accuracy', 0)))
    
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        
        with valohai.metadata.logger() as logger:
            logger.log('epoch', epoch + 1)
            for key, value in logs.items():
                logger.log(key, float(value))

Use sparingly: logging every batch creates a lot of data. With 50,000 training samples and a batch size of 32, for example, that is more than 1,500 metric rows per epoch. Reserve per-batch logging for debugging or for very long epochs.


Best Practices

Always Convert to Python Types

Keras reports metric values as NumPy scalars (for example np.float32). Convert them to plain Python floats so they serialize to JSON:

# ✅ Good: Convert to float
logger.log("accuracy", float(logs['accuracy']))

# ❌ Avoid: NumPy types
logger.log("accuracy", logs['accuracy'])  # May cause issues

Handle Missing Metrics

Not all metrics are available in every callback:

def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    
    with valohai.metadata.logger() as logger:
        logger.log('epoch', epoch + 1)
        
        # Safe: Check if metric exists
        if 'val_loss' in logs:
            logger.log('val_loss', float(logs['val_loss']))
        
        # Or use .get() with default
        logger.log('loss', float(logs.get('loss', 0)))

Use Descriptive Metric Names

Keep names consistent with Keras conventions:

# ✅ Good: Standard Keras names
"loss"
"val_loss"
"accuracy"
"val_accuracy"

# ❌ Avoid: Custom abbreviations
"acc"
"vloss"
"train_a"

Common Issues

Metrics Not Appearing

Symptom: Callback runs but no metrics in Valohai

Causes & Fixes:

  • Missing validation_data → Add validation split or data

  • Incorrect metric names → Check available keys in logs

  • JSON serialization error → Convert NumPy/Tensor types to float

Debug:

def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    print(f"Available metrics: {list(logs.keys())}")  # Debug
    
    with valohai.metadata.logger() as logger:
        logger.log('epoch', epoch + 1)
        for key, value in logs.items():
            try:
                logger.log(key, float(value))
            except Exception as e:
                print(f"Failed to log {key}: {e}")

Validation Metrics Missing

Symptom: Only training metrics logged, no validation metrics

Solution: Make sure you provide validation data:

# ✅ Good: Includes validation
model.fit(
    x_train,
    y_train,
    validation_split=0.1,  # or validation_data=(x_val, y_val)
    callbacks=[ValohaiMetricsCallback()]
)

# ❌ No validation metrics
model.fit(
    x_train,
    y_train,
    callbacks=[ValohaiMetricsCallback()]  # logs won't have val_* metrics
)

Example Project

Check out our complete working example on GitHub:

valohai/tensorflow-example

The repository includes:

  • Complete training script with Valohai integration

  • valohai.yaml configuration

  • Example notebooks

  • Step-by-step setup instructions

