• Ray Christian

Saving multi-gpu models with Keras ModelCheckpoint


For all but the simplest neural network models, computational requirements are extreme. At Textpert, our models for AiME range from several hundred thousand to millions of parameters. Individual input examples can reach hundreds of MB each and comprise several modalities, many of them temporal (full-color video at high frame rates, high-sample-rate audio, word vectors, and more), and each model sees tens to hundreds of thousands of these training examples.

Suffice it to say, in practice neural networks are resource hungry, particularly for high-throughput compute devices: GPUs, APUs, and TPUs. Normally, with high-level frameworks such as Keras, the differences between coding, training, and deploying GPU models vs. CPU models are negligible or nonexistent.

Figure: Major contributors to AiME's parameter overhead (combinatorial explosion)

Enter multi-GPU models

Multi-GPU setups are common enough to have warranted a built-in abstraction in Keras for a popular implementation using data parallelism: multi_gpu_model(), which requires only a few extra lines of code to use.

This method works by making a “meta-model” that splits each large input batch into smaller sub-batches and distributes them to several GPUs. After the forward pass on each GPU (in parallel), the predictions are merged back into a single batch of the original size, and loss and backpropagation are then performed.
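The split-and-merge step can be pictured with a small framework-free sketch. Everything here is an illustrative assumption: the helper names (split_batch, forward, data_parallel_predict) are hypothetical, the "model" is a single multiplication, and the replicas run one after another rather than on separate GPUs. It is not how Keras implements multi_gpu_model() internally.

```python
def split_batch(batch, n_replicas):
    """Split a batch into n_replicas contiguous sub-batches of near-equal size."""
    size, extra = divmod(len(batch), n_replicas)
    shards, start = [], 0
    for i in range(n_replicas):
        # The first `extra` shards take one additional example each.
        stop = start + size + (1 if i < extra else 0)
        shards.append(batch[start:stop])
        start = stop
    return shards


def forward(shard, weight):
    """Hypothetical stand-in for one replica's forward pass."""
    return [weight * x for x in shard]


def data_parallel_predict(batch, weight, n_replicas):
    shards = split_batch(batch, n_replicas)
    # In a real setup each call would run on its own GPU, in parallel.
    outputs = [forward(shard, weight) for shard in shards]
    # Merge per-replica predictions back into one batch, preserving order.
    return [y for out in outputs for y in out]


batch = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
merged = data_parallel_predict(batch, weight=0.5, n_replicas=4)
```

Because every replica uses the same weights and the outputs are concatenated in order, the merged result matches what a single device would produce on the full batch.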

Pain point

Due to Keras’ architecture and the fact that you are now technically working with a new model, some subtle changes are required to continue working with the model in other environments, particularly when available devices change.

Consider an example in which you train a model on 8 GPUs and save it after some number of epochs. If you used ModelCheckpoint or model.save to store your model to disk, loading the model architecture and weights simultaneously via load_model() requires you to have the same device profile as when you trained the model—in this example 8 GPUs. But what if you just want to make some ad-hoc predictions on CPU? Or serve the model on production servers with only 1 GPU each?

Saving the template model

As the Keras documentation for multi_gpu_model() suggests, rather than saving the multi-GPU model, you should save the “template” model, i.e. the original single-device model passed to the multi_gpu_model() function. Since the template model and the meta-model share weights, at the end of training your template model is your simple single-device model with everything the multi-GPU model has learned.
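Why saving the template is enough can be shown with a toy sketch of weight sharing. Both classes are hypothetical stand-ins, not Keras classes: the wrapper keeps a reference to the template's weight list instead of a copy, so an update made through the wrapper is visible on the template.

```python
class ToyModel:
    """Hypothetical single-device model: just a list of weights."""

    def __init__(self, weights):
        self.weights = weights


class ToyMultiDeviceModel:
    """Hypothetical wrapper that reuses the template's weights, not a copy."""

    def __init__(self, template):
        self.weights = template.weights  # same list object as the template

    def train_step(self, gradients, lr=0.1):
        # Plain gradient descent, applied in place to the shared weights.
        for i, grad in enumerate(gradients):
            self.weights[i] -= lr * grad


template = ToyModel([1.0, 2.0])
wrapper = ToyMultiDeviceModel(template)
wrapper.train_step([10.0, -10.0])
```

After the training step, template.weights already holds the updated values, so saving the template captures the trained state without any dependence on the wrapper or its device layout.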

Integrating with ModelCheckpoint

A limitation of the Keras Callback architecture as it stands is that callbacks can only operate on the model being trained. You can set an alternate model to be used as the “callback model”, but it is an undocumented feature, and you can’t set that alternate model on a per-callback basis.

In our case at Textpert, we need to use the multi-GPU model on our other callbacks for performance reasons, but we also need the template model for ModelCheckpoint and some other callbacks. For that reason, we made a tiny adapter called AltModelCheckpoint that wraps ModelCheckpoint so that the model to checkpoint is specified explicitly.
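The adapter idea can be sketched without Keras installed. This is a minimal illustration under stated assumptions, not the package's actual source: Checkpoint is a hypothetical stand-in that loosely mimics the callback contract (the trainer calls set_model(), then on_epoch_end()) and records saves in a list instead of writing files, and AltCheckpoint is a hypothetical name for the wrapper.

```python
class Checkpoint:
    """Hypothetical stand-in for a checkpoint callback."""

    def __init__(self, filepath):
        self.filepath = filepath
        self.model = None
        self.saved = []  # records (path, model) instead of writing to disk

    def set_model(self, model):
        # Called by the trainer with the model being trained.
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        path = self.filepath.format(epoch=epoch + 1)
        self.saved.append((path, self.model))


class AltCheckpoint(Checkpoint):
    """Checkpoints alternate_model regardless of the model being trained."""

    def __init__(self, filepath, alternate_model):
        super().__init__(filepath)
        self.alternate_model = alternate_model

    def on_epoch_end(self, epoch, logs=None):
        trained_model = self.model
        # Swap in the alternate model only for the duration of the save.
        self.model = self.alternate_model
        try:
            super().on_epoch_end(epoch, logs)
        finally:
            self.model = trained_model


template_model, multi_gpu = object(), object()
callback = AltCheckpoint("model.{epoch:02d}.hdf5", template_model)
callback.set_model(multi_gpu)   # what the training loop would do
callback.on_epoch_end(epoch=0)  # the save targets the template model
```

Swapping the model only around the save keeps the choice per-callback: other callbacks attached to the same training run still see the multi-GPU model.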

Installation is easy:

pip install alt-model-checkpoint

And so is general usage:

from alt_model_checkpoint import AltModelCheckpoint
from keras.models import Model
from keras.utils import multi_gpu_model

# Template (single-device) model: this is the one that gets checkpointed
base_model = Model(...)

# Multi-GPU wrapper: this is the one that is compiled and trained
gpu_model = multi_gpu_model(base_model)
gpu_model.compile(...)
gpu_model.fit(..., callbacks=[
    AltModelCheckpoint('save/path/for/model.hdf5', base_model)
])