9. The predict function

The logic of Model.predict() is very similar to that of Model.fit(), as the following pseudocode shows. It first wraps the input x in a DataHandler. It then builds Model.predict_step() into a tf.function with Model.make_predict_function(). Model.predict_step() makes predictions for a single batch of data, and it can be overridden to customize prediction behavior. As with .make_train_function(), .make_predict_function() takes care of the distribution strategy while building the tf.function.

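class Model(Layer):
    def predict(self, x, ...):
        # Wrap x (arrays, a tf.data.Dataset, a generator, ...) in a DataHandler.
        data_handler = DataHandler(x)
        # Build the tf.function that runs predict_step() on one batch.
        self.predict_function = self.make_predict_function()
        outputs = []
        # predict() makes a single pass over the data (one "epoch").
        for _, iterator in data_handler.enumerate_epochs():
            for step in data_handler.steps():
                outputs.append(self.predict_function(iterator))
        # Concatenate the per-batch predictions along the batch axis.
        return concat(outputs)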
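To make the control flow above concrete, here is a minimal framework-free sketch in plain Python. It is not Keras code: ToyModel, its elementwise "model" y = 2x + 1, and the list-slicing batcher standing in for DataHandler are all hypothetical. make_predict_function() here returns an ordinary closure where Keras would build a tf.function (and apply the distribution strategy), but the shape of the loop is the same: build the step function once, then call it once per step, each call pulling one batch from the iterator.

```python
from typing import Callable, Iterator, List, Sequence


class ToyModel:
    """Hypothetical stand-in for a Keras model: y = 2 * x + 1, elementwise."""

    def __call__(self, batch: Sequence[float], training: bool = False) -> List[float]:
        # Forward pass; `training` is accepted only to mirror the Keras call signature.
        return [2.0 * v + 1.0 for v in batch]

    def predict_step(self, data: Sequence[float]) -> List[float]:
        # Predictions for a single batch, in inference mode.
        return self(data, training=False)

    def make_predict_function(self) -> Callable[[Iterator], List[float]]:
        # Keras would wrap this in a tf.function; here it is a plain closure
        # that pulls exactly one batch from the iterator per call.
        def predict_function(iterator: Iterator) -> List[float]:
            return self.predict_step(next(iterator))

        return predict_function

    def predict(self, x: Sequence[float], batch_size: int = 2) -> List[float]:
        # Hypothetical replacement for DataHandler: split x into batches.
        batches = [list(x[i:i + batch_size]) for i in range(0, len(x), batch_size)]
        iterator = iter(batches)
        predict_function = self.make_predict_function()
        outputs = []
        for _ in range(len(batches)):  # one pass over the data, one call per step
            outputs.append(predict_function(iterator))
        # Concatenate the per-batch outputs along the batch axis.
        return [v for batch in outputs for v in batch]


print(ToyModel().predict([0, 1, 2, 3, 4], batch_size=2))  # [1.0, 3.0, 5.0, 7.0, 9.0]
```

With batch_size=2 the five inputs are split into three batches ([0, 1], [2, 3], [4]), so predict_function is called three times before the results are stitched back into one list.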

By default, Model.predict_step() simply unpacks x from the provided data (which may also contain y and sample weights) and calls the model on it for a forward pass in inference mode (training=False), as shown in the following pseudocode.

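class Model(Layer):
    def predict_step(self, data):
        # data may be x, (x,), (x, y) or (x, y, sample_weight); keep only x.
        x, _, _ = data_adapter.unpack_x_y_sample_weight(data)
        # Forward pass in inference mode (e.g. dropout disabled).
        return self(x, training=False)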
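The sketch below illustrates, again in plain Python rather than Keras, both the default unpacking behavior and what overriding predict_step can look like. Everything in it is hypothetical: unpack_x is loosely modeled on the x-extraction that predict_step performs (it treats a tuple as (x,), (x, y) or (x, y, sample_weight) and anything else as bare x), ToyModel's forward pass simply clips each value to [0, 1], and LabelModel is an invented subclass that overrides predict_step to turn scores into hard 0/1 labels.

```python
from typing import Any, List


def unpack_x(data: Any) -> Any:
    """Hypothetical helper: return x from x, (x,), (x, y) or (x, y, sample_weight)."""
    if not isinstance(data, tuple):
        return data  # data is just x
    if len(data) in (1, 2, 3):
        return data[0]  # drop y and sample_weight if present
    raise ValueError(f"expected x, (x,), (x, y) or (x, y, w); got {len(data)} elements")


class ToyModel:
    def __call__(self, x: List[float], training: bool = False) -> List[float]:
        # Hypothetical forward pass: clip each value to a [0, 1] "score".
        return [min(max(v, 0.0), 1.0) for v in x]

    def predict_step(self, data: Any) -> List[float]:
        # Default behavior: ignore y / sample weights, run an inference forward pass.
        x = unpack_x(data)
        return self(x, training=False)


class LabelModel(ToyModel):
    """Overrides predict_step to return class labels instead of raw scores."""

    def predict_step(self, data: Any) -> List[int]:
        scores = super().predict_step(data)
        return [int(s >= 0.5) for s in scores]


print(ToyModel().predict_step(([0.2, 0.7], [0, 1])))  # [0.2, 0.7]
print(LabelModel().predict_step([0.2, 0.7]))          # [0, 1]
```

Because predict() only ever calls predict_step() through the function built by make_predict_function(), overriding this single method is enough to change what every batch returns, without touching the surrounding loop.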