Coursera

Ungraded Lab: Hyperparameter tuning and model training with TFX

In this lab, you will again do hyperparameter tuning, but this time within a TensorFlow Extended (TFX) pipeline.

Course 2 of this specialization introduced TFX components for data ingestion, validation, and transformation. In this notebook, you will work with two more, both related to model development and training: Tuner and Trainer.

tfx pipeline image source: https://www.tensorflow.org/tfx/guide

You will again be working with the FashionMNIST dataset and will feed it through the TFX pipeline up to the Trainer component. You will quickly review the earlier components from Course 2, then focus on the two new components.

Let’s begin!

Imports

import os
import pprint

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from absl import logging

from tfx import v1 as tfx
from tfx.proto import example_gen_pb2, trainer_pb2
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

tf.get_logger().propagate = False
tf.get_logger().setLevel('ERROR')
pp = pprint.PrettyPrinter()
logging.set_verbosity(logging.ERROR)

Load and prepare the dataset

As mentioned earlier, you will be using the Fashion MNIST dataset just like in the previous lab. This will allow you to compare the similarities and differences when using Keras Tuner as a standalone library and within an ML pipeline.

You will first need to set up the directories that you will use to store the dataset, as well as the pipeline artifacts and metadata store.

# Location of the pipeline metadata store
_pipeline_root = './pipeline/'

# Directory of the raw data files
_data_root = './data/fmnist'

# Temporary directory
tempdir = './tempdir'
# Create the dataset directory
!mkdir -p {_data_root}

# Create the TFX pipeline files directory
!mkdir -p {_pipeline_root}

You will now load FashionMNIST from TensorFlow Datasets. The with_info flag is set to True so you can display information about the dataset in the next cell (using ds_info). The dataset is already in your workspace, so the download flag is set to False.

# Load the dataset (already in the workspace, so no download is needed)
ds, ds_info = tfds.load('fashion_mnist', data_dir=tempdir, with_info=True, download=False)
# Display info about the dataset
print(ds_info)
tfds.core.DatasetInfo(
    name='fashion_mnist',
    full_name='fashion_mnist/3.0.1',
    description="""
    Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.
    """,
    homepage='https://github.com/zalandoresearch/fashion-mnist',
    data_path='./tempdir/fashion_mnist/3.0.1',
    file_format=tfrecord,
    download_size=29.45 MiB,
    dataset_size=36.42 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{DBLP:journals/corr/abs-1708-07747,
      author    = {Han Xiao and
                   Kashif Rasul and
                   Roland Vollgraf},
      title     = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
                   Algorithms},
      journal   = {CoRR},
      volume    = {abs/1708.07747},
      year      = {2017},
      url       = {http://arxiv.org/abs/1708.07747},
      archivePrefix = {arXiv},
      eprint    = {1708.07747},
      timestamp = {Mon, 13 Aug 2018 16:47:27 +0200},
      biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1708-07747},
      bibsource = {dblp computer science bibliography, https://dblp.org}
    }""",
)

You can review the downloaded files with the code below. For this lab, you will be using the train TFRecord so you will need to take note of its filename. You will not use the test TFRecord in this lab.

# Define the location of the train tfrecord downloaded via TFDS
tfds_data_path = f'{tempdir}/{ds_info.name}/{ds_info.version}'

# Display contents of the TFDS data directory
os.listdir(tfds_data_path)
['fashion_mnist-train.tfrecord-00000-of-00001',
 'dataset_info.json',
 'fashion_mnist-test.tfrecord-00000-of-00001',
 'features.json',
 'label.labels.txt']
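
If you would rather not copy the filename by hand, the train shard can be picked out of the listing programmatically. The following is a minimal sketch in plain Python. The helper name find_train_shard is hypothetical, and the listing is hard-coded from the output above so the snippet runs on its own.

```python
def find_train_shard(filenames):
    """Return the single filename that looks like a train TFRecord shard."""
    matches = [name for name in filenames if '-train.tfrecord' in name]
    if len(matches) != 1:
        raise ValueError(f'expected exactly one train shard, found {len(matches)}')
    return matches[0]


# Listing copied from the output above. In the notebook you could pass
# os.listdir(tfds_data_path) instead of this hard-coded list.
listing = [
    'fashion_mnist-train.tfrecord-00000-of-00001',
    'dataset_info.json',
    'fashion_mnist-test.tfrecord-00000-of-00001',
    'features.json',
    'label.labels.txt',
]

train_shard = find_train_shard(listing)
print(train_shard)
```

Raising an error when zero or several files match is a design choice for this sketch: it avoids silently copying the wrong file if the dataset is ever sharded differently.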

You will then copy the train split from the downloaded data so it can be consumed by the ExampleGen component in the next step. This component requires that your files are in a directory without extra files (e.g. JSONs and TXT files).

# Define the train tfrecord filename
train_filename = 'fashion_mnist-train.tfrecord-00000-of-00001'

# Copy the train tfrecord into the data root folder
!cp {tfds_data_path}/{train_filename} {_data_root}
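
The same copy step can be written without shell magics, which also makes it easy to confirm that the destination holds only the TFRecord. This is a sketch using the standard library. The helper copy_into_clean_dir is hypothetical, and the demo runs on throwaway placeholder files rather than the real dataset.

```python
import shutil
import tempfile
from pathlib import Path


def copy_into_clean_dir(src_dir, filename, dst_dir):
    """Copy one file into dst_dir and return the sorted listing of dst_dir."""
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    shutil.copy(Path(src_dir) / filename, dst / filename)
    return sorted(p.name for p in dst.iterdir())


# Demo on temporary directories with placeholder files (not real TFRecords).
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / 'tfds'
    src.mkdir()
    for name in ('fashion_mnist-train.tfrecord-00000-of-00001',
                 'dataset_info.json',
                 'features.json'):
        (src / name).write_bytes(b'placeholder')

    dst_listing = copy_into_clean_dir(
        src, 'fashion_mnist-train.tfrecord-00000-of-00001',
        Path(tmp) / 'data' / 'fmnist')

print(dst_listing)
```

In the notebook, the equivalent call would take tfds_data_path, train_filename, and _data_root as its arguments.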

TFX Pipeline

With the setup complete, you can now proceed to creating the pipeline.

Initialize the Interactive Context

You will start by initializing the InteractiveContext so you can run the components within this notebook environment. You can safely ignore the warning because you will just be using a local SQLite file for the metadata store.

# Initialize the InteractiveContext
context = InteractiveContext(pipeline_root=_pipeline_root)

ExampleGen

You will start the pipeline by ingesting the TFRecord you set aside. The ImportExampleGen component consumes TFRecords, and you can specify splits as shown below. For this exercise, you will split the train TFRecord 80/20: hash_buckets values of 8 and 2 assign 80% of the examples to the train set and the remaining 20% to the eval/validation set.

# Specify 80/20 split for the train and eval set
output = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=8),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=2),
    ]))

# Ingest the data through ExampleGen
example_gen = tfx.components.ImportExampleGen(input_base=_data_root, output_config=output)

# Run the component
context.run(example_gen)
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
ExecutionResult at 0x7fcbb82a7710
.execution_id  1
.component.inputs  {}
.component.outputs  ['examples']
# Print split names and URI
artifact = example_gen.outputs['examples'].get()[0]
print(artifact.split_names, artifact.uri)
["train", "eval"] ./pipeline/ImportExampleGen/examples/1
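
To build intuition for the hash_buckets values, here is a sketch of hash-based splitting in plain Python. It is not the TFX implementation: the hash function (SHA-256 here) and the helper name assign_split are assumptions made for illustration. The idea it conveys is that each record is hashed into one of 8 + 2 = 10 buckets, the first 8 of which go to train and the last 2 to eval, so roughly 80% of the records land in train.

```python
import hashlib


def assign_split(record, buckets=(('train', 8), ('eval', 2))):
    """Map a record (bytes) to a split name using a stable hash."""
    total = sum(count for _, count in buckets)
    # Use the first 8 bytes of the digest as an integer (illustrative choice).
    digest = hashlib.sha256(record).digest()
    bucket = int.from_bytes(digest[:8], 'big') % total
    upper = 0
    for name, count in buckets:
        upper += count
        if bucket < upper:
            return name
    raise AssertionError('bucket index out of range')


# Synthetic records stand in for serialized examples.
counts = {'train': 0, 'eval': 0}
num_records = 10_000
for i in range(num_records):
    counts[assign_split(f'record-{i}'.encode())] += 1

train_fraction = counts['train'] / num_records
print(counts, round(train_fraction, 3))
```

Because the assignment depends only on the record contents, the same record always lands in the same split, which keeps the train and eval sets stable across reruns.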

StatisticsGen

Next, you will compute the statistics of the dataset with the StatisticsGen component.

# Run StatisticsGen
statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'])

context.run(statistics_gen)
ExecutionResult at 0x7fcb079ed390
.execution_id  2
.component.inputs  ['examples']
.component.outputs  ['statistics']

SchemaGen

You can then infer the dataset schema with SchemaGen. This will be used to validate incoming data to ensure that it is formatted correctly.

# Run SchemaGen
schema_gen = tfx.components.SchemaGen(
      statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)
context.run(schema_gen)
ExecutionResult at 0x7fcb9c103e10
.execution_id  3
.component.inputs  ['statistics']
.component.outputs  ['schema']
# Visualize the results
context.show(schema_gen.outputs['schema'])

Artifact at ./pipeline/SchemaGen/schema/3

Feature name  Type   Presence  Valency  Domain
'image'       BYTES  required           -
'label'       INT    required           -

ExampleValidator

You can assume that the dataset is clean since we downloaded it from TFDS. But just to review, let’s run it through ExampleValidator to detect if there are anomalies within the dataset.

# Run ExampleValidator
example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])
context.run(example_validator)
ExecutionResult at 0x7fc989b5ea90
.execution_id  4
.component.inputs  ['statistics'] ['schema']
.component.outputs  ['anomalies']
# Visualize the results. There should be no anomalies.
context.show(example_validator.outputs['anomalies'])

Artifact at ./pipeline/ExampleValidator/anomalies/4

'train' split:

No anomalies found.

'eval' split:

No anomalies found.
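
To make the idea of schema validation concrete, here is a conceptual sketch in plain Python. ExampleValidator compares the dataset statistics against the schema, as its inputs above show. This sketch instead checks individual records, purely to convey what kinds of problems an anomaly report describes. The schema dictionary and the helper find_anomalies are hypothetical and do not mirror the TFX or TFDV APIs.

```python
# Hypothetical schema modeled on the inferred one: both features are required.
SCHEMA = {'image': bytes, 'label': int}


def find_anomalies(example, schema=SCHEMA):
    """Return a list of human-readable problems found in one example."""
    anomalies = []
    for name, expected_type in schema.items():
        if name not in example:
            anomalies.append(f'{name}: missing required feature')
        elif not isinstance(example[name], expected_type):
            anomalies.append(
                f'{name}: expected {expected_type.__name__}, '
                f'got {type(example[name]).__name__}')
    for name in example:
        if name not in schema:
            anomalies.append(f'{name}: feature not in schema')
    return anomalies


clean = {'image': b'\x89PNG', 'label': 3}
broken = {'image': 'not-bytes', 'label': 3, 'extra': 1}

print(find_anomalies(clean))
print(find_anomalies(broken))
```

An empty list corresponds to the "No anomalies found." message above.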

Transform

Let’s now use the Transform component to scale the image pixels and convert the data types to float. You will first define the transform module containing these operations before you run the component.

_transform_module_file = 'fmnist_transform.py'
%%writefile {_transform_module_file}

import tensorflow as tf
import tensorflow_transform as tft

# Keys
_LABEL_KEY = 'label'
_IMAGE_KEY = 'image'


def _transformed_name(key):
    return key + '_xf'

def _image_parser(image_str):
    '''converts the images to a float tensor'''
    image = tf.image.decode_image(image_str, channels=1)
    image = tf.reshape(image, (28, 28, 1))
    image = tf.cast(image, tf.float32)
    return image


def _label_parser(label_id):
    '''converts the labels to a float tensor'''
    label = tf.cast(label_id, tf.float32)
    return label


def preprocessing_fn(inputs):
    """tf.transform's callback function for preprocessing inputs.
    Args:
        inputs: map from feature keys to raw not-yet-transformed features.
    Returns:
        Map from string feature key to transformed feature operations.
    """
    
    # Convert the raw image and labels to a float array
    with tf.device("/cpu:0"):
        outputs = {
            _transformed_name(_IMAGE_KEY):
                tf.map_fn(
                    _image_parser,
                    tf.squeeze(inputs[_IMAGE_KEY], axis=1),
                    dtype=tf.float32),
            _transformed_name(_LABEL_KEY):
                tf.map_fn(
                    _label_parser,
                    inputs[_LABEL_KEY],
                    dtype=tf.float32)
        }
    
    # scale the pixels from 0 to 1
    outputs[_transformed_name(_IMAGE_KEY)] = tft.scale_to_0_1(outputs[_transformed_name(_IMAGE_KEY)])
    
    return outputs
Writing fmnist_transform.py
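
The last step of preprocessing_fn rescales the pixel values into the range [0, 1]. As a rough illustration of min-max scaling, the sketch below applies (x - min) / (max - min) to a plain list. In a Transform pipeline, the minimum and maximum are computed over the dataset during the analysis phase; here they come from the list itself. The function name min_max_scale is hypothetical, and returning zeros when all values are equal is a choice made for this sketch only, not a description of how tft.scale_to_0_1 treats that case.

```python
def min_max_scale(values):
    """Scale numbers linearly so the smallest maps to 0 and the largest to 1."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # Degenerate case: no spread to scale by (sketch-only convention).
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


# Three sample 8-bit pixel intensities.
pixels = [0, 51, 255]
scaled = min_max_scale(pixels)
print(scaled)
```

For 8-bit images whose values span the full 0 to 255 range, this is equivalent to dividing each pixel by 255.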

You will run the component by passing in the examples, schema, and transform module file.

Note: You can safely ignore the warnings and udf_utils related errors.

# Setup the Transform component
transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_transform_module_file))

# Run the component
context.run(transform)
WARNING:root:This output type hint will be ignored and not used for type-checking purposes. Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: Tuple[Dict[<class 'str'>, Union[<class 'NoneType'>, <class 'tfx.components.transform.executor._Dataset'>]], Union[<class 'NoneType'>, Dict[<class 'str'>, Dict[<class 'str'>, <class 'apache_beam.pvalue.PCollection'>]]], <class 'int'>] instead.
WARNING:root:This output type hint will be ignored and not used for type-checking purposes. Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: Tuple[Dict[<class 'str'>, Union[<class 'NoneType'>, <class 'tfx.components.transform.executor._Dataset'>]], Union[<class 'NoneType'>, Dict[<class 'str'>, Dict[<class 'str'>, <class 'apache_beam.pvalue.PCollection'>]]], <class 'int'>] instead.
WARNING:root:This input type hint will be ignored and not used for type-checking purposes. Typically, input type hints for a PTransform are single (or nested) types wrapped by a PCollection, or PBegin. Got: Dict[<class 'tensorflow_transform.beam.analyzer_cache.DatasetKey'>, <class 'tensorflow_transform.beam.analyzer_cache.DatasetCache'>] instead.
WARNING:root:This output type hint will be ignored and not used for type-checking purposes. Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: List[<class 'apache_beam.pvalue.PDone'>] instead.
WARNING:root:This input type hint will be ignored and not used for type-checking purposes. Typically, input type hints for a PTransform are single (or nested) types wrapped by a PCollection, or PBegin. Got: Dict[<class 'tensorflow_transform.beam.analyzer_cache.DatasetKey'>, <class 'tensorflow_transform.beam.analyzer_cache.DatasetCache'>] instead.
WARNING:root:This output type hint will be ignored and not used for type-checking purposes. Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: List[<class 'apache_beam.pvalue.PDone'>] instead.
ExecutionResult at 0x7fcb9c7dea90
.execution_id: 5
.component
.component.inputs
['examples']
['schema']
.component.outputs
['transform_graph']
['transformed_examples']
['updated_analyzer_cache']
['pre_transform_schema']
['pre_transform_stats']
['post_transform_schema']
['post_transform_stats']
['post_transform_anomalies']

Tuner

As the name suggests, the Tuner component tunes the hyperparameters of your model. To use it, you need to provide a tuner module file that contains a tuner_fn() function. In this function, you will mostly follow the same steps as in the previous ungraded lab, with some key differences in how the dataset is handled.

The Transform component earlier saved the transformed examples as TFRecords compressed in .gz format, and you will need to load them into memory. Once loaded, you will need to create batches of features and labels so you can use them for hypertuning. This process is modularized in the _input_fn() function below.

The tuner_fn() function itself returns a TunerFnResult namedtuple containing your tuner object and a set of arguments to pass to the tuner.search() method. You will see these in action in the following cells. When reviewing the module file, we recommend reading tuner_fn() first before looking at the other auxiliary functions.
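
To make the shape of that return value concrete, here is a minimal sketch of how a result with tuner and fit_kwargs fields can be unpacked into a search() call. FakeTuner and the toy tuner_fn() below are hypothetical stand-ins so the snippet runs without TFX or KerasTuner installed. The way the real component invokes the tuner internally is assumed here rather than shown.

```python
from typing import Any, Dict, NamedTuple


class TunerFnResult(NamedTuple):
    """Same two fields as the namedtuple declared in the tuner module."""
    tuner: Any
    fit_kwargs: Dict[str, Any]


class FakeTuner:
    """Hypothetical stand-in for a KerasTuner tuner; it only records search() arguments."""

    def __init__(self):
        self.search_kwargs = None

    def search(self, **kwargs):
        self.search_kwargs = kwargs


def tuner_fn():
    """Toy version of tuner_fn(): returns the tuner plus the keyword arguments for search()."""
    return TunerFnResult(
        tuner=FakeTuner(),
        fit_kwargs={"x": [1, 2, 3], "validation_data": [4, 5], "steps_per_epoch": 2},
    )


result = tuner_fn()
# The fit_kwargs dict is expanded into keyword arguments of search()
result.tuner.search(**result.fit_kwargs)
print(sorted(result.tuner.search_kwargs))
```

Keeping the tuner and its search arguments together in one return value lets the caller run the search without knowing which datasets, callbacks, or step counts the module chose.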

# Declare name of module file
_tuner_module_file = 'tuner.py'
%%writefile {_tuner_module_file}

# Define imports
from kerastuner.engine import base_tuner
import kerastuner as kt
from tensorflow import keras
from typing import NamedTuple, Dict, Text, Any, List
from tfx.components.trainer.fn_args_utils import FnArgs, DataAccessor
import tensorflow as tf
import tensorflow_transform as tft

# Declare namedtuple field names
TunerFnResult = NamedTuple('TunerFnResult', [('tuner', base_tuner.BaseTuner),
                                             ('fit_kwargs', Dict[Text, Any])])

# Input key
_IMAGE_KEY = 'image_xf'

# Label key
_LABEL_KEY = 'label_xf'

# Callback for the search strategy
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)


def _gzip_reader_fn(filenames):
  '''Load compressed dataset
  
  Args:
    filenames - filenames of TFRecords to load

  Returns:
    TFRecordDataset loaded from the filenames
  '''

  # Load the dataset. Specify the compression type since it is saved as `.gz`
  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')
  

def _input_fn(file_pattern,
              tf_transform_output,
              num_epochs=None,
              batch_size=32) -> tf.data.Dataset:
  '''Create batches of features and labels from TF Records

  Args:
    file_pattern - List of files or patterns of file paths containing Example records.
    tf_transform_output - transform output graph
    num_epochs - Integer specifying the number of times to read through the dataset. 
            If None, cycles through the dataset forever.
    batch_size - An int representing the number of records to combine in a single batch.

  Returns:
    A dataset of (features, label) tuples, since a label key is passed below.
    The features dict maps feature keys to Tensor or SparseTensor objects.
  '''

  # Get feature specification based on transform output
  transformed_feature_spec = (
      tf_transform_output.transformed_feature_spec().copy())
  
  # Create batches of features and labels
  dataset = tf.data.experimental.make_batched_features_dataset(
      file_pattern=file_pattern,
      batch_size=batch_size,
      features=transformed_feature_spec,
      reader=_gzip_reader_fn,
      num_epochs=num_epochs,
      label_key=_LABEL_KEY)
  
  return dataset


def model_builder(hp):
  '''
  Builds the model and sets up the hyperparameters to tune.

  Args:
    hp - KerasTuner HyperParameters object used to define the search space

  Returns:
    model with hyperparameters to tune
  '''

  # Initialize the Sequential API and start stacking the layers
  model = keras.Sequential()
  model.add(keras.layers.Input(shape=(28, 28, 1), name=_IMAGE_KEY))
  model.add(keras.layers.Flatten())

  # Tune the number of units in the first Dense layer
  # Choose an optimal value between 32-512
  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu', name='dense_1'))

  # Add next layers
  model.add(keras.layers.Dropout(0.2))
  model.add(keras.layers.Dense(10, activation='softmax'))

  # Tune the learning rate for the optimizer
  # Choose an optimal value from 0.01, 0.001, or 0.0001
  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(),
                metrics=['accuracy'])

  return model

def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """Build the tuner using the KerasTuner API.
  Args:
    fn_args: Holds args as name/value pairs.

      - working_dir: working dir for tuning.
      - train_files: List of file paths containing training tf.Example data.
      - eval_files: List of file paths containing eval tf.Example data.
      - train_steps: number of train steps.
      - eval_steps: number of eval steps.
      - schema_path: optional schema of the input data.
      - transform_graph_path: optional transform graph produced by TFT.
  
  Returns:
    A namedtuple containing the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to tuner's run_trial function for fitting the
                    model, e.g., the training and validation dataset. Required
                    args depend on the above tuner's implementation.
  """

  # Define tuner search strategy
  tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     directory=fn_args.working_dir,
                     project_name='kt_hyperband')

  # Load transform output
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

  # Use _input_fn() to extract input features and labels from the train and val set
  train_set = _input_fn(fn_args.train_files[0], tf_transform_output)
  val_set = _input_fn(fn_args.eval_files[0], tf_transform_output)


  return TunerFnResult(
      tuner=tuner,
      fit_kwargs={
          'callbacks': [stop_early],
          'x': train_set,
          'validation_data': val_set,
          'steps_per_epoch': fn_args.train_steps,
          'validation_steps': fn_args.eval_steps
      }
  )
Writing tuner.py

With the module defined, you can now set up the Tuner component. You can see the description of each argument here.

Notice that we passed a num_steps argument to the train and eval args, and that this value is used for the steps_per_epoch and validation_steps arguments in the tuner module above. This is useful if you don’t want to go through the entire dataset when tuning. For example, if you have 10GB of training data, iterating through all of it for just one epoch and one set of hyperparameters would be incredibly time consuming. You can set the number of steps so your program only goes through a fraction of the dataset.

You can compute the total number of steps in one epoch as: number of examples / batch size. For this particular example, we have 48000 examples / 32 (default batch size), which equals 1500 steps per epoch for the train set (you can compute the val steps the same way from 12000 examples). Since you passed 500 as num_steps in the train args, each epoch covers only a third of the training examples and the rest are skipped for that epoch. This will likely result in lower accuracy readings but will save time in doing the hypertuning. Try modifying this value later and see if you arrive at the same set of hyperparameters.
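
The step arithmetic above can be checked with a few lines of Python. This is only a sketch: the helper name steps_per_epoch is hypothetical, and the example counts, batch size, and num_steps values are the ones quoted in this lab.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches needed to see every example once (hypothetical helper)."""
    return math.ceil(num_examples / batch_size)


BATCH_SIZE = 32  # default batch size of _input_fn()

train_steps_full = steps_per_epoch(48000, BATCH_SIZE)
val_steps_full = steps_per_epoch(12000, BATCH_SIZE)

# Fraction of each split covered per epoch with num_steps=500 (train) and num_steps=100 (eval)
train_fraction = 500 / train_steps_full
val_fraction = 100 / val_steps_full

print(f"full train epoch: {train_steps_full} steps, covered: {train_fraction:.1%}")
print(f"full val epoch:   {val_steps_full} steps, covered: {val_fraction:.1%}")
```

Raising num_steps toward the full-epoch values trades longer tuning time for metrics computed on more of the data.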

# Set up the Tuner component
tuner = tfx.components.Tuner(
    module_file=_tuner_module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(splits=['train'], num_steps=500),
    eval_args=trainer_pb2.EvalArgs(splits=['eval'], num_steps=100)
    )
# Run the component. This will take around 10 minutes to run.
# When done, it will summarize the results and show the 10 best trials.
context.run(tuner, enable_cache=False)
Trial 30 Complete [00h 00m 24s]
val_accuracy: 0.878125011920929

Best val_accuracy So Far: 0.8803125023841858
Total elapsed time: 00h 06m 37s
Results summary
Results in ./pipeline/.temp/6/kt_hyperband
Showing 10 best trials
<keras_tuner.engine.objective.Objective object at 0x7fc989d446d0>
Trial summary
Hyperparameters:
units: 448
learning_rate: 0.001
tuner/epochs: 10
tuner/initial_epoch: 4
tuner/bracket: 2
tuner/round: 2
tuner/trial_id: 0014
Score: 0.8803125023841858
Trial summary
Hyperparameters:
units: 192
learning_rate: 0.001
tuner/epochs: 10
tuner/initial_epoch: 0
tuner/bracket: 0
tuner/round: 0
Score: 0.878125011920929
Trial summary
Hyperparameters:
units: 256
learning_rate: 0.001
tuner/epochs: 10
tuner/initial_epoch: 0
tuner/bracket: 0
tuner/round: 0
Score: 0.8778125047683716
Trial summary
Hyperparameters:
units: 224
learning_rate: 0.001
tuner/epochs: 10
tuner/initial_epoch: 4
tuner/bracket: 2
tuner/round: 2
tuner/trial_id: 0013
Score: 0.8771874904632568
Trial summary
Hyperparameters:
units: 192
learning_rate: 0.001
tuner/epochs: 10
tuner/initial_epoch: 4
tuner/bracket: 1
tuner/round: 1
tuner/trial_id: 0020
Score: 0.8771874904632568
Trial summary
Hyperparameters:
units: 448
learning_rate: 0.001
tuner/epochs: 4
tuner/initial_epoch: 2
tuner/bracket: 2
tuner/round: 1
tuner/trial_id: 0008
Score: 0.8700000047683716
Trial summary
Hyperparameters:
units: 416
learning_rate: 0.001
tuner/epochs: 10
tuner/initial_epoch: 0
tuner/bracket: 0
tuner/round: 0
Score: 0.8693749904632568
Trial summary
Hyperparameters:
units: 224
learning_rate: 0.001
tuner/epochs: 4
tuner/initial_epoch: 2
tuner/bracket: 2
tuner/round: 1
tuner/trial_id: 0004
Score: 0.8612499833106995
Trial summary
Hyperparameters:
units: 352
learning_rate: 0.0001
tuner/epochs: 10
tuner/initial_epoch: 4
tuner/bracket: 1
tuner/round: 1
tuner/trial_id: 0022
Score: 0.8603125214576721
Trial summary
Hyperparameters:
units: 192
learning_rate: 0.001
tuner/epochs: 4
tuner/initial_epoch: 0
tuner/bracket: 1
tuner/round: 0
Score: 0.8587499856948853
ExecutionResult at 0x7fc989bdfdd0
.execution_id: 6
.component
.component.inputs
['examples']
['schema']
['transform_graph']
.component.outputs
['best_hyperparameters']<style> .tfx-object.expanded { padding: 4px 8px 4px 8px; background: white; border: 1px solid #bbbbbb; box-shadow: 4px 4px 2px rgba(0,0,0,0.05); } html[theme=dark] .tfx-object.expanded { background: black; } .tfx-object, .tfx-object * { font-size: 11pt; } .tfx-object > .title { cursor: pointer; } .tfx-object .expansion-marker { color: #999999; } .tfx-object.expanded > .title > .expansion-marker:before { content: '▼'; } .tfx-object.collapsed > .title > .expansion-marker:before { content: '▶'; } .tfx-object .class-name { font-weight: bold; } .tfx-object .deemphasize { opacity: 0.5; } .tfx-object.collapsed > table.attr-table { display: none; } .tfx-object.expanded > table.attr-table { display: block; } .tfx-object table.attr-table { border: 2px solid white; margin-top: 5px; } html[theme=dark] .tfx-object table.attr-table { border: 2px solid black; } .tfx-object table.attr-table td.attr-name { vertical-align: top; font-weight: bold; } .tfx-object table.attr-table td.attrvalue { text-align: left; } <script> function toggleTfxObject(element) { var objElement = element.parentElement; if (objElement.classList.contains('collapsed')) { objElement.classList.remove('collapsed'); objElement.classList.add('expanded'); } else { objElement.classList.add('collapsed'); objElement.classList.remove('expanded'); } }
['tuner_results']<style> .tfx-object.expanded { padding: 4px 8px 4px 8px; background: white; border: 1px solid #bbbbbb; box-shadow: 4px 4px 2px rgba(0,0,0,0.05); } html[theme=dark] .tfx-object.expanded { background: black; } .tfx-object, .tfx-object * { font-size: 11pt; } .tfx-object > .title { cursor: pointer; } .tfx-object .expansion-marker { color: #999999; } .tfx-object.expanded > .title > .expansion-marker:before { content: '▼'; } .tfx-object.collapsed > .title > .expansion-marker:before { content: '▶'; } .tfx-object .class-name { font-weight: bold; } .tfx-object .deemphasize { opacity: 0.5; } .tfx-object.collapsed > table.attr-table { display: none; } .tfx-object.expanded > table.attr-table { display: block; } .tfx-object table.attr-table { border: 2px solid white; margin-top: 5px; } html[theme=dark] .tfx-object table.attr-table { border: 2px solid black; } .tfx-object table.attr-table td.attr-name { vertical-align: top; font-weight: bold; } .tfx-object table.attr-table td.attrvalue { text-align: left; } <script> function toggleTfxObject(element) { var objElement = element.parentElement; if (objElement.classList.contains('collapsed')) { objElement.classList.remove('collapsed'); objElement.classList.add('expanded'); } else { objElement.classList.add('collapsed'); objElement.classList.remove('expanded'); } }

Trainer

Like the Tuner component, the Trainer component also requires a module file to set up the training process. It will look for a run_fn() function that defines and trains the model. The steps look similar to those in the tuner module file:

# Declare trainer module file
_trainer_module_file = 'trainer.py'
%%writefile {_trainer_module_file}

from tensorflow import keras
from typing import NamedTuple, Dict, Text, Any, List
from tfx.components.trainer.fn_args_utils import FnArgs, DataAccessor
import tensorflow as tf
import tensorflow_transform as tft

# Input key
_IMAGE_KEY = 'image_xf'

# Label key
_LABEL_KEY = 'label_xf'

def _gzip_reader_fn(filenames):
  '''Load compressed dataset
  
  Args:
    filenames - filenames of TFRecords to load

  Returns:
    TFRecordDataset loaded from the filenames
  '''

  # Load the dataset. Specify the compression type since it is saved as `.gz`
  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')
  

def _input_fn(file_pattern,
              tf_transform_output,
              num_epochs=None,
              batch_size=32) -> tf.data.Dataset:
  '''Create batches of features and labels from TF Records

  Args:
    file_pattern - List of files or patterns of file paths containing Example records.
    tf_transform_output - transform output graph
    num_epochs - Integer specifying the number of times to read through the dataset. 
            If None, cycles through the dataset forever.
    batch_size - An int representing the number of records to combine in a single batch.

  Returns:
    A dataset of dict elements, (or a tuple of dict elements and label). 
    Each dict maps feature keys to Tensor or SparseTensor objects.
  '''
  transformed_feature_spec = (
      tf_transform_output.transformed_feature_spec().copy())
  
  dataset = tf.data.experimental.make_batched_features_dataset(
      file_pattern=file_pattern,
      batch_size=batch_size,
      features=transformed_feature_spec,
      reader=_gzip_reader_fn,
      num_epochs=num_epochs,
      label_key=_LABEL_KEY)
  
  return dataset


def model_builder(hp):
  '''
  Builds the model using the best hyperparameters found by the Tuner.

  Args:
    hp - dictionary of tuned hyperparameter values ('units', 'learning_rate')

  Returns:
    compiled model that uses the tuned hyperparameters
  '''

  # Initialize the Sequential API and start stacking the layers
  model = keras.Sequential()
  model.add(keras.layers.Input(shape=(28, 28, 1), name=_IMAGE_KEY))
  model.add(keras.layers.Flatten())

  # Get the number of units from the Tuner results
  hp_units = hp.get('units')
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))

  # Add next layers
  model.add(keras.layers.Dropout(0.2))
  model.add(keras.layers.Dense(10, activation='softmax'))

  # Get the learning rate from the Tuner results
  hp_learning_rate = hp.get('learning_rate')

  # Setup model for training
  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(),
                metrics=['accuracy'])

  # Print the model summary
  model.summary()
  
  return model


def run_fn(fn_args: FnArgs) -> None:
  """Defines and trains the model.
  Args:
    fn_args: Holds args as name/value pairs. Refer here for the complete attributes: 
    https://www.tensorflow.org/tfx/api_docs/python/tfx/components/trainer/fn_args_utils/FnArgs#attributes
  """
  
  # Load transform output
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)
  
  # Create batches of data. Each dataset repeats 10 times (num_epochs=10),
  # and model.fit() below consumes it as one long Keras epoch.
  train_set = _input_fn(fn_args.train_files[0], tf_transform_output, 10)
  val_set = _input_fn(fn_args.eval_files[0], tf_transform_output, 10)

  # Load best hyperparameters
  hp = fn_args.hyperparameters.get('values')

  # Build the model
  model = model_builder(hp)

  # Train the model
  model.fit(
      x=train_set,
      validation_data=val_set,
      )
  
  # Save the model
  model.save(fn_args.serving_model_dir, save_format='tf')
Writing trainer.py
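
In run_fn() above, fn_args.hyperparameters.get('values') and the later hp.get('units') calls are plain dictionary lookups. The sketch below mimics them with a hand-written dictionary, assuming the hyperparameters follow the layout returned by keras_tuner.HyperParameters.get_config() (a 'space' list plus a 'values' mapping). The name example_config and the learning rate value are made up for illustration and are not taken from this run.

```python
# Hypothetical stand-in for fn_args.hyperparameters. Assumed layout:
# the dict returned by keras_tuner.HyperParameters.get_config().
example_config = {
    'space': [],  # search-space definitions omitted for brevity
    'values': {'units': 448, 'learning_rate': 0.001},  # illustrative values
}

# Same lookups as in run_fn() and model_builder()
hp = example_config.get('values')
hp_units = hp.get('units')
hp_learning_rate = hp.get('learning_rate')

print(f'units: {hp_units}, learning rate: {hp_learning_rate}')
```

Because hp is an ordinary dict at this point, model_builder() in the trainer module only reads values with hp.get() and does not define a search space.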

You can pass the output of the Tuner component to the Trainer through the hyperparameters argument, as tuner.outputs['best_hyperparameters'] does below. You can see the definition of the other arguments here.

# Setup the Trainer component
trainer = tfx.components.Trainer(
    module_file=_trainer_module_file,
    examples=transform.outputs['transformed_examples'],
    hyperparameters=tuner.outputs['best_hyperparameters'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(splits=['train']),
    eval_args=trainer_pb2.EvalArgs(splits=['eval']))

Take note that when re-training your model, you don’t always have to retune your hyperparameters. Once you have a set that you think performs well, you can just import it with the Importer component as shown in the official docs:

hparams_importer = Importer(
    # This can be Tuner's output file or manually edited file. The file contains
    # text format of hyperparameters (keras_tuner.HyperParameters.get_config())
    source_uri='path/to/best_hyperparameters.txt',
    artifact_type=HyperParameters,
).with_id('import_hparams')

trainer = Trainer(
    ...
    # An alternative is directly use the tuned hyperparameters in Trainer's user
    # module code and set hyperparameters to None here.
    hyperparameters = hparams_importer.outputs['result'])

# Run the component
context.run(trainer, enable_cache=False)
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten_1 (Flatten)         (None, 784)               0         
                                                                 
 dense_1 (Dense)             (None, 448)               351680    
                                                                 
 dropout_1 (Dropout)         (None, 448)               0         
                                                                 
 dense_2 (Dense)             (None, 10)                4490      
                                                                 
=================================================================
Total params: 356,170
Trainable params: 356,170
Non-trainable params: 0
_________________________________________________________________
14993/14993 [==============================] - 123s 8ms/step - loss: 0.3370 - accuracy: 0.8760 - val_loss: 0.3042 - val_accuracy: 0.8920
ExecutionResult at 0x7fcb07b64990
.execution_id 7
.component
.component.inputs
['examples']
['transform_graph']
['schema']
['hyperparameters']
.component.outputs
['model']
['model_run']
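
As a quick sanity check, the parameter counts in the model summary above can be reproduced by hand. The sketch below assumes the standard count for a Dense layer (one weight per input-unit pair plus one bias per unit); dense_params is a helper name introduced here for illustration.

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

flattened = 28 * 28   # Flatten turns each 28x28x1 image into 784 values
units = 448           # width of dense_1 in the model summary

hidden = dense_params(flattened, units)   # dense_1
output = dense_params(units, 10)          # dense_2 (10 classes)
total = hidden + output                   # Flatten and Dropout add no parameters

print(hidden, output, total)  # prints: 351680 4490 356170
```

The totals match the summary, which confirms that the Trainer built the model with the tuned value of 448 units.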

Your model should now be saved in your pipeline directory and you can navigate through it as shown below. The file is saved as saved_model.pb.

# Get artifact uri of trainer model output
model_artifact_dir = trainer.outputs['model'].get()[0].uri

# List subdirectories artifact uri
print(f'contents of model artifact directory:{os.listdir(model_artifact_dir)}')

# Define the model directory
model_dir = os.path.join(model_artifact_dir, 'Format-Serving')

# List contents of model directory
print(f'contents of model directory: {os.listdir(model_dir)}')
contents of model artifact directory:['Format-Serving']
contents of model directory: ['fingerprint.pb', 'keras_metadata.pb', 'variables', 'saved_model.pb', 'assets']
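
If the trainer object is no longer in memory, one way to locate the saved model is to walk the pipeline directory and look for saved_model.pb. This is a minimal sketch using only the standard library; find_saved_models is a hypothetical helper, not part of TFX, and './pipeline/' is the _pipeline_root value set earlier in this lab.

```python
import os

def find_saved_models(root):
    """Return sorted directories under `root` that contain a saved_model.pb file."""
    matches = []
    for dirpath, _dirnames, filenames in os.walk(root):
        if 'saved_model.pb' in filenames:
            matches.append(dirpath)
    return sorted(matches)

# Returns an empty list if the directory does not exist or holds no model
print(find_saved_models('./pipeline/'))
```

Each returned path is a SavedModel directory like the Format-Serving folder listed above.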

Congratulations! You have now created an ML pipeline that includes hyperparameter tuning and model training. You will learn more about the next components in future lessons, but in the next section, you will first learn about a framework for automatically building ML pipelines: AutoML. Enjoy the rest of the course!