Coursera

Components of BigGAN

In this notebook, you’ll learn about and implement the components of BigGAN, the first large-scale GAN architecture proposed in Large Scale GAN Training for High Fidelity Natural Image Synthesis (Brock et al. 2019). BigGAN performs conditional generation: unlike StyleGAN, it conditions on a class label to generate results. BigGAN’s design is driven mainly by empirical findings, and it produces extremely good results when trained on ImageNet and its 1000 classes.

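The authors propose several changes that improve on the state-of-the-art Inception Score (IS) and Fréchet Inception Distance (FID), including larger batch sizes, wider layers (more channels per layer), shared class embeddings in the class-conditional batch norm layers, skip connections from the noise vector $z$ (chunks of $z$ concatenated with the class embedding $c$), and orthogonal regularization to make the generator amenable to the truncation trick.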

BigGAN Architecture Components, taken from Figure 15 in Large Scale GAN Training for High Fidelity Natural Image Synthesis (Brock et al. 2019). (a) A typical architectural layout for BigGAN’s generator. See Appendix B for details. (b) A Residual Block (ResBlock up) in BigGAN’s generator. (c) A Residual Block (ResBlock down) in BigGAN’s discriminator.

The Truncation Trick and Orthogonal Regularization

You should already be familiar with the truncation trick, which truncates the range of values of random noise $z$. Truncation to values close to 0 increases fidelity but decreases variety. Truncation to values further from 0 does the opposite.
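As a refresher, here is a minimal sketch of truncated noise sampling. It assumes the resampling scheme the BigGAN paper describes: entries of $z$ that fall outside $[-\text{threshold}, \text{threshold}]$ are redrawn from $\mathcal{N}(0, 1)$ until they land inside the range. The helper name `truncated_noise` and its `truncation` argument are hypothetical and are not the StyleGAN notebook’s code.

```python
import torch

def truncated_noise(n_samples, z_dim, truncation=0.5):
    '''
    Hypothetical helper: sample z from a standard normal truncated to
    [-truncation, truncation] by redrawing every out-of-range entry.
    '''
    z = torch.randn(n_samples, z_dim)
    outside = z.abs() > truncation
    while outside.any():
        # Redraw only the entries that fell outside the range
        z[outside] = torch.randn(int(outside.sum()))
        outside = z.abs() > truncation
    return z

torch.manual_seed(0)
z_low = truncated_noise(4, 120, truncation=0.5)   # close to 0: higher fidelity, less variety
z_high = truncated_noise(4, 120, truncation=2.0)  # wider range: more variety, lower fidelity
print(z_low.abs().max().item(), z_high.abs().max().item())
```

Smaller thresholds need more redraws per entry, so a very small threshold such as 0.04 takes more loop iterations, but the result still comes from the truncated normal.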

Truncation results in a different distribution of $z$ values from the one seen in training, which can cause saturation artifacts. The authors address this by making $G$ well-defined, or smooth, on the full distribution of $z$ values.

To do this, they employ orthogonal regularization, first introduced in Neural Photo Editing with Introspective Adversarial Networks (Brock et al. 2017). The authors modify this regularization technique for BigGAN and formulate it as

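\begin{align*} R_\beta(W) = \beta\big|\big|W^\top W \odot (\pmb{1} - I)\big|\big|^2_F, \end{align*}

where $\pmb{1}$ denotes a matrix of 1’s, $I$ is the identity matrix, and $\beta$ is a small coefficient ($10^{-4}$ in the paper). With each flattened filter stored as a column of $W$, the entries of $W^\top W$ are the pairwise dot products between filters. Multiplying by $(\pmb{1} - I)$ removes the diagonal terms (the squared filter norms) from the regularization, so the term aims to minimize the pairwise cosine similarity between filters without constraining their norm.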

Truncation Trick. Generated images with different truncation thresholds, taken from Figure 2 in Large Scale GAN Training for High Fidelity Natural Image Synthesis (Brock et al. 2019). (a) The effects of increasing truncation. From left to right, the threshold is set to 2, 1, 0.5, 0.04. (b) Saturation artifacts from applying truncation to a poorly conditioned model.

Below is the implementation for orthogonal regularization. You can refer to the StyleGAN notebook for the truncation trick code.

# Some setup
import torch
import torch.nn as nn
import torch.nn.functional as F

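def orthogonal_regularization(weight, beta=1e-4):
    '''
    Function for computing the orthogonal regularization term for a given weight matrix.
    Each row of the flattened weight is treated as one filter; beta defaults to 1e-4.
    '''
    # Flatten to (out_channels, in_channels * kernel_height * kernel_width)
    weight = weight.flatten(1)
    # Pairwise dot products between filters, shape (out_channels, out_channels)
    gram = weight @ weight.t()
    # Zero out the diagonal so that filter norms are not penalized
    mask = 1 - torch.eye(weight.shape[0], device=weight.device, dtype=weight.dtype)
    # Squared Frobenius norm of the off-diagonal terms, scaled by beta
    return beta * (gram * mask).pow(2).sum()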
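To see what the penalty does and does not constrain, the self-contained sketch below defines the penalty with filters as the rows of the flattened weight (so the Gram matrix is `weight @ weight.t()`) and evaluates it on three made-up weight tensors. An orthogonally initialized weight scores about zero, rescaling each filter by a different factor still scores about zero (norms are not penalized), and a random Gaussian weight scores clearly above zero. The shapes and scale factors are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

def orthogonal_regularization(weight, beta=1e-4):
    # Rows of the flattened weight are filters; penalize only off-diagonal Gram entries
    weight = weight.flatten(1)
    gram = weight @ weight.t()
    mask = 1 - torch.eye(weight.shape[0], device=weight.device, dtype=weight.dtype)
    return beta * (gram * mask).pow(2).sum()

torch.manual_seed(0)
w = torch.empty(16, 8, 3, 3)          # 16 filters with 72 weights each
nn.init.orthogonal_(w)                # 16 <= 72, so the flattened rows are orthonormal
scaled = w * torch.linspace(0.1, 10.0, 16).view(16, 1, 1, 1)  # different norm per filter
w_random = torch.randn(16, 8, 3, 3)   # filters with nonzero pairwise dot products

print(orthogonal_regularization(w).item())         # approximately 0
print(orthogonal_regularization(scaled).item())    # still approximately 0
print(orthogonal_regularization(w_random).item())  # clearly positive
```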

BigGAN Parts

Before jumping into the full implementation, let’s first take a look at some submodules that will be important in our BigGAN implementation later.

Class-conditional Batch Normalization

Recall that batch norm normalizes each channel’s activations to zero mean and unit variance (using minibatch statistics during training and an exponential moving average of those statistics at inference), but also applies trainable parameters, $\gamma$ and $\beta$, that can undo this normalization if the model sees fit:

\begin{align*} y &= \dfrac{x - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \epsilon}} \gamma + \beta. \end{align*}
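The formula can be checked numerically. This sketch computes per-channel minibatch statistics by hand (over the batch, height, and width dimensions, with the biased variance and $\epsilon$ inside the square root) and compares the result with PyTorch’s `nn.BatchNorm2d` in training mode. The tensor sizes and the constant values chosen for $\gamma$ and $\beta$ are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4, 5, 5)   # (batch, channels, height, width)
eps = 1e-5

# Per-channel minibatch mean and (biased) variance over batch, height, and width
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = ((x - mu) ** 2).mean(dim=(0, 2, 3), keepdim=True)

gamma = torch.full((1, 4, 1, 1), 2.0)
beta = torch.full((1, 4, 1, 1), 0.5)
y_manual = (x - mu) / torch.sqrt(var + eps) * gamma + beta

bn = nn.BatchNorm2d(4, eps=eps)   # a fresh module is in training mode: uses minibatch stats
with torch.no_grad():
    bn.weight.fill_(2.0)          # gamma
    bn.bias.fill_(0.5)            # beta
y_torch = bn(x)

print(torch.allclose(y_manual, y_torch, atol=1e-5))  # True
```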

BigGAN injects class-conditional information by parameterizing $\gamma$ and $\beta$ as linear transformations of the class embedding, $c$. Recall that BigGAN also concatenates $c$ with $z$ skip connections (denoted $[c, z]$), so

\begin{align*} \gamma &:= W_\gamma^\top[c, z] \\ \beta &:= W_\beta^\top[c, z] \end{align*}
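A shape-level sketch of where $[c, z]$ comes from and how it turns into per-channel $\gamma$ and $\beta$, using this notebook’s defaults (`shared_dim=128`, `z_dim=120` split into 6 chunks of 20). The names `to_gamma`, `to_beta`, and the channel count of 64 are hypothetical. This follows the equations as written; the notebook’s `ClassConditionalBatchNorm2d` additionally offsets the scale by 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_classes, shared_dim, z_dim, n_chunks = 5, 128, 120, 6
channels = 64                     # channels of the activation being normalized (assumed)
chunk = z_dim // n_chunks         # 20 noise dimensions per chunk

shared_emb = nn.Embedding(n_classes, shared_dim)                # shared class embedding table
to_gamma = nn.Linear(shared_dim + chunk, channels, bias=False)  # plays the role of W_gamma
to_beta = nn.Linear(shared_dim + chunk, channels, bias=False)   # plays the role of W_beta

labels = torch.tensor([0, 3])
c = shared_emb(labels)                                      # (2, 128)
z_chunks = torch.split(torch.randn(2, z_dim), chunk, dim=1)  # 6 chunks of shape (2, 20)
cz = torch.cat([c, z_chunks[1]], dim=1)                     # [c, z] for one block: (2, 148)

# One scale and shift per channel, broadcast over height and width
gamma = to_gamma(cz)[:, :, None, None]                      # (2, 64, 1, 1)
beta = to_beta(cz)[:, :, None, None]

x_hat = torch.randn(2, channels, 8, 8)   # stand-in for a batch-normalized activation
y = gamma * x_hat + beta
print(len(z_chunks), tuple(cz.shape), tuple(y.shape))
```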

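The idea is very similar to the adaptive instance normalization (AdaIN) module that you implemented in the StyleGAN notebook, so we’ve copied that code in comments below for reference. Note that the implementation below computes the scale as $1 + W_\gamma^\top[c, z]$, so the learned term acts as an offset around a scale of 1.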

class ClassConditionalBatchNorm2d(nn.Module):
    '''
    ClassConditionalBatchNorm2d Class
    Values:
    in_channels: the dimension of the class embedding (c) + noise vector (z), a scalar
    out_channels: the dimension of the activation tensor to be normalized, a scalar
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()
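        # affine=False: the class-conditional scale and shift below replace BatchNorm's own gamma and beta
        self.bn = nn.BatchNorm2d(out_channels, affine=False)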
        self.class_scale_transform = nn.utils.spectral_norm(nn.Linear(in_channels, out_channels, bias=False))
        self.class_shift_transform = nn.utils.spectral_norm(nn.Linear(in_channels, out_channels, bias=False))

    def forward(self, x, y):
        normalized_image = self.bn(x)
        class_scale = (1 + self.class_scale_transform(y))[:, :, None, None]
        class_shift = self.class_shift_transform(y)[:, :, None, None]
        transformed_image = class_scale * normalized_image + class_shift
        return transformed_image

# class AdaIN(nn.Module):
#     '''
#     AdaIN Class, extends/subclass of nn.Module
#     Values:
#       channels: the number of channels the image has, a scalar
#       w_dim: the dimension of the intermediate tensor, w, a scalar 
#     '''

#     def __init__(self, channels, w_dim):
#         super().__init__()
#         self.instance_norm = nn.InstanceNorm2d(channels)
#         self.style_scale_transform = nn.Linear(w_dim, channels)
#         self.style_shift_transform = nn.Linear(w_dim, channels)

#     def forward(self, image, w):
#         normalized_image = self.instance_norm(image)
#         style_scale = self.style_scale_transform(w)[:, :, None, None]
#         style_shift = self.style_shift_transform(w)[:, :, None, None]
#         transformed_image = style_scale * normalized_image + style_shift
#         return transformed_image

Self-Attention Block

As you may already know, self-attention has been a successful technique in helping models learn arbitrary, long-range dependencies. Self-Attention Generative Adversarial Networks (Zhang et al. 2018) first introduced the self-attention mechanism into the GAN architecture. BigGAN augments its residual blocks with these attention blocks.

A Quick Primer on Self-Attention

Self-attention is just scaled dot product attention. Given a sequence $S$ (with images, $S$ is just the image flattened across its height and width), the model learns mappings to query ($Q$), key ($K$), and value ($V$) matrices:

\begin{align*} Q &:= W_q^\top S \\ K &:= W_k^\top S \\ V &:= W_v^\top S \end{align*}

where $W_q$, $W_k$, and $W_v$ are learned parameters. The subsequent self-attention mechanism is then computed as

\begin{align*} \text{Attention}(Q, K, V) = \text{softmax}\left(\dfrac{QK^\top}{\sqrt{d_k}}\right)V \end{align*}

where $d_k$ is the dimensionality of the query and key vectors (SA-GAN and BigGAN both omit this scaling term). Intuitively, you can think of the query matrix as containing the representations of each position with respect to itself and the key matrix as containing the representations of each position with respect to the others. How important two positions are to each other is measured by their dot product, collected in $QK^\top$, hence dot-product attention. A softmax over each row converts these relative importances into a probability distribution over all positions.

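The value matrix holds the content at each position. Multiplying the softmax output by $V$ turns each output position into a weighted average of the values, so relevant positions receive larger weights and irrelevant ones smaller weights. The "scaled" in scaled dot-product attention refers to the division by $\sqrt{d_k}$, which keeps the dot products from growing with the dimensionality.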
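A from-scratch sketch of scaled dot-product self-attention over an image treated as a sequence of $H \times W$ positions. Here each row of `S` is one position, so the products are written `S @ W_q` (the row-vector form of $W_q^\top S$). The projection matrices are random stand-ins for learned parameters, and the sizes are arbitrary; the notebook’s `AttentionBlock` below uses 1x1 convolutions, max-pooled keys and values, and no $\sqrt{d_k}$ scaling instead.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, channels, height, width, d = 2, 16, 4, 4, 8

x = torch.randn(batch, channels, height, width)
S = x.flatten(2).transpose(1, 2)   # (batch, N, channels) with N = height * width

W_q = torch.randn(channels, d)     # learned in practice; random here
W_k = torch.randn(channels, d)
W_v = torch.randn(channels, channels)

Q, K, V = S @ W_q, S @ W_k, S @ W_v          # (batch, N, d), (batch, N, d), (batch, N, channels)

scores = Q @ K.transpose(1, 2) / d ** 0.5    # (batch, N, N) pairwise similarities
attn = F.softmax(scores, dim=-1)             # each row is a distribution over positions
out = attn @ V                               # each output is a weighted average of values

out_image = out.transpose(1, 2).reshape(batch, channels, height, width)
print(attn.sum(dim=-1)[0, :3], tuple(out_image.shape))
```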

Don’t worry if you don’t understand this right away - it’s a tough concept! For extra reading, you should check out Attention Is All You Need (Vaswani et al. 2017), the paper that introduced the Transformer and scaled dot-product attention, and The Illustrated Transformer, which breaks down and explains the self-attention mechanism clearly.

class AttentionBlock(nn.Module):
    '''
    AttentionBlock Class
    Values:
    channels: number of channels in input
    '''
    def __init__(self, channels):
        super().__init__()

        self.channels = channels

        self.theta = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=False))
        self.phi = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=False))
        self.g = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 2, kernel_size=1, padding=0, bias=False))
        self.o = nn.utils.spectral_norm(nn.Conv2d(channels // 2, channels, kernel_size=1, padding=0, bias=False))

        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, x):
        spatial_size = x.shape[2] * x.shape[3]

        # Apply convolutions to get query (theta), key (phi), and value (g) transforms
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), kernel_size=2)
        g = F.max_pool2d(self.g(x), kernel_size=2)

        # Reshape spatial size for self-attention
        theta = theta.view(-1, self.channels // 8, spatial_size)
        phi = phi.view(-1, self.channels // 8, spatial_size // 4)
        g = g.view(-1, self.channels // 2, spatial_size // 4)

        # Compute dot product attention with query (theta) and key (phi) matrices
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), dim=-1)

        # Weight the values (g) by the attention map (beta), then map back to the input channel count with a 1x1 conv
        o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.channels // 2, x.shape[2], x.shape[3]))

        # Apply gain and residual
        return self.gamma * o + x

BigGAN Generator

Before implementing the generator in full, you first need to implement the generator residual block.

Generator Residual Block

As with many state-of-the-art computer vision models, BigGAN employs skip connections in the form of residual blocks to map random noise to a fake image. You can think of BigGAN residual blocks as having 3 steps. Given input $x$ and class embedding $y$:

  1. $h :=$ bn-relu-upsample-conv$(x, y)$
  2. $h :=$ bn-relu-conv$(h, y)$
  3. $x :=$ upsample-conv$(x)$,

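after which you can apply a residual connection and return $h + x$. In the implementation below, the convolution in step 3 is a 1x1 convolution that is applied only when the input and output channel counts differ.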

class GResidualBlock(nn.Module):
    '''
    GResidualBlock Class
    Values:
    c_dim: the dimension of conditional vector [c, z], a scalar
    in_channels: the number of channels in the input, a scalar
    out_channels: the number of channels in the output, a scalar
    '''

    def __init__(self, c_dim, in_channels, out_channels):
        super().__init__()

        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))

        self.bn1 = ClassConditionalBatchNorm2d(c_dim, in_channels)
        self.bn2 = ClassConditionalBatchNorm2d(c_dim, out_channels)

        self.activation = nn.ReLU()
        self.upsample_fn = nn.Upsample(scale_factor=2)     # upsample occurs in every gblock

        self.mixin = (in_channels != out_channels)
        if self.mixin:
            self.conv_mixin = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0))

    def forward(self, x, y):
        # h := bn-relu-upsample-conv(x, y)
        h = self.bn1(x, y)
        h = self.activation(h)
        h = self.upsample_fn(h)
        h = self.conv1(h)

        # h := bn-relu-conv(h, y)
        h = self.bn2(h, y)
        h = self.activation(h)
        h = self.conv2(h)

        # x := upsample-conv(x)
        x = self.upsample_fn(x)
        if self.mixin:
            x = self.conv_mixin(x)

        return h + x

You can now implement the BigGAN generator in full! Below is an implementation of the base model (at 128x128 resolution) from the paper.

This implementation uses nn.ModuleList for convenience. If you’re not familiar with it, you can think of it as simply a Pythonic list that registers your modules with the PyTorch backend. For more information, see the torch.nn.ModuleList documentation.
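To see how the generator reaches 128x128, this short sketch traces channels and resolution through the five generator blocks using the `Generator` defaults (`base_channels=96`, `bottom_width=4`) and the channel multipliers used in its `g_blocks`. It is plain arithmetic and does not build any modules.

```python
base_channels, bottom_width = 96, 4
# (in, out) channel multipliers of base_channels for the five GResidualBlocks
multipliers = [(16, 16), (16, 8), (8, 4), (4, 2), (2, 1)]

resolution = bottom_width
shapes = [(16 * base_channels, resolution)]   # output of proj_z after reshaping
for m_in, m_out in multipliers:
    resolution *= 2                            # every GResidualBlock upsamples by 2
    shapes.append((m_out * base_channels, resolution))

for channels, res in shapes:
    print(f"{channels} channels at {res}x{res}")
# 1536 @ 4, 1536 @ 8, 768 @ 16, 384 @ 32, 192 @ 64, 96 @ 128;
# proj_o then maps the 96 channels to 3 RGB channels at 128x128
```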

class Generator(nn.Module):
    '''
    Generator Class
    Values:
    z_dim: the dimension of random noise sampled, a scalar
    shared_dim: the dimension of shared class embeddings, a scalar
    base_channels: the number of base channels, a scalar
    bottom_width: the height/width of image before it gets upsampled, a scalar
    n_classes: the number of image classes, a scalar
    '''
 
    def __init__(self, base_channels=96, bottom_width=4, z_dim=120, shared_dim=128, n_classes=1000):
        super().__init__()
 
        n_chunks = 6    # 5 (generator blocks) + 1 (generator input)
        self.z_chunk_size = z_dim // n_chunks
        self.z_dim = z_dim
        self.shared_dim = shared_dim
        self.bottom_width = bottom_width
 
        # No spectral normalization on embeddings, which authors observe to cripple the generator
        self.shared_emb = nn.Embedding(n_classes, shared_dim)
 
        self.proj_z = nn.Linear(self.z_chunk_size, 16 * base_channels * bottom_width ** 2)
 
        # Can't use one big nn.Sequential since we are adding class+noise at each block
        self.g_blocks = nn.ModuleList([
            nn.ModuleList([
                GResidualBlock(shared_dim + self.z_chunk_size, 16 * base_channels, 16 * base_channels),
                AttentionBlock(16 * base_channels),
            ]),
            nn.ModuleList([
                GResidualBlock(shared_dim + self.z_chunk_size, 16 * base_channels, 8 * base_channels),
                AttentionBlock(8 * base_channels),
            ]),
            nn.ModuleList([
                GResidualBlock(shared_dim + self.z_chunk_size, 8 * base_channels, 4 * base_channels),
                AttentionBlock(4 * base_channels),
            ]),
            nn.ModuleList([
                GResidualBlock(shared_dim + self.z_chunk_size, 4 * base_channels, 2 * base_channels),
                AttentionBlock(2 * base_channels),
            ]),
            nn.ModuleList([
                GResidualBlock(shared_dim + self.z_chunk_size, 2 * base_channels, base_channels),
                AttentionBlock(base_channels),
            ]),
        ])
        self.proj_o = nn.Sequential(
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.utils.spectral_norm(nn.Conv2d(base_channels, 3, kernel_size=1, padding=0)),
            nn.Tanh(),
        )
 
    def forward(self, z, y):
        '''
        z: random noise with size self.z_dim
        y: class embeddings with size self.shared_dim
            = NOTE =
            y should be class embeddings from self.shared_emb, not the raw class labels
        '''
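        # Chunk z: the first chunk feeds the input projection, and each remaining chunk
        # is concatenated with the shared class embedding for one generator block
        zs = torch.split(z, self.z_chunk_size, dim=1)
        z = zs[0]
        ys = [torch.cat([y, z_chunk], dim=1) for z_chunk in zs[1:]]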
 
        # Project noise and reshape to feed through generator blocks
        h = self.proj_z(z)
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
 
        # Feed through generator blocks
        for idx, g_block in enumerate(self.g_blocks):
            h = g_block[0](h, ys[idx])
            h = g_block[1](h)
 
        # Project to 3 RGB channels with tanh to map values to [-1, 1]
        h = self.proj_o(h)
 
        return h

BigGAN Discriminator

Before implementing the discriminator in full, you need to implement a discriminator residual block, which is simpler than the generator’s. Note that the last residual block does not apply downsampling.

  1. $h :=$ relu-conv-relu-conv-downsample$(x)$
  2. $x :=$ conv-downsample$(x)$

In the official BigGAN implementation, the architecture is slightly different for the first discriminator residual block, since it handles the raw image as input:

  1. $h :=$ conv-relu-conv-downsample$(x)$
  2. $x :=$ downsample-conv$(x)$

After these two steps, you can return the residual connection $h + x$. You might notice that there is no class information in these residual blocks. As you’ll see later in the code, the authors inject class-conditional information after the final hidden layer (and before the output layer) via channel-wise dot product.
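The class-conditional output described above follows the projection discriminator of Miyato and Koyama (2018). This sketch shows only the final step on made-up pooled features: an unconditional linear score plus the dot product between each image’s features and a learned embedding of its label. The sizes and names are illustrative, and spectral normalization is omitted for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, feat_dim, n_classes = 4, 32, 5

h = torch.randn(batch, feat_dim)            # pooled final hidden features, one vector per image
labels = torch.tensor([0, 1, 2, 4])

linear = nn.Linear(feat_dim, 1)             # class-unconditional score
embed = nn.Embedding(n_classes, feat_dim)   # one learned vector per class

uncond = linear(h)                                         # (batch, 1)
cond = torch.sum(embed(labels) * h, dim=1, keepdim=True)   # channel-wise dot product <e_y, h>
score = uncond + cond                                      # (batch, 1)
print(score.shape)
```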

class DResidualBlock(nn.Module):
    '''
    DResidualBlock Class
    Values:
    in_channels: the number of channels in the input, a scalar
    out_channels: the number of channels in the output, a scalar
    downsample: whether to apply downsampling
    use_preactivation: whether to apply an activation function before the first convolution
    '''

    def __init__(self, in_channels, out_channels, downsample=True, use_preactivation=False):
        super().__init__()

        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))

        self.activation = nn.ReLU()
        self.use_preactivation = use_preactivation  # apply preactivation in all except first dblock

        self.downsample = downsample    # downsample occurs in all except last dblock
        if downsample:
            self.downsample_fn = nn.AvgPool2d(2)
        self.mixin = (in_channels != out_channels) or downsample
        if self.mixin:
            self.conv_mixin = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0))

    def _residual(self, x):
        if self.use_preactivation:
            if self.mixin:
                x = self.conv_mixin(x)
            if self.downsample:
                x = self.downsample_fn(x)
        else:
            if self.downsample:
                x = self.downsample_fn(x)
            if self.mixin:
                x = self.conv_mixin(x)
        return x

    def forward(self, x):
        # Apply preactivation if applicable
        if self.use_preactivation:
            h = F.relu(x)
        else:
            h = x

        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample_fn(h)

        return h + self._residual(x)

Now implement the BigGAN discriminator in full!

class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
    base_channels: the number of base channels, a scalar
    n_classes: the number of image classes, a scalar
    '''

    def __init__(self, base_channels=96, n_classes=1000):
        super().__init__()

        # For adding class-conditional evidence
        self.shared_emb = nn.utils.spectral_norm(nn.Embedding(n_classes, 16 * base_channels))

        self.d_blocks = nn.Sequential(
            DResidualBlock(3, base_channels, downsample=True, use_preactivation=False),
            AttentionBlock(base_channels),

            DResidualBlock(base_channels, 2 * base_channels, downsample=True, use_preactivation=True),
            AttentionBlock(2 * base_channels),

            DResidualBlock(2 * base_channels, 4 * base_channels, downsample=True, use_preactivation=True),
            AttentionBlock(4 * base_channels),

            DResidualBlock(4 * base_channels, 8 * base_channels, downsample=True, use_preactivation=True),
            AttentionBlock(8 * base_channels),

            DResidualBlock(8 * base_channels, 16 * base_channels, downsample=True, use_preactivation=True),
            AttentionBlock(16 * base_channels),

            DResidualBlock(16 * base_channels, 16 * base_channels, downsample=False, use_preactivation=True),
            AttentionBlock(16 * base_channels),

            nn.ReLU(inplace=True),
        )
        self.proj_o = nn.utils.spectral_norm(nn.Linear(16 * base_channels, 1))

    def forward(self, x, y=None):
        h = self.d_blocks(x)
        h = torch.sum(h, dim=[2, 3])

        # Class-unconditional output
        uncond_out = self.proj_o(h)
        if y is None:
            return uncond_out

        # Class-conditional output
        cond_out = torch.sum(self.shared_emb(y) * h, dim=1, keepdim=True)
        return uncond_out + cond_out

Setting Up BigGAN Training

Now you’re ready to set up BigGAN for training! Unfortunately, this notebook does not provide actual training code because of BigGAN’s size.

device = 'cpu'

# Initialize models
base_channels = 96
z_dim = 120
n_classes = 5   # 5 classes are used instead of the original 1000, for efficiency
shared_dim = 128
generator = Generator(base_channels=base_channels, bottom_width=4, z_dim=z_dim, shared_dim=shared_dim, n_classes=n_classes).to(device)
discriminator = Discriminator(base_channels=base_channels, n_classes=n_classes).to(device)

# Initialize weights orthogonally. Spectrally normalized layers keep their trainable
# weight in `weight_orig` (torch.nn.utils.spectral_norm), so initialize that tensor when present.
for model in (generator, discriminator):
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
            nn.init.orthogonal_(getattr(module, 'weight_orig', module.weight))

# Initialize optimizers
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.999), eps=1e-6)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=4e-4, betas=(0.0, 0.999), eps=1e-6)

Here is a sample forward pass:

batch_size = n_classes

z = torch.randn(batch_size, z_dim, device=device)                 # Generate random noise (z)
y = torch.arange(start=0, end=n_classes, device=device).long()    # Generate a batch of labels (y), one for each class
y_emb = generator.shared_emb(y)                                   # Retrieve class embeddings (y_emb) from generator

x_gen = generator(z, y_emb)                                       # Generate fake images from z and y_emb
score = discriminator(x_gen, y)                                   # Generate classification for fake images
x_gen
tensor([[[[ 0.3764,  0.2062,  0.1003,  ...,  0.2038,  0.6493,  0.5354],
          [ 0.2120, -0.0244,  0.0219,  ...,  0.1268,  0.6609,  0.4200],
          [ 0.3764,  0.0762,  0.0928,  ..., -0.1703,  0.4835,  0.0715],
          ...,
          [ 0.1498,  0.2105,  0.2855,  ...,  0.4390,  0.5243,  0.4624],
          [ 0.0997,  0.2285,  0.1930,  ...,  0.6904,  0.6122,  0.5547],
          [-0.1155, -0.0126,  0.0427,  ...,  0.5848,  0.5929,  0.5208]],

         [[-0.6701, -0.6237, -0.4032,  ..., -0.2320, -0.1352, -0.2514],
          [-0.5353, -0.5857, -0.4197,  ..., -0.0198,  0.2324, -0.1094],
          [-0.5952, -0.6149, -0.4792,  ..., -0.0295,  0.1717, -0.0544],
          ...,
          [ 0.2554,  0.1961, -0.2161,  ..., -0.1560, -0.1466, -0.1050],
          [ 0.0770, -0.0344, -0.2909,  ..., -0.5381, -0.3345, -0.2841],
          [-0.0124, -0.2283, -0.4513,  ..., -0.6251, -0.5011, -0.4414]],

         [[-0.1124,  0.1320,  0.1490,  ...,  0.3488,  0.3543,  0.2205],
          [-0.2230,  0.3010,  0.1071,  ...,  0.3642,  0.3061,  0.2106],
          [-0.0705,  0.3298,  0.0484,  ...,  0.3001,  0.3170,  0.2937],
          ...,
          [ 0.2445,  0.5904,  0.3022,  ..., -0.3751, -0.4246, -0.2433],
          [-0.0503,  0.1878, -0.0436,  ..., -0.4021, -0.4097, -0.2530],
          [-0.2906, -0.0291, -0.1358,  ..., -0.5380, -0.4662, -0.3413]]],


        [[[ 0.5642,  0.2991,  0.3152,  ...,  0.6265,  0.7206,  0.5652],
          [ 0.3819,  0.0656,  0.1877,  ...,  0.3383,  0.5584,  0.4696],
          [ 0.5467,  0.4571,  0.3816,  ..., -0.1298, -0.0461,  0.0637],
          ...,
          [ 0.6548,  0.5589,  0.7471,  ...,  0.4723,  0.3542,  0.2683],
          [ 0.7054,  0.7183,  0.7809,  ...,  0.4768,  0.2671,  0.2824],
          [ 0.7160,  0.7471,  0.7625,  ...,  0.3614,  0.3019,  0.3225]],

         [[-0.6420, -0.4541, -0.4790,  ..., -0.0721, -0.2161, -0.2023],
          [-0.5909, -0.3950, -0.4190,  ...,  0.0698, -0.0466, -0.0419],
          [-0.3349, -0.1775, -0.3724,  ...,  0.3626,  0.0664,  0.1243],
          ...,
          [ 0.2504,  0.3931,  0.1362,  ...,  0.1815,  0.0648, -0.0537],
          [-0.1490, -0.0898, -0.5650,  ..., -0.2663, -0.1301, -0.3496],
          [-0.0882, -0.1057, -0.4517,  ..., -0.2561, -0.2773, -0.2850]],

         [[-0.2685,  0.0065,  0.0269,  ...,  0.1906,  0.4550,  0.1919],
          [-0.3043,  0.1801,  0.2201,  ...,  0.0987,  0.3714,  0.2143],
          [-0.2136,  0.1485,  0.2857,  ...,  0.2016,  0.4364,  0.2552],
          ...,
          [-0.1101,  0.4065,  0.5461,  ...,  0.5666,  0.5787,  0.5179],
          [-0.3900, -0.0261,  0.1544,  ...,  0.3306,  0.4521,  0.3786],
          [-0.4478, -0.1992, -0.0117,  ...,  0.0107,  0.3137,  0.1977]]],


        [[[ 0.4363,  0.4289,  0.2121,  ..., -0.2542,  0.0885,  0.2126],
          [ 0.3076,  0.1416,  0.0356,  ..., -0.4902, -0.0588,  0.2214],
          [ 0.5754,  0.3541,  0.5082,  ..., -0.4994, -0.0241,  0.0354],
          ...,
          [ 0.5406,  0.5846,  0.6225,  ...,  0.2741,  0.2576,  0.3010],
          [ 0.5879,  0.6735,  0.5780,  ...,  0.3568,  0.1035,  0.0337],
          [ 0.5953,  0.6407,  0.5753,  ...,  0.1657, -0.0354, -0.0823]],

         [[-0.8172, -0.7306, -0.7337,  ..., -0.0693,  0.0182, -0.2005],
          [-0.7926, -0.7521, -0.7359,  ..., -0.1903,  0.0283, -0.2898],
          [-0.8871, -0.8661, -0.8765,  ..., -0.1838, -0.0615, -0.2026],
          ...,
          [-0.0294, -0.2769, -0.1690,  ..., -0.3879, -0.5560, -0.7139],
          [-0.3564, -0.5523, -0.6157,  ..., -0.7267, -0.6741, -0.7203],
          [-0.2925, -0.4557, -0.5065,  ..., -0.7561, -0.7102, -0.7310]],

         [[-0.0530,  0.3433,  0.3574,  ...,  0.1662,  0.3874,  0.1833],
          [-0.1432,  0.3105,  0.3135,  ...,  0.2063,  0.4068,  0.3589],
          [ 0.1059,  0.4621,  0.3752,  ...,  0.0058,  0.2921,  0.2202],
          ...,
          [ 0.3029,  0.6315,  0.3367,  ...,  0.1285,  0.0328,  0.0680],
          [-0.2057,  0.0406, -0.1752,  ..., -0.1776, -0.0647,  0.0253],
          [-0.1908,  0.0437, -0.1215,  ..., -0.2816, -0.1932, -0.0304]]],


        [[[ 0.7349,  0.7402,  0.5093,  ...,  0.4228,  0.6353,  0.5729],
          [ 0.6171,  0.6006,  0.3223,  ...,  0.3218,  0.4861,  0.4332],
          [ 0.6064,  0.6659,  0.4816,  ...,  0.1503,  0.1281,  0.1686],
          ...,
          [ 0.8096,  0.7520,  0.5844,  ...,  0.1101,  0.3970,  0.0932],
          [ 0.7348,  0.6589,  0.5708,  ...,  0.3332,  0.3334,  0.1541],
          [ 0.7083,  0.6791,  0.6723,  ...,  0.3586,  0.3242,  0.1613]],

         [[-0.6720, -0.6150, -0.5122,  ..., -0.0120,  0.1110, -0.0563],
          [-0.6801, -0.6543, -0.6448,  ..., -0.1728,  0.0339, -0.2468],
          [-0.7743, -0.7142, -0.6596,  ..., -0.3069, -0.1335, -0.5169],
          ...,
          [ 0.1018, -0.0362, -0.1768,  ..., -0.3483, -0.3730, -0.4768],
          [ 0.1513, -0.0649, -0.2814,  ..., -0.4619, -0.3229, -0.4285],
          [ 0.1112,  0.0683, -0.0914,  ..., -0.1884, -0.0454, -0.2744]],

         [[-0.2800,  0.0730,  0.1129,  ...,  0.2084,  0.3633,  0.1039],
          [-0.3499,  0.0238, -0.0326,  ...,  0.1355,  0.1874,  0.0722],
          [-0.1856,  0.1628,  0.0617,  ...,  0.3222,  0.3389,  0.2680],
          ...,
          [-0.2156, -0.0309, -0.3184,  ...,  0.3738,  0.4178,  0.3200],
          [-0.2431, -0.2815, -0.4111,  ...,  0.2746,  0.3818,  0.2741],
          [-0.1532, -0.1846, -0.2620,  ..., -0.0852,  0.1240,  0.1145]]],


        [[[ 0.5993,  0.6808,  0.7763,  ...,  0.6647,  0.8183,  0.7997],
          [ 0.4799,  0.4115,  0.5632,  ...,  0.5805,  0.7273,  0.7534],
          [ 0.5875,  0.4885,  0.6752,  ...,  0.5689,  0.6573,  0.7198],
          ...,
          [ 0.7894,  0.7746,  0.7015,  ...,  0.3178,  0.3655,  0.2970],
          [ 0.7641,  0.7415,  0.5554,  ...,  0.4623,  0.3067,  0.2978],
          [ 0.7174,  0.7244,  0.5910,  ...,  0.6001,  0.4971,  0.5115]],

         [[-0.7059, -0.5483, -0.5289,  ..., -0.6020, -0.5322, -0.5925],
          [-0.6409, -0.7030, -0.7319,  ..., -0.5934, -0.5573, -0.7274],
          [-0.6504, -0.7190, -0.7890,  ..., -0.5382, -0.5645, -0.7094],
          ...,
          [ 0.1639,  0.1628,  0.1495,  ..., -0.0085, -0.1639, -0.3762],
          [-0.1256, -0.2887, -0.3985,  ..., -0.3419, -0.3214, -0.5222],
          [ 0.1453, -0.0969, -0.2538,  ..., -0.4123, -0.4326, -0.5766]],

         [[-0.3271, -0.1243, -0.2936,  ...,  0.1795,  0.3613,  0.2006],
          [-0.4231, -0.1418, -0.2940,  ...,  0.0010,  0.0608,  0.1459],
          [-0.1204,  0.1780, -0.1157,  ...,  0.1180,  0.1269,  0.0769],
          ...,
          [-0.3976, -0.1623, -0.2040,  ...,  0.2804,  0.3946,  0.3863],
          [-0.5408, -0.4046, -0.4966,  ...,  0.4870,  0.4759,  0.4846],
          [-0.6698, -0.5049, -0.5486,  ...,  0.2692,  0.2718,  0.3009]]]],
       grad_fn=<TanhBackward>)
score
tensor([[-2.6239],
        [-6.1415],
        [-6.3239],
        [-9.7936],
        [-8.5146]], grad_fn=<AddBackward0>)

BigGAN-deep

Initially, the authors of the BigGAN paper found that increasing the depth of the network did not help much. But they experimented further (research is always improving!) and added a few notes about an additional architecture, called BigGAN-deep. This variant of BigGAN is 4x deeper, uses a modified residual block architecture, and concatenates the entire $z$ vector to $c$ (as opposed to separate chunks at different resolutions).

Typically, on a difficult and complex task that you’re unlikely to overfit, you expect better performance when a model has more parameters, because it has more room to learn. Surprisingly, BigGAN-deep has fewer parameters than its BigGAN counterpart. Architectural changes such as bottleneck residual blocks (1x1 convolutions that reduce the channel count before the more expensive 3x3 convolutions) and truncating/concatenating channels in skip connections (as opposed to using pointwise convolutions) decrease the parameter count without sacrificing expressivity.
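A rough sketch of the channel truncation/concatenation idea in skip connections, based on our reading of Appendix B: in the generator, the skip path keeps the first channels and drops the rest; in the discriminator, it passes the input channels through unchanged and appends the missing channels from a 1x1 convolution. The helper names `g_skip`, `DSkip`, and `n_params` are hypothetical, and this is not the paper’s implementation. Comparing parameter counts with a full 1x1 convolution shows where the savings come from.

```python
import torch
import torch.nn as nn

def g_skip(x, out_channels):
    # Hypothetical generator skip: keep the first out_channels channels, drop the rest (no parameters)
    return x[:, :out_channels]

class DSkip(nn.Module):
    '''Hypothetical discriminator skip: pass x through and append extra channels from a 1x1 conv.'''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.extra = nn.Conv2d(in_channels, out_channels - in_channels, kernel_size=1)

    def forward(self, x):
        return torch.cat([x, self.extra(x)], dim=1)

def n_params(module):
    return sum(p.numel() for p in module.parameters())

x = torch.randn(2, 384, 8, 8)
d_skip = DSkip(384, 768)
print(g_skip(x, 192).shape)   # torch.Size([2, 192, 8, 8])
print(d_skip(x).shape)        # torch.Size([2, 768, 8, 8])
print(n_params(d_skip), n_params(nn.Conv2d(384, 768, kernel_size=1)))  # 147840 295680
```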

For more details on the BigGAN-deep architecture, see Appendix B of the paper.

And as for the implementation of the BigGAN-deep variant, well, that’s left as an exercise for the reader. You’re a smart cookie, you’ll figure it out! Just keep in mind that with great power comes great responsibility ;)