4. The Functional class
(Source)
There is another way of using the Model class besides subclassing it, which
is the functional API. It connects the layers to each other to form a directed
acyclic graph (DAG), where the nodes are layer call events, and the edges are
KerasTensors. Please refer to this
tutorial for more details on how to
use it. The following is a code example of using the functional API.
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
Although it looks like it is still using the Model class, it is actually
using the Functional class, which is an internal class not exposed in the
public API. Model.__new__() creates a Functional instance when it detects
the functional-API initialization arguments. The source code looks like this.
(Source)
class Model(Layer):
    def __new__(cls, *args, **kwargs):
        if is_functional_model_init_params(args, kwargs) and cls == Model:
            return functional.Functional(skip_init=True, *args, **kwargs)
        ...
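We can verify this by checking the type of a model built with the functional
API. Note that the exact module path of Functional may vary across Keras
versions.

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(5)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Although we called tf.keras.Model, the instance is a Functional.
print(type(model).__name__)  # Functional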
Now, let's see how Functional is tracking these layers and intermediate
outputs in the computational graph.
The KerasTensor class
keras.Input() looks like a class, but it is actually a function, which
returns a KerasTensor object.
from tensorflow import keras

print(type(keras.Input(shape=(28, 28, 1))))
Outputs:
<class 'keras.engine.keras_tensor.KerasTensor'>
KerasTensor is a class that represents the intermediate output tensors of
the layers in a Keras model. It has some useful properties like shape and
dtype.
(Source)
class KerasTensor(object):
    @property
    def shape(self):
        ...

    @property
    def dtype(self):
        ...
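For example, we can inspect these properties on the KerasTensor returned by
keras.Input(). The leading None in the shape is the batch dimension, which
is unknown at graph construction time.

from tensorflow import keras

x = keras.Input(shape=(28, 28, 1))
print(x.shape)  # (None, 28, 28, 1)
print(x.dtype)  # <dtype: 'float32'>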
It is passed to each of the layers by calling them, as shown in the functional
API example. The purpose is for the layers to create their weights using the
shape and dtype information of the input tensor. That is also why
Layer.__call__() has a special check to see if the input tensor is a
KerasTensor, as we introduced before.
class Layer(module.Module, ...):
    def __call__(self, inputs, **kwargs):
        if isinstance(inputs, keras_tensor.KerasTensor):
            inputs = convert_to_tf_tensor(inputs)
            outputs = self.call(inputs)
            return convert_to_keras_tensor(outputs)
        ...
From the source code above, we can see that if we call a layer with a
KerasTensor, the return value is also a KerasTensor, which is then used to
call the next layer.
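We can see this chaining with a quick check. The exact class path printed
may vary across Keras versions.

import tensorflow as tf

x = tf.keras.Input(shape=(3,))
y = tf.keras.layers.Dense(2)(x)  # calling a layer with a KerasTensor
print(type(y))  # <class 'keras.engine.keras_tensor.KerasTensor'>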
Connecting the layers
The question we try to answer here is: how is the computational graph
recorded and fetched given only the inputs and outputs? This functionality
is implemented in
Functional._init_graph_network().
The graph is fetched starting from the outputs, which is a list of
KerasTensors. Each KerasTensor records the Layer instance that produces
it during the call of the Layer. The algorithm goes like this. First, from
the outputs, we get the Layers producing these outputs. Second, we use
the Layers to get their input KerasTensors. Third, we use these KerasTensors
to get the previous Layers. We keep doing this until the inputs to the model
are reached.
Here are another two questions to answer:
- How does a KerasTensor get the Layer producing it?
- How does the Layer get the input KerasTensors?
To answer these two questions, there are two important classes or concepts to
make clear first: Node and KerasHistory.
A Node is created at each call of a layer to represent the connectivity
between two Layers. In other words, a Node corresponds to one call of a
Layer. Each Layer has an attribute
_inbound_nodes
to track its input Nodes. The reason why there can be multiple inbound
Nodes is that a Layer may be used in multiple places in a model for weight
sharing.
In the following example, layer_a and layer_b are both called multiple
times. Therefore, node1 and node4 are in layer_a._inbound_nodes, while
node2 and node5 are in layer_b._inbound_nodes.
node1 -> layer_a -> node2 -> layer_b -> node3
node4 -> layer_a -> node5 -> layer_b -> node6
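As a quick sanity check, the following snippet reuses a single layer instance
twice and inspects its _inbound_nodes attribute. This is a private attribute,
so we peek at it for illustration only.

import tensorflow as tf

shared = tf.keras.layers.Dense(4)  # one Layer instance shared in two places

out_1 = shared(tf.keras.Input(shape=(3,)))  # the first call creates one Node
out_2 = shared(tf.keras.Input(shape=(3,)))  # the second call creates another Node

print(len(shared._inbound_nodes))  # 2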
When building a functional model, we can call a Layer with multiple
KerasTensors. For example, the
Add layer adds multiple
tensors together. Therefore, a call of a layer may correspond to multiple
KerasTensors. Since a Node also corresponds to a call of a Layer,
a Node may correspond to multiple KerasTensors. A Node records these
KerasTensors in Node._keras_inputs.
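For example, the single call of the Add layer below consumes two
KerasTensors, so the resulting Node records both of them in its
_keras_inputs.

import tensorflow as tf

inputs_1 = tf.keras.Input(shape=(4,))
inputs_2 = tf.keras.Input(shape=(4,))
# One call of Add with two KerasTensors creates one Node
# whose _keras_inputs contains both input tensors.
outputs = tf.keras.layers.Add()([inputs_1, inputs_2])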
KerasHistory is for the KerasTensor to find the Layer producing it and the
corresponding Node. It is stored in the attribute
KerasTensor._keras_history. KerasHistory.layer records the Layer
producing it. KerasHistory.node_index records the index of the corresponding
Node in KerasHistory.layer._inbound_nodes.
For example, the KerasTensor consumed by node2, named keras_tensor_2, is
produced by layer_a during the call that created node1. Therefore,
keras_tensor_2._keras_history.node_index records the index of node1 in
layer_a._inbound_nodes.
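We can observe this with the snippet below, which again peeks at private
attributes for illustration only.

import tensorflow as tf

layer_a = tf.keras.layers.Dense(4)
layer_b = tf.keras.layers.Dense(4)

inputs = tf.keras.Input(shape=(3,))
x = layer_a(inputs)   # creates node1 in layer_a._inbound_nodes
outputs = layer_b(x)  # creates node2 in layer_b._inbound_nodes

print(x._keras_history.layer is layer_a)  # True
print(x._keras_history.node_index)        # 0, the index of node1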
With all these recording mechanisms in place, we can write the following pseudo-code.
(Source)
class Layer(tf.Module):
    def __call__(self, inputs):
        ...
        outputs = self.call(inputs)
        # Create a Node to record this call of the layer.
        node = Node(self, inputs, outputs)
        ...
(Source)
class Node:
    def __init__(self, layer, inputs, outputs):
        ...
        # Record the input KerasTensors of this call.
        self._keras_inputs = inputs
        layer._inbound_nodes.append(self)
        node_index = len(layer._inbound_nodes) - 1
        # Let each output KerasTensor remember the Layer and the Node
        # that produced it.
        for keras_tensor in outputs:
            keras_tensor._keras_history = KerasHistory(layer, node_index)
        ...
Now, we have the answers to the two questions above. A KerasTensor finds the
Layer producing it with KerasTensor._keras_history.layer. A Layer finds the
input KerasTensors with Layer._inbound_nodes[0]._keras_inputs.
Finally, we can fetch the entire computational graph with the algorithm described above. The pseudo-code is shown as follows.
(Source)
def fetch_nodes(output_keras_tensor):
    # Stop the recursion when we reach the inputs of the model.
    if output_keras_tensor in model.inputs:
        return []
    # Find the Node that produced this KerasTensor.
    layer = output_keras_tensor._keras_history.layer
    node_index = output_keras_tensor._keras_history.node_index
    node = layer._inbound_nodes[node_index]
    node_list = [node]
    # Recurse into the KerasTensors consumed by that Node.
    for input_keras_tensor in node._keras_inputs:
        node_list += fetch_nodes(input_keras_tensor)
    return node_list
The actual code would not only fetch the nodes but also the layers, and sort
them in topological order. In Functional.call(), it calls the layers in
topological order to produce the outputs.
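To make that last step concrete, here is a minimal sketch of what such a
topological execution loop could look like. This is not the actual
implementation: the attribute _nodes_in_topological_order is hypothetical,
and the real Functional.call() handles many more cases, such as nested input
structures, masks, and the training argument.

import tensorflow as tf

def call(self, inputs):
    # Hypothetical sketch: map every symbolic KerasTensor (keyed by id)
    # to the concrete tensor computed for it, starting from the inputs.
    computed = {id(t): value for t, value in zip(self.inputs, inputs)}
    # Run each layer call (Node) in topological order.
    for node in self._nodes_in_topological_order:  # hypothetical attribute
        layer_inputs = [computed[id(t)] for t in node._keras_inputs]
        outputs = node.layer(layer_inputs)
        for t, value in zip(node.outputs, tf.nest.flatten(outputs)):
            computed[id(t)] = value
    return [computed[id(t)] for t in self.outputs]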