Skip to content

8. The fit function


The first thing the function does is to convert the user passed dataset (x and y) into a compatible format that ready to be used for the training. They are wrapped into a DataHandler, which can be used conveniently during the training. It has a method, named enumerate_epochs(), which returns the current epoch number and the iterator of the dataset. It also has a method, named steps(), which returns the current step number. These two functions are mainly used by the function to iterate over the dataset for each epoch.

In side DataHandler, it would convert different types of data into a object with different DataAdapters. That is how Keras supports so many different types of data inputs.

With DataHandler, we prepared the dataset to be iterated batch-by-batch. For each batch of data, we would do a forward pass and updates of the weights, which is called a step in an epoch. We want to build a function to execute a single step, and compile it into a tf.function to accelerate the process. Here we use Model.make_train_function() to get the function.

Now we can use the following pseudo code to summarize what happend in We first wrap the data into a DataHandler. Get the tf.function for running a step. Use a for loop to iterate through the epochs, and use an inner for loop to iterate the steps. In each step, we just call the function to execute the step.

class Model(Layer):
    def fit(self, x, y, ...):
        data_handler = DataHandler(x, y)
        self.train_function = self.make_train_function()
        for epoch, iterator in data_handler.enumerate_epochs():
            for step in data_handler.steps():

Train step

Now, it comes to the question of what happens in Model.make_train_function(). You can think it just returns the Model.train_step()). Notebly, the user can also override this Model.train_step() function to customize their training step.

Following is the pseudo code for Model.train_step(). It runs the forward pass using self(x) and compute the loss value, while recording all the gradients with tf.GradientTape(). Then, it use the optimizer to minimize the loss function to update the trainable variables using the gradients. Finally, it returns the metrics.

class Model(Layer):
    def train_step(self, data):
        x, y = data_adapter.unpack(data)
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        self.compiled_metrics.update_state(y, y_pred)
        return { metric.result() for metric in self.metrics}

Distributed training

Actually, Model.make_train_function() adds one more functionality to Model.train_step(), which is to support the distributed training. Let's see how distributed training is supported in Keras with the TensorFlow APIs.

First, we need to use tf.distribute.Strategy.scope(), which opens up a scope to track all the TensorFlow variables created in this scope, for example the weights of the neural network.

TensorFlow API
tf.distribute.Strategy.scope(). scope() opens up a scope that any tf.Variable() created inside the scope is caught by TensorFlow to run distributedly.

To ensure everything is caught by the distributed strategy, we need to put almost the entire function in the scope as shown in the following pseudo code.

class Model(Layer):
    def fit(self, x, y, ...):
        with self.distribute_strategy.scope():
            data_handler = data_adapter.get_data_handler(x, y)
            self.train_function = self.make_train_function()
            for epoch, iterator in data_handler.enumerate_epochs():
                for step in data_handler.steps():

Another TensorFlow API we need to use here is When run distributedly, the Model.train_step() function needs to run on each replica in parallel. The Model.make_train_function() function wraps Model.train_step() into another function that uses to run train_step() distributedly. It also convert the function into a tf.function.

TensorFlow API runs a function on each replica with the given arguments. For example,, args=(arg1, arg2)).

The pseudo code of wrapping Model.train_step() is as follows. We wrap .train_step() into a new function, train_function(), convert it to a tf.function, and return it to .fit(). In train_function(), we just call .train_step() using and aggregate the outputs and return it.

class Model(Layer):
    def make_train_function(self, ...):
        def train_function(iterator):
            data = next(iterator)
            outputs =, args=(data,))
            outputs = reduce_per_replica(outputs)
            return outputs
        train_function = tf.function(train_function)
        return train_function

There is another TensorFlow distribute strategy API that is used by Keras is tf.distribute.get_strategy(). That is how Keras get the distribute strategy defined by the user.

TensorFlow API
tf.distribute.get_strategy() Returns the current tf.distribute.Strategy object.