11. The losses
All loss functions implemented in Keras are subclasses of the Loss class. For example, you can implement a MeanSquaredError loss as follows.
```python
class MeanSquaredError(Loss):
    def call(self, y_true, y_pred):
        return tf.reduce_mean(tf.math.square(y_pred - y_true), axis=-1)
```
You can use it as a standalone callable as follows. Under the hood, Loss.__call__() invokes Loss.call().
```python
loss = MeanSquaredError()
print(loss(np.array([0.0, 1.0]), np.array([1.0, 1.0])))
# tf.Tensor(0.5, shape=(), dtype=float64)
```
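The delegation from __call__() to call() can be sketched in plain Python. This is a minimal, illustrative stand-in, not the real Keras base class, which also handles sample weighting and reduction.

```python
# Minimal sketch of the delegation described above: Loss.__call__()
# forwards to the subclass's call(). Not the real Keras base class,
# which also applies sample weights and loss reduction.

class Loss:
    def __call__(self, y_true, y_pred):
        # Delegate to the subclass's call() implementation.
        return self.call(y_true, y_pred)

    def call(self, y_true, y_pred):
        raise NotImplementedError


class MeanSquaredError(Loss):
    def call(self, y_true, y_pred):
        # Plain-Python mean of squared differences, standing in for
        # the TensorFlow ops used by the real implementation.
        return sum((p - t) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


loss = MeanSquaredError()
print(loss([0.0, 1.0], [1.0, 1.0]))  # 0.5
```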
All the built-in losses are implemented in a similar way: each overrides the call() function, which computes the loss for the given ground truth and predicted values.
As you may know, the loss passed to .compile() can be a string, a function, or a Loss subclass instance. However, they are all converted to a Loss subclass instance in the end. In .compile(), Keras uses the keras.losses.get() function to convert the losses. The get() function can accept a string, which is the name of the loss, and returns either a loss function or a Loss subclass instance. If the result is a function, Keras further wraps it into a Loss subclass instance using LossFunctionWrapper. The overall conversion process is shown in the following pseudo-code.
```python
def get_loss(loss):
    loss = keras.losses.get(loss)
    # If it is a function, wrap it into a Loss subclass instance.
    if not isinstance(loss, keras.losses.Loss):
        loss = losses_mod.LossFunctionWrapper(loss)
    return loss
```
The LossFunctionWrapper class is simply a subclass of Loss that calls the provided loss function in its call() function, as shown in the following pseudo-code. We also show an example of wrapping a loss function into a Loss subclass instance.
```python
class LossFunctionWrapper(Loss):
    def __init__(self, fn):
        self.fn = fn

    def call(self, y_true, y_pred):
        return self.fn(y_true, y_pred)


def mean_squared_error(y_true, y_pred):
    return tf.reduce_mean(tf.math.square(y_pred - y_true), axis=-1)


loss = LossFunctionWrapper(mean_squared_error)
```
Wrapping a function into a class instance is a common pattern in Keras. We will see this pattern again when we introduce metrics later.
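The conversion pipeline above can be combined into a small runnable sketch. Note that `LOSS_REGISTRY` and `get_loss` here are illustrative stand-ins for the lookup done by keras.losses.get(), not Keras API, and the plain-Python loss replaces the TensorFlow version.

```python
# Runnable sketch of the string -> function -> Loss conversion pipeline.
# LOSS_REGISTRY and get_loss are hypothetical names standing in for the
# lookup that keras.losses.get() performs internally.

class Loss:
    def __call__(self, y_true, y_pred):
        return self.call(y_true, y_pred)


class LossFunctionWrapper(Loss):
    def __init__(self, fn):
        self.fn = fn

    def call(self, y_true, y_pred):
        return self.fn(y_true, y_pred)


def mean_squared_error(y_true, y_pred):
    return sum((p - t) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Stand-in for the name-to-object lookup done by keras.losses.get().
LOSS_REGISTRY = {"mean_squared_error": mean_squared_error}


def get_loss(loss):
    if isinstance(loss, str):
        loss = LOSS_REGISTRY[loss]
    if not isinstance(loss, Loss):
        # A bare function: wrap it into a Loss subclass instance.
        loss = LossFunctionWrapper(loss)
    return loss


# A string, a function, and a Loss instance all end up as Loss instances.
for spec in ("mean_squared_error",
             mean_squared_error,
             LossFunctionWrapper(mean_squared_error)):
    obj = get_loss(spec)
    assert isinstance(obj, Loss)
    print(obj([0.0, 1.0], [1.0, 1.0]))  # 0.5 each time
```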
As you can see from above, the Loss.call() function only deals with a single batch of data. During training, computing the loss of a single batch is enough for backpropagation. However, we also need to compute the average loss value across all the trained batches to print to the screen during training. This is done by the LossesContainer class, which we mentioned when introducing .compile(). It manages a metric to track the historical loss values across batches. Unlike a loss function, a metric can not only compute a value for one batch of data but also record statistics across the historical batches. We will introduce more details about how metrics work in Keras later.
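The metric-based tracking described above can be sketched with a minimal mean-tracking class. This is an illustrative stand-in, not the Keras Mean metric implementation: it simply keeps a running total and a count.

```python
# Minimal sketch of a mean-tracking metric: update_state() accumulates
# per-batch loss values, result() returns the running average.
# Illustrative only; the real Keras metric also supports weighting.

class Mean:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count


loss_metric = Mean()
for batch_loss in [0.5, 0.3, 0.1]:  # per-batch loss values
    loss_metric.update_state(batch_loss)
print(loss_metric.result())  # the running average, approximately 0.3
```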
The LossesContainer class also manages multiple losses, which Keras supports. The user can pass multiple losses to a model with multiple heads, where each loss corresponds to one head. The LossesContainer class contains all these losses and exposes methods that manage them as if they were a single loss.
This is another common pattern: using a container to manage a collection of objects while behaving like a single object. We will see it again when we introduce the metrics.
The only method of LossesContainer that is directly called from outside is .__call__(). As we introduced earlier, the training process calls self.compiled_loss(y, y_pred) to compute the loss value, which invokes LossesContainer.__call__(). It iterates through the different heads of the model, computes the losses, and sums them up. The pseudo-code is as follows.
```python
class LossesContainer(Container):
    def __call__(self, y_true, y_pred):
        # y_pred is a list of outputs, one element per head.
        # y_true is the ground truth for the heads in a similar format.
        loss_values = []
        for single_y_true, single_y_pred, single_loss in zip(
                y_true, y_pred, self._losses):
            loss_values.append(
                single_loss(single_y_true, single_y_pred))
        total_loss = sum(loss_values)
        # Update the metric tracking the average loss value
        # across the historical batches.
        self._losses_metric.update_state(total_loss)
        return total_loss
```
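The container behavior can also be demonstrated end to end in plain Python. The sketch below is illustrative, not the Keras internal class: it pairs one loss per head, sums the per-head losses into a total, and feeds the total into a mean-tracking metric.

```python
# Runnable pure-Python sketch of the container pattern: one loss per
# model head, summed into a single total loss, with a mean metric
# tracking the average across batches. Names are illustrative.

class Mean:
    def __init__(self):
        self.total, self.count = 0.0, 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count


def mse(y_true, y_pred):
    return sum((p - t) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true, y_pred):
    return sum(abs(p - t) for t, p in zip(y_true, y_pred)) / len(y_true)


class LossesContainer:
    def __init__(self, losses):
        self._losses = losses
        self._loss_metric = Mean()

    def __call__(self, y_true, y_pred):
        # One (y_true, y_pred, loss) triple per model head.
        total = sum(loss(t, p)
                    for t, p, loss in zip(y_true, y_pred, self._losses))
        self._loss_metric.update_state(total)
        return total


container = LossesContainer([mse, mae])
# Two heads: y_true and y_pred are lists with one entry per head.
total = container([[0.0, 1.0], [2.0]], [[1.0, 1.0], [3.5]])
print(total)  # 0.5 (mse) + 1.5 (mae) = 2.0
```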