GauGAN

Please note that this is an optional notebook, meant to introduce more advanced concepts if you’re up for a challenge, so don’t worry if you don’t completely follow!

It is recommended that you should already be familiar with:

Goals

In this notebook, you will learn about GauGAN, which synthesizes high-resolution images from semantic label maps, and you will implement and train it yourself. GauGAN is built around a special denormalization technique proposed in Semantic Image Synthesis with Spatially-Adaptive Normalization (Park et al. 2019).

Background

GauGAN builds on Pix2PixHD but simplifies the overall network by adding spatially adaptive denormalization layers. Because it learns its denormalization parameters by convolving the semantic segmentation map, it is better suited to multi-modal synthesis, since all the generator needs as input is a random noise vector. Later in the notebook, you will see how the authors further control diversity with the noise vector.

GauGAN Submodules

Let’s first take a look at the building blocks behind GauGAN.

Synchronized BatchNorm

You’ve already heard of batch norm, a normalization technique that tries to normalize the statistics of activations to those of a standard Gaussian distribution.

Batch norm, however, performs poorly with small batch sizes. This becomes a problem when training large models that can only fit small batches on each GPU. Training on multiple GPUs increases the effective batch size, but vanilla batch norm computes and updates its statistics independently on each GPU. Essentially, if you train on 2 GPUs with nn.BatchNorm2d, the two batch norm modules will keep different running averages of the statistics, and batch norm stability won’t improve with the larger effective batch size.

Synchronized batch norm (nn.SyncBatchNorm) does exactly what its name suggests: it synchronizes batch norm statistics across multiple processes so that each update uses the statistics of all the minibatches combined.

The authors report slightly better scores with synchronized batch norm as opposed to regular (asynchronous) batch norm. Since you will likely be running this on one machine, this notebook will stick to regular nn.BatchNorm2d modules.
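
To make the difference concrete, here is a minimal sketch in plain Python (no GPUs involved) that compares the statistics each device would compute on its own shard with the statistics of the combined batch. The activation values and the two-way split are made up for illustration.

```python
from statistics import fmean, pvariance

# Hypothetical activations for one channel, split across two devices.
shard_gpu0 = [0.1, 0.4, 0.2, 0.3]
shard_gpu1 = [2.0, 2.4, 1.8, 2.2]

def batch_stats(values):
    """Return the (mean, biased variance) that batch norm would normalize with."""
    return fmean(values), pvariance(values)

# Unsynchronized batch norm: each device only sees its own shard.
local_stats = [batch_stats(shard_gpu0), batch_stats(shard_gpu1)]

# Synchronized batch norm: statistics are computed over all shards together.
global_stats = batch_stats(shard_gpu0 + shard_gpu1)

for rank, (mean, var) in enumerate(local_stats):
    print(f"device {rank}: mean={mean:.3f}, var={var:.3f}")
print(f"synchronized: mean={global_stats[0]:.3f}, var={global_stats[1]:.3f}")
```

When the shards differ, neither device’s local statistics match the statistics of the full batch, which is the mismatch that synchronization removes.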

Spatially Adaptive Denormalization (SPADE)

Recall that normalization layers are formulated as

\begin{align*} y &= \dfrac{x - \hat{\mu}}{\hat{\sigma}} * \gamma + \beta \end{align*}

where $\hat{\mu}$ and $\hat{\sigma}$ correspond to an exponential moving average of minibatch means and standard deviations and are used to normalize the input activation $x$. The parameters $\gamma$ and $\beta$ apply “denormalization,” essentially allowing the model to invert the normalization if necessary.

In GauGAN, batch norm is the preferred normalization scheme. Recall that batch norm can be formulated for each input neuron as

\begin{align*} y_{c,h,w} &= \dfrac{x_{c,h,w} - \hat{\mu}_c}{\hat{\sigma}_c} * \gamma_c + \beta_c \end{align*}

where $\hat{\mu}_c$ and $\hat{\sigma}_c$ are per-channel statistics computed across the batch and spatial dimensions. Similarly, $\gamma_c$ and $\beta_c$ are per-channel denormalization parameters.

With vanilla batch norm, these denormalization parameters are spatially invariant; that is, the same values are applied to every position in the input activation. As you may imagine, this can be limiting for the model. It is often helpful for the model to learn denormalization parameters for each position.

The authors address this with SPatially Adaptive DEnormalization (SPADE). They compute denormalization parameters $\gamma$ and $\beta$ by convolving the input segmentation masks and apply these elementwise. SPADE can therefore be formulated as

\begin{align*} y_{c,h,w} &= \dfrac{x_{c,h,w} - \hat{\mu}_c}{\hat{\sigma}_c} * \gamma_{c,h,w} + \beta_{c,h,w} \end{align*}
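
As a small worked example of the two formulas, the sketch below normalizes a single channel in plain Python and then applies either one shared $(\gamma_c, \beta_c)$ pair or per-position values $\gamma_{c,h,w}$ and $\beta_{c,h,w}$. All numbers are hypothetical and hard-coded; in SPADE the per-position values would come from convolving the segmentation map.

```python
from statistics import fmean, pstdev

# One channel of a hypothetical 2x3 activation map (batch size 1 for brevity).
x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]

# Per-channel statistics, computed across all positions of the channel.
flat = [v for row in x for v in row]
mu, sigma = fmean(flat), pstdev(flat)
x_hat = [[(v - mu) / sigma for v in row] for row in x]

# Batch norm: one (gamma, beta) pair is shared by every position.
gamma_c, beta_c = 2.0, 0.5
y_bn = [[v * gamma_c + beta_c for v in row] for row in x_hat]

# SPADE: gamma and beta hold one value per position (illustrative values).
gamma_hw = [[2.0, 2.0, 0.0],
            [2.0, 0.0, 0.0]]
beta_hw = [[0.5, 0.5, -1.0],
           [0.5, -1.0, -1.0]]
y_spade = [
    [v * g + b for v, g, b in zip(x_row, g_row, b_row)]
    for x_row, g_row, b_row in zip(x_hat, gamma_hw, beta_hw)
]

print(y_bn)
print(y_spade)
```

Positions that share the same $\gamma$ and $\beta$ behave exactly like batch norm, while the others are modulated independently, which is what lets the label map steer the output region by region.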

Now let’s implement SPADE!

Note: the authors use spectral norm in all convolutional layers in the generator and discriminator, but the official code omits spectral norm for SPADE layers.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SPADE(nn.Module):
    '''
    SPADE Class
    Values:
        channels: the number of channels in the input, a scalar
        cond_channels: the number of channels in conditional input (one-hot semantic labels), a scalar
    '''

    def __init__(self, channels, cond_channels):
        super().__init__()

        self.batchnorm = nn.BatchNorm2d(channels)
        self.spade = nn.Sequential(
            nn.Conv2d(cond_channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, 2 * channels, kernel_size=3, padding=1),
        )

    def forward(self, x, seg):
        # Apply normalization
        x = self.batchnorm(x)

        # Compute denormalization
        seg = F.interpolate(seg, size=x.shape[-2:], mode='nearest')
        gamma, beta = torch.chunk(self.spade(seg), 2, dim=1)

        # Apply denormalization
        x = x * (1 + gamma) + beta
        return x

Residual Blocks

Let’s now implement residual blocks with SPADE normalization. You should be familiar with the residual block by now, but this implementation will be a bit different to accommodate the extra semantic label map input. For a refresher on residual blocks, please take a look here.

class ResidualBlock(nn.Module):
    '''
    ResidualBlock Class
    Values:
        in_channels: the number of input channels, a scalar
        out_channels: the number of output channels, a scalar
        cond_channels: the number of channels in conditional input in spade layer, a scalar
    '''

    def __init__(self, in_channels, out_channels, cond_channels):
        super().__init__()

        hid_channels = min(in_channels, out_channels)

        self.proj = in_channels != out_channels
        if self.proj:
            self.norm0 = SPADE(in_channels, cond_channels)
            self.conv0 = nn.utils.spectral_norm(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
            )

        self.activation = nn.LeakyReLU(0.2)
        self.norm1 = SPADE(in_channels, cond_channels)
        self.norm2 = SPADE(hid_channels, cond_channels)
        self.conv1 = nn.utils.spectral_norm(
            nn.Conv2d(in_channels, hid_channels, kernel_size=3, padding=1)
        )
        self.conv2 = nn.utils.spectral_norm(
            nn.Conv2d(hid_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x, seg):
        dx = self.norm1(x, seg)
        dx = self.activation(dx)
        dx = self.conv1(dx)
        dx = self.norm2(dx, seg)
        dx = self.activation(dx)
        dx = self.conv2(dx)

        # Learn skip connection if in_channels != out_channels
        if self.proj:
            x = self.norm0(x, seg)
            x = self.conv0(x)

        return x + dx

GauGAN Parts

Now that you understand the main contributions of GauGAN and its submodules, let’s dive into the encoder, generator, and discriminator!

Encoder

GauGAN’s encoder serves a different purpose than Pix2PixHD’s. Instead of learning feature maps to be fed as input to the generator, GauGAN’s encoder encodes the original image into a mean and standard deviation from which to sample noise, which is given to the generator. You may recall that this same technique of encoding to a mean and standard deviation is used in variational autoencoders (VAEs, covered in a Course 2 optional notebook).

class Encoder(nn.Module):
    '''
    Encoder Class
    Values:
        spatial_size: tuple specifying (height, width) of full size image, a tuple
        z_dim: number of dimensions of latent noise vector (z), a scalar
        n_downsample: number of downsampling blocks in the encoder, a scalar
        base_channels: number of channels in the first hidden layer, a scalar
    '''

    max_channels = 512

    def __init__(self, spatial_size, z_dim=256, n_downsample=6, base_channels=64):
        super().__init__()

        layers = []
        channels = base_channels
        for i in range(n_downsample):
            in_channels = 3 if i == 0 else channels
            # Start at base_channels, then double at each block up to max_channels
            out_channels = base_channels if i == 0 else min(self.max_channels, channels * 2)
            layers += [
                nn.utils.spectral_norm(
                    nn.Conv2d(in_channels, out_channels, stride=2, kernel_size=3, padding=1)
                ),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(0.2),
            ]
            channels = out_channels

        h, w = spatial_size[0] // 2 ** n_downsample, spatial_size[1] // 2 ** n_downsample
        layers += [
            nn.Flatten(1),
            nn.Linear(channels * h * w, 2 * z_dim),
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return torch.chunk(self.layers(x), 2, dim=1)
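
The encoder sizes its final linear layer with `spatial_size // 2 ** n_downsample`. As a quick sanity check of that arithmetic, the sketch below applies the standard convolution output-size formula, floor((n + 2p - k) / s) + 1, once per stride-2 block. The helper names are hypothetical and only used for this illustration.

```python
def conv_out_size(size, kernel_size=3, stride=2, padding=1):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1

def encoder_feature_size(spatial_size, n_downsample):
    """Apply the stride-2 convolution n_downsample times to (height, width)."""
    h, w = spatial_size
    for _ in range(n_downsample):
        h, w = conv_out_size(h), conv_out_size(w)
    return h, w

full = encoder_feature_size((256, 512), n_downsample=6)
small = encoder_feature_size((128, 256), n_downsample=5)
print(full, small)
```

The shortcut with integer division matches this layer-by-layer count when the height and width are divisible by `2 ** n_downsample`. For other sizes the two can disagree (for example, a length of 5 becomes 3 after one block, not 2), so it is safest to pick image sizes that divide evenly.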

Generator

The GauGAN generator is actually very different from previous image-to-image translation generators. Because information from the semantic label map is injected at each batch normalization layer, the generator is able to just take random noise $z$ as input. This noise is reshaped and upsampled to the target image size.

Let’s take a look at the implementation!

class Generator(nn.Module):
    '''
    Generator Class
    Values:
        n_classes: the number of object classes in the dataset, a scalar
        spatial_size: tuple specifying (height, width) of full size image, a tuple
        z_dim: the number of dimensions the z noise vector has, a scalar
        base_channels: the number of channels in last hidden layer, a scalar
        n_upsample: the number of upsampling operations to apply, a scalar
    '''

    max_channels = 1024

    def __init__(self, n_classes, spatial_size, z_dim=256, base_channels=64, n_upsample=6):
        super().__init__()

        h, w = spatial_size[0] // 2 ** n_upsample, spatial_size[1] // 2 ** n_upsample
        self.proj_z = nn.Linear(z_dim, self.max_channels * h * w)
        self.reshape = lambda x: torch.reshape(x, (-1, self.max_channels, h, w))

        self.upsample = nn.Upsample(scale_factor=2)
        self.res_blocks = nn.ModuleList()
        for i in reversed(range(n_upsample)):
            in_channels = min(self.max_channels, base_channels * 2 ** (i+1))
            out_channels = min(self.max_channels, base_channels * 2 ** i)
            self.res_blocks.append(ResidualBlock(in_channels, out_channels, n_classes))

        self.proj_o = nn.Sequential(
            nn.Conv2d(base_channels, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, z, seg):
        h = self.proj_z(z)
        h = self.reshape(h)
        for res_block in self.res_blocks:
            h = res_block(h, seg)
            h = self.upsample(h)
        h = self.proj_o(h)
        return h
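
To see how the channel widths evolve through the generator, here is a small sketch that repeats the channel arithmetic from the constructor above without building any layers. The function name is hypothetical; the defaults mirror the class defaults.

```python
def generator_channels(base_channels=64, n_upsample=6, max_channels=1024):
    """List the (in_channels, out_channels) pair of each residual block."""
    schedule = []
    for i in reversed(range(n_upsample)):
        in_channels = min(max_channels, base_channels * 2 ** (i + 1))
        out_channels = min(max_channels, base_channels * 2 ** i)
        schedule.append((in_channels, out_channels))
    return schedule

for block, (c_in, c_out) in enumerate(generator_channels()):
    print(f"block {block}: {c_in} -> {c_out}")
```

Note that the projected noise is reshaped to `max_channels` feature maps, so the first block must also expect `max_channels` inputs. That holds whenever `base_channels * 2 ** n_upsample` is at least `max_channels`, which is the case for both the default settings and the smaller configuration used for training later in this notebook.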

Discriminator

The architecture of the discriminator follows the one used in Pix2PixHD, which uses a multi-scale design with instance normalization. The only difference here is that spectral normalization is applied to all convolutional layers. GauGAN’s discriminator also takes as input the image concatenated with the semantic label map (without the instance boundary map used in Pix2PixHD).

class PatchGANDiscriminator(nn.Module):
    '''
    PatchGANDiscriminator Class
    Implements the discriminator class for a subdiscriminator, 
    which can be used for all the different scales, just with different argument values.
    Values:
        in_channels: the number of channels in input, a scalar
        base_channels: the number of channels in first convolutional layer, a scalar
        n_layers: the number of convolutional layers, a scalar
    '''

    def __init__(self, in_channels, base_channels=64, n_layers=3):
        super().__init__()

        # Use nn.ModuleList so we can output intermediate values for loss.
        self.layers = nn.ModuleList()

        # Initial convolutional layer
        self.layers.append(
            nn.Sequential(
                nn.utils.spectral_norm(
                    nn.Conv2d(in_channels, base_channels, kernel_size=4, stride=2, padding=2)
                ),
                nn.LeakyReLU(0.2, inplace=True),
            )
        )

        # Downsampling convolutional layers
        channels = base_channels
        for _ in range(1, n_layers):
            prev_channels = channels
            channels = min(2 * channels, 512)
            self.layers.append(
                nn.Sequential(
                    nn.utils.spectral_norm(
                        nn.Conv2d(prev_channels, channels, kernel_size=4, stride=2, padding=2)
                    ),
                    nn.InstanceNorm2d(channels, affine=False),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )

        # Output convolutional layer
        prev_channels = channels
        channels = min(2 * channels, 512)
        self.layers.append(
            nn.Sequential(
                nn.utils.spectral_norm(
                    nn.Conv2d(prev_channels, channels, kernel_size=4, stride=1, padding=2)
                ),
                nn.InstanceNorm2d(channels, affine=False),
                nn.LeakyReLU(0.2, inplace=True),
                nn.utils.spectral_norm(
                    nn.Conv2d(channels, 1, kernel_size=4, stride=1, padding=2)
                ),
            )
        )

    def forward(self, x):
        outputs = [] # for feature matching loss
        for layer in self.layers:
            x = layer(x)
            outputs.append(x)

        return outputs

Now you’re ready to implement the multiscale discriminator in full! This puts together the different subdiscriminator scales.

class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
        in_channels: number of input channels to each discriminator, a scalar
        base_channels: number of channels in first convolutional layer of each discriminator, a scalar
        n_layers: number of downsampling layers in each discriminator, a scalar
        n_discriminators: number of discriminators at different scales, a scalar
    '''

    def __init__(self, in_channels, base_channels=64, n_layers=3, n_discriminators=3):
        super().__init__()

        # Initialize all discriminators
        self.discriminators = nn.ModuleList()
        for _ in range(n_discriminators):
            self.discriminators.append(
                PatchGANDiscriminator(in_channels, base_channels=base_channels, n_layers=n_layers)
            )

        # Downsampling layer to pass inputs between discriminators at different scales
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)

    def forward(self, x):
        outputs = []

        for i, discriminator in enumerate(self.discriminators):
            # Downsample input for subsequent discriminators
            if i != 0:
                x = self.downsample(x)

            outputs.append(discriminator(x))

        # Return list of multiscale discriminator outputs
        return outputs

    @property
    def n_discriminators(self):
        return len(self.discriminators)
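
Each subdiscriminator outputs a grid of patch predictions rather than a single score, and each scale sees a downsampled copy of the input. The sketch below traces the size of that grid along one spatial dimension, using the kernel, stride, and padding values from the code above and the standard output-size formula floor((n + 2p - k) / s) + 1. The helper names are hypothetical.

```python
def conv_out_size(size, kernel_size, stride, padding):
    """Output length of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1

def patchgan_output_size(size, n_layers=3):
    """Trace one dimension through a PatchGANDiscriminator."""
    # n_layers stride-2 convolutions (the initial one plus n_layers - 1 more)
    for _ in range(n_layers):
        size = conv_out_size(size, kernel_size=4, stride=2, padding=2)
    # Two stride-1 convolutions in the output block
    for _ in range(2):
        size = conv_out_size(size, kernel_size=4, stride=1, padding=2)
    return size

def multiscale_output_sizes(size, n_layers=3, n_discriminators=3):
    """Prediction-grid size at every scale of the multiscale discriminator."""
    sizes = []
    for i in range(n_discriminators):
        if i != 0:
            # nn.AvgPool2d(3, stride=2, padding=1) halves the input
            size = conv_out_size(size, kernel_size=3, stride=2, padding=1)
        sizes.append(patchgan_output_size(size, n_layers))
    return sizes

print(multiscale_output_sizes(256))
print(multiscale_output_sizes(512))
```

Coarser scales produce smaller grids in which each prediction covers a larger part of the original image.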

GauGAN: Putting it all together

You can now create your GauGAN model that encapsulates all the parts you’ve just learned about! Since the encoder outputs mean and log-variance values to sample random noise from, this implementation will use the ‘reparameterization trick’ to allow gradient flow to the encoder. If you’re not familiar with this trick, it samples from $\mathcal{N}(0, I)$ and applies shift and scale ($\mu, \sigma$) as opposed to sampling directly from $\mathcal{N}(\mu, \sigma^2I)$:

\begin{align*} z \sim \mathcal{N}(\mu, \sigma^2 I) \quad \Longleftrightarrow \quad z = \sigma * \epsilon + \mu, \quad \epsilon \sim \mathcal{N}(0, I). \end{align*}

Check out the optional VAE notebook for a more detailed description!
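
If you would like to convince yourself that shifting and scaling standard normal noise gives the right distribution, the following plain-Python sketch draws many samples and checks their empirical mean and standard deviation. The values of `mu` and `logvar` are arbitrary, and the function is only an illustration of the trick, not the model code.

```python
import math
import random

def sample_z(mu, logvar, rng):
    """Draw z ~ N(mu, sigma^2) per dimension via z = sigma * eps + mu."""
    return [
        math.exp(lv / 2) * rng.gauss(0.0, 1.0) + m
        for m, lv in zip(mu, logvar)
    ]

rng = random.Random(0)                     # fixed seed for repeatability
mu = [1.0, -2.0]
logvar = [math.log(0.25), math.log(4.0)]   # sigma = 0.5 and sigma = 2.0

samples = [sample_z(mu, logvar, rng) for _ in range(20000)]

empirical = []
for d in range(len(mu)):
    column = [s[d] for s in samples]
    mean = sum(column) / len(column)
    std = math.sqrt(sum((v - mean) ** 2 for v in column) / len(column))
    empirical.append((mean, std))
    print(f"dim {d}: mean={mean:.2f}, std={std:.2f}")
```

Because the randomness lives entirely in $\epsilon$, the sample is a deterministic, differentiable function of $\mu$ and $\sigma$, which is what allows gradients to reach the encoder.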

class GauGAN(nn.Module):
    '''
    GauGAN Class
    Values:
        n_classes: number of object classes in dataset, a scalar
        spatial_size: tuple containing (height, width) of full-size image, a tuple
        base_channels: number of channels in last generator & first discriminator layers, a scalar
        z_dim: number of dimensions in noise vector (z), a scalar
        n_upsample: number of downsampling (encoder) and upsampling (generator) operations, a scalar
        n_disc_layer:: number of discriminator layers, a scalar
        n_disc: number of discriminators (at different scales), a scalar
    '''

    def __init__(
        self,
        n_classes,
        spatial_size,
        base_channels=64,
        z_dim=256,
        n_upsample=6,
        n_disc_layers=3,
        n_disc=3,
    ):
        super().__init__()

        self.encoder = Encoder(
            spatial_size, z_dim=z_dim, n_downsample=n_upsample, base_channels=base_channels,
        )
        self.generator = Generator(
            n_classes, spatial_size, z_dim=z_dim, base_channels=base_channels, n_upsample=n_upsample,
        )
        self.discriminator = Discriminator(
            n_classes + 3, base_channels=base_channels, n_layers=n_disc_layers, n_discriminators=n_disc,
        )

    def forward(self, x, seg):
        ''' Performs a full forward pass for training. '''
        mu, logvar = self.encode(x)
        z = self.sample_z(mu, logvar)
        x_fake = self.generate(z, seg)
        pred = self.discriminate(x_fake, seg)
        return x_fake, pred

    def encode(self, x):
        return self.encoder(x)

    def generate(self, z, seg):
        ''' Generates fake image from noise vector and segmentation. '''
        return self.generator(z, seg)

    def discriminate(self, x, seg):
        ''' Predicts whether input image is real. '''
        return self.discriminator(torch.cat((x, seg), dim=1))

    @staticmethod
    def sample_z(mu, logvar):
        ''' Samples noise vector with reparameterization trick. '''
        eps = torch.randn(mu.size(), device=mu.device).to(mu.dtype)
        return (logvar / 2).exp() * eps + mu

    @property
    def n_disc(self):
        return self.discriminator.n_discriminators

Loss Functions

GauGAN reuses Pix2PixHD’s composite loss function, except that it replaces the LSGAN loss with hinge loss. It also imposes a soft (0.05 weight) Kullback-Leibler divergence (KLD) loss term on the Gaussian statistics generated by the encoder.
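
As a refresher, here is the standard hinge formulation of the adversarial loss written in plain Python for a handful of made-up discriminator scores. It is a minimal sketch to show the behavior of the loss; the function names are hypothetical.

```python
def d_hinge_loss(real_scores, fake_scores):
    """Discriminator hinge loss: mean(max(0, 1 - D(x))) + mean(max(0, 1 + D(G(z))))."""
    real_term = sum(max(0.0, 1.0 - s) for s in real_scores) / len(real_scores)
    fake_term = sum(max(0.0, 1.0 + s) for s in fake_scores) / len(fake_scores)
    return real_term + fake_term

def g_hinge_loss(fake_scores):
    """Generator loss paired with the hinge loss: -mean(D(G(z)))."""
    return -sum(fake_scores) / len(fake_scores)

# A confident discriminator: real scores above +1, fake scores below -1.
confident = d_hinge_loss(real_scores=[2.0, 1.5], fake_scores=[-3.0, -1.2])

# An unsure discriminator: scores inside the margin are penalized.
unsure = d_hinge_loss(real_scores=[0.0, 0.5], fake_scores=[0.0, -0.5])

print(confident, unsure, g_hinge_loss([0.0, -0.5]))
```

The discriminator term is never negative, and it stops contributing once a score clears the margin of 1, so the discriminator is not rewarded for pushing already-correct scores further apart.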

A brief primer on KLD

KLD measures how different two probability distributions are. In the case of $\mathcal{N}(\mu, \sigma^2I)$ learned by the encoder, KLD loss encourages the learned distribution to be close to a standard Gaussian. For more information on implementation, check out PyTorch’s KLDivLoss documentation.
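
For a diagonal Gaussian compared against a standard Gaussian, the KLD has the closed form $-\frac{1}{2} \sum_i (1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2)$. The sketch below evaluates it in plain Python on a few hand-picked values; the function name is hypothetical and the inputs are lists rather than tensors.

```python
import math

def kld_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over dimensions."""
    return -0.5 * sum(
        1 + lv - m ** 2 - math.exp(lv)
        for m, lv in zip(mu, logvar)
    )

# Already a standard Gaussian: no penalty.
print(kld_to_standard_normal([0.0, 0.0], [0.0, 0.0]))

# Shifted mean: penalty grows with mu^2.
print(kld_to_standard_normal([1.0], [0.0]))

# Inflated variance (sigma^2 = 4): also penalized.
print(kld_to_standard_normal([0.0], [math.log(4.0)]))
```

The penalty is zero only when $\mu = 0$ and $\sigma = 1$, which is why a small weight on this term gently pulls the encoder’s output toward the distribution you sample from at test time.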

import torchvision.models as models

class VGG19(nn.Module):
    '''
    VGG19 Class
    Wrapper for pretrained torchvision.models.vgg19 to output intermediate feature maps
    '''

    def __init__(self):
        super().__init__()
        vgg_features = models.vgg19(pretrained=True).features

        self.f1 = nn.Sequential(*[vgg_features[x] for x in range(2)])
        self.f2 = nn.Sequential(*[vgg_features[x] for x in range(2, 7)])
        self.f3 = nn.Sequential(*[vgg_features[x] for x in range(7, 12)])
        self.f4 = nn.Sequential(*[vgg_features[x] for x in range(12, 21)])
        self.f5 = nn.Sequential(*[vgg_features[x] for x in range(21, 30)])

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        h1 = self.f1(x)
        h2 = self.f2(h1)
        h3 = self.f3(h2)
        h4 = self.f4(h3)
        h5 = self.f5(h4)
        return [h1, h2, h3, h4, h5]

class Loss(nn.Module):
    '''
    Loss Class
    Implements composite loss for GauGAN
    Values:
        lambda1: weight for feature matching loss, a float
        lambda2: weight for vgg perceptual loss, a float
        lambda3: weight for KLD loss, a float
        device: 'cuda' or 'cpu' for hardware to use
        norm_weight_to_one: whether to normalize weights to (0, 1], a bool
    '''

    def __init__(self, lambda1=10., lambda2=10., lambda3=0.05, device='cuda', norm_weight_to_one=True):
        super().__init__()

        self.vgg = VGG19().to(device)
        self.vgg_weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

        lambda0 = 1.0
        # Keep ratio of composite loss, but scale down max to 1.0
        scale = max(lambda0, lambda1, lambda2, lambda3) if norm_weight_to_one else 1.0

        self.lambda0 = lambda0 / scale
        self.lambda1 = lambda1 / scale
        self.lambda2 = lambda2 / scale
        self.lambda3 = lambda3 / scale

    def kld_loss(self, mu, logvar):
        return -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())

    def g_adv_loss(self, discriminator_preds):
        adv_loss = 0.0
        for preds in discriminator_preds:
            pred = preds[-1]
            adv_loss += -pred.mean()
        return adv_loss

    def d_adv_loss(self, discriminator_preds, is_real):
        adv_loss = 0.0
        for preds in discriminator_preds:
            pred = preds[-1]
            target = -1 + pred if is_real else -1 - pred
            mask = target < 0
            # Hinge loss is -min(0, target), which is always non-negative
            adv_loss += -(mask * target).mean()
        return adv_loss

    def fm_loss(self, real_preds, fake_preds):
        fm_loss = 0.0
        for real_features, fake_features in zip(real_preds, fake_preds):
            for real_feature, fake_feature in zip(real_features, fake_features):
                fm_loss += F.l1_loss(real_feature.detach(), fake_feature)
        return fm_loss

    def vgg_loss(self, x_real, x_fake):
        vgg_real = self.vgg(x_real)
        vgg_fake = self.vgg(x_fake)

        vgg_loss = 0.0
        for real, fake, weight in zip(vgg_real, vgg_fake, self.vgg_weights):
            vgg_loss += weight * F.l1_loss(real.detach(), fake)
        return vgg_loss

    def forward(self, x_real, label_map, gaugan):
        '''
        Function that computes the forward pass and total loss for GauGAN.
        '''
        mu, logvar = gaugan.encode(x_real)
        z = gaugan.sample_z(mu, logvar)
        x_fake = gaugan.generate(z, label_map)

        # Get necessary outputs for loss/backprop for both generator and discriminator
        fake_preds_for_g = gaugan.discriminate(x_fake, label_map)
        fake_preds_for_d = gaugan.discriminate(x_fake.detach(), label_map)
        real_preds_for_d = gaugan.discriminate(x_real.detach(), label_map)

        g_loss = (
            self.lambda0 * self.g_adv_loss(fake_preds_for_g) + \
            self.lambda1 * self.fm_loss(real_preds_for_d, fake_preds_for_g) / gaugan.n_disc + \
            self.lambda2 * self.vgg_loss(x_real, x_fake) + \
            self.lambda3 * self.kld_loss(mu, logvar)
        )
        d_loss = 0.5 * (
            self.d_adv_loss(real_preds_for_d, True) + \
            self.d_adv_loss(fake_preds_for_d, False)
        )

        return g_loss, d_loss, x_fake.detach()
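
The `norm_weight_to_one` option rescales the four loss weights so that the largest one becomes 1 while their ratios stay the same. Here is that arithmetic on its own, as a plain-Python sketch with the default weights; the function name and dictionary keys are hypothetical labels for this illustration.

```python
def normalize_weights(lambda1=10.0, lambda2=10.0, lambda3=0.05, norm_weight_to_one=True):
    """Return the effective weight of each loss term."""
    lambda0 = 1.0  # adversarial loss weight
    weights = {
        "adversarial": lambda0,
        "feature_matching": lambda1,
        "vgg": lambda2,
        "kld": lambda3,
    }
    # Keep the ratios between the terms, but scale the largest weight to 1.0
    scale = max(weights.values()) if norm_weight_to_one else 1.0
    return {name: w / scale for name, w in weights.items()}

weights = normalize_weights()
for name, w in weights.items():
    print(f"{name}: {w}")
```

With the defaults, the feature matching and VGG terms end up with weight 1 and the adversarial term is scaled down by the same factor, so the relative balance between the terms is unchanged and only the overall magnitude of the generator loss shrinks.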

Training GauGAN

You now have GauGAN implemented! All you have to do now is prepare your dataset. The authors trained GauGAN on a variety of datasets, but for simplicity, this notebook is geared towards the Cityscapes dataset (Pix2PixHD is also trained on the Cityscapes dataset). You’ll have to download the dataset, unzip it, and put it in your data folder to initialize the dataset code below.

Specifically, you should download gtFine_trainvaltest and leftImg8bit_trainvaltest and point the dataloader at the corresponding data splits.

Below is a quick dataset class to help you load and preprocess the Cityscapes examples. You can also run the code below to download the dataset, after you’ve created an account at http://cityscapes-dataset.com.

# Store necessary cookies
import getpass, os
username = input("What is your Cityscapes username? (http://cityscapes-dataset.com) ")
password = getpass.getpass("What is your Cityscapes password? ")
os.makedirs("data", exist_ok=True)  # don't fail if the folder already exists
os.system(f"wget --keep-session-cookies --save-cookies=data/cookies.txt --post-data 'username={username}&password={password}&submit=Login' https://www.cityscapes-dataset.com/login/")
# Download data
!cd data; wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1
!cd data; wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3
# Unzip data
!cd data; unzip leftImg8bit_trainvaltest
!cd data; unzip gtFine_trainvaltest
import os

import numpy as np
import torchvision.transforms as transforms
from PIL import Image

class CityscapesDataset(torch.utils.data.Dataset):
    '''
    CityscapesDataset Class
    Values:
        paths: (a list of) paths to construct dataset from, a list or string
        img_size: tuple containing the (height, width) for resizing, a tuple
        n_classes: the number of object classes, a scalar
    '''

    def __init__(self, paths, img_size=(256, 512), n_classes=35):
        super().__init__()

        self.n_classes = n_classes

        # Collect list of examples
        self.examples = {}
        if type(paths) == str:
            self.load_examples_from_dir(paths)
        elif type(paths) == list:
            for path in paths:
                self.load_examples_from_dir(path)
        else:
            raise ValueError('`paths` should be a single path or list of paths')

        self.examples = list(self.examples.values())
        assert all(len(example) == 2 for example in self.examples)

        # Initialize transforms for the real color image
        self.img_transforms = transforms.Compose([
            transforms.Resize(img_size),
            transforms.Lambda(lambda img: np.array(img)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

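        # Initialize transforms for semantic label maps
        # Nearest-neighbor resizing keeps label IDs intact (no blending between classes)
        self.map_transforms = transforms.Compose([
            transforms.Resize(img_size, interpolation=Image.NEAREST),
            transforms.Lambda(lambda img: np.array(img)),
            transforms.ToTensor(),
        ])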

    def load_examples_from_dir(self, abs_path):
        '''
        Given a folder of examples, this function collects the paired image and label paths into self.examples.
        '''
        assert os.path.isdir(abs_path)

        img_suffix = '_leftImg8bit.png'
        label_suffix = '_gtFine_labelIds.png'

        for root, _, files in os.walk(abs_path):
            for f in files:
                if f.endswith(img_suffix):
                    prefix = f[:-len(img_suffix)]
                    attr = 'orig_img'
                elif f.endswith(label_suffix):
                    prefix = f[:-len(label_suffix)]
                    attr = 'label_map'
                else:
                    continue

                if prefix not in self.examples.keys():
                    self.examples[prefix] = {}
                self.examples[prefix][attr] = os.path.join(root, f)

    def __getitem__(self, idx):
        example = self.examples[idx]

        # Load image and maps
        img = Image.open(example['orig_img']).convert('RGB') # color image: (3, h, w)
        label = Image.open(example['label_map'])             # semantic label map: (1, h, w)

        # Apply corresponding transforms
        img = self.img_transforms(img)
        # ToTensor scales uint8 label IDs to [0, 1], so rescale before casting to integers
        label = (self.map_transforms(label) * 255).round().long()

        # Convert labels to one-hot vectors
        label = F.one_hot(label, num_classes=self.n_classes)
        label = label.squeeze(0).permute(2, 0, 1).to(img.dtype)

        return (img, label)

    def __len__(self):
        return len(self.examples)

    @staticmethod
    def collate_fn(batch):
        imgs, labels = [], []
        for (x, l) in batch:
            imgs.append(x)
            labels.append(l)
        return torch.stack(imgs, dim=0), torch.stack(labels, dim=0)
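
The one-hot step turns a single map of class IDs into one binary mask per class, which is the form the SPADE layers convolve. As a tiny illustration without tensors, the sketch below does the same conversion on a made-up 2x2 label map with 3 classes; the function name is hypothetical.

```python
def one_hot_label_map(label_map, n_classes):
    """Convert an (H, W) grid of class IDs into n_classes binary (H, W) masks."""
    return [
        [[1.0 if label == c else 0.0 for label in row] for row in label_map]
        for c in range(n_classes)
    ]

# Hypothetical label map: each entry is the class ID of one pixel.
label_map = [[0, 2],
             [1, 2]]
masks = one_hot_label_map(label_map, n_classes=3)

for c, mask in enumerate(masks):
    print(f"class {c}: {mask}")
```

The result has shape (n_classes, H, W), and exactly one mask is active at every pixel.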

Now initialize everything you’ll need for training. Don’t worry if this looks like a lot of code; it’s all stuff you’ve seen before!

from tqdm import tqdm
from torch.utils.data import DataLoader

def lr_lambda(epoch):
    ''' Function for scheduling learning rate '''
    return 1. if epoch < decay_after else 1 - float(epoch - decay_after) / (epochs - decay_after)

def weights_init(m):
    ''' Function for initializing all model weights '''
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)

# Initialize model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gaugan_config = {
    'n_classes': 35,
    'spatial_size': (128, 256), # Default (256, 512): halve size for memory
    'base_channels': 32,        # Default 64: halve channels for memory
    'z_dim': 256,
    'n_upsample': 5,            # Default 6: decrease layers for memory
    'n_disc_layers': 2,
    'n_disc': 3,
}
gaugan = GauGAN(**gaugan_config).to(device)
loss = Loss(device=device)

# Initialize dataloader
train_dir = ['data']
batch_size = 16                 # Default 32: decrease for memory
dataset = CityscapesDataset(
    train_dir, img_size=gaugan_config['spatial_size'], n_classes=gaugan_config['n_classes'],
)
dataloader = DataLoader(
    dataset, collate_fn=CityscapesDataset.collate_fn,
    batch_size=batch_size, shuffle=True,
    drop_last=False, pin_memory=True,
)

# Initialize optimizers + schedulers
epochs = 200                    # total number of train epochs
decay_after = 100               # number of epochs with constant lr
betas = (0.0, 0.999)

g_params = list(gaugan.generator.parameters()) + list(gaugan.encoder.parameters())
d_params = list(gaugan.discriminator.parameters())

g_optimizer = torch.optim.Adam(g_params, lr=1e-4, betas=betas)
d_optimizer = torch.optim.Adam(d_params, lr=4e-4, betas=betas)
g_scheduler = torch.optim.lr_scheduler.LambdaLR(g_optimizer, lr_lambda)
d_scheduler = torch.optim.lr_scheduler.LambdaLR(d_optimizer, lr_lambda)
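
The schedule keeps the learning rate constant for the first `decay_after` epochs and then decays it linearly toward zero. Here is the same multiplier as a standalone function with the values used above passed in as defaults, so you can print a few points of the schedule without building any optimizers. The function name is hypothetical.

```python
def lr_multiplier(epoch, epochs=200, decay_after=100):
    """Factor applied to the base learning rate at a given epoch."""
    if epoch < decay_after:
        return 1.0
    return 1.0 - (epoch - decay_after) / (epochs - decay_after)

base_lr = 1e-4  # generator learning rate from the setup above
for epoch in (0, 99, 100, 150, 199):
    print(f"epoch {epoch}: lr = {base_lr * lr_multiplier(epoch):.2e}")
```

The multiplier would reach exactly zero at epoch 200, which is one past the last epoch that is actually trained, so the learning rate stays positive throughout training.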

And now the training loop:

from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Parse torch version for autocast
# #################################################
version = torch.__version__
version = tuple(int(n) for n in version.split('.')[:-1])
has_autocast = version >= (1, 6)
# #################################################

def show_tensor_images(image_tensor):
    '''
    Function for visualizing images: Given a tensor of images, plots the first
    image in the batch.
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:1], nrow=1)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def train(dataloader, gaugan, optimizers, schedulers, device):
    g_optimizer, d_optimizer = optimizers
    g_scheduler, d_scheduler = schedulers

    cur_step = 0
    display_step = 100

    mean_g_loss = 0.0
    mean_d_loss = 0.0

    for epoch in range(epochs):
        # Training epoch
        for (x_real, labels) in tqdm(dataloader, position=0):
            x_real = x_real.to(device)
            labels = labels.to(device)

            # Enable autocast to FP16 tensors (new feature since torch==1.6.0)
            # If you're running older versions of torch, comment this out
            # and use NVIDIA apex for mixed/half precision training
            if has_autocast:
                with torch.cuda.amp.autocast(enabled=(device=='cuda')):
                    g_loss, d_loss, x_fake = loss(x_real, labels, gaugan)
            else:
                g_loss, d_loss, x_fake = loss(x_real, labels, gaugan)

            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            mean_g_loss += g_loss.item() / display_step
            mean_d_loss += d_loss.item() / display_step

            if cur_step % display_step == 0 and cur_step > 0:
                print('Step {}: Generator loss: {:.5f}, Discriminator loss: {:.5f}'
                      .format(cur_step, mean_g_loss, mean_d_loss))
                show_tensor_images(x_fake.to(x_real.dtype))
                show_tensor_images(x_real)
                mean_g_loss = 0.0
                mean_d_loss = 0.0
            cur_step += 1

        g_scheduler.step()
        d_scheduler.step()

train(
    dataloader, gaugan,
    [g_optimizer, d_optimizer],
    [g_scheduler, d_scheduler],
    device,
)

[Progress bars: 11 iterations per epoch, about 32 s per epoch (roughly 3 s/it)]

Step 100: Generator loss: 9.22943, Discriminator loss: -26.50570
Step 200: Generator loss: 8.51473, Discriminator loss: -32.23375
Step 300: Generator loss: 8.22687, Discriminator loss: -32.52094
Step 400: Generator loss: 7.55541, Discriminator loss: -32.72000
Step 500: Generator loss: 6.25381, Discriminator loss: -33.17156
Step 600: Generator loss: 5.81972, Discriminator loss: -33.37750
Step 700: Generator loss: 5.39125, Discriminator loss: -33.56469
Step 800: Generator loss: 5.42630, Discriminator loss: -33.64375
Step 900: Generator loss: 4.99733, Discriminator loss: -33.77875
Step 1000: Generator loss: 4.93960, Discriminator loss: -33.81656
Step 1100: Generator loss: 5.27886, Discriminator loss: -33.70094
Step 1200: Generator loss: 4.79553, Discriminator loss: -33.86656
Step 1300: Generator loss: 4.56317, Discriminator loss: -33.98875
Step 1400: Generator loss: 4.35006, Discriminator loss: -34.02750
Step 1500: Generator loss: 4.22546, Discriminator loss: -34.06219
Step 1600: Generator loss: 4.13058, Discriminator loss: -34.10000
Step 1700: Generator loss: 3.98097, Discriminator loss: -34.13719
Step 1800: Generator loss: 3.90972, Discriminator loss: -34.17094
Step 1900: Generator loss: 3.90601, Discriminator loss: -34.16000
Step 2000: Generator loss: 3.86219, Discriminator loss: -34.18219
Step 2100: Generator loss: 3.74077, Discriminator loss: -34.18875

[Images: after each logged step, the generated image (x_fake) followed by the real image (x_real)]
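
The `Step N` lines above are running means rather than single-batch losses: on every iteration `train` adds `loss / display_step` to an accumulator, and it resets the accumulator after printing. The sketch below replays only that bookkeeping in plain Python with a constant stand-in loss. `simulate_logging` and its arguments are hypothetical names used for illustration and are not part of the notebook. Read this way, the first window covers steps 0 through 100 (101 values), so the first printed mean is slightly inflated relative to the later ones.

```python
def simulate_logging(losses, display_step=100):
    """Replay the accumulate / print / reset bookkeeping used in train().

    `losses` is any iterable of per-iteration loss values (hypothetical
    stand-ins for g_loss.item()). Returns a list of (step, reported_mean).
    """
    reported = []
    mean_loss = 0.0
    cur_step = 0
    for loss_value in losses:
        # Same update as in train(): each value contributes 1 / display_step.
        mean_loss += loss_value / display_step
        # Same check as in train(): step 0 is skipped, so the first report
        # happens at cur_step == display_step, after display_step + 1 values.
        if cur_step % display_step == 0 and cur_step > 0:
            reported.append((cur_step, mean_loss))
            mean_loss = 0.0
        cur_step += 1
    return reported


# A constant loss of 1.0 makes the window sizes visible in the reported means.
reports = simulate_logging([1.0] * 301, display_step=100)
for step, mean in reports:
    print('Step {}: mean loss {:.5f}'.format(step, mean))
```

With a constant loss of 1.0, the first reported mean comes out near 1.01 and the later ones near 1.00. If you want every window to hold exactly `display_step` values, one option is to increment `cur_step` before the check instead of after it.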