9. The predict function
(Source)
The logic of Model.predict() is very similar to that of Model.fit(), as shown in the following pseudo code. It first wraps the input x into a DataHandler. Then, it builds Model.predict_step() into a tf.function with Model.make_predict_function(). Model.predict_step() makes predictions for a single batch of data, and it can be overridden to customize the predict behavior. Similar to .make_train_function(), .make_predict_function() handles the distribution strategy while building the tf.function.
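```python
class Model(Layer):
    def predict(self, x, ...):
        # Wrap the input into a DataHandler that yields batches.
        data_handler = DataHandler(x)
        # Build predict_step() into a tf.function.
        self.predict_function = self.make_predict_function()
        outputs = []
        for epoch, iterator in data_handler.enumerate_epochs():
            for step in data_handler.steps():
                # Each call runs predict_step() on one batch.
                outputs.append(self.predict_function(iterator))
        return outputs
```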
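To make the control flow concrete, here is a minimal pure-Python sketch that mimics the loop above without TensorFlow. `ToyModel`, `ThresholdModel`, and the `batch_size` parameter are hypothetical stand-ins, not the Keras API: a plain closure replaces the tf.function, a list iterator replaces the DataHandler, and there is no distribution strategy. It illustrates why overriding `predict_step()` alone is enough to change what `predict()` returns.

```python
from typing import Callable, Iterator, List, Sequence


class ToyModel:
    """Hypothetical stand-in for keras.Model; not the real Keras API."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def __call__(self, x: List[float], training: bool = False) -> List[float]:
        # Forward pass: scale every input by a single weight.
        return [self.weight * v for v in x]

    def predict_step(self, data: List[float]) -> List[float]:
        # Default behavior: forward pass on one batch in inference mode.
        return self(data, training=False)

    def make_predict_function(self) -> Callable[[Iterator[List[float]]], List[float]]:
        # In Keras this step would build a tf.function and handle the
        # distribution strategy; here it is just a plain closure.
        def predict_function(iterator: Iterator[List[float]]) -> List[float]:
            return self.predict_step(next(iterator))
        return predict_function

    def predict(self, x: Sequence[float], batch_size: int = 2) -> List[float]:
        # Stand-in for DataHandler: split x into fixed-size batches.
        batches = [list(x[i:i + batch_size]) for i in range(0, len(x), batch_size)]
        iterator = iter(batches)
        predict_function = self.make_predict_function()
        outputs: List[List[float]] = []
        for _ in range(len(batches)):  # one step per batch
            outputs.append(predict_function(iterator))
        # This toy flattens the per-batch outputs into one list.
        return [v for batch in outputs for v in batch]


class ThresholdModel(ToyModel):
    """Overrides predict_step only; the inherited predict() uses it as-is."""

    def predict_step(self, data: List[float]) -> List[float]:
        return [1 if v > 2.5 else 0 for v in super().predict_step(data)]


print(ToyModel(1.5).predict([1, 2, 3]))        # [1.5, 3.0, 4.5]
print(ThresholdModel(1.5).predict([1, 2, 3]))  # [0, 1, 1]
```

The subclass never touches `predict()` or `make_predict_function()`, which mirrors the customization point described above.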
By default, Model.predict_step() just unpacks x from the provided data (because data may also contain y) and calls the model to do a forward pass with training=False, as shown in the following pseudo code.
(Source)
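```python
class Model(Layer):
    def predict_step(self, data):
        # data may be x alone or a tuple such as (x, y); keep only x.
        x = data_adapter.unpack(data)
        # Forward pass in inference mode.
        return self(x, training=False)
```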
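The default step can likewise be sketched in plain Python. Here `unpack` is a hypothetical simplification of the data_adapter helper (it treats a tuple as (x, y) or (x, y, sample_weight) and keeps only x), and `DropoutLikeModel` is a made-up model whose output depends on the `training` flag, included to show why predict_step() passes training=False.

```python
import random
from typing import Any, List, Tuple, Union

Batch = Union[List[float], Tuple[Any, ...]]


def unpack(data: Batch) -> List[float]:
    # Hypothetical helper: a tuple is treated as (x, y[, sample_weight]);
    # anything else is assumed to be x alone.
    if isinstance(data, tuple):
        return data[0]
    return data


class DropoutLikeModel:
    """Made-up model: randomly zeroes inputs only when training=True."""

    def __init__(self, rate: float = 0.5, seed: int = 0) -> None:
        self.rate = rate
        self.rng = random.Random(seed)

    def __call__(self, x: List[float], training: bool = False) -> List[float]:
        if training:
            # Training mode: randomly drop values and rescale the survivors.
            return [0.0 if self.rng.random() < self.rate else v / (1 - self.rate)
                    for v in x]
        # Inference mode: deterministic pass-through.
        return list(x)

    def predict_step(self, data: Batch) -> List[float]:
        x = unpack(data)  # labels, if present, are ignored
        return self(x, training=False)


model = DropoutLikeModel()
print(model.predict_step(([1.0, 2.0], [0, 1])))  # [1.0, 2.0]
print(model.predict_step([3.0]))                 # [3.0]
```

Because training=False selects the deterministic branch, repeated predictions on the same batch give the same result even though the model has a random training-time behavior.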