In this notebook you’ll get to know the Trax framework and some of its basic building blocks.
TensorFlow and PyTorch are both extensive frameworks that can do almost anything in deep learning. They offer a lot of flexibility, but that often means verbosity of syntax and extra time to code.
Trax is much more concise. It runs on a TensorFlow backend but allows you to train models with one-line commands. Trax also runs end to end, allowing you to get data, build a model, and train it, all with single, terse statements. This means you can focus on learning instead of spending hours on the idiosyncrasies of a big framework's implementation.
Keras has been part of TensorFlow itself since version 2.0. Trax is also well suited to implementing new state-of-the-art models such as Transformers, Reformers, and BERT, because it is actively maintained by the Google Brain team for advanced deep learning tasks. It runs smoothly on CPUs, GPUs, and TPUs with comparatively few code modifications.
Building models in Trax relies on two key concepts: layers and combinators. Trax layers are simple objects that process data and perform computations. They can be chained together into composite layers using Trax combinators, allowing you to build layers and models of any complexity.
You already know that Trax uses TensorFlow as a backend, but it also uses the JAX library to speed up computation. You can view JAX as an enhanced and optimized version of numpy.
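To make the layer/combinator idea concrete, here is a minimal sketch in plain NumPy. It is not how Trax implements anything: `serial`, `relu`, and `double` are hypothetical names invented for illustration. It only shows the pattern of treating layers as callables and a combinator as a function that chains them into one composite layer.

```python
import numpy as np


def serial(*layers):
    """Hypothetical combinator: chain layers into one composite layer."""
    def composite(x):
        # Feed the output of each layer into the next one, in order.
        for layer in layers:
            x = layer(x)
        return x
    return composite


def relu(x):
    """Hypothetical layer: elementwise max(x, 0)."""
    return np.maximum(x, 0)


def double(x):
    """Hypothetical layer: multiply every element by 2."""
    return 2 * x


# The composite behaves like a single layer: relu first, then double.
model = serial(relu, double)
out = model(np.array([-2, -1, 0, 1, 2]))
print(out)  # [0 0 0 2 4]
```

Because the composite is itself a callable, it can be passed to `serial` again, which is what lets small pieces grow into models of any depth.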
Watch out for assignments that contain the line import trax.fastmath.numpy as np. If you see this line, remember that when calling np you are really calling Trax’s version of numpy that is compatible with JAX. As a result, where you used to encounter the type numpy.ndarray, you will now find the type jax.interpreters.xla.DeviceArray.
Tensor2Tensor is another name you might have heard. It started as an end-to-end solution, much like Trax, but it grew unwieldy and complicated. You can view Trax as the new, improved version that is faster and simpler to use.
Trax depends on JAX and some JAX-related libraries, which are not yet supported on Windows but work well on Ubuntu and macOS. If you are working on Windows, we suggest installing Trax on WSL2.
Officially maintained documentation: trax-ml (not to be confused with this TraX).
!pip install trax==1.3.9 #Use this version for this notebook
import numpy as np # regular ol' numpy
from trax import layers as tl # core building block
from trax import shapes # data signatures: dimensionality and type
from trax import fastmath # uses jax, offers numpy on steroids
# Trax version 1.3.9 or better
!pip list | grep trax
trax 1.3.9
Layers are the core building blocks in Trax; as mentioned in the lectures, they are the base classes.
They take inputs, compute functions or custom calculations, and return outputs.
You can also inspect layer properties. Let me show you some examples.
First let’s see how to build a relu activation function as a layer. A layer like this is one of the simplest types. Notice there is no object initialization, so it works just like a math function.
Note: Activation functions are also layers in Trax, which might look odd if you have used other frameworks for a long time.
# Layers
# Create a relu trax layer
relu = tl.Relu()
# Inspect properties
print("-- Properties --")
print("name :", relu.name)
print("expected inputs :", relu.n_in)
print("promised outputs :", relu.n_out, "\n")
# Inputs
x = np.array([-2, -1, 0, 1, 2])
print("-- Inputs --")
print("x :", x, "\n")
# Outputs
y = relu(x)
print("-- Outputs --")
print("y :", y)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
-- Properties --
name : Serial
expected inputs : 1
promised outputs : 1
-- Inputs --
x : [-2 -1 0 1 2]
-- Outputs --
y : [0 0 0 1 2]
Now let’s check how to build a layer that takes 2 inputs. Notice the change in the expected inputs property from 1 to 2.
# Create a concatenate trax layer
concat = tl.Concatenate()
print("-- Properties --")
print("name :", concat.name)
print("expected inputs :", concat.n_in)
print("promised outputs :", concat.n_out, "\n")
# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")
# Outputs
y = concat([x1, x2])
print("-- Outputs --")
print("y :", y)
-- Properties --
name : Concatenate
expected inputs : 2
promised outputs : 1
-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.]
-- Outputs --
y : [-10. -20. -30. 1. 2. 3.]
You can change the default settings of layers. For example, you can change the expected inputs for a concatenate layer from 2 to 3 using the optional parameter n_items.
# Configure a concatenate layer
concat_3 = tl.Concatenate(n_items=3) # configure the layer's expected inputs
print("-- Properties --")
print("name :", concat_3.name)
print("expected inputs :", concat_3.n_in)
print("promised outputs :", concat_3.n_out, "\n")
# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
x3 = x2 * 0.99
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2)
print("x3 :", x3, "\n")
# Outputs
y = concat_3([x1, x2, x3])
print("-- Outputs --")
print("y :", y)
-- Properties --
name : Concatenate
expected inputs : 3
promised outputs : 1
-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.]
x3 : [0.99 1.98 2.97]
-- Outputs --
y : [-10. -20. -30. 1. 2. 3. 0.99 1.98 2.97]
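Two details in the output above are worth a closer look: the integers in x1 come out as floats, and the inputs are joined along the last axis. The sketch below reproduces both with plain NumPy's `np.concatenate`. It assumes the Trax layer behaves like NumPy concatenation for these inputs, which matches the printed results but is not a statement about Trax's internals; the 2-D arrays `a` and `b` are made up for illustration.

```python
import numpy as np

# Same inputs as the concat_3 example above.
x1 = np.array([-10, -20, -30])   # integer array
x2 = x1 / -10                    # true division produces floats
x3 = x2 * 0.99

# Mixing integer and float arrays promotes the result to float,
# which is why -10 is printed as -10. in the layer output.
y = np.concatenate([x1, x2, x3], axis=-1)
print(y.shape, y.dtype)  # (9,) float64

# With 2-D inputs the choice of axis matters.
a = np.ones((2, 3))
b = np.zeros((2, 3))
wide = np.concatenate([a, b], axis=-1)  # join along the last axis
tall = np.concatenate([a, b], axis=0)   # join along the first axis
print(wide.shape)  # (2, 6)
print(tall.shape)  # (4, 3)
```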
Note: At any point, if you want to look up the documentation for a function or class, use the help function.
help(tl.Concatenate)  # Prints the class docstring with an explanation
Help on class Concatenate in module trax.layers.combinators:
class Concatenate(trax.layers.base.Layer)
| Concatenate(n_items=2, axis=-1)
|
| Concatenates a number of tensors into a single tensor.
|
| For example::
|
| x = np.array([1, 2])
| y = np.array([3, 4])
| z = np.array([5, 6])
| concat3 = tl.Concatenate(n_items=3)
| z = concat3((x, y, z)) # z = [1, 2, 3, 4, 5, 6]
|
| Use the `axis` argument to specify on which axis to concatenate the tensors.
| By default it's the last axis, `axis=-1`, and `n_items=2`.
|
| Method resolution order:
| Concatenate
| trax.layers.base.Layer
| builtins.object
|
| Methods defined here:
|
| __init__(self, n_items=2, axis=-1)
| Creates a partially initialized, unconnected layer instance.
|
| Args:
| n_in: Number of inputs expected by this layer.
| n_out: Number of outputs promised by this layer.
| name: Class-like name for this layer; for use when printing this layer.
| sublayers_to_print: Sublayers to display when printing out this layer;
| if None (the default), display all sublayers.
|
| forward(self, xs)
| Executes this layer as part of a forward pass through the model.
|
| ----------------------------------------------------------------------
| Methods inherited from trax.layers.base.Layer:
|
| __call__(self, x, weights=None, state=None, rng=None)
| Makes layers callable; for use in tests or interactive settings.
|
| This convenience method helps library users play with, test, or otherwise
| probe the behavior of layers outside of a full training environment. It
| presents the layer as callable function from inputs to outputs, with the
| option of manually specifying weights and non-parameter state per individual
| call. For convenience, weights and non-parameter state are cached per layer
| instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`,
| and acquiring non-empty values either by initialization or from values
| explicitly provided via the weights and state keyword arguments, in which
| case the old weights will be preserved, and the state will be updated.
|
| Args:
| x: Zero or more input tensors, packaged as described in the `Layer` class
| docstring.
| weights: Weights or `None`; if `None`, use self's cached weights value.
| state: State or `None`; if `None`, use self's cached state value.
| rng: Single-use random number generator (JAX PRNG key), or `None`;
| if `None`, use a default computed from an integer 0 seed.
|
| Returns:
| Zero or more output tensors, packaged as described in the `Layer` class
| docstring.
|
| __repr__(self)
| Renders this layer as a medium-detailed string, to help in debugging.
|
| Subclasses should aim for high-signal/low-noise when overriding this
| method.
|
| Returns:
| A high signal-to-noise string representing this layer.
|
| __setattr__(self, attr, value)
| Sets class attributes and protects from typos.
|
| In Trax layers, we only allow to set the following public attributes::
|
| - weights
| - state
| - rng
|
| This function prevents from setting other public attributes to avoid typos,
| for example, this is not possible and would be without this function::
|
| [typo] layer.weighs = some_tensor
|
| If you need to set other public attributes in a derived class (which we
| do not recommend as in almost all cases it suffices to use a private
| attribute), override self._settable_attrs to include the attribute name.
|
| Args:
| attr: Name of the attribute to be set.
| value: Value to be assigned to the attribute.
|
| backward(self, inputs, output, grad, weights, state, new_state, rng)
| Custom backward pass to propagate gradients in a custom way.
|
| Args:
| inputs: Input tensors; can be a (possibly nested) tuple.
| output: The result of running this layer on inputs.
| grad: Gradient signal computed based on subsequent layers; its structure
| and shape must match output.
| weights: This layer's weights.
| state: This layer's state prior to the current forward pass.
| new_state: This layer's state after the current forward pass.
| rng: Single-use random number generator (JAX PRNG key).
|
| Returns:
| The custom gradient signal for the input. Note that we need to return
| a gradient for each argument of forward, so it will usually be a tuple
| of signals: the gradient for inputs and weights.
|
| init(self, input_signature, rng=None, use_cache=False)
| Initializes weights/state of this layer and its sublayers recursively.
|
| Initialization creates layer weights and state, for layers that use them.
| It derives the necessary array shapes and data types from the layer's input
| signature, which is itself just shape and data type information.
|
| For layers without weights or state, this method safely does nothing.
|
| This method is designed to create weights/state only once for each layer
| instance, even if the same layer instance occurs in multiple places in the
| network. This enables weight sharing to be implemented as layer sharing.
|
| Args:
| input_signature: `ShapeDtype` instance (if this layer takes one input)
| or list/tuple of `ShapeDtype` instances.
| rng: Single-use random number generator (JAX PRNG key), or `None`;
| if `None`, use a default computed from an integer 0 seed.
| use_cache: If `True`, and if this layer instance has already been
| initialized elsewhere in the network, then return special marker
| values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.
| Else return this layer's newly initialized weights and state.
|
| Returns:
| A `(weights, state)` tuple.
|
| init_from_file(self, file_name, weights_only=False, input_signature=None)
| Initializes this layer and its sublayers from a pickled checkpoint.
|
| In the common case (`weights_only=False`), the file must be a gziped pickled
| dictionary containing items with keys `'flat_weights', `'flat_state'` and
| `'input_signature'`, which are used to initialize this layer.
| If `input_signature` is specified, it's used instead of the one in the file.
| If `weights_only` is `True`, the dictionary does not need to have the
| `'flat_state'` item and the state it not restored either.
|
| Args:
| file_name: Name/path of the pickled weights/state file.
| weights_only: If `True`, initialize only the layer's weights. Else
| initialize both weights and state.
| input_signature: Input signature to be used instead of the one from file.
|
| Returns:
| A `(weights, state)` tuple.
|
| init_weights_and_state(self, input_signature)
| Initializes weights and state, to handle input with the given signature.
|
| A layer subclass must override this method if the layer uses weights or
| state. To initialize weights, set `self.weights` to desired (typically
| random) values. To initialize state (uncommon), set `self.state` to desired
| starting values.
|
| Args:
| input_signature: A `ShapeDtype` instance (if this layer takes one input)
| or a list/tuple of `ShapeDtype` instances.
|
| output_signature(self, input_signature)
| Returns output signature this layer would give for `input_signature`.
|
| pure_fn(self, x, weights, state, rng, use_cache=False)
| Applies this layer as a pure function with no optional args.
|
| This method exposes the layer's computation as a pure function. This is
| especially useful for JIT compilation. Do not override, use `forward`
| instead.
|
| Args:
| x: Zero or more input tensors, packaged as described in the `Layer` class
| docstring.
| weights: A tuple or list of trainable weights, with one element for this
| layer if this layer has no sublayers, or one for each sublayer if
| this layer has sublayers. If a layer (or sublayer) has no trainable
| weights, the corresponding weights element is an empty tuple.
| state: Layer-specific non-parameter state that can update between batches.
| rng: Single-use random number generator (JAX PRNG key).
| use_cache: if `True`, cache weights and state in the layer object; used
| to implement layer sharing in combinators.
|
| Returns:
| A tuple of `(tensors, state)`. The tensors match the number (`n_out`)
| promised by this layer, and are packaged as described in the `Layer`
| class docstring.
|
| save_to_file(self, file_name, weights_only=False, input_signature=None)
| Saves this layer and its sublayers to a pickled checkpoint.
|
| Args:
| file_name: Name/path of the pickled weights/state file.
| weights_only: If `True`, save only the layer's weights. Else
| save both weights and state.
| input_signature: Input signature to be used.
|
| weights_and_state_signature(self, input_signature, unsafe=False)
| Return a pair containing the signatures of weights and state.
|
| ----------------------------------------------------------------------
| Data descriptors inherited from trax.layers.base.Layer:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| has_backward
| Returns `True` if this layer provides its own custom backward pass code.
|
| A layer subclass that provides custom backward pass code (for custom
| gradients) must override this method to return `True`.
|
| n_in
| Returns how many tensors this layer expects as input.
|
| n_out
| Returns how many tensors this layer promises as output.
|
| name
| Returns the name of this layer.
|
| rng
| Returns this layer's current single-use random number generator.
|
| Code that wants to base random samples on this generator must explicitly
| split off new generators from it. (See, for example, the `rng` setter code
| below.)
|
| state
| Returns a tuple containing this layer's state; may be empty.
|
| If the layer has sublayers, the state by convention will be
| a tuple of length `len(sublayers)` containing sublayer states.
| Note that in this case self._state only marks which ones are shared.
|
| sublayers
| Returns a tuple containing this layer's sublayers; may be empty.
|
| weights
| Returns this layer's weights.
|
| Depending on the layer, the weights can be in the form of:
|
| - an empty tuple
| - a tensor (ndarray)
| - a nested structure of tuples and tensors
|
| If the layer has sublayers, the weights by convention will be
| a tuple of length `len(sublayers)` containing the weights of sublayers.
| Note that in this case self._weights only marks which ones are shared.
Some layer types include mutable weights and biases that are used in computation and training. Layers of this type require initialization before use.
For example, the LayerNorm layer calculates normalized data that is also scaled by weights and biases. During initialization you pass the data shape and data type of the inputs, so the layer can initialize compatible arrays of weights and biases.
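As a rough illustration of what "normalized data, also scaled by weights and biases" means, here is a plain NumPy sketch of layer normalization over the last axis. The function `layer_norm` is a hypothetical helper, not the Trax implementation; the starting values (scale of ones, bias of zeros) are assumptions chosen for illustration, and `epsilon=1e-6` mirrors the default shown in the LayerNorm signature.

```python
import numpy as np


def layer_norm(x, scale, bias, epsilon=1e-6):
    """Hypothetical helper: normalize over the last axis, then scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    # epsilon keeps the division stable when the variance is close to zero.
    normalized = (x - mean) / np.sqrt(var + epsilon)
    return normalized * scale + bias


x = np.array([0.0, 1.0, 2.0, 3.0])

# The weights must be compatible with the input's shape and dtype, which is
# why the layer needs the input signature before it can create them.
scale = np.ones(x.shape[-1], dtype=x.dtype)   # assumed starting scale
bias = np.zeros(x.shape[-1], dtype=x.dtype)   # assumed starting bias

y = layer_norm(x, scale, bias)
print(y.mean(), y.std())
```

With these starting values the output has a mean close to 0 and a standard deviation close to 1; training would then adjust `scale` and `bias`.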
# Comment out either call if you do not need its documentation
help(tl.LayerNorm)
help(shapes.signature)
Help on class LayerNorm in module trax.layers.normalization:
class LayerNorm(LayerNorm)
| LayerNorm(center=True, epsilon=1e-06)
|
| Layer normalization.
|
| Method resolution order:
| LayerNorm
| LayerNorm
| trax.layers.base.Layer
| builtins.object
|
| Methods defined here:
|
| __init__(self, center=True, epsilon=1e-06)
|
| ----------------------------------------------------------------------
| Methods inherited from LayerNorm:
|
| forward(self, x)
| Computes this layer's output as part of a forward pass through the model.
|
| A layer subclass overrides this method to define how the layer computes
| outputs from inputs. If the layer depends on weights, state, or randomness
| as part of the computation, the needed information can be accessed as
| properties of the layer object: `self.weights`, `self.state`, and
| `self.rng`. (See numerous examples in `trax.layers.core`.)
|
| Args:
| inputs: Zero or more input tensors, packaged as described in the `Layer`
| class docstring.
|
| Returns:
| Zero or more output tensors, packaged as described in the `Layer` class
| docstring.
|
| init_weights_and_state(self, input_signature)
| Initializes weights and state, to handle input with the given signature.
|
| A layer subclass must override this method if the layer uses weights or
| state. To initialize weights, set `self.weights` to desired (typically
| random) values. To initialize state (uncommon), set `self.state` to desired
| starting values.
|
| Args:
| input_signature: A `ShapeDtype` instance (if this layer takes one input)
| or a list/tuple of `ShapeDtype` instances.
|
| ----------------------------------------------------------------------
| Methods inherited from trax.layers.base.Layer:
|
| __call__(self, x, weights=None, state=None, rng=None)
| Makes layers callable; for use in tests or interactive settings.
|
| This convenience method helps library users play with, test, or otherwise
| probe the behavior of layers outside of a full training environment. It
| presents the layer as callable function from inputs to outputs, with the
| option of manually specifying weights and non-parameter state per individual
| call. For convenience, weights and non-parameter state are cached per layer
| instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`,
| and acquiring non-empty values either by initialization or from values
| explicitly provided via the weights and state keyword arguments, in which
| case the old weights will be preserved, and the state will be updated.
|
| Args:
| x: Zero or more input tensors, packaged as described in the `Layer` class
| docstring.
| weights: Weights or `None`; if `None`, use self's cached weights value.
| state: State or `None`; if `None`, use self's cached state value.
| rng: Single-use random number generator (JAX PRNG key), or `None`;
| if `None`, use a default computed from an integer 0 seed.
|
| Returns:
| Zero or more output tensors, packaged as described in the `Layer` class
| docstring.
|
| __repr__(self)
| Renders this layer as a medium-detailed string, to help in debugging.
|
| Subclasses should aim for high-signal/low-noise when overriding this
| method.
|
| Returns:
| A high signal-to-noise string representing this layer.
|
| __setattr__(self, attr, value)
| Sets class attributes and protects from typos.
|
| In Trax layers, we only allow to set the following public attributes::
|
| - weights
| - state
| - rng
|
| This function prevents from setting other public attributes to avoid typos,
| for example, this is not possible and would be without this function::
|
| [typo] layer.weighs = some_tensor
|
| If you need to set other public attributes in a derived class (which we
| do not recommend as in almost all cases it suffices to use a private
| attribute), override self._settable_attrs to include the attribute name.
|
| Args:
| attr: Name of the attribute to be set.
| value: Value to be assigned to the attribute.
|
| backward(self, inputs, output, grad, weights, state, new_state, rng)
| Custom backward pass to propagate gradients in a custom way.
|
| Args:
| inputs: Input tensors; can be a (possibly nested) tuple.
| output: The result of running this layer on inputs.
| grad: Gradient signal computed based on subsequent layers; its structure
| and shape must match output.
| weights: This layer's weights.
| state: This layer's state prior to the current forward pass.
| new_state: This layer's state after the current forward pass.
| rng: Single-use random number generator (JAX PRNG key).
|
| Returns:
| The custom gradient signal for the input. Note that we need to return
| a gradient for each argument of forward, so it will usually be a tuple
| of signals: the gradient for inputs and weights.
|
| init(self, input_signature, rng=None, use_cache=False)
| Initializes weights/state of this layer and its sublayers recursively.
|
| Initialization creates layer weights and state, for layers that use them.
| It derives the necessary array shapes and data types from the layer's input
| signature, which is itself just shape and data type information.
|
| For layers without weights or state, this method safely does nothing.
|
| This method is designed to create weights/state only once for each layer
| instance, even if the same layer instance occurs in multiple places in the
| network. This enables weight sharing to be implemented as layer sharing.
|
| Args:
| input_signature: `ShapeDtype` instance (if this layer takes one input)
| or list/tuple of `ShapeDtype` instances.
| rng: Single-use random number generator (JAX PRNG key), or `None`;
| if `None`, use a default computed from an integer 0 seed.
| use_cache: If `True`, and if this layer instance has already been
| initialized elsewhere in the network, then return special marker
| values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.
| Else return this layer's newly initialized weights and state.
|
| Returns:
| A `(weights, state)` tuple.
|
| init_from_file(self, file_name, weights_only=False, input_signature=None)
| Initializes this layer and its sublayers from a pickled checkpoint.
|
| In the common case (`weights_only=False`), the file must be a gziped pickled
| dictionary containing items with keys `'flat_weights', `'flat_state'` and
| `'input_signature'`, which are used to initialize this layer.
| If `input_signature` is specified, it's used instead of the one in the file.
| If `weights_only` is `True`, the dictionary does not need to have the
| `'flat_state'` item and the state it not restored either.
|
| Args:
| file_name: Name/path of the pickled weights/state file.
| weights_only: If `True`, initialize only the layer's weights. Else
| initialize both weights and state.
| input_signature: Input signature to be used instead of the one from file.
|
| Returns:
| A `(weights, state)` tuple.
|
| output_signature(self, input_signature)
| Returns output signature this layer would give for `input_signature`.
|
| pure_fn(self, x, weights, state, rng, use_cache=False)
| Applies this layer as a pure function with no optional args.
|
| This method exposes the layer's computation as a pure function. This is
| especially useful for JIT compilation. Do not override, use `forward`
| instead.
|
| Args:
| x: Zero or more input tensors, packaged as described in the `Layer` class
| docstring.
| weights: A tuple or list of trainable weights, with one element for this
| layer if this layer has no sublayers, or one for each sublayer if
| this layer has sublayers. If a layer (or sublayer) has no trainable
| weights, the corresponding weights element is an empty tuple.
| state: Layer-specific non-parameter state that can update between batches.
| rng: Single-use random number generator (JAX PRNG key).
| use_cache: if `True`, cache weights and state in the layer object; used
| to implement layer sharing in combinators.
|
| Returns:
| A tuple of `(tensors, state)`. The tensors match the number (`n_out`)
| promised by this layer, and are packaged as described in the `Layer`
| class docstring.
|
| save_to_file(self, file_name, weights_only=False, input_signature=None)
| Saves this layer and its sublayers to a pickled checkpoint.
|
| Args:
| file_name: Name/path of the pickled weights/state file.
| weights_only: If `True`, save only the layer's weights. Else
| save both weights and state.
| input_signature: Input signature to be used.
|
| weights_and_state_signature(self, input_signature, unsafe=False)
| Return a pair containing the signatures of weights and state.
|
| ----------------------------------------------------------------------
| Data descriptors inherited from trax.layers.base.Layer:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| has_backward
| Returns `True` if this layer provides its own custom backward pass code.
|
| A layer subclass that provides custom backward pass code (for custom
| gradients) must override this method to return `True`.
|
| n_in
| Returns how many tensors this layer expects as input.
|
| n_out
| Returns how many tensors this layer promises as output.
|
| name
| Returns the name of this layer.
|
| rng
| Returns this layer's current single-use random number generator.
|
| Code that wants to base random samples on this generator must explicitly
| split off new generators from it. (See, for example, the `rng` setter code
| below.)
|
| state
| Returns a tuple containing this layer's state; may be empty.
|
| If the layer has sublayers, the state by convention will be
| a tuple of length `len(sublayers)` containing sublayer states.
| Note that in this case self._state only marks which ones are shared.
|
| sublayers
| Returns a tuple containing this layer's sublayers; may be empty.
|
| weights
| Returns this layer's weights.
|
| Depending on the layer, the weights can be in the form of:
|
| - an empty tuple
| - a tensor (ndarray)
| - a nested structure of tuples and tensors
|
| If the layer has sublayers, the weights by convention will be
| a tuple of length `len(sublayers)` containing the weights of sublayers.
| Note that in this case self._weights only marks which ones are shared.
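The `__setattr__` entry in the help text above says that only `weights`, `state`, and `rng` may be set as public attributes, so that a typo such as `layer.weighs` fails loudly. Here is a minimal plain-Python sketch of that idea. The class name `GuardedLayer` and the choice of `ValueError` are assumptions made for illustration; this is not Trax's actual code.

```python
class GuardedLayer:
    """Hypothetical stand-in for a layer that restricts its public attributes."""

    # Public names that may be assigned (mirrors the list in the help text).
    _settable_attrs = ("weights", "state", "rng")

    def __init__(self):
        self._cache = ()  # private attributes (leading underscore) are always allowed

    def __setattr__(self, attr, value):
        # Private names and whitelisted public names pass through; anything else is rejected.
        if attr.startswith("_") or attr in self._settable_attrs:
            super().__setattr__(attr, value)
        else:
            raise ValueError(f"cannot set public attribute {attr!r}")


layer = GuardedLayer()
layer.weights = (1.0, 2.0)         # allowed
try:
    layer.weighs = (1.0, 2.0)      # typo: rejected instead of silently creating a new attribute
    typo_rejected = False
except ValueError:
    typo_rejected = True
print(typo_rejected)
```

A subclass that really needs another public attribute would extend `_settable_attrs`, which is what the help text recommends.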
Help on function signature in module trax.shapes:
signature(obj)
Returns a `ShapeDtype` signature for the given `obj`.
A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`
instances. Note that this function is permissive with respect to its inputs
(accepts lists or tuples or dicts, and underlying objects can be any type
as long as they have shape and dtype attributes) and returns the corresponding
nested structure of `ShapeDtype`.
Args:
obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict
of such objects.
Returns:
A corresponding nested structure of `ShapeDtype` instances.
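To get a feel for what a signature is, here is a small sketch that mimics the behavior described in the help text using only the standard library. `ShapeDtypeSketch` and `signature_sketch` are hypothetical names invented for this example, and the real `shapes.signature` may differ in its details.

```python
from collections import namedtuple
from types import SimpleNamespace

# Hypothetical stand-in for trax.shapes.ShapeDtype.
ShapeDtypeSketch = namedtuple("ShapeDtypeSketch", ["shape", "dtype"])


def signature_sketch(obj):
    """Maps an object, or a nested list/tuple/dict of objects, to shape/dtype records."""
    if isinstance(obj, (list, tuple)):
        return tuple(signature_sketch(item) for item in obj)
    if isinstance(obj, dict):
        return {key: signature_sketch(value) for key, value in obj.items()}
    # Any object with `shape` and `dtype` attributes is accepted.
    return ShapeDtypeSketch(shape=obj.shape, dtype=obj.dtype)


# Array-like placeholders: only the shape and dtype attributes matter here.
vector = SimpleNamespace(shape=(4,), dtype="float64")
matrix = SimpleNamespace(shape=(2, 3), dtype="float32")

single = signature_sketch(vector)
nested = signature_sketch({"inputs": (vector, matrix), "targets": vector})
print(single)
print(nested)
```

Notice that the signature carries no data at all. It only records shapes and data types, which is all a layer needs in order to create its weights.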
# Layer initialization
norm = tl.LayerNorm()
# You first must know what the input data will look like
x = np.array([0, 1, 2, 3], dtype="float")
# Initialize weights and biases from the input signature (shape and dtype)
norm.init(shapes.signature(x)) # init expects a trax ShapeDtype signature, not the plain shape tuple
print("Normal shape:",x.shape, "Data Type:",type(x.shape))
print("Shapes Trax:",shapes.signature(x),"Data Type:",type(shapes.signature(x)))
# Inspect properties
print("-- Properties --")
print("name :", norm.name)
print("expected inputs :", norm.n_in)
print("promised outputs :", norm.n_out)
# Weights and biases
print("weights :", norm.weights[0])
print("biases :", norm.weights[1], "\n")
# Inputs
print("-- Inputs --")
print("x :", x)
# Outputs
y = norm(x)
print("-- Outputs --")
print("y :", y)
Normal shape: (4,) Data Type: <class 'tuple'>
Shapes Trax: ShapeDtype{shape:(4,), dtype:float64} Data Type: <class 'trax.shapes.ShapeDtype'>
-- Properties --
name : LayerNorm
expected inputs : 1
promised outputs : 1
weights : [1. 1. 1. 1.]
biases : [0. 0. 0. 0.]
-- Inputs --
x : [0. 1. 2. 3.]
-- Outputs --
y : [-1.3416404 -0.44721344 0.44721344 1.3416404 ]
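If you are wondering where these output values come from, the following sketch reproduces them with plain NumPy. It assumes that layer normalization subtracts the mean over the last axis, divides by the square root of the variance plus `epsilon` (1e-6, the default shown in the `LayerNorm` signature), and then applies the weights and biases element-wise. Here `np` is ordinary NumPy rather than `trax.fastmath.numpy`, and `layer_norm_sketch` is a name made up for this example.

```python
import numpy as np


def layer_norm_sketch(x, scale, bias, epsilon=1e-6):
    """Normalizes x over its last axis, then applies an element-wise scale and bias."""
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    normalized = (x - mean) / np.sqrt(variance + epsilon)
    return normalized * scale + bias


x = np.array([0.0, 1.0, 2.0, 3.0])
scale = np.ones(4)   # matches the printed weights [1. 1. 1. 1.]
bias = np.zeros(4)   # matches the printed biases  [0. 0. 0. 0.]
y = layer_norm_sketch(x, scale, bias)
print(y)
```

With these inputs the result should agree with the `y` printed above up to floating-point rounding.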
This is where things start getting more interesting!
You can create your own custom layers too and define custom functions for computations by using tl.Fn. Let me show you how.
help(tl.Fn) # Show the documentation for tl.Fn
Help on function Fn in module trax.layers.base:
Fn(name, f, n_out=1)
Returns a layer with no weights that applies the function `f`.
`f` can take and return any number of arguments, and takes only positional
arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
The following, for example, would create a layer that takes two inputs and
returns two outputs -- element-wise sums and maxima:
`Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
The layer's number of inputs (`n_in`) is automatically set to number of
positional arguments in `f`, but you must explicitly set the number of
outputs (`n_out`) whenever it's not the default value 1.
Args:
name: Class-like name for the resulting layer; for use in debugging.
f: Pure function from input tensors to output tensors, where each input
tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
Output tensors must be packaged as specified in the `Layer` class
docstring.
n_out: Number of outputs promised by the layer; default value 1.
Returns:
Layer executing the function `f`.
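The help text says that `n_in` is set automatically to the number of positional arguments of `f`, while `n_out` must be given explicitly. One way such counting could work is sketched below with the standard `inspect` module. This is an assumption for illustration only, since the notebook does not show how Trax does it, and `count_positional_args` and `sum_and_max` are hypothetical names.

```python
import inspect


def count_positional_args(f):
    """Counts the positional parameters that `f` declares."""
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return sum(
        1
        for param in inspect.signature(f).parameters.values()
        if param.kind in positional_kinds
    )


def sum_and_max(x0, x1):
    # Plain-Python version of the docstring's example: two inputs, two outputs.
    return x0 + x1, max(x0, x1)


n_in = count_positional_args(sum_and_max)
outputs = sum_and_max(3, 5)
print(n_in, len(outputs))
```

The number of inputs can be read from the function signature, but the number of outputs cannot be known without running the function, which is presumably why `n_out` has to be stated by you.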
# Define a custom layer
# In this example you will create a layer that multiplies its input by 2
def TimesTwo():
    layer_name = "TimesTwo"  # give your custom layer a name so you can identify it

    # Custom function for the custom layer
    def func(x):
        return x * 2

    return tl.Fn(layer_name, func)
# Test it
times_two = TimesTwo()
# Inspect properties
print("-- Properties --")
print("name :", times_two.name)
print("expected inputs :", times_two.n_in)
print("promised outputs :", times_two.n_out, "\n")
# Inputs
x = np.array([1, 2, 3])
print("-- Inputs --")
print("x :", x, "\n")
# Outputs
y = times_two(x)
print("-- Outputs --")
print("y :", y)
-- Properties --
name : TimesTwo
expected inputs : 1
promised outputs : 1
-- Inputs --
x : [1 2 3]
-- Outputs --
y : [2 4 6]
You can combine layers to build more complex layers. Trax provides a set of objects named combinator layers to make this happen. Combinators are themselves layers, so they can be combined in the same way as any other layer.
The Serial combinator is the most common and the easiest to use. For example, you could build a simple neural network by combining layers into a single layer using the Serial combinator. This new layer then acts just like a single layer, so you can inspect its inputs, outputs and weights, or even combine it into another layer! Combinators can then be used as trainable models. Try adding more layers.
Note: As you may have guessed, if there is a Serial combinator, there is a Parallel combinator as well. Explore combinators and other layers in the Trax documentation, and look at the repo to understand how these layers are written.
# Show the documentation for the Serial and Parallel combinators
help(tl.Serial)
help(tl.Parallel)
Help on class Serial in module trax.layers.combinators:
class Serial(trax.layers.base.Layer)
| Serial(*sublayers, name=None, sublayers_to_print=None)
|
| Combinator that applies layers serially (by function composition).
|
| This combinator is commonly used to construct deep networks, e.g., like this::
|
| mlp = tl.Serial(
| tl.Dense(128),
| tl.Relu(),
| tl.Dense(10),
| )
|
| A Serial combinator uses stack semantics to manage data for its sublayers.
| Each sublayer sees only the inputs it needs and returns only the outputs it
| has generated. The sublayers interact via the data stack. For instance, a
| sublayer k, following sublayer j, gets called with the data stack in the
| state left after layer j has applied. The Serial combinator then:
|
| - takes n_in items off the top of the stack (n_in = k.n_in) and calls
| layer k, passing those items as arguments; and
|
| - takes layer k's n_out return values (n_out = k.n_out) and pushes
| them onto the data stack.
|
| A Serial instance with no sublayers acts as a special-case (but useful)
| 1-input 1-output no-op.
|
| Method resolution order:
| Serial
| trax.layers.base.Layer
| builtins.object
|
| Methods defined here:
|
| __init__(self, *sublayers, name=None, sublayers_to_print=None)
| Creates a partially initialized, unconnected layer instance.
|
| Args:
| n_in: Number of inputs expected by this layer.
| n_out: Number of outputs promised by this layer.
| name: Class-like name for this layer; for use when printing this layer.
| sublayers_to_print: Sublayers to display when printing out this layer;
| if None (the default), display all sublayers.
|
| forward(self, xs)
| Executes this layer as part of a forward pass through the model.
|
| init_weights_and_state(self, input_signature)
| Initializes weights and state for inputs with the given signature.
|
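The stack semantics described in the Serial help text can be hard to picture, so here is a small plain-Python sketch. Each layer is modeled as a hypothetical `(fn, n_in, n_out)` triple and the data stack is a list whose top is index 0. Both conventions, and the name `run_serial`, are assumptions of this sketch rather than Trax internals.

```python
def run_serial(layers, inputs):
    """Applies (fn, n_in, n_out) triples in order, using a list as the data stack.

    The top of the stack is index 0.
    """
    stack = list(inputs)
    for fn, n_in, n_out in layers:
        args, stack = stack[:n_in], stack[n_in:]    # take n_in items off the top
        result = fn(*args)
        outputs = list(result) if n_out > 1 else [result]
        stack = outputs + stack                     # push the n_out results back on top
    return stack


layers = [
    (lambda a, b: a + b, 2, 1),       # consumes two items, produces one
    (lambda s: (s, s * 2), 1, 2),     # consumes one item, produces two
]
final_stack = run_serial(layers, [1, 2, 10])
print(final_stack)
```

The third input is never consumed by either layer, so it stays on the stack untouched while the layers above it do their work.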
Help on class Parallel in module trax.layers.combinators:
class Parallel(trax.layers.base.Layer)
| Parallel(*sublayers, name=None)
|
| Combinator that applies a list of layers in parallel to its inputs.
|
| Layers in the list apply to successive spans of inputs, where the spans are
| determined how many inputs each layer takes. The resulting output is the
| (flattened) concatenation of the respective layer outputs.
|
| For example, suppose one has three layers:
|
| - F: 1 input, 1 output
| - G: 3 inputs, 1 output
| - H: 2 inputs, 2 outputs (h1, h2)
|
| Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:
|
| - inputs: a, b, c, d, e, f
| - outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f)
|
| As an important special case, a None argument to Parallel acts as if it takes
| one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For
| example:
|
| Parallel(None, F)
|
| creates a layer that passes its first input unchanged and applies F to the
| following input(s).
|
| Method resolution order:
| Parallel
| trax.layers.base.Layer
| builtins.object
|
| Methods defined here:
|
| __init__(self, *sublayers, name=None)
| The constructor.
|
| Args:
| *sublayers: A list of sublayers.
| name: Descriptive name for this layer.
|
| Returns:
| A new layer in which each of the given sublayers applies to its
| corresponding span of elements in the dataflow stack.
|
| forward(self, inputs)
| Executes this layer as part of a forward pass through the model.
|
| init_weights_and_state(self, input_signature)
| Initializes weights and state for inputs with the given signature.
|
# Serial combinator
serial = tl.Serial(
tl.LayerNorm(), # normalize input
tl.Relu(), # convert negative values to zero
times_two, # the custom layer you created above; multiplies the input it receives by 2
### START CODE HERE
# tl.Dense(n_units=2), # try adding more layers. eg uncomment these lines
# tl.Dense(n_units=1), # Binary classification, maybe? uncomment at your own peril
# tl.LogSoftmax() # Yes, LogSoftmax is also a layer
### END CODE HERE
)
# Initialization
x = np.array([-2, -1, 0, 1, 2]) # input
serial.init(shapes.signature(x)) # initialize the Serial instance
print("-- Serial Model --")
print(serial,"\n")
print("-- Properties --")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
print("weights & biases:", serial.weights, "\n")
# Inputs
print("-- Inputs --")
print("x :", x, "\n")
# Outputs
y = serial(x)
print("-- Outputs --")
print("y :", y)
-- Serial Model --
Serial[
LayerNorm
Serial[
Relu
]
TimesTwo
]
-- Properties --
name : Serial
sublayers : [LayerNorm, Serial[
Relu
], TimesTwo]
expected inputs : 1
promised outputs : 1
weights & biases: ((DeviceArray([1, 1, 1, 1, 1], dtype=int32), DeviceArray([0, 0, 0, 0, 0], dtype=int32)), ((), (), ()), ())
-- Inputs --
x : [-2 -1 0 1 2]
-- Outputs --
y : [0. 0. 0. 1.4142132 2.8284264]
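To see where the numbers in the output come from, the three layers can be reproduced by hand in plain NumPy. This is only a sketch, not Trax's implementation. It assumes that LayerNorm normalizes over the last axis with a small epsilon of 1e-6 inside the square root, and that its scale and bias are still at the initial values of 1 and 0 shown in the printed weights. The helper names below are hypothetical stand-ins for the Trax layers.

```python
import numpy as np

EPSILON = 1e-6  # assumed stabilizer inside the square root; not stated in the notebook


def layer_norm(x, epsilon=EPSILON):
    """Shift and scale x to zero mean and unit variance over the last axis."""
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    # Scale = 1 and bias = 0 (the initial weights), so they are omitted here.
    return (x - mean) / np.sqrt(variance + epsilon)


def relu(x):
    """Replace negative values with zero."""
    return np.maximum(x, 0.0)


def times_two(x):
    """Stand-in for the custom TimesTwo layer."""
    return x * 2.0


x = np.array([-2, -1, 0, 1, 2], dtype=np.float32)
y = times_two(relu(layer_norm(x)))  # same order as the Serial combinator
print(y)
```

The input has mean 0 and variance 2, so LayerNorm divides each value by roughly the square root of 2. Relu then zeroes the two negative entries, and doubling the rest gives values close to 1.4142 and 2.8284, which matches the output of the Serial model.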
Just remember to look out for which numpy you are using: the regular ol’ numpy or Trax’s JAX-compatible numpy. Both tend to use the alias np, so watch those import blocks.
Note: Some operations that work in regular numpy are not yet supported in fastmath.numpy, so the assignments switch between the two as needed.
# numpy and fastmath.numpy return different array types
# Regular ol' numpy
x_numpy = np.array([1, 2, 3])
print("good old numpy : ", type(x_numpy), "\n")
# Fastmath and jax numpy
x_jax = fastmath.numpy.array([1, 2, 3])
print("jax trax numpy : ", type(x_jax))
good old numpy : <class 'numpy.ndarray'>
jax trax numpy : <class 'jaxlib.xla_extension.DeviceArray'>
Trax is a concise framework, built on TensorFlow and JAX, for end-to-end machine learning. The key building blocks are layers and combinators. This notebook is just a taste, but it sets you up with some key intuitions to take forward into the rest of the course and the assignments, where you will build end-to-end models.