4. The Functional class
Besides subclassing it, there is another way of using the Model class, which
is the functional API. It connects the layers to each other to form a directed
acyclic graph (DAG), where the nodes are layer call events and the edges are
KerasTensors. Please refer to this tutorial for more details on how to use it.
The following is a code example of using the functional API.
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
Although it looks like it is still using the Model class, it is actually using
the Functional class, which is an internal class not exposed in the public API.
Model.__new__() creates a Functional instance when the functional API is used.
The source code looks like this.
class Model(Layer):
    def __new__(cls, *args, **kwargs):
        if is_functional_model_init_params(args, kwargs) and cls == Model:
            return functional.Functional(skip_init=True, *args, **kwargs)
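The dispatch in Model.__new__() can be tried out without Keras. The following is a minimal sketch in plain Python, not the Keras implementation. The helper looks_functional and its rule for detecting functional-style arguments are hypothetical stand-ins for is_functional_model_init_params, and the toy Model and Functional classes only mimic the relationship between the real ones.

```python
def looks_functional(args, kwargs):
    # Hypothetical stand-in for is_functional_model_init_params: treat a call
    # as functional-style when both inputs and outputs are supplied.
    return len(args) == 2 or ("inputs" in kwargs and "outputs" in kwargs)


class Model:
    def __new__(cls, *args, **kwargs):
        # Only redirect when Model itself is instantiated, so that user
        # subclasses of Model are still created normally.
        if cls is Model and looks_functional(args, kwargs):
            return Functional(*args, **kwargs)
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        pass


class Functional(Model):
    def __init__(self, inputs=None, outputs=None):
        # Python calls __init__ again on the object returned from
        # Model.__new__, so this body runs twice in this sketch.
        self.inputs = inputs
        self.outputs = outputs


class MyModel(Model):
    pass


functional_model = Model(inputs="in", outputs="out")
subclassed_model = MyModel()
print(type(functional_model).__name__)  # Functional
print(type(subclassed_model).__name__)  # MyModel
```

Because the object returned from __new__ is still an instance of Model, Python runs its __init__ a second time. That is likely the reason the real code passes skip_init=True on the inner construction.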
Now, let's see how Functional tracks these layers and intermediate outputs in
the computational graph.
The KerasTensor class
keras.Input, which looks like a class, is actually a function that returns a
KerasTensor.
print(type(keras.Input(shape=(28, 28, 1))))
<class 'keras.engine.keras_tensor.KerasTensor'>
KerasTensor is a class that just represents the intermediate output tensors of
the layers in a Keras model. It has some useful properties like shape and dtype.
class KerasTensor(object):
    @property
    def shape(self): ...
    @property
    def dtype(self): ...
It is passed to each of the layers by calling them, as shown in the functional
API example. The purpose is for the layers to create their weights using the
shape and type information of the input tensor. That is also why there is a
special check to see whether the input tensor is a KerasTensor, as we
introduced before.
class Layer(module.Module, ...):
    def __call__(self, inputs, **kwargs):
        if isinstance(inputs, keras_tensor.KerasTensor):
            inputs = convert_to_tf_tensor(inputs)
            outputs = self.call(inputs)
            return convert_to_keras_tensor(outputs)
From the source code above, we can see that if we call a layer with a
KerasTensor, the return value is also a KerasTensor, which will be used to call
the next layer.
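To make the idea of calling layers with symbolic tensors concrete, here is a minimal sketch that does not depend on Keras. SymbolicTensor and ToyDense are hypothetical names used only for illustration. The sketch assumes that a symbolic tensor carries nothing but a shape and a dtype, and that a layer can derive both its weight shape and its output shape from them.

```python
from dataclasses import dataclass


@dataclass
class SymbolicTensor:
    """Stand-in for KerasTensor: carries only shape and dtype, no values."""
    shape: tuple
    dtype: str = "float32"


class ToyDense:
    def __init__(self, units):
        self.units = units
        self.kernel_shape = None  # unknown until the first call

    def __call__(self, inputs):
        if not isinstance(inputs, SymbolicTensor):
            raise TypeError("this sketch only handles symbolic inputs")
        # Use the symbolic shape to decide the weight shape on the first call.
        if self.kernel_shape is None:
            self.kernel_shape = (inputs.shape[-1], self.units)
        # Return another symbolic tensor that describes the output.
        return SymbolicTensor(inputs.shape[:-1] + (self.units,), inputs.dtype)


inputs = SymbolicTensor(shape=(None, 3))
hidden = ToyDense(4)
x = hidden(inputs)
outputs = ToyDense(5)(x)
print(hidden.kernel_shape, outputs.shape)
```

No numerical computation happens here. Only shapes flow from one call to the next, which is enough for each layer to know how large its weights need to be.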
Connecting the layers
The question we try to answer here is: how is the computational graph recorded
and fetched given only the inputs and outputs? This functionality is
implemented in the Functional class.
The graph is fetched starting from the outputs, which is a list of
KerasTensors. Each KerasTensor records the Layer instance that produced it
during the call of the Layer. The algorithm works like this. First, from the
outputs, we get the Layers producing these outputs. Second, we use the Layers
to get their input KerasTensors. Third, we use these KerasTensors to get the
previous layers. We keep doing this until the inputs to the model are reached.
Here are two more questions to answer:
- How does a KerasTensor get the Layer producing it?
- How does the Layer get the input KerasTensors?
To answer these two questions, there are two important classes or concepts to
make clear first: Node and KerasHistory.
A Node is created at each call of a layer to represent the connectivity
between two Layers. In other words, a Node corresponds to a call of a Layer.
Each Layer has an attribute, _inbound_nodes, to track its input Nodes. There
can be multiple inbound Nodes because a Layer may be used in multiple places in
a model for weight sharing.
In the following example, layer_a and layer_b are both called multiple times.
Therefore, node1 and node4 are in layer_a._inbound_nodes, and node2 and node5
are in layer_b._inbound_nodes.
node1 -> layer_a -> node2 -> layer_b -> node3
node4 -> layer_a -> node5 -> layer_b -> node6
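The bookkeeping in this example can be reproduced with a small stand-in. The sketch below assumes hypothetical ToyLayer and ToyNode classes that only mimic the idea of appending one node per call to _inbound_nodes. They are not the Keras classes, and strings stand in for tensors.

```python
class ToyNode:
    def __init__(self, layer, inputs):
        self.layer = layer
        self.inputs = inputs


class ToyLayer:
    def __init__(self, name):
        self.name = name
        self._inbound_nodes = []

    def __call__(self, inputs):
        # One node is recorded per call, even though the layer is shared.
        self._inbound_nodes.append(ToyNode(self, inputs))
        return f"{self.name}({inputs})"


layer_a = ToyLayer("layer_a")
layer_b = ToyLayer("layer_b")

# Two separate paths through the same pair of layers.
out_1 = layer_b(layer_a("x1"))
out_2 = layer_b(layer_a("x2"))

for layer in (layer_a, layer_b):
    print(layer.name, [node.inputs for node in layer._inbound_nodes])
```

Each layer ends up with two entries in _inbound_nodes, one per call, while there is still only one layer object (and therefore one set of weights) behind both paths.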
When building a functional model, we can call a Layer with multiple
KerasTensors. For example, the Add layer adds multiple tensors together.
Therefore, a call of a layer may correspond to multiple KerasTensors. A Node
also corresponds to a call of a Layer. Therefore, a Node may correspond to
multiple KerasTensors. A Node records these KerasTensors in Node._keras_inputs.
KerasHistory is for the KerasTensor to find the Layer producing it and the
corresponding Node. It is stored in the attribute KerasTensor._keras_history.
KerasHistory.layer records the Layer producing the KerasTensor.
KerasHistory.node_index records the index of the corresponding Node in the
KerasHistory.layer._inbound_nodes list.
For example, if node2 has a corresponding KerasTensor named keras_tensor_2,
then keras_tensor_2._keras_history.node_index records the index of node1 in
layer_a._inbound_nodes.
With all these recording mechanisms, we can have the following pseudo-code.
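class Layer(tf.Module):
    def __call__(self, inputs):
        outputs = self.call(inputs)
        # Each call creates a Node linking the inputs, this layer, and the outputs.
        node = Node(self, inputs, outputs)
        return outputs

class Node:
    def __init__(self, layer, inputs, outputs):
        self.layer = layer
        self._keras_inputs = inputs
        # Register this node on the layer; its position in the list is its index.
        layer._inbound_nodes.append(self)
        node_index = len(layer._inbound_nodes) - 1
        for keras_tensor in outputs:
            keras_tensor._keras_history = KerasHistory(layer, node_index)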
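Now we have the answers to the two questions above. A KerasTensor finds the
layer producing it with KerasTensor._keras_history.layer. A Layer finds its
input KerasTensors with Layer._inbound_nodes[0]._keras_inputs, where the index
0 selects the first call of the layer. In general, the index is the node_index
recorded in the KerasHistory.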
Finally, we can fetch the entire computational graph with the algorithm described above. The pseudo-code is shown as follows.
def fetch_nodes(output_keras_tensor):
    if output_keras_tensor in model.inputs:
        return []
    layer = output_keras_tensor._keras_history.layer
    node_index = output_keras_tensor._keras_history.node_index
    node = layer._inbound_nodes[node_index]
    node_list = [node]
    for input_keras_tensor in node._keras_inputs:
        node_list += fetch_nodes(input_keras_tensor)
    return node_list
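The pseudo-code above can be turned into something runnable with a few toy classes. This is only a sketch under stated assumptions: ToyTensor and ToyLayer are hypothetical stand-ins for KerasTensor and Layer, the model inputs are passed explicitly as model_inputs instead of being read from a model object, and the layers named "dense" and "add" perform no computation.

```python
class KerasHistory:
    def __init__(self, layer, node_index):
        self.layer = layer
        self.node_index = node_index


class ToyTensor:
    def __init__(self, name):
        self.name = name
        self._keras_history = None  # stays None for model inputs


class Node:
    def __init__(self, layer, inputs, outputs):
        self.layer = layer
        self._keras_inputs = inputs
        layer._inbound_nodes.append(self)
        node_index = len(layer._inbound_nodes) - 1
        for tensor in outputs:
            tensor._keras_history = KerasHistory(layer, node_index)


class ToyLayer:
    def __init__(self, name):
        self.name = name
        self._inbound_nodes = []

    def __call__(self, *inputs):
        # No real computation: just create an output and record the call.
        output = ToyTensor(f"{self.name}_out")
        Node(self, list(inputs), [output])
        return output


def fetch_nodes(output_tensor, model_inputs):
    # Stop the recursion once a model input is reached.
    if any(output_tensor is t for t in model_inputs):
        return []
    history = output_tensor._keras_history
    node = history.layer._inbound_nodes[history.node_index]
    node_list = [node]
    for input_tensor in node._keras_inputs:
        node_list += fetch_nodes(input_tensor, model_inputs)
    return node_list


a, b = ToyTensor("a"), ToyTensor("b")
dense = ToyLayer("dense")
add = ToyLayer("add")
out = add(dense(a), b)  # a layer called with multiple tensors

visited = [node.layer.name for node in fetch_nodes(out, [a, b])]
print(visited)  # ['add', 'dense']
```

The traversal starts at the output and walks backwards, so the nodes come out in reverse order of execution. A tensor consumed by several layers would be visited more than once in this sketch, since it keeps no record of what it has already seen.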
The actual code not only fetches the nodes but also the layers, and sorts them
in topological order. Functional.call then calls the layers in topological
order to produce the outputs.
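To illustrate what sorting in topological order and then calling each step could look like, here is a small sketch. It uses a plain dictionary of hypothetical step names, each mapped to a function and its dependencies, rather than Keras nodes, and it uses a depth-first post-order traversal as one possible way to obtain the order. The actual Keras code is organized differently.

```python
def topological_order(graph, output_name):
    """Return step names so that every step comes after its dependencies."""
    order, seen = [], set()

    def visit(name):
        if name in seen:
            return
        seen.add(name)
        for dep in graph[name]["deps"]:
            visit(dep)
        order.append(name)  # appended only after all of its dependencies

    visit(output_name)
    return order


def run(graph, output_name, feed):
    values = dict(feed)
    for name in topological_order(graph, output_name):
        if name in values:  # model inputs are already provided
            continue
        step = graph[name]
        values[name] = step["fn"](*(values[d] for d in step["deps"]))
    return values[output_name]


graph = {
    "x": {"fn": None, "deps": []},
    "double": {"fn": lambda v: 2 * v, "deps": ["x"]},
    "square": {"fn": lambda v: v * v, "deps": ["x"]},
    "add": {"fn": lambda p, q: p + q, "deps": ["double", "square"]},
}

print(topological_order(graph, "add"))  # ['x', 'double', 'square', 'add']
print(run(graph, "add", {"x": 3}))      # 15
```

Because every step appears after the steps it depends on, a single pass over the sorted list is enough: by the time a step is called, all of its inputs have already been computed.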