
Keras GPU: Using Keras On Single GPU, Multi-GPU, And TPUs

Jason Karlin
Last Updated: Jun 29, 2026

A Keras model may work correctly on a CPU, but long training runs can quickly slow experimentation. Moving to a single GPU can reduce training time without requiring major model changes, although the environment still needs compatible drivers, packages, and correct device detection.

When one GPU no longer provides enough throughput, Keras training can scale across several GPUs using model replicas and synchronized gradient updates. A compatible TPU provides another option for workloads dominated by large tensor operations.

This blog uses Keras 3 with the TensorFlow backend. Keras also supports JAX and PyTorch, but their accelerator and distribution workflows differ. You will reuse one model across a single GPU, multiple GPUs, and a TensorFlow-compatible TPU rather than maintaining three implementations.

How does Keras use GPUs and TPUs?

Keras provides the model-building and training interface. Its selected backend performs tensor operations, device placement, gradient calculation, and distributed execution.

With the TensorFlow backend:

  • Supported operations can run automatically on one detected GPU.
  • tf.distribute.MirroredStrategy distributes training across GPUs.
  • tf.distribute.TPUStrategy distributes training across compatible TPU devices.

Starting with TensorFlow 2.16, installing TensorFlow also installs Keras 3 by default.

Keras uses its selected backend to execute models on available accelerators. TensorFlow users can train automatically on one GPU, scale across GPUs with MirroredStrategy, and use compatible TPUs with TPUStrategy.
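A quick way to see automatic placement is to run an operation and inspect the device attribute of the result. This is a minimal sketch assuming TensorFlow 2.x eager execution; on a machine without a visible GPU, the same code runs and reports a CPU device instead.

```python
import tensorflow as tf

# Multiply a small random matrix by itself; TensorFlow picks the device.
a = tf.random.normal((256, 256))
b = tf.linalg.matmul(a, a)

# Eager tensors expose the device that produced them, for example
# ".../device:GPU:0" when a GPU is visible or ".../device:CPU:0" otherwise.
print("Result computed on:", b.device)
print("Visible GPUs:", len(tf.config.list_physical_devices("GPU")))
```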

Keras GPU Setup and Verification

A reliable installation begins with compatible operating system, Python, driver and TensorFlow versions because each component affects GPU discovery.

Before installing anything, confirm the following:

  • OS: Linux or Windows with WSL2 (macOS does not support NVIDIA CUDA)
  • Python: 3.9 to 3.12 for the current documented TensorFlow pip installation
  • NVIDIA driver: Recent enough to support the CUDA runtime that your TensorFlow release was built against
  • CUDA and cuDNN libraries: Installed as Python dependencies with the pip method below

TensorFlow 2.10 was the last release to support GPUs on native Windows, so current Windows GPU users generally need WSL2. TensorFlow does not provide official GPU support on macOS.
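The checklist above can be turned into a small preflight script that runs before installation. This is a sketch only: the helper name preflight is hypothetical, the 3.9 to 3.12 range is simply the one stated in this guide (confirm it against the current TensorFlow install page), and it checks the Python version and operating system but not the driver.

```python
import platform
import sys

# Range stated in this guide; confirm against the current TensorFlow install page.
MIN_PYTHON = (3, 9)
MAX_PYTHON = (3, 12)


def preflight(version=None, system=None):
    """Hypothetical helper: return a list of problems (empty means none found)."""
    version = version or tuple(sys.version_info[:2])
    system = system or platform.system()  # WSL2 reports "Linux"
    problems = []
    if not (MIN_PYTHON <= version <= MAX_PYTHON):
        problems.append(
            f"Python {version[0]}.{version[1]} is outside "
            f"{MIN_PYTHON[0]}.{MIN_PYTHON[1]}-{MAX_PYTHON[0]}.{MAX_PYTHON[1]}"
        )
    if system == "Darwin":
        problems.append("macOS has no NVIDIA CUDA support for TensorFlow")
    elif system == "Windows":
        problems.append("Native Windows GPU support ended after TensorFlow 2.10; use WSL2")
    return problems


for issue in preflight():
    print("Warning:", issue)
```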


Installing TensorFlow with CUDA Dependencies

Create an isolated virtual environment, update pip and install TensorFlow with its CUDA dependencies.

python3 -m venv .venv
source .venv/bin/activate
python3 -m pip install --upgrade pip
python3 -m pip install "tensorflow[and-cuda]"

The tensorflow[and-cuda] package installs the required NVIDIA runtime libraries as Python dependencies, but the host still needs a compatible NVIDIA driver. Confirm that the operating system can see the GPU:

nvidia-smi
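For provisioning scripts or CI jobs, the same check can be scripted. This sketch assumes nvidia-smi is on PATH and uses its --query-gpu and --format=csv,noheader options; the function name query_gpus is hypothetical.

```python
import shutil
import subprocess


def query_gpus():
    """Return one CSV line per GPU reported by nvidia-smi, or [] if none are visible."""
    if shutil.which("nvidia-smi") is None:
        return []  # Driver tools are not installed or not on PATH.
    result = subprocess.run(
        [
            "nvidia-smi",
            "--query-gpu=name,driver_version,memory.total",
            "--format=csv,noheader",
        ],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        return []  # nvidia-smi exists but could not communicate with the driver.
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


gpus = query_gpus()
print(gpus or "No GPU visible to the operating system")
```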

Selecting the TensorFlow Backend

Set the backend before importing Keras, because Keras reads the KERAS_BACKEND environment variable when it is first imported.

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import tensorflow as tf

print("Keras:", keras.__version__)
print("TensorFlow:", tf.__version__)

Verifying GPU Detection

Use TensorFlow’s physical-device list after confirming that nvidia-smi can identify the installed GPU.

gpus = tf.config.list_physical_devices("GPU")
print(gpus)

if not gpus:
    raise RuntimeError("TensorFlow did not detect a GPU.")

By default, TensorFlow places an operation on the GPU when a GPU kernel exists for it. Enable device-placement logs only temporarily while debugging, because they can produce large logs and slow down or obscure normal benchmark runs:

tf.debugging.set_log_device_placement(True)
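To avoid leaving placement logging switched on by accident, it can be scoped to a single block. This is a sketch; the context manager device_placement_logging is a hypothetical helper built on tf.debugging.set_log_device_placement and tf.debugging.get_log_device_placement.

```python
import contextlib

import tensorflow as tf


@contextlib.contextmanager
def device_placement_logging():
    """Hypothetical helper: log op placement only inside the `with` block."""
    tf.debugging.set_log_device_placement(True)
    try:
        yield
    finally:
        tf.debugging.set_log_device_placement(False)


with device_placement_logging():
    x = tf.random.uniform((4, 4))
    y = tf.linalg.matmul(x, x)  # The placement of this MatMul is logged.

print("Logging still enabled:", tf.debugging.get_log_device_placement())
```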

Managing TensorFlow GPU Memory

By default, TensorFlow maps nearly all memory on visible GPUs to reduce fragmentation. Enabling memory growth allows the process to request additional GPU memory as needed instead. Configure it before TensorFlow initializes any GPU.

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
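An alternative to memory growth, useful when several processes share one GPU, is a hard memory cap. The sketch below uses tf.config.set_logical_device_configuration with a LogicalDeviceConfiguration(memory_limit=...) value in megabytes. The same timing rule applies: it must run before the GPU is initialized, otherwise TensorFlow raises a RuntimeError. Use either memory growth or a cap on a given device, not both. The helper name and the 4096 MB value are illustrative assumptions.

```python
import tensorflow as tf


def cap_first_gpu_memory(limit_mb):
    """Hypothetical helper: limit TensorFlow to `limit_mb` MB on the first GPU.

    Must run before any GPU is initialized; do not combine with
    set_memory_growth() on the same device.
    """
    gpus = tf.config.list_physical_devices("GPU")
    if not gpus:
        return []  # Nothing to configure on a CPU-only machine.
    try:
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=limit_mb)],
        )
    except RuntimeError as err:
        # Raised when TensorFlow has already initialized the GPU.
        print("Could not change GPU configuration:", err)
    return tf.config.list_logical_devices("GPU")


logical_gpus = cap_first_gpu_memory(4096)  # 4 GB is an illustrative value.
print(logical_gpus)
```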

How to Train Keras on a Single GPU?

A single GPU is the best baseline because it avoids synchronization overhead and makes input-pipeline and memory bottlenecks easier to identify.

This example uses CIFAR-10 instead of MNIST. It is still compact enough for a tutorial, but its three-channel images and larger CNN make it a more realistic accelerator demonstration. Do not treat it as a production benchmark.

keras.utils.set_random_seed(42)

AUTOTUNE = tf.data.AUTOTUNE

(x_train, y_train), (x_test, y_test) = (
    keras.datasets.cifar10.load_data()
)

y_train = y_train.squeeze()
y_test = y_test.squeeze()

def preprocess(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

def make_dataset(
    images,
    labels,
    batch_size,
    training=False,
    drop_remainder=False,
):
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    if training:
        dataset = dataset.shuffle(
            len(images),
            reshuffle_each_iteration=True,
        )

    return (
        dataset
        .map(preprocess, num_parallel_calls=AUTOTUNE)
        .batch(batch_size, drop_remainder=drop_remainder)
        .prefetch(AUTOTUNE)
    )

tf.data supports parallel transformation, batching, and prefetching so the next batch can be prepared while the accelerator processes the current one.
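One way to check whether the input pipeline can keep up is to iterate over the dataset on its own, with no model attached, and time it. If the pipeline alone delivers fewer samples per second than training consumes, the accelerator is waiting on input. In this sketch, random uint8 arrays stand in for CIFAR-10 so it runs without a download, and measure_pipeline is a hypothetical helper. The first pass also includes one-off setup cost, so repeat the measurement in practice.

```python
import time

import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

# Synthetic stand-in for CIFAR-10 so the sketch runs without a download.
images = np.random.randint(0, 256, size=(300, 32, 32, 3), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(300,), dtype=np.int64)


def preprocess(image, label):
    return tf.cast(image, tf.float32) / 255.0, label


dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .shuffle(len(images), reshuffle_each_iteration=True)
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .batch(128)
    .prefetch(AUTOTUNE)
)


def measure_pipeline(ds):
    """Iterate once over `ds` with no model; return (samples, samples per second)."""
    start = time.perf_counter()
    samples = 0
    for batch_images, _ in ds:
        samples += int(batch_images.shape[0])
    elapsed = time.perf_counter() - start
    return samples, samples / elapsed


samples, rate = measure_pipeline(dataset)
print(f"{samples} samples at {rate:,.0f} samples/sec from the input pipeline alone")
```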

Creating a Reusable Keras Model

def build_model():
    inputs = keras.Input(shape=(32, 32, 3))

    x = keras.layers.Conv2D(
        64, 3, padding="same", activation="relu"
    )(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling2D()(x)

    x = keras.layers.Conv2D(
        128, 3, padding="same", activation="relu"
    )(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling2D()(x)

    x = keras.layers.Conv2D(
        256, 3, padding="same", activation="relu"
    )(x)

    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(0.3)(x)

    logits = keras.layers.Dense(10)(x)
    outputs = keras.layers.Activation(
        "softmax",
        dtype="float32",
    )(logits)

    return keras.Model(inputs, outputs)

The final output remains in float32 because TensorFlow recommends a float32 model output, particularly when a softmax feeds directly into the loss during mixed-precision training.
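To confirm which dtypes the layers actually use, inspect each layer's compute_dtype and variable_dtype. The small model below is a stand-in for build_model() (an assumption made to keep the sketch fast): under mixed_float16 the hidden Dense layer computes in float16 while keeping float32 weights, and the pinned Activation returns float32.

```python
import os

os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras
import numpy as np

keras.mixed_precision.set_global_policy("mixed_float16")

# Small stand-in for build_model(): hidden layers follow the global policy,
# while the final Activation is pinned to float32.
inputs = keras.Input(shape=(8,))
hidden_layer = keras.layers.Dense(16, activation="relu")
hidden = hidden_layer(inputs)
logits = keras.layers.Dense(4)(hidden)
outputs = keras.layers.Activation("softmax", dtype="float32")(logits)
model = keras.Model(inputs, outputs)

print("Dense compute dtype:", hidden_layer.compute_dtype)    # float16
print("Dense variable dtype:", hidden_layer.variable_dtype)  # float32

predictions = model.predict(np.random.rand(2, 8).astype("float32"), verbose=0)
print("Model output dtype:", predictions.dtype)              # float32

# Restore the default so later code is unaffected.
keras.mixed_precision.set_global_policy("float32")
```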

Establishing a Single-GPU Baseline

Run the first benchmark with float32:

keras.mixed_precision.set_global_policy("float32")

BATCH_SIZE = 128

train_dataset = make_dataset(
    x_train, y_train, BATCH_SIZE, training=True
)

test_dataset = make_dataset(
    x_test, y_test, BATCH_SIZE
)

model = build_model()

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=5,
)


Testing Mixed-Precision Training

After recording the float32 result, clear the previous model and build a new one under the mixed-precision policy:

keras.backend.clear_session()

keras.mixed_precision.set_global_policy(
    "mixed_float16"
)

mixed_model = build_model()

Compile and train mixed_model using the same data and optimizer settings.
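The compile and fit step looks the same as in the float32 run. In Keras 3, Model.compile() has an auto_scale_loss argument (default True) that, under the mixed_float16 policy, wraps the optimizer in a LossScaleOptimizer, so no manual loss scaling should be needed; check this against your Keras version. The tiny model and random data below are assumptions that keep the sketch self-contained.

```python
import os

os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras
import numpy as np

keras.backend.clear_session()
keras.mixed_precision.set_global_policy("mixed_float16")


def build_tiny_model():
    """Stand-in for build_model(): float16 compute with a float32 output."""
    inputs = keras.Input(shape=(20,))
    x = keras.layers.Dense(32, activation="relu")(inputs)
    logits = keras.layers.Dense(3)(x)
    outputs = keras.layers.Activation("softmax", dtype="float32")(logits)
    return keras.Model(inputs, outputs)


mixed_model = build_tiny_model()
mixed_model.compile(
    optimizer=keras.optimizers.Adam(),  # same settings as the float32 run
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

# Random data purely so the sketch runs end to end.
x = np.random.rand(256, 20).astype("float32")
y = np.random.randint(0, 3, size=(256,))
history = mixed_model.fit(x, y, batch_size=64, epochs=2, verbose=0)

print("Optimizer:", type(mixed_model.optimizer).__name__)
print("Final loss:", history.history["loss"][-1])

keras.mixed_precision.set_global_policy("float32")
```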

Mixed precision can reduce tensor and activation memory use and improve throughput on GPUs with hardware support for float16 math, such as NVIDIA GPUs with Tensor Cores (compute capability 7.0 or higher). The actual benefit depends on the model, GPU, batch size, and operations used.

Record:

  • Samples per second
  • Median epoch duration
  • Peak GPU memory
  • GPU utilization
  • Validation accuracy
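Several of these numbers can be collected from inside fit(). The sketch below uses a hypothetical ThroughputLogger callback for epoch duration and samples per second, plus a helper that reads the peak allocation from tf.config.experimental.get_memory_info(). Epoch times include validation when validation_data is passed, and tf.config.experimental.reset_memory_stats() can reset the peak between runs. The tiny model and random data are placeholders; GPU utilization still comes from nvidia-smi.

```python
import os
import statistics
import time

os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras
import numpy as np
import tensorflow as tf


class ThroughputLogger(keras.callbacks.Callback):
    """Hypothetical callback: records per-epoch wall time and samples/second."""

    def __init__(self, samples_per_epoch):
        super().__init__()
        self.samples_per_epoch = samples_per_epoch
        self.epoch_seconds = []
        self.samples_per_second = []

    def on_epoch_begin(self, epoch, logs=None):
        self._start = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        # Includes validation time when validation_data is passed to fit().
        elapsed = time.perf_counter() - self._start
        self.epoch_seconds.append(elapsed)
        self.samples_per_second.append(self.samples_per_epoch / elapsed)


def peak_gpu_memory_mb(device="GPU:0"):
    """Peak memory TensorFlow has allocated on `device`, or None without a GPU."""
    if not tf.config.list_physical_devices("GPU"):
        return None
    return tf.config.experimental.get_memory_info(device)["peak"] / 2**20


# Tiny placeholder model and random data so the sketch runs anywhere.
x = np.random.rand(512, 16).astype("float32")
y = np.random.randint(0, 2, size=(512,))
model = keras.Sequential([
    keras.Input(shape=(16,)),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

logger = ThroughputLogger(samples_per_epoch=len(x))
history = model.fit(x, y, batch_size=128, epochs=3, verbose=0, callbacks=[logger])

print("Median epoch seconds:", statistics.median(logger.epoch_seconds))
print("Samples/sec per epoch:", [round(s) for s in logger.samples_per_second])
print("Peak GPU memory (MB):", peak_gpu_memory_mb())
print("Final accuracy:", history.history["accuracy"][-1])
```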

How to Train Keras on Multiple GPUs?

MirroredStrategy implements synchronous data parallelism. It creates a complete model replica on each GPU, distributes different portions of the global batch to those replicas, aggregates their gradients, and keeps their weights synchronized.

strategy = tf.distribute.MirroredStrategy()

PER_REPLICA_BATCH_SIZE = 128
GLOBAL_BATCH_SIZE = (
    PER_REPLICA_BATCH_SIZE
    * strategy.num_replicas_in_sync
)

train_dataset = make_dataset(
    x_train,
    y_train,
    GLOBAL_BATCH_SIZE,
    training=True,
    drop_remainder=True,
)

test_dataset = make_dataset(
    x_test,
    y_test,
    GLOBAL_BATCH_SIZE,
)

with strategy.scope():
    model = build_model()

    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
    )

model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=5,
)
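To measure scaling, it helps to run the same script with 1, 2, and more GPUs. MirroredStrategy accepts an explicit devices list for this. In the sketch below, MAX_GPUS is an illustrative setting; when no GPU is visible, passing None lets the strategy fall back to a single CPU replica, so the code still runs.

```python
import tensorflow as tf

# Illustrative choice: use at most the first two GPUs, e.g. for a 1-GPU vs
# 2-GPU comparison. None means "use every visible GPU" (or the CPU if none).
MAX_GPUS = 2
gpu_names = [d.name for d in tf.config.list_logical_devices("GPU")[:MAX_GPUS]]
strategy = tf.distribute.MirroredStrategy(devices=gpu_names or None)

PER_REPLICA_BATCH_SIZE = 128
GLOBAL_BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
print("Replicas:", strategy.num_replicas_in_sync)
print("Global batch size:", GLOBAL_BATCH_SIZE)
```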

Calculating Global Batch Size

The global batch equals:

per-replica batch size × number of replicas

Scaling the global batch linearly with the number of replicas keeps the work assigned to each GPU constant, but doing so is not mandatory. For convergence comparisons, you may keep the global batch fixed. For throughput scaling, you may keep the per-replica batch fixed and let the global batch grow. Larger global batches can require learning-rate and regularization retuning.
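The two approaches differ only in which quantity is held constant. This hypothetical helper returns both batch sizes for either approach and rejects a fixed global batch that does not divide evenly across replicas.

```python
def resolve_batch_sizes(num_replicas, per_replica=None, global_batch=None):
    """Hypothetical helper: return (per_replica_batch, global_batch).

    Pass `per_replica` to keep per-GPU work fixed (throughput scaling), or
    `global_batch` to keep the optimizer's batch fixed (convergence comparisons).
    """
    if (per_replica is None) == (global_batch is None):
        raise ValueError("Pass exactly one of per_replica or global_batch.")
    if per_replica is not None:
        return per_replica, per_replica * num_replicas
    if global_batch % num_replicas:
        raise ValueError(
            f"Global batch {global_batch} is not divisible by {num_replicas} replicas."
        )
    return global_batch // num_replicas, global_batch


print(resolve_batch_sizes(4, per_replica=128))   # (128, 512)
print(resolve_batch_sizes(4, global_batch=512))  # (128, 512)
```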

Why GPU Memory Is Not Automatically Combined

Under MirroredStrategy, each GPU normally stores a complete model replica, its optimizer state, and the activations for its local batch. Memory is therefore duplicated across devices rather than pooled.

Adding GPUs primarily increases throughput. Models that do not fit on one GPU may require a larger-memory device, gradient checkpointing, parameter sharding, or model parallelism.
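A back-of-the-envelope estimate shows why the per-GPU requirement does not shrink. This sketch assumes float32 values (4 bytes each), one gradient per weight, and Adam's two moment slots per weight; it deliberately ignores activations, temporary buffers, and framework overhead, which are often larger. The parameter count is illustrative; model.count_params() gives the real value.

```python
def replica_state_bytes(num_params, bytes_per_value=4, optimizer_slots=2):
    """Rough per-GPU bytes for weights, gradients and optimizer slots.

    Assumes float32 values and Adam-style slots; activations are NOT included.
    """
    return num_params * bytes_per_value * (1 + 1 + optimizer_slots)


params = 1_000_000  # illustrative; use model.count_params() for a real model
per_gpu = replica_state_bytes(params)
for gpu_count in (1, 2, 4):
    # Data parallelism replicates this state, so the per-GPU figure never shrinks.
    print(f"{gpu_count} GPU(s): {per_gpu / 1e6:.0f} MB per GPU, "
          f"{gpu_count * per_gpu / 1e6:.0f} MB in total")
```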

How Does Batch Normalization Behave?

By default, each replica calculates Batch Normalization statistics from its local batch. TensorFlow-backed Keras supports BatchNormalization(synchronized=True) when global replica statistics are required, but synchronization adds communication overhead.
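Enabling synchronized statistics is a one-argument change on the layer, made inside the strategy scope. The sketch below uses a tiny model and random data as placeholders; on a single device it behaves like ordinary Batch Normalization, and with several replicas each training step adds an all-reduce of the batch statistics.

```python
import os

os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    inputs = keras.Input(shape=(8, 8, 3))
    x = keras.layers.Conv2D(16, 3, padding="same", activation="relu")(inputs)
    # synchronized=True averages batch statistics across replicas
    # (TensorFlow backend only).
    x = keras.layers.BatchNormalization(synchronized=True)(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = keras.layers.Dense(10, activation="softmax")(x)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Random placeholder data; batch_size here is the global batch.
x_demo = np.random.rand(64, 8, 8, 3).astype("float32")
y_demo = np.random.randint(0, 10, size=(64,))
history = model.fit(
    x_demo,
    y_demo,
    batch_size=16 * strategy.num_replicas_in_sync,
    epochs=1,
    verbose=0,
)
```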

Why Multi-GPU Scaling Can Be Sublinear

Adding GPUs does not guarantee proportional speedups. Common causes include:

  • Gradient synchronization overhead: The all-reduce step takes longer as the number of parameters, and therefore the gradient volume, grows. Models with little computation per step spend a larger fraction of each step synchronizing rather than computing.
  • Slow input pipeline: A CPU-bound data pipeline starves GPUs between batches. Profile with the TensorFlow Profiler before concluding that more GPUs are the solution.
  • Small per-GPU batch: Each replica needs enough work to justify the synchronization cost. Per-replica batches that are too small reduce GPU utilization.
  • Weak interconnect: NVLink generally provides much higher peer-to-peer bandwidth than PCIe. Whether communication becomes the primary bottleneck depends on the model, gradient volume, topology, and amount of useful computation performed by each replica.

When Multi-Worker Training Becomes Relevant

MultiWorkerMirroredStrategy extends synchronous training across several workers. It requires worker processes, a consistent TF_CONFIG, correct dataset sharding, coordinated checkpointing, and sufficient network bandwidth.
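TF_CONFIG is a JSON document with a cluster section, identical on every worker, and a task section that identifies the current worker. The sketch below builds it for a hypothetical two-worker cluster; the addresses are placeholders, and make_tf_config is a hypothetical helper. The strategy line is commented out because creating it would try to contact the other workers.

```python
import json
import os


def make_tf_config(worker_addresses, task_index):
    """Hypothetical helper: TF_CONFIG value for one worker in the cluster."""
    return json.dumps({
        "cluster": {"worker": worker_addresses},           # identical on every worker
        "task": {"type": "worker", "index": task_index},   # differs per worker
    })


# Placeholder addresses; each machine sets its own task index.
workers = ["10.0.0.1:12345", "10.0.0.2:12345"]
os.environ["TF_CONFIG"] = make_tf_config(workers, task_index=0)

# Create the strategy only after TF_CONFIG is set, on every worker:
# strategy = tf.distribute.MultiWorkerMirroredStrategy()
print(os.environ["TF_CONFIG"])
```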

How to Run the Same Keras Model on a TPU

TPU support depends on both the backend and TPU generation. Google currently supports JAX and PyTorch, but not TensorFlow, on TPU7x. Confirm framework support before provisioning a TPU for this TensorFlow workflow.

In a configured TensorFlow-compatible TPU runtime:

tpu = (
    tf.distribute.cluster_resolver
    .TPUClusterResolver
    .connect()
)

strategy = tf.distribute.TPUStrategy(tpu)

keras.mixed_precision.set_global_policy(
    "mixed_bfloat16"
)

PER_REPLICA_BATCH_SIZE = 128
GLOBAL_BATCH_SIZE = (
    PER_REPLICA_BATCH_SIZE
    * strategy.num_replicas_in_sync
)

train_dataset = make_dataset(
    x_train,
    y_train,
    GLOBAL_BATCH_SIZE,
    training=True,
    drop_remainder=True,
)

with strategy.scope():
    model = build_model()

    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
        steps_per_execution=50,
    )

model.fit(train_dataset, epochs=5)

drop_remainder=True creates fixed batch shapes but discards the final incomplete batch. Disclose this when reporting sample counts or accuracy.
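The number of discarded samples is simple arithmetic. For CIFAR-10's 50,000 training images and an assumed 8 replicas at 128 samples each (global batch 1,024), each epoch runs 48 full steps and skips 848 samples. Because the training set is reshuffled every epoch, different samples are skipped each time. The helper name is hypothetical.

```python
def drop_remainder_stats(num_samples, global_batch):
    """Hypothetical helper: (full_batches_per_epoch, samples_dropped) with drop_remainder=True."""
    full_batches = num_samples // global_batch
    dropped = num_samples - full_batches * global_batch
    return full_batches, dropped


# CIFAR-10 has 50,000 training images; 8 replicas is an illustrative assumption.
steps, dropped = drop_remainder_stats(50_000, 128 * 8)
print(f"{steps} steps per epoch, {dropped} samples skipped per epoch")
# 48 steps per epoch, 848 samples skipped per epoch
```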

steps_per_execution can reduce host overhead by running multiple batches in one compiled call. Batch-level callbacks and logs then run once per execution block rather than after every batch.

For TPU training:

  • Use tf.data.
  • Test larger global batches.
  • Reduce Python-side work.
  • Keep tensor shapes stable where practical.
  • Ensure workers can access data efficiently.
  • Use mixed_bfloat16 when appropriate.

Benchmarking Single-GPU, Multi-GPU, and TPU Performance

Keep the dataset, model, optimizer, precision policy, epochs, and validation method consistent.

Use a warm-up period followed by at least three measured runs. Report the median rather than the fastest result.
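A small helper can enforce this rule by discarding warm-up runs and refusing to summarize too few measured runs. The durations below are made-up values for illustration only; the first run is treated as warm-up because it includes one-off costs such as graph tracing.

```python
import statistics


def summarize_runs(durations, warmup_runs=1, min_measured=3):
    """Hypothetical helper: drop warm-up runs and return the median of the rest."""
    measured = durations[warmup_runs:]
    if len(measured) < min_measured:
        raise ValueError(
            f"Need at least {min_measured} measured runs, got {len(measured)}."
        )
    return statistics.median(measured)


# Made-up epoch times in seconds, for illustration only.
print(summarize_runs([41.0, 30.2, 29.8, 30.5]))  # 30.2
```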

Measure:

  • Samples per second
  • Median epoch duration
  • Peak memory
  • Time to target accuracy
  • Scaling efficiency
  • Cost per epoch
  • Cost to target accuracy

Scaling efficiency (%) = multi-GPU throughput ÷ (single-GPU throughput × GPU count) × 100
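The same formula as a function, with illustrative numbers only: if one GPU processes 1,000 samples per second and four GPUs together process 3,400, the scaling efficiency is 85%.

```python
def scaling_efficiency(multi_gpu_throughput, single_gpu_throughput, gpu_count):
    """Scaling efficiency in percent: measured throughput vs. perfect linear scaling."""
    return 100 * multi_gpu_throughput / (single_gpu_throughput * gpu_count)


# Illustrative numbers only: 1,000 samples/s on one GPU, 3,400 samples/s on four.
print(scaling_efficiency(3_400, 1_000, 4))  # 85.0
```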

TensorFlow Profiler can help distinguish host, input-pipeline, and accelerator bottlenecks.
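A short trace can be captured programmatically with tf.profiler.experimental.start() and stop() and then opened in TensorBoard's Profile tab (which requires the TensorBoard profiler plugin). In this sketch, a loop of matrix multiplications stands in for a few training steps; keep the profiled window short, because tracing adds overhead.

```python
import tempfile

import tensorflow as tf

logdir = tempfile.mkdtemp(prefix="tf_profile_")

# Capture a short trace; keep the profiled window small to limit overhead.
tf.profiler.experimental.start(logdir)
try:
    x = tf.random.normal((512, 512))
    for _ in range(10):
        x = tf.linalg.matmul(x, x) / 512.0  # stand-in for a few training steps
finally:
    tf.profiler.experimental.stop()

print("Trace written under:", logdir)
```

For model.fit(), the Keras TensorBoard callback's profile_batch argument (for example, profile_batch=(10, 20)) can profile a range of batches with the TensorFlow backend; check the callback documentation for your Keras version.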

Do not publish estimated TPU or multi-GPU numbers as measured results.


Common Keras GPU and TPU Problems

| Problem | First check | Likely action |
| --- | --- | --- |
| GPU not detected | nvidia-smi and TensorFlow device list | Correct the driver or environment |
| Out of memory | Batch and input sizes | Reduce memory use or choose more GPU memory |
| Low utilization | Input pipeline | Prefetch, parallelize, and profile |
| Multi-GPU slowdown | Per-replica workload | Increase useful work per device |
| Accuracy changes | Batch size and precision | Retune and validate numerics |
| TPU failure | Backend and generation | Use a compatible runtime |

Decision Table

| Option | Best suited for | Main limitation | First metric to check |
| --- | --- | --- | --- |
| Single GPU | Development and moderate training | One device’s memory and throughput | Peak memory |
| Multi-GPU | Throughput-limited training | Synchronization overhead | Scaling efficiency |
| Multi-worker | Workloads exceeding one host | Network and operational complexity | Worker utilization |
| TPU | Large, compatible tensor workloads | Backend and runtime constraints | Time to target accuracy |


Scale Your Keras Workloads Faster with AceCloud GPU Infrastructure

The right Keras accelerator strategy starts with measurement, not assumptions. Build a reliable single-GPU baseline, verify device detection, optimize tf.data, and test mixed precision before adding more hardware.

If throughput remains the bottleneck, use MirroredStrategy to scale across multiple GPUs and monitor global batch size, synchronization overhead, memory use, and scaling efficiency. TPUs may suit large, regular tensor workloads, but backend and generation compatibility must be checked first.

AceCloud helps ML engineers and data scientists run Keras and TensorFlow workloads on scalable NVIDIA GPU infrastructure without managing physical servers.

Book a free consultation with AceCloud to select the right NVIDIA GPU configuration for your model, dataset, memory requirement, batch size, precision policy and training target.

Frequently Asked Questions

Can Keras use a GPU?

Yes. Keras can run supported operations on a GPU through a compatible backend such as TensorFlow, JAX, or PyTorch.

With TensorFlow, supported operations are generally placed on an available GPU once the device and runtime are detected correctly.

How do I train a Keras model on multiple GPUs?

Create a tf.distribute.MirroredStrategy and build and compile the model inside its scope. Keep the global batch fixed for convergence comparisons or the per-replica batch fixed for throughput tests.

Does multi-GPU training combine GPU memory?

No. Standard data parallelism normally stores a complete model replica on every GPU.

Can the same Keras model run on a TPU?

Yes, provided the selected Keras backend, TensorFlow version, and TPU generation are compatible.

Why does TensorFlow not detect my GPU?

Common causes include a missing NVIDIA driver, unsupported native-Windows configuration, unavailable runtime libraries, hidden devices, or installation in the wrong Python environment.

Is a TPU always faster than a GPU for Keras training?

No. Compare the same model using time to target accuracy and total job cost. Performance depends on architecture, batch size, input throughput, supported operations, and compilation overhead.

Jason Karlin
Industry veteran with over 10 years of experience architecting and managing GPU-powered cloud solutions. Specializes in enabling scalable AI/ML and HPC workloads for enterprise and research applications. Former lead solutions architect for top-tier cloud providers and startups in the AI infrastructure space.
