In this lab, you will build your first simple autoencoder. It takes three-dimensional data, encodes it to two dimensions, and decodes it back to 3D.
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
You will first create a synthetic dataset to act as input to the autoencoder. You can do that with the function below.
def generate_data(m):
    '''generates m random 3D points that lie close to a plane'''
    angles = np.random.rand(m) * 3 * np.pi / 2 - 0.5
    data = np.empty((m, 3))
    data[:,0] = np.cos(angles) + np.sin(angles)/2 + 0.1 * np.random.randn(m)/2
    data[:,1] = np.sin(angles) * 0.7 + 0.1 * np.random.randn(m) / 2
    # the third feature is a noisy linear combination of the first two
    data[:,2] = data[:, 0] * 0.1 + data[:, 1] * 0.3 + 0.1 * np.random.randn(m)
    return data
# use the function above to generate data points
X_train = generate_data(100)
# center the data so that each feature has zero mean
X_train = X_train - X_train.mean(axis=0)
# preview the data
ax = plt.axes(projection='3d')
# color by the first coordinate so that the 'Reds' colormap is actually used
ax.scatter3D(X_train[:, 0], X_train[:, 1], X_train[:, 2], c=X_train[:, 0], cmap='Reds');
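Why should two dimensions be enough? In `generate_data`, the third feature is the first two scaled and summed plus a little noise, so the points sit close to a plane. The sketch below checks this with the singular values of the centered data: the third one should be much smaller than the other two. It re-implements the same recipe with an explicit `np.random.default_rng` generator (an assumption made only for reproducibility; the lab uses the global `np.random` functions).

```python
import numpy as np

def generate_data(m, rng):
    """Same recipe as the lab's generate_data, but with an explicit RNG."""
    angles = rng.random(m) * 3 * np.pi / 2 - 0.5
    data = np.empty((m, 3))
    data[:, 0] = np.cos(angles) + np.sin(angles) / 2 + 0.1 * rng.standard_normal(m) / 2
    data[:, 1] = np.sin(angles) * 0.7 + 0.1 * rng.standard_normal(m) / 2
    data[:, 2] = data[:, 0] * 0.1 + data[:, 1] * 0.3 + 0.1 * rng.standard_normal(m)
    return data

rng = np.random.default_rng(42)
X = generate_data(100, rng)
X = X - X.mean(axis=0)   # the (3,) mean broadcasts over all 100 rows

# Singular values measure the spread of the centered data along its principal directions
s = np.linalg.svd(X, compute_uv=False)
print("column means after centering:", X.mean(axis=0))
print("singular values:", s)
```

A small third singular value means little information is lost by keeping only two directions, which is what the 2-unit encoder is asked to do.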
Now you will build the simple encoder-decoder model. Notice the number of neurons in each Dense layer: the encoder contracts the input from 3 dimensions to 2, and the decoder expands it back from 2 to 3.
encoder = keras.models.Sequential([keras.layers.Dense(2, input_shape=[3])])
decoder = keras.models.Sequential([keras.layers.Dense(3, input_shape=[2])])
autoencoder = keras.models.Sequential([encoder, decoder])
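Neither Dense layer is given an activation, so each computes only `inputs @ kernel + bias`, and the stacked model is a single affine map whose linear part has rank at most 2. The NumPy sketch below shows this with randomly drawn weights that stand in for the trained ones (the weight names are hypothetical; only the shapes match the lab's layers).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical weights with the same shapes as the lab's layers:
# Dense(2) on 3 inputs -> kernel (3, 2), bias (2,); Dense(3) on 2 inputs -> kernel (2, 3), bias (3,)
W_enc, b_enc = rng.standard_normal((3, 2)), rng.standard_normal(2)
W_dec, b_dec = rng.standard_normal((2, 3)), rng.standard_normal(3)

X = rng.standard_normal((5, 3))   # a batch of 5 three-dimensional points

Z = X @ W_enc + b_enc             # encoder: (5, 3) -> (5, 2)
X_hat = Z @ W_dec + b_dec         # decoder: (5, 2) -> (5, 3)

# With no activations, both layers collapse into one affine map x -> x @ W + b
W = W_enc @ W_dec                 # (3, 3) matrix, but its rank is at most 2
b = b_enc @ W_dec + b_dec
print(Z.shape, X_hat.shape, np.linalg.matrix_rank(W))
```

Because of this rank limit, the model cannot simply copy its input; it has to find the 2D subspace that preserves the most information.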
You can then set up the model for training.
autoencoder.compile(loss="mse", optimizer=keras.optimizers.SGD(learning_rate=0.1))
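`SGD(learning_rate=0.1)` updates every weight as `w <- w - 0.1 * dL/dw` (Keras' SGD uses no momentum by default). Keras computes the gradients automatically and, with the default batch size of 32, takes 4 steps per epoch on 100 points, which is the `4/4` in the training log. The sketch below spells out the same update by hand for the linear 3 -> 2 -> 3 model, using full-batch gradient descent on toy data; the gradients are derived manually and the data and initialization are assumptions for illustration, not what Keras does internally.

```python
import numpy as np

def mse(X, X_hat):
    # Mean over features, then over samples: the mean of all squared entries
    return np.mean((X_hat - X) ** 2)

def sgd_step(X, params, lr=0.1):
    """One gradient-descent step on the linear autoencoder (hand-derived gradients)."""
    W1, b1, W2, b2 = params
    n, d = X.shape
    Z = X @ W1 + b1
    X_hat = Z @ W2 + b2
    G = 2.0 * (X_hat - X) / (n * d)   # dL/dX_hat for the mean squared error
    dW2, db2 = Z.T @ G, G.sum(axis=0)
    dZ = G @ W2.T                     # backpropagate through the decoder
    dW1, db1 = X.T @ dZ, dZ.sum(axis=0)
    # Plain SGD update: w <- w - lr * gradient
    return [W1 - lr * dW1, b1 - lr * db1, W2 - lr * dW2, b2 - lr * db2]

rng = np.random.default_rng(1)
X = rng.standard_normal((100, 3)) * [1.0, 0.5, 0.2]   # toy data (assumed)
X -= X.mean(axis=0)
params = [rng.normal(0, 0.5, (3, 2)), np.zeros(2), rng.normal(0, 0.5, (2, 3)), np.zeros(3)]

def loss(params):
    W1, b1, W2, b2 = params
    return mse(X, (X @ W1 + b1) @ W2 + b2)

start = loss(params)
for _ in range(200):
    params = sgd_step(X, params, lr=0.1)
print(f"loss: {start:.4f} -> {loss(params):.4f}")
```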
You will configure the training to use the input data as the target output as well. In this example, that is X_train.
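Passing `X_train` as both inputs and targets means the model is scored on how well it reproduces each point. Writing $f$ for the encoder, $g$ for the decoder (notation introduced here), and $m$ for the number of points, the `"mse"` loss averages the squared error over the 3 features and then over the points:

```latex
\mathcal{L} = \frac{1}{3m} \sum_{i=1}^{m} \sum_{j=1}^{3} \bigl( x_{ij} - \hat{x}_{ij} \bigr)^2,
\qquad \hat{\mathbf{x}}_i = g\bigl( f(\mathbf{x}_i) \bigr)
```

During training, each SGD step evaluates this expression on one mini-batch rather than on all $m$ points.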
history = autoencoder.fit(X_train, X_train, epochs=200)
Epoch 1/200
4/4 [==============================] - 6s 15ms/step - loss: 0.2580
Epoch 2/200
4/4 [==============================] - 0s 8ms/step - loss: 0.2306
Epoch 3/200
4/4 [==============================] - 0s 8ms/step - loss: 0.2063
Epoch 4/200
4/4 [==============================] - 0s 11ms/step - loss: 0.1833
Epoch 5/200
4/4 [==============================] - 0s 11ms/step - loss: 0.1651
Epoch 6/200
4/4 [==============================] - 0s 8ms/step - loss: 0.1507
Epoch 7/200
4/4 [==============================] - 0s 8ms/step - loss: 0.1382
Epoch 8/200
4/4 [==============================] - 0s 6ms/step - loss: 0.1283
Epoch 9/200
4/4 [==============================] - 0s 5ms/step - loss: 0.1203
Epoch 10/200
4/4 [==============================] - 0s 5ms/step - loss: 0.1127
Epoch 11/200
4/4 [==============================] - 0s 5ms/step - loss: 0.1053
Epoch 12/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0995
Epoch 13/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0940
Epoch 14/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0891
Epoch 15/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0848
Epoch 16/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0814
Epoch 17/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0780
Epoch 18/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0756
Epoch 19/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0733
Epoch 20/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0710
Epoch 21/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0690
Epoch 22/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0675
Epoch 23/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0662
Epoch 24/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0646
Epoch 25/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0634
Epoch 26/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0625
Epoch 27/200
4/4 [==============================] - 0s 11ms/step - loss: 0.0617
Epoch 28/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0607
Epoch 29/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0599
Epoch 30/200
4/4 [==============================] - 0s 10ms/step - loss: 0.0592
Epoch 31/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0586
Epoch 32/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0580
Epoch 33/200
4/4 [==============================] - 0s 12ms/step - loss: 0.0576
Epoch 34/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0571
Epoch 35/200
4/4 [==============================] - 0s 11ms/step - loss: 0.0566
Epoch 36/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0562
Epoch 37/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0559
Epoch 38/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0557
Epoch 39/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0553
Epoch 40/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0549
Epoch 41/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0547
Epoch 42/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0544
Epoch 43/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0543
Epoch 44/200
4/4 [==============================] - 0s 10ms/step - loss: 0.0539
Epoch 45/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0536
Epoch 46/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0534
Epoch 47/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0531
Epoch 48/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0529
Epoch 49/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0526
Epoch 50/200
4/4 [==============================] - 0s 12ms/step - loss: 0.0524
Epoch 51/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0522
Epoch 52/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0520
Epoch 53/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0519
Epoch 54/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0517
Epoch 55/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0517
Epoch 56/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0515
Epoch 57/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0514
Epoch 58/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0512
Epoch 59/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0511
Epoch 60/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0510
Epoch 61/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0509
Epoch 62/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0507
Epoch 63/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0507
Epoch 64/200
4/4 [==============================] - 0s 4ms/step - loss: 0.0505
Epoch 65/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0504
Epoch 66/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0504
Epoch 67/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0503
Epoch 68/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0502
Epoch 69/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0501
Epoch 70/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0500
Epoch 71/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0500
Epoch 72/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0499
Epoch 73/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0498
Epoch 74/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0497
Epoch 75/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0496
Epoch 76/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0496
Epoch 77/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0495
Epoch 78/200
4/4 [==============================] - 0s 11ms/step - loss: 0.0494
Epoch 79/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0493
Epoch 80/200
4/4 [==============================] - 0s 10ms/step - loss: 0.0492
Epoch 81/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0492
Epoch 82/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0491
Epoch 83/200
4/4 [==============================] - 0s 12ms/step - loss: 0.0491
Epoch 84/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0490
Epoch 85/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0489
Epoch 86/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0489
Epoch 87/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0488
Epoch 88/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0488
Epoch 89/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0487
Epoch 90/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0487
Epoch 91/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0486
Epoch 92/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0485
Epoch 93/200
4/4 [==============================] - 0s 10ms/step - loss: 0.0485
Epoch 94/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0484
Epoch 95/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0484
Epoch 96/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0483
Epoch 97/200
4/4 [==============================] - 0s 12ms/step - loss: 0.0482
Epoch 98/200
4/4 [==============================] - 0s 12ms/step - loss: 0.0482
Epoch 99/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0481
Epoch 100/200
4/4 [==============================] - 0s 14ms/step - loss: 0.0481
Epoch 101/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0481
Epoch 102/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0480
Epoch 103/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0479
Epoch 104/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0479
Epoch 105/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0478
Epoch 106/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0478
Epoch 107/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0477
Epoch 108/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0477
Epoch 109/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0476
Epoch 110/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0476
Epoch 111/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0475
Epoch 112/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0475
Epoch 113/200
4/4 [==============================] - 0s 10ms/step - loss: 0.0474
Epoch 114/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0474
Epoch 115/200
4/4 [==============================] - 0s 4ms/step - loss: 0.0474
Epoch 116/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0473
Epoch 117/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0472
Epoch 118/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0472
Epoch 119/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0471
Epoch 120/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0471
Epoch 121/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0471
Epoch 122/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0470
Epoch 123/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0470
Epoch 124/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0469
Epoch 125/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0469
Epoch 126/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0468
Epoch 127/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0468
Epoch 128/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0468
Epoch 129/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0467
Epoch 130/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0467
Epoch 131/200
4/4 [==============================] - 0s 10ms/step - loss: 0.0467
Epoch 132/200
4/4 [==============================] - 0s 12ms/step - loss: 0.0466
Epoch 133/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0466
Epoch 134/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0465
Epoch 135/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0465
Epoch 136/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0464
Epoch 137/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0464
Epoch 138/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0464
Epoch 139/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0463
Epoch 140/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0463
Epoch 141/200
4/4 [==============================] - 0s 10ms/step - loss: 0.0463
Epoch 142/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0462
Epoch 143/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0462
Epoch 144/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0461
Epoch 145/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0461
Epoch 146/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0461
Epoch 147/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0461
Epoch 148/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0460
Epoch 149/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0460
Epoch 150/200
4/4 [==============================] - 0s 4ms/step - loss: 0.0460
Epoch 151/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0459
Epoch 152/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0459
Epoch 153/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0458
Epoch 154/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0458
Epoch 155/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0458
Epoch 156/200
4/4 [==============================] - 0s 11ms/step - loss: 0.0457
Epoch 157/200
4/4 [==============================] - 0s 13ms/step - loss: 0.0457
Epoch 158/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0457
Epoch 159/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0457
Epoch 160/200
4/4 [==============================] - 0s 12ms/step - loss: 0.0456
Epoch 161/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0456
Epoch 162/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0456
Epoch 163/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0455
Epoch 164/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0455
Epoch 165/200
4/4 [==============================] - 0s 10ms/step - loss: 0.0455
Epoch 166/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0455
Epoch 167/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0454
Epoch 168/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0454
Epoch 169/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0454
Epoch 170/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0453
Epoch 171/200
4/4 [==============================] - 0s 14ms/step - loss: 0.0453
Epoch 172/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0453
Epoch 173/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0452
Epoch 174/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0452
Epoch 175/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0452
Epoch 176/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0452
Epoch 177/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0451
Epoch 178/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0451
Epoch 179/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0451
Epoch 180/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0450
Epoch 181/200
4/4 [==============================] - 0s 11ms/step - loss: 0.0450
Epoch 182/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0450
Epoch 183/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0450
Epoch 184/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0450
Epoch 185/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0449
Epoch 186/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0449
Epoch 187/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0449
Epoch 188/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0449
Epoch 189/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0448
Epoch 190/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0448
Epoch 191/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0448
Epoch 192/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0448
Epoch 193/200
4/4 [==============================] - 0s 15ms/step - loss: 0.0447
Epoch 194/200
4/4 [==============================] - 0s 9ms/step - loss: 0.0447
Epoch 195/200
4/4 [==============================] - 0s 7ms/step - loss: 0.0447
Epoch 196/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0446
Epoch 197/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0446
Epoch 198/200
4/4 [==============================] - 0s 8ms/step - loss: 0.0446
Epoch 199/200
4/4 [==============================] - 0s 5ms/step - loss: 0.0446
Epoch 200/200
4/4 [==============================] - 0s 6ms/step - loss: 0.0446
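Because this autoencoder is linear and the data is centered, there is a useful reference point for the final loss: a classical result for linear autoencoders says the best squared-error reconstruction through a 2-unit bottleneck is the projection onto the top two principal components. The NumPy sketch below computes that floor on toy data (an assumption; substitute `X_train` to use the lab's data). Its residual equals the squared third singular value spread over all entries.

```python
import numpy as np

def pca_reconstruction_mse(X, k=2):
    """MSE of reconstructing centered X from its top-k principal components."""
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    V_k = Vt[:k].T                  # (d, k): top-k principal directions
    X_hat = (X @ V_k) @ V_k.T       # "encode" to k dims, then "decode" back to d dims
    return np.mean((X - X_hat) ** 2), s

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 3)) * [1.0, 0.6, 0.1]   # toy data near a plane (assumed)
X -= X.mean(axis=0)

best_mse, s = pca_reconstruction_mse(X, k=2)
# The residual is exactly the energy in the discarded direction, averaged over all n*d entries
print(best_mse, s[2] ** 2 / X.size)
```

You could compare this value with `history.history["loss"][-1]` from the `fit` call. The logged loss is averaged over mini-batches while the weights are still changing, so the comparison is approximate, but a large gap suggests that SGD has not yet converged.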
As mentioned, you can use the encoder to compress the input to two dimensions.
# encode the data
codings = encoder.predict(X_train)
# see a sample input-encoder output pair
print(f'input point: {X_train[0]}')
print(f'encoded point: {codings[0]}')
4/4 [==============================] - 0s 3ms/step
input point: [0.99123401 0.26140142 0.16681574]
encoded point: [-0.84345716 1.2612782 ]
# plot all encoder outputs
fig = plt.figure(figsize=(4,3))
plt.plot(codings[:,0], codings[:, 1], "b.")
plt.xlabel("$z_1$", fontsize=18)
plt.ylabel("$z_2$", fontsize=18, rotation=0)
plt.grid(True)
plt.show()
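The individual coordinates $z_1$ and $z_2$ in this plot have no fixed meaning. For any invertible 2x2 matrix `A`, an encoder whose outputs are multiplied by `A`, paired with a decoder that first undoes `A`, reconstructs exactly the same points. Retraining with a different random initialization can therefore produce a rotated, scaled, or sheared version of this plot. The sketch below demonstrates this with hypothetical linear weights of the same shapes as the lab's layers.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((10, 3))

# Hypothetical linear encoder/decoder weights (shapes match Dense(2) and Dense(3))
W_enc, b_enc = rng.standard_normal((3, 2)), rng.standard_normal(2)
W_dec, b_dec = rng.standard_normal((2, 3)), rng.standard_normal(3)

A = np.array([[2.0, 1.0],
              [0.0, 0.5]])          # any invertible 2x2 matrix
A_inv = np.linalg.inv(A)

# Alternative weights: the encoder absorbs A, the decoder absorbs its inverse
W_enc_alt, b_enc_alt = W_enc @ A, b_enc @ A
W_dec_alt = A_inv @ W_dec

Z = X @ W_enc + b_enc
Z_alt = X @ W_enc_alt + b_enc_alt   # equals Z @ A: different 2D codings

X_hat = Z @ W_dec + b_dec
X_hat_alt = Z_alt @ W_dec_alt + b_dec

print(np.allclose(X_hat, X_hat_alt))   # expected True: identical reconstructions
```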
The decoder then tries to reconstruct the original 3D input from the 2D codings. As the outputs below show, the reconstruction is not perfect, but it still follows the general shape of the original input.
# decode the encoder output
decodings = decoder.predict(codings)
# see a sample output for a single point
print(f'input point: {X_train[0]}')
print(f'encoded point: {codings[0]}')
print(f'decoded point: {decodings[0]}')
4/4 [==============================] - 0s 4ms/step
input point: [0.99123401 0.26140142 0.16681574]
encoded point: [-0.84345716 1.2612782 ]
decoded point: [0.96741295 0.28178403 0.16318715]
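To put a number on "not perfect", the sketch below computes the per-coordinate error and the squared error for the single point printed above, averaged over its three features in the same way as the `"mse"` loss. The values are copied from the output; nothing else is assumed.

```python
import numpy as np

# Values copied from the printed output above
x = np.array([0.99123401, 0.26140142, 0.16681574])
x_hat = np.array([0.96741295, 0.28178403, 0.16318715])

abs_err = np.abs(x - x_hat)             # per-coordinate absolute error
point_mse = np.mean((x - x_hat) ** 2)   # mean over the 3 features, as in the "mse" loss
print(abs_err, point_mse)
```

For the whole dataset, `np.mean((decodings - X_train) ** 2)` gives the reconstruction MSE of the trained model. It should be close to, though not identical to, the last logged loss, because the logged value is averaged over mini-batches while the weights are still being updated.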
# plot the decoder output
ax = plt.axes(projection='3d')
ax.scatter3D(decodings[:, 0], decodings[:, 1], decodings[:, 2], c=decodings[:, 0], cmap='Reds');
That’s it for this simple demonstration of the autoencoder!