In this notebook, you will perform transfer learning to train a classifier for the CIFAR-10 dataset on top of the ResNet50 model available in Keras.
import os, re, time, json
import PIL.Image, PIL.ImageFont, PIL.ImageDraw
import numpy as np
try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
except Exception:
    pass
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
from matplotlib import pyplot as plt
import tensorflow_datasets as tfds
print("Tensorflow version " + tf.__version__)
Colab only includes TensorFlow 2.x; %tensorflow_version has no effect.
Tensorflow version 2.13.0
BATCH_SIZE = 64
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Define some utility functions that will help you create visualizations. (These will be used later.)
#@title Visualization Utilities [RUN ME]
# Matplotlib config
plt.rc('image', cmap='gray')
plt.rc('grid', linewidth=0)
plt.rc('xtick', top=False, bottom=False, labelsize='large')
plt.rc('ytick', left=False, right=False, labelsize='large')
plt.rc('axes', facecolor='F8F8F8', titlesize="large", edgecolor='white')
plt.rc('text', color='a8151a')
plt.rc('figure', facecolor='F0F0F0')

# Matplotlib fonts
MATPLOTLIB_FONT_DIR = os.path.join(os.path.dirname(plt.__file__), "mpl-data/fonts/ttf")
# utility to display a row of images with their predicted classes
def display_images(digits, predictions, labels, title):
    n = 10
    indexes = np.random.choice(len(predictions), size=n)
    n_digits = digits[indexes]
    n_predictions = predictions[indexes]
    n_predictions = n_predictions.reshape((n,))
    n_labels = labels[indexes].reshape((n,))

    fig = plt.figure(figsize=(20, 4))
    plt.title(title)
    plt.yticks([])
    plt.xticks([])

    for i in range(n):
        ax = fig.add_subplot(1, n, i + 1)
        class_index = n_predictions[i]
        plt.xlabel(classes[class_index])
        # color the label red when the prediction disagrees with the true label
        ax.xaxis.label.set_color('red' if n_predictions[i] != n_labels[i] else 'black')
        plt.xticks([])
        plt.yticks([])
        plt.imshow(n_digits[i])
# utility to display training and validation curves
def plot_metrics(metric_name, title, ylim=5):
    plt.title(title)
    plt.ylim(0, ylim)
    plt.plot(history.history[metric_name], color='blue', label=metric_name)
    plt.plot(history.history['val_' + metric_name], color='green', label='val_' + metric_name)
    plt.legend()
The CIFAR-10 dataset consists of 32 x 32 RGB images belonging to 10 classes. You will load the dataset from Keras.
(training_images, training_labels) , (validation_images, validation_labels) = tf.keras.datasets.cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 4s 0us/step
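As a quick sanity check, you can print the shapes that load_data returns: 50,000 training and 10,000 validation images of size 32 x 32 x 3, with integer labels of shape (n, 1).
# Sanity check: shapes returned by cifar10.load_data()
print(training_images.shape, training_labels.shape)      # (50000, 32, 32, 3) (50000, 1)
print(validation_images.shape, validation_labels.shape)  # (10000, 32, 32, 3) (10000, 1)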
Use the display_images utility to view some of the images and their class labels.
display_images(training_images, training_labels, training_labels, "Training Data")
display_images(validation_images, validation_labels, validation_labels, "Validation Data")
validation_images[0].astype('float32').shape
(32, 32, 3)
Here, you'll preprocess the images in the training and validation sets the same way the pretrained network expects: resnet50.preprocess_input converts the images from RGB to BGR and zero-centers each color channel with respect to the ImageNet dataset, without scaling.
def preprocess_image_input(input_images):
    input_images = input_images.astype('float32')
    output_ims = tf.keras.applications.resnet50.preprocess_input(input_images)
    return output_ims
train_X = preprocess_image_input(training_images)
valid_X = preprocess_image_input(validation_images)
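If you want to verify what the preprocessing actually does, a minimal sketch (assuming the documented 'caffe' defaults of resnet50.preprocess_input) checks that it matches an RGB-to-BGR flip followed by subtracting the ImageNet channel means:
# Check: preprocess_input in 'caffe' mode = BGR flip + ImageNet mean subtraction.
# Note: preprocess_input can modify NumPy inputs in place, hence the copy.
sample = training_images[:1].astype('float32')
processed = tf.keras.applications.resnet50.preprocess_input(sample.copy())
manual = sample[..., ::-1] - np.array([103.939, 116.779, 123.68], dtype='float32')
print(np.allclose(processed, manual))  # expected: True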
You will perform transfer learning using the ResNet50 model available in Keras.
'''
Feature extraction is performed by ResNet50 pretrained on ImageNet weights.
Input size is 224 x 224.
'''
def feature_extractor(inputs):
    feature_extractor = tf.keras.applications.resnet.ResNet50(input_shape=(224, 224, 3),
                                                              include_top=False,
                                                              weights='imagenet')(inputs)
    return feature_extractor
'''
Defines the final dense layers and subsequent softmax layer for classification.
'''
def classifier(inputs):
    x = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    # Flatten is a no-op here: GlobalAveragePooling2D already outputs shape (batch, 2048)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(1024, activation="relu")(x)
    x = tf.keras.layers.Dense(512, activation="relu")(x)
    x = tf.keras.layers.Dense(10, activation="softmax", name="classification")(x)
    return x
'''
Since the input image size is 32 x 32, first upsample the images by a factor of 7
along each spatial axis to transform them to 224 x 224, the size ResNet50 expects.
Connect the feature extraction and "classifier" layers to build the model.
'''
def final_model(inputs):
    resize = tf.keras.layers.UpSampling2D(size=(7, 7))(inputs)
    resnet_feature_extractor = feature_extractor(resize)
    classification_output = classifier(resnet_feature_extractor)
    return classification_output
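As a quick check on that resizing step, UpSampling2D with size=(7, 7) repeats each pixel 7 times along each spatial axis, so a one-liner confirms that a 32 x 32 input becomes 224 x 224:
# Sanity check: UpSampling2D(size=(7, 7)) maps (1, 32, 32, 3) -> (1, 224, 224, 3)
print(tf.keras.layers.UpSampling2D(size=(7, 7))(tf.zeros((1, 32, 32, 3))).shape)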
'''
Define the model and compile it.
Use Stochastic Gradient Descent as the optimizer.
Use Sparse Categorical Crossentropy as the loss function.
'''
def define_compile_model():
    inputs = tf.keras.layers.Input(shape=(32, 32, 3))
    classification_output = final_model(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=classification_output)

    model.compile(optimizer='SGD',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
model = define_compile_model()
model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94765736/94765736 [==============================] - 1s 0us/step
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 32, 32, 3)] 0
up_sampling2d (UpSampling2 (None, 224, 224, 3) 0
D)
resnet50 (Functional) (None, 7, 7, 2048) 23587712
global_average_pooling2d ( (None, 2048) 0
GlobalAveragePooling2D)
flatten (Flatten) (None, 2048) 0
dense (Dense) (None, 1024) 2098176
dense_1 (Dense) (None, 512) 524800
classification (Dense) (None, 10) 5130
=================================================================
Total params: 26215818 (100.01 MB)
Trainable params: 26162698 (99.80 MB)
Non-trainable params: 53120 (207.50 KB)
_________________________________________________________________
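Note that the summary reports almost all parameters as trainable: define_compile_model leaves the ResNet50 weights unfrozen, so the training below fine-tunes the whole network. If you wanted pure feature extraction instead, you could freeze the backbone before compiling; here is a minimal sketch (the function name is illustrative):
# Variant: freeze the pretrained backbone so only the new dense head is trained.
def define_compile_frozen_model():
    inputs = tf.keras.layers.Input(shape=(32, 32, 3))
    resize = tf.keras.layers.UpSampling2D(size=(7, 7))(inputs)
    backbone = tf.keras.applications.resnet50.ResNet50(input_shape=(224, 224, 3),
                                                       include_top=False,
                                                       weights='imagenet')
    backbone.trainable = False            # freeze all ResNet50 layers
    x = backbone(resize, training=False)  # keep BatchNorm layers in inference mode
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(1024, activation="relu")(x)
    x = tf.keras.layers.Dense(512, activation="relu")(x)
    outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='SGD',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model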
# this will take around 20 minutes to complete
EPOCHS = 4
history = model.fit(train_X, training_labels, epochs=EPOCHS,
                    validation_data=(valid_X, validation_labels), batch_size=BATCH_SIZE)
Epoch 1/4
782/782 [==============================] - 560s 674ms/step - loss: 0.4068 - accuracy: 0.8673 - val_loss: 0.3160 - val_accuracy: 0.8893
Epoch 2/4
14/782 [..............................] - ETA: 7:50 - loss: 0.1409 - accuracy: 0.9542
KeyboardInterrupt
(Training was interrupted manually during epoch 2; the model keeps the weights it had at the moment of interruption, and the cells below evaluate those weights.)
Calculate the loss and accuracy metrics using the model's .evaluate function.
loss, accuracy = model.evaluate(valid_X, validation_labels, batch_size=BATCH_SIZE)
157/157 [==============================] - 28s 175ms/step - loss: 0.2116 - accuracy: 0.9243
Plot the loss (in blue) and validation loss (in green).
plot_metrics("loss", "Loss")
NameError: name 'history' is not defined
(The interrupted fit call above raised before returning, so history was never assigned; re-run training to completion before plotting.)
Plot the training accuracy (blue) as well as the validation accuracy (green).
plot_metrics("accuracy", "Accuracy")
You can take a look at the predictions on the validation set.
probabilities = model.predict(valid_X, batch_size=BATCH_SIZE)
probabilities = np.argmax(probabilities, axis=1)
display_images(validation_images, probabilities, validation_labels, "Bad predictions indicated in red.")
157/157 [==============================] - 29s 171ms/step
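To see which classes the model confuses with one another, you can also tabulate a confusion matrix from the same predictions; a short sketch using tf.math.confusion_matrix:
# Confusion matrix over the validation set: rows are true classes, columns are predictions.
cm = tf.math.confusion_matrix(validation_labels.flatten(), probabilities, num_classes=10)
print(cm.numpy())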