Please note that this is an optional notebook, meant to introduce more advanced concepts if you’re up for a challenge, so don’t worry if you don’t completely follow!
It is recommended that you already be familiar with:
In this notebook, you will learn about and implement MUNIT, a method for unsupervised image-to-image translation, as proposed in Multimodal Unsupervised Image-to-Image Translation (Huang et al. 2018).
MUNIT builds on UNIT’s idea of a shared latent space, but uses only a partially shared latent space. Specifically, the authors assume that the content latent space is shared between the two domains, while each domain has its own style latent space.
Don’t worry if you aren’t familiar with UNIT - there will be a section that briefly goes over it!
Let’s begin with a quick overview of the UNIT framework and then move onto the MUNIT framework.
UNIT, proposed in Unsupervised Image-to-Image Translation Networks (Liu et al. 2017), is a method of image translation that assumes that images from different domains share a latent distribution.
Suppose that there are two image domains, $\mathcal{A}$ and $\mathcal{B}$. Images $(x_a, x_b) \in (\mathcal{A}, \mathcal{B})$ can be mapped to a shared latent space, $\mathcal{Z}$, via encoders $E_a: x_a \mapsto z$ and $E_b: x_b \mapsto z$, respectively. Synthetic images can be produced via generators $G_a: z \mapsto x_a'$ and $G_b: z \mapsto x_b'$, respectively. Note that the generators can generate self-reconstructed or domain-translated images for their respective domains.
As in all other GAN frameworks, synthetic and real images, $(x_a', x_a)$ and $(x_b', x_b)$, are passed into discriminators, $D_a$ and $D_b$, respectively.
MUNIT, in contrast, assumes only a partially shared latent space. Suppose again that there are two image domains, $\mathcal{A}$ and $\mathcal{B}$. A pair of corresponding images $(x_a, x_b) \in (\mathcal{A}, \mathcal{B})$ can be generated as $x_a = F_a(c, s_a)$ and $x_b = F_b(c, s_b)$, where $c$ is a content vector from a shared distribution, $s_a, s_b$ are style vectors from domain-specific distributions, and $F_a, F_b$ are decoders that synthesize images from the content and style vectors.
The idea is that while the content between two domains can be shared (i.e. you can interchange horses and zebras in an image), the styles are different between the two (i.e. you would draw horses and zebras differently).
To learn the content and style distributions during training, the authors also assume encoders $E_a, E_b$ that invert $F_a, F_b$, respectively. Specifically, $E_a^c: x_a \mapsto c$ extracts content and $E_a^s: x_a \mapsto s_a$ extracts style from images in domain $\mathcal{A}$. The same applies for $E_b^c(x_b)$ and $E_b^s(x_b)$ with images in domain $\mathcal{B}$. You can mix and match the content and style vectors from the two domains to translate images between them.
For example, if you take the content of image $b$, $c_b = E_b^c(x_b)$, and the style of image $a$, $s_a = E_a^s(x_a)$, and pass these through the horse decoder as $F_a(c_b, s_a)$, you should end up with the content of image $b$ drawn with the characteristics of image $a$.
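In equation form, this translation of $x_b$ into domain $\mathcal{A}$ using $x_a$’s style is
\begin{align*} x_{b\rightarrow a} &= F_a\left(E_b^c(x_b), E_a^s(x_a)\right) \end{align*}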
Don’t worry if this is still unclear now! You’ll go over this in more detail later in the notebook.
Model overview, taken from Figure 2 of Multimodal Unsupervised Image-to-Image Translation (Huang et al. 2018). The red and blue arrows denote encoder-decoder pairs within the same domain. Left: same-domain image reconstruction. Right: cross-domain latent (content and style) vector reconstruction.
You will start by importing libraries and defining a visualization function. This code is borrowed from the CycleGAN notebook so you should already be familiar with this!
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
torch.manual_seed(0)
def show_tensor_images(x_real, x_fake):
''' For visualizing images '''
image_tensor = torch.cat((x_fake[:1, ...], x_real[:1, ...]), dim=0)
image_tensor = (image_tensor + 1) / 2
image_unflat = image_tensor.detach().cpu()
image_grid = make_grid(image_unflat, nrow=1)
plt.imshow(image_grid.permute(1, 2, 0).squeeze())
plt.show()
import glob
import random
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
# Inspired by https://github.com/aitorzip/PyTorch-CycleGAN/blob/master/datasets.py
class ImageDataset(Dataset):
def __init__(self, root, transform=None, mode='train'):
super().__init__()
self.transform = transform
self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))
if len(self.files_A) > len(self.files_B):
self.files_A, self.files_B = self.files_B, self.files_A
self.new_perm()
assert len(self.files_A) > 0, "Make sure you downloaded the horse2zebra images!"
def new_perm(self):
self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]
def __getitem__(self, index):
item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
item_B = self.transform(Image.open(self.files_B[self.randperm[index]]))
if item_A.shape[0] != 3:
item_A = item_A.repeat(3, 1, 1)
if item_B.shape[0] != 3:
item_B = item_B.repeat(3, 1, 1)
if index == len(self) - 1:
self.new_perm()
return item_A, item_B
def __len__(self):
return min(len(self.files_A), len(self.files_B))
if len(os.listdir(".")) < 3:
!wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
!unzip horse2zebra.zip
MUNIT has a few key subcomponents that are used throughout the model. It’ll generally make your life easier if you implement the smaller parts first!
You’ve already learned about this layer in StyleGAN and seen a very similar cousin - class-conditional batch normalization - in the BigGAN components notebook.
The authors enhance the linear layers for scale and shift with a multi-layer perceptron (MLP), which is essentially just a series of linear layers to help learn more complex representations. See the figure in Submodules and the notes in Submodules: Decoder for more details.
class AdaptiveInstanceNorm2d(nn.Module):
'''
AdaptiveInstanceNorm2d Class
Values:
channels: the number of channels the image has, a scalar
s_dim: the dimension of the style tensor (s), a scalar
h_dim: the hidden dimension of the MLP, a scalar
'''
def __init__(self, channels, s_dim=8, h_dim=256):
super().__init__()
self.instance_norm = nn.InstanceNorm2d(channels, affine=False)
self.style_scale_transform = self.mlp(s_dim, h_dim, channels)
self.style_shift_transform = self.mlp(s_dim, h_dim, channels)
@staticmethod
    def mlp(in_dim, h_dim, out_dim):
return nn.Sequential(
nn.Linear(in_dim, h_dim),
nn.ReLU(inplace=True),
nn.Linear(h_dim, h_dim),
nn.ReLU(inplace=True),
nn.Linear(h_dim, out_dim),
)
def forward(self, image, w):
'''
Function for completing a forward pass of AdaIN: Given an image and a style,
returns the normalized image that has been scaled and shifted by the style.
Parameters:
image: the feature map of shape (n_samples, channels, width, height)
            w: the style code (s) used to compute the AdaIN scale and shift parameters
'''
        normalized_image = self.instance_norm(image)
        # Flatten the style code to (n_samples, s_dim) so it can pass through the linear layers
        w = w.view(w.size(0), -1)
        style_scale = self.style_scale_transform(w)[:, :, None, None]
        style_shift = self.style_shift_transform(w)[:, :, None, None]
        transformed_image = style_scale * normalized_image + style_shift
        return transformed_image
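As a quick sanity check, here’s a minimal usage sketch (the batch size and feature map shape are arbitrary, chosen just for illustration):
adain = AdaptiveInstanceNorm2d(channels=256, s_dim=8, h_dim=256)
feature_map = torch.randn(2, 256, 64, 64)  # (n_samples, channels, height, width)
style = torch.randn(2, 8)                  # (n_samples, s_dim)
print(adain(feature_map, style).shape)     # torch.Size([2, 256, 64, 64])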
MUNIT uses layer normalization in the upsampling layers of the decoder. Proposed in Layer Normalization (Ba et al. 2016), layer normalization operates similarly to all other normalization techniques like batch normalization, but instead of normalizing across minibatch examples per channel, it normalizes across channels per minibatch example.
Layer normalization is much more prevalent in NLP but quite rare in computer vision. However, batch normalization is not viable here because of the training batch size of 1, and instance normalization is undesirable because it normalizes away each feature map’s mean and variance, which carry style information.
PyTorch implements this as nn.LayerNorm, but it requires the normalized shape (including spatial size) to be specified at initialization in order to accommodate 1D, 2D, and 3D inputs. For convenience, let’s implement a size-agnostic layer normalization module for 2D inputs (i.e. images).
class LayerNorm2d(nn.Module):
'''
LayerNorm2d Class
Values:
channels: number of channels in input, a scalar
        affine: whether to apply a learned affine transformation (scale and shift), a bool
'''
def __init__(self, channels, eps=1e-5, affine=True):
super().__init__()
self.affine = affine
self.eps = eps
if self.affine:
self.gamma = nn.Parameter(torch.rand(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
mean = x.flatten(1).mean(1).reshape(-1, 1, 1, 1)
std = x.flatten(1).std(1).reshape(-1, 1, 1, 1)
x = (x - mean) / (std + self.eps)
if self.affine:
x = x * self.gamma.reshape(1, -1, 1, 1) + self.beta.reshape(1, -1, 1, 1)
return x
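Here’s a small sketch showing that the same module can handle inputs of different spatial sizes without re-initialization (shapes are arbitrary):
layer_norm = LayerNorm2d(channels=128)
print(layer_norm(torch.randn(4, 128, 32, 32)).shape)  # torch.Size([4, 128, 32, 32])
print(layer_norm(torch.randn(4, 128, 64, 64)).shape)  # same module, different spatial size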
By now, you should already be very familiar with residual blocks. Below is an implementation that supports both adaptive and non-adaptive instance normalization layers, since both are used throughout the model.
class ResidualBlock(nn.Module):
'''
ResidualBlock Class
Values:
channels: number of channels throughout residual block, a scalar
s_dim: the dimension of the style tensor (s), a scalar
h_dim: the hidden dimension of the MLP, a scalar
'''
def __init__(self, channels, s_dim=None, h_dim=None):
super().__init__()
self.conv1 = nn.Sequential(
nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(channels, channels, kernel_size=3)
),
)
self.conv2 = nn.Sequential(
nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(channels, channels, kernel_size=3)
),
)
self.use_style = s_dim is not None and h_dim is not None
if self.use_style:
self.norm1 = AdaptiveInstanceNorm2d(channels, s_dim, h_dim)
self.norm2 = AdaptiveInstanceNorm2d(channels, s_dim, h_dim)
else:
self.norm1 = nn.InstanceNorm2d(channels)
self.norm2 = nn.InstanceNorm2d(channels)
self.activation = nn.ReLU()
def forward(self, x, s=None):
x_id = x
x = self.conv1(x)
x = self.norm1(x, s) if self.use_style else self.norm1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.norm2(x, s) if self.use_style else self.norm2(x)
return x + x_id
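Here’s a brief sketch of both variants: the plain instance norm version (as used in the content encoder) and the AdaIN version with a style vector injected at each call (as used in the decoder). Shapes are illustrative:
x = torch.randn(1, 256, 64, 64)
plain_block = ResidualBlock(channels=256)                       # plain instance norm
styled_block = ResidualBlock(channels=256, s_dim=8, h_dim=256)  # AdaIN
s = torch.randn(1, 8)
print(plain_block(x).shape, styled_block(x, s=s).shape)         # both torch.Size([1, 256, 64, 64])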
Now that you’re all set up and have implemented some basic building blocks, let’s take a look at the content encoder, style encoder, and decoder! These will be used in the generator.
Generator architecture, taken from Figure 3 of Multimodal Unsupervised Image-to-Image Translation (Huang et al. 2018). Content encoder: generates a downsampled representation of the original. Style encoder: generates a style code from the original. Decoder: synthesizes a fake image from content code, infused with style info.
Note: the official implementation feeds the style code through a multi-layer perceptron (MLP) and assigns these values to the scale and shift parameters in the instance normalization layers. To be compatible with our previous definitions of AdaIN, your implementation will simply apply the MLP within AdaptiveInstanceNorm2d.
The content encoder is similar to many encoders you’ve already seen: it simply downsamples the input image and feeds it through residual blocks to obtain a condensed representation.
class ContentEncoder(nn.Module):
'''
ContentEncoder Class
Values:
base_channels: number of channels in first convolutional layer, a scalar
n_downsample: number of downsampling layers, a scalar
n_res_blocks: number of residual blocks, a scalar
'''
def __init__(self, base_channels=64, n_downsample=2, n_res_blocks=4):
super().__init__()
channels = base_channels
# Input convolutional layer
layers = [
nn.ReflectionPad2d(3),
nn.utils.spectral_norm(
nn.Conv2d(3, channels, kernel_size=7)
),
nn.InstanceNorm2d(channels),
nn.ReLU(inplace=True),
]
# Downsampling layers
for i in range(n_downsample):
layers += [
nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(channels, 2 * channels, kernel_size=4, stride=2)
),
nn.InstanceNorm2d(2 * channels),
nn.ReLU(inplace=True),
]
channels *= 2
# Residual blocks
layers += [
ResidualBlock(channels) for _ in range(n_res_blocks)
]
self.layers = nn.Sequential(*layers)
self.out_channels = channels
def forward(self, x):
return self.layers(x)
@property
def channels(self):
return self.out_channels
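With these default arguments, a 256x256 input is downsampled twice, so a quick sketch like the one below should produce a content code at 1/4 resolution with 256 channels:
content_encoder = ContentEncoder(base_channels=64, n_downsample=2, n_res_blocks=4)
x = torch.randn(1, 3, 256, 256)
content = content_encoder(x)
print(content.shape, content_encoder.channels)  # torch.Size([1, 256, 64, 64]) 256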
The style encoder operates similarly to the content encoder, but instead of residual blocks, it uses global pooling and a pointwise (1x1) convolution to distill the input image into its style vector. An important difference is that the style encoder doesn’t use any normalization layers, since normalization would remove the feature statistics that encode style. This style code will be passed to the decoder, which uses it along with the content code to synthesize a fake image.
class StyleEncoder(nn.Module):
'''
StyleEncoder Class
Values:
base_channels: number of channels in first convolutional layer, a scalar
n_downsample: number of downsampling layers, a scalar
s_dim: the dimension of the style tensor (s), a scalar
'''
n_deepen_layers = 2
def __init__(self, base_channels=64, n_downsample=4, s_dim=8):
super().__init__()
channels = base_channels
# Input convolutional layer
layers = [
nn.ReflectionPad2d(3),
nn.utils.spectral_norm(
nn.Conv2d(3, channels, kernel_size=7, padding=0)
),
nn.ReLU(inplace=True),
]
# Downsampling layers
for i in range(self.n_deepen_layers):
layers += [
nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(channels, 2 * channels, kernel_size=4, stride=2)
),
nn.ReLU(inplace=True),
]
channels *= 2
for i in range(n_downsample - self.n_deepen_layers):
layers += [
nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(channels, channels, kernel_size=4, stride=2)
),
nn.ReLU(inplace=True),
]
# Apply global pooling and pointwise convolution to style_channels
layers += [
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, s_dim, kernel_size=1),
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
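Here’s a quick sketch of the style encoder’s output shape: the global pooling collapses the spatial dimensions to 1x1, leaving an s_dim-channel style code regardless of input resolution:
style_encoder = StyleEncoder(base_channels=64, n_downsample=4, s_dim=8)
x = torch.randn(1, 3, 256, 256)
print(style_encoder(x).shape)  # torch.Size([1, 8, 1, 1])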
As with all encoder-decoder frameworks, the decoder serves to synthesize images from the latent information passed through by the encoder. In this case, the decoder works with both content and style encodings.
You can think of the content encoder and decoder as the backbone of the encoder-decoder framework, with style information injected into the residual blocks via AdaIN layers.
Note: the official implementation feeds the style code through a multi-layer perceptron (MLP) and assigns these values to the scale and shift parameters in the instance normalization layers. To be compatible with the previous definitions of AdaIN in this course, your implementation will simply apply the MLP within AdaptiveInstanceNorm2d.
Let’s take a look at the implementation!
class Decoder(nn.Module):
'''
Decoder Class
Values:
in_channels: number of channels from encoder output, a scalar
n_upsample: number of upsampling layers, a scalar
n_res_blocks: number of residual blocks, a scalar
s_dim: the dimension of the style tensor (s), a scalar
h_dim: the hidden dimension of the MLP, a scalar
'''
def __init__(self, in_channels, n_upsample=2, n_res_blocks=4, s_dim=8, h_dim=256):
super().__init__()
channels = in_channels
# Residual blocks with AdaIN
self.res_blocks = nn.ModuleList([
            ResidualBlock(channels, s_dim=s_dim, h_dim=h_dim) for _ in range(n_res_blocks)
])
# Upsampling blocks
layers = []
for i in range(n_upsample):
layers += [
nn.Upsample(scale_factor=2),
nn.ReflectionPad2d(2),
nn.utils.spectral_norm(
nn.Conv2d(channels, channels // 2, kernel_size=5)
),
LayerNorm2d(channels // 2),
]
channels //= 2
layers += [
nn.ReflectionPad2d(3),
nn.utils.spectral_norm(
nn.Conv2d(channels, 3, kernel_size=7)
),
nn.Tanh(),
]
self.layers = nn.Sequential(*layers)
def forward(self, x, s):
for res_block in self.res_blocks:
x = res_block(x, s=s)
x = self.layers(x)
return x
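As a quick sketch, a content code of the shape produced by the content encoder plus a style code (the AdaIN layers flatten it internally) should decode back to a full-resolution 3-channel image:
decoder = Decoder(in_channels=256, n_upsample=2, n_res_blocks=4, s_dim=8, h_dim=256)
content = torch.randn(1, 256, 64, 64)
style = torch.randn(1, 8, 1, 1)       # same shape as the style encoder's output
print(decoder(content, style).shape)  # torch.Size([1, 3, 256, 256])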
Now you’re ready to implement the MUNIT generator and discriminator, as well as the composite loss function that ties everything together in training.
The generator is essentially just composed of the two encoders and one decoder implemented in the previous section, so let’s wrap everything in a Generator module!
class Generator(nn.Module):
'''
Generator Class
Values:
base_channels: number of channels in first convolutional layer, a scalar
n_downsample: number of downsampling layers, a scalar
n_res_blocks: number of residual blocks, a scalar
s_dim: the dimension of the style tensor (s), a scalar
h_dim: the hidden dimension of the MLP, a scalar
'''
def __init__(
self,
base_channels: int = 64,
n_c_downsample: int = 2,
n_s_downsample: int = 4,
n_res_blocks: int = 4,
s_dim: int = 8,
h_dim: int = 256,
):
super().__init__()
self.c_enc = ContentEncoder(
base_channels=base_channels, n_downsample=n_c_downsample, n_res_blocks=n_res_blocks,
)
self.s_enc = StyleEncoder(
base_channels=base_channels, n_downsample=n_s_downsample, s_dim=s_dim,
)
self.dec = Decoder(
self.c_enc.channels, n_upsample=n_c_downsample, n_res_blocks=n_res_blocks, s_dim=s_dim, h_dim=h_dim,
)
def encode(self, x):
content = self.c_enc(x)
style = self.s_enc(x)
return (content, style)
def decode(self, content, style):
return self.dec(content, style)
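Here’s a minimal sketch of how two generators, one per domain, can be mixed and matched to translate an image from domain $\mathcal{A}$ to domain $\mathcal{B}$:
gen_a, gen_b = Generator(), Generator()
x_a, x_b = torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256)
c_a, s_a = gen_a.encode(x_a)   # content and style of the domain A image
c_b, s_b = gen_b.encode(x_b)   # content and style of the domain B image
x_ab = gen_b.decode(c_a, s_b)  # A content rendered with a B style
print(x_ab.shape)              # torch.Size([1, 3, 256, 256])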
The discriminator, identical to the one used in Pix2PixHD, is comprised of several PatchGAN discriminators operating at different scales. For details on how these work together, take a look at the Pix2PixHD optional notebook.
The discriminator is trained with the least squares objective.
class Discriminator(nn.Module):
'''
    Discriminator Class
Values:
base_channels: number of channels in first convolutional layer, a scalar
n_layers: number of downsampling layers, a scalar
n_discriminators: number of discriminators (all at different scales), a scalar
'''
def __init__(
self,
base_channels: int = 64,
n_layers: int = 3,
n_discriminators: int = 3,
):
super().__init__()
self.discriminators = nn.ModuleList([
self.patchgan_discriminator(base_channels, n_layers) for _ in range(n_discriminators)
])
self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
@staticmethod
def patchgan_discriminator(base_channels, n_layers):
'''
Function that constructs and returns one PatchGAN discriminator module.
'''
channels = base_channels
# Input convolutional layer
layers = [
nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(3, channels, kernel_size=4, stride=2),
),
nn.LeakyReLU(0.2, inplace=True),
]
# Hidden convolutional layers
for _ in range(n_layers):
layers += [
nn.ReflectionPad2d(1),
nn.utils.spectral_norm(
nn.Conv2d(channels, 2 * channels, kernel_size=4, stride=2)
),
nn.LeakyReLU(0.2, inplace=True),
]
channels *= 2
# Output projection layer
layers += [
nn.utils.spectral_norm(
nn.Conv2d(channels, 1, kernel_size=1)
),
]
return nn.Sequential(*layers)
def forward(self, x):
outputs = []
for discriminator in self.discriminators:
outputs.append(discriminator(x))
x = self.downsample(x)
return outputs
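Here’s a quick sketch of the multi-scale outputs: each PatchGAN produces a patch-wise prediction map on a progressively downsampled copy of the input (for a 256x256 image, the maps come out at 16x16, 8x8, and 4x4):
discriminator = Discriminator(base_channels=64, n_layers=3, n_discriminators=3)
x = torch.randn(1, 3, 256, 256)
for pred in discriminator(x):
    print(pred.shape)  # torch.Size([1, 1, 16, 16]), then [1, 1, 8, 8], then [1, 1, 4, 4]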
There are a lot of moving parts in MUNIT, so this section will break down how the different parts interact. Recall the notation from earlier: $E_a^c, E_a^s$ and $E_b^c, E_b^s$ denote the content and style encoders, $F_a, F_b$ the decoders, and $D_a, D_b$ the discriminators for domains $\mathcal{A}$ and $\mathcal{B}$, respectively, with content codes $c$ and style codes $s$ subscripted by the domain they come from.
Image Reconstruction Loss
The model should be able to encode and decode a reconstruction of the image. For domain $\mathcal{A}$, this can be expressed as
\begin{align*} \mathcal{L}_{\text{recon}}^a &= \mathbb{E}_{a\sim p(a)}\left|\left|F_a(E_a^c(a), E_a^s(a)) - a\right|\right|_1 \end{align*}
and for domain $\mathcal{B}$, this can be expressed as
\begin{align*} \mathcal{L}_{\text{recon}}^b &= \mathbb{E}_{b\sim p(b)}\left|\left|F_b(E_b^c(b), E_b^s(b)) - b\right|\right|_1 \end{align*}
Latent Reconstruction Loss
The same principle from above applies to the latent space: decoding content and style vectors into an image and then re-encoding that image should reproduce the original vectors. Don’t worry if the equations look complicated! Intuitively, they just say that whatever content and style vectors you pass into a decoder should be recoverable from the resulting image. For domain $\mathcal{A}$, this can be expressed as
\begin{align*} \mathcal{L}_{\text{recon}}^{c_b} &= \mathbb{E}_{c_b\sim p(c_b),s_a\sim q(s_a)}\left|\left|E_a^c(F_a(c_b, s_a)) - c_b\right|\right|_1 \\ \mathcal{L}_{\text{recon}}^{s_a} &= \mathbb{E}_{c_b\sim p(c_b),s_a\sim q(s_a)}\left|\left|E_a^s(F_a(c_b, s_a)) - s_a\right|\right|_1 \end{align*}
and for domain $\mathcal{B}$, this can be expressed as
\begin{align*} \mathcal{L}_{\text{recon}}^{c_a} &= \mathbb{E}_{c_a\sim p(c_a),s_b\sim q(s_b)}\left|\left|E_b^c(F_b(c_a, s_b)) - c_a\right|\right|_1 \\ \mathcal{L}_{\text{recon}}^{s_b} &= \mathbb{E}_{c_a\sim p(c_a),s_b\sim q(s_b)}\left|\left|E_b^s(F_b(c_a, s_b)) - s_b\right|\right|_1 \end{align*}
Adversarial Loss
As with all other GANs, MUNIT is trained with an adversarial loss; the authors opt for the least-squares (LSGAN) objective. For domain $\mathcal{A}$, this can be expressed as
\begin{align*} \mathcal{L}_{\text{GAN}}^a &= \mathbb{E}_{c_b\sim p(c_b),s_a\sim q(s_a)}\left[(1 - D_a(G_a(c_b, s_a)))^2\right] + \mathbb{E}_{a\sim p(a)}\left[D_a(a)^2\right] \end{align*}
and for domain $\mathcal{B}$, this can be expressed as
\begin{align*} \mathcal{L}_{\text{GAN}}^b &= \mathbb{E}_{c_a\sim p(c_a),s_b\sim q(s_b)}\left[(1 - D_b(G_b(c_a, s_b)))^2\right] + \mathbb{E}_{b\sim p(b)}\left[D_b(b)^2\right] \end{align*}
Total Loss
The total loss can now be expressed in terms of the individual losses from above, so the objective can be expressed as
\begin{align*} \mathcal{L}(E_a, E_b, F_a, F_b, D_a, D_b) &= \mathcal{L}_{\text{GAN}}^a + \mathcal{L}_{\text{GAN}}^b + \lambda_x(\mathcal{L}_{\text{recon}}^a + \mathcal{L}_{\text{recon}}^b) + \lambda_c(\mathcal{L}_{\text{recon}}^{c_a} + \mathcal{L}_{\text{recon}}^{c_b}) + \lambda_s(\mathcal{L}_{\text{recon}}^{s_a} + \mathcal{L}_{\text{recon}}^{s_b}) \end{align*}
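For reference, the training code later in this notebook uses $\lambda_x = 10$ and $\lambda_c = \lambda_s = 1$, so the generator objective actually optimized below is
\begin{align*} \mathcal{L}_{\text{gen}} &= \mathcal{L}_{\text{GAN}}^a + \mathcal{L}_{\text{GAN}}^b + 10\left(\mathcal{L}_{\text{recon}}^a + \mathcal{L}_{\text{recon}}^b\right) + \mathcal{L}_{\text{recon}}^{c_a} + \mathcal{L}_{\text{recon}}^{c_b} + \mathcal{L}_{\text{recon}}^{s_a} + \mathcal{L}_{\text{recon}}^{s_b} \end{align*}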
And now the fun part: implementing this ginormous composite loss!
class GinormousCompositeLoss(nn.Module):
'''
GinormousCompositeLoss Class: implements all losses for MUNIT
'''
@staticmethod
def image_recon_loss(x, gen):
c, s = gen.encode(x)
recon = gen.decode(c, s)
return F.l1_loss(recon, x), c, s
@staticmethod
def latent_recon_loss(c, s, gen):
x_fake = gen.decode(c, s)
recon = gen.encode(x_fake)
return F.l1_loss(recon[0], c), F.l1_loss(recon[1], s), x_fake
@staticmethod
def adversarial_loss(x, dis, is_real):
preds = dis(x)
target = torch.ones_like if is_real else torch.zeros_like
loss = 0.0
for pred in preds:
loss += F.mse_loss(pred, target(pred))
return loss
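To see how these static methods fit together, here’s a rough sketch of the generator-side loss for a single translation direction ($\mathcal{A} \rightarrow \mathcal{B}$), using $\lambda_x = 10$ and $\lambda_c = \lambda_s = 1$ as in the MUNIT module later in this notebook:
gen_a, gen_b, dis_b = Generator(), Generator(), Discriminator()
x_a = torch.randn(1, 3, 256, 256)
s_b = torch.randn(1, 8, 1, 1)  # randomly sampled target-domain style

x_a_loss, c_a, _ = GinormousCompositeLoss.image_recon_loss(x_a, gen_a)
c_a_loss, s_b_loss, x_ab = GinormousCompositeLoss.latent_recon_loss(c_a, s_b, gen_b)
adv_loss = GinormousCompositeLoss.adversarial_loss(x_ab, dis_b, is_real=True)
gen_loss_ab = 10 * x_a_loss + c_a_loss + s_b_loss + adv_loss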
The authors also propose a couple of auxiliary loss functions that can help improve convergence. This section will just go over these losses but implementation is left as an exercise for the curious reader :)
Style-augmented Cycle Consistency
You’ve already heard of cycle consistency from CycleGAN, which implies that an image translated to the target domain and back should be identical to the original.
Intuitively, style-augmented cycle consistency implies that an image translated to the target domain and back using the original style should result in the original image. Style-augmented cycle consistency is implicitly encouraged by the reconstruction losses, but the authors note that explicitly enforcing it could be useful in some cases.
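For the curious, here’s one possible (unofficial) sketch of a style-augmented cycle consistency loss for domain $\mathcal{A}$, written in terms of the Generator modules defined above:
def style_cycle_loss_a(x_a, gen_a, gen_b, s_dim=8):
    c_a, s_a = gen_a.encode(x_a)     # original content and style
    s_b = torch.randn(x_a.size(0), s_dim, 1, 1, device=x_a.device)
    x_ab = gen_b.decode(c_a, s_b)    # translate A -> B with a random B style
    c_ab, _ = gen_b.encode(x_ab)     # re-encode the translated image
    x_aba = gen_a.decode(c_ab, s_a)  # translate back with the original A style
    return F.l1_loss(x_aba, x_a)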
Domain Invariant Perceptual Loss
You’re probably already familiar with perceptual loss, which is usually implemented as an MSE loss between feature maps of fake and real images. However, because images from the two domains are unpaired, a raw element-wise loss may not be ideal, since the values at each position do not correspond spatially.
The authors get around this discrepancy by applying instance normalization to the feature maps before computing the MSE. This removes each feature map’s mean and variance, which carry domain-specific information, so the loss penalizes differences in the normalized, more domain-invariant features rather than in the raw feature statistics.
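Here’s a rough sketch of what this could look like. The choice of VGG-16 features up to relu3_3 is an assumption for illustration (the paper uses VGG features, but this exact layer and the omitted ImageNet preprocessing are not prescribed here):
from torchvision.models import vgg16

vgg_features = vgg16(pretrained=True).features[:16].eval()  # up to relu3_3 (assumed choice)
for p in vgg_features.parameters():
    p.requires_grad = False

def domain_invariant_perceptual_loss(x_fake, x_real):
    # Instance-normalize the feature maps so domain-specific statistics are removed
    # before comparing them. (Proper ImageNet preprocessing of the inputs is omitted here.)
    f_fake = F.instance_norm(vgg_features(x_fake))
    f_real = F.instance_norm(vgg_features(x_real))
    return F.mse_loss(f_fake, f_real)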
Now that you’ve got all the necessary modules, let’s see how we can put them all together.
class MUNIT(nn.Module):
def __init__(
self,
gen_channels: int = 64,
n_c_downsample: int = 2,
n_s_downsample: int = 4,
n_res_blocks: int = 4,
s_dim: int = 8,
h_dim: int = 256,
dis_channels: int = 64,
n_layers: int = 3,
n_discriminators: int = 3,
):
super().__init__()
self.gen_a = Generator(
base_channels=gen_channels, n_c_downsample=n_c_downsample, n_s_downsample=n_s_downsample, n_res_blocks=n_res_blocks, s_dim=s_dim, h_dim=h_dim,
)
self.gen_b = Generator(
base_channels=gen_channels, n_c_downsample=n_c_downsample, n_s_downsample=n_s_downsample, n_res_blocks=n_res_blocks, s_dim=s_dim, h_dim=h_dim,
)
self.dis_a = Discriminator(
base_channels=dis_channels, n_layers=n_layers, n_discriminators=n_discriminators,
)
self.dis_b = Discriminator(
base_channels=dis_channels, n_layers=n_layers, n_discriminators=n_discriminators,
)
self.s_dim = s_dim
self.loss = GinormousCompositeLoss
def forward(self, x_a, x_b):
s_a = torch.randn(x_a.size(0), self.s_dim, 1, 1, device=x_a.device).to(x_a.dtype)
s_b = torch.randn(x_b.size(0), self.s_dim, 1, 1, device=x_b.device).to(x_b.dtype)
# Encode real x and compute image reconstruction loss
x_a_loss, c_a, s_a_fake = self.loss.image_recon_loss(x_a, self.gen_a)
x_b_loss, c_b, s_b_fake = self.loss.image_recon_loss(x_b, self.gen_b)
# Decode real (c, s) and compute latent reconstruction loss
c_b_loss, s_a_loss, x_ba = self.loss.latent_recon_loss(c_b, s_a, self.gen_a)
c_a_loss, s_b_loss, x_ab = self.loss.latent_recon_loss(c_a, s_b, self.gen_b)
# Compute adversarial losses
gen_a_adv_loss = self.loss.adversarial_loss(x_ba, self.dis_a, True)
gen_b_adv_loss = self.loss.adversarial_loss(x_ab, self.dis_b, True)
# Sum up losses for gen
gen_loss = (
10 * x_a_loss + c_b_loss + s_a_loss + gen_a_adv_loss + \
10 * x_b_loss + c_a_loss + s_b_loss + gen_b_adv_loss
)
# Sum up losses for dis
dis_loss = (
self.loss.adversarial_loss(x_ba.detach(), self.dis_a, False) + \
self.loss.adversarial_loss(x_a.detach(), self.dis_a, True) + \
self.loss.adversarial_loss(x_ab.detach(), self.dis_b, False) + \
self.loss.adversarial_loss(x_b.detach(), self.dis_b, True)
)
return gen_loss, dis_loss, x_ab, x_ba
Now you’re ready to train MUNIT! Let’s start by setting some optimization parameters and initializing everything you’ll need for training.
# Initialize model
def weights_init(m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
munit_config = {
'gen_channels': 64,
'n_c_downsample': 2,
'n_s_downsample': 4,
'n_res_blocks': 4,
's_dim': 8,
'h_dim': 256,
'dis_channels': 64,
'n_layers': 3,
'n_discriminators': 3,
}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
munit = MUNIT(**munit_config).to(device).apply(weights_init)
# Initialize dataloader
transform = transforms.Compose([
transforms.Resize(286),
transforms.RandomCrop(256),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5), (0.5))
])
dataloader = DataLoader(
ImageDataset('horse2zebra', transform),
batch_size=1, pin_memory=True, shuffle=True,
)
# Initialize optimizers
gen_params = list(munit.gen_a.parameters()) + list(munit.gen_b.parameters())
dis_params = list(munit.dis_a.parameters()) + list(munit.dis_b.parameters())
gen_optimizer = torch.optim.Adam(gen_params, lr=1e-4, betas=(0.5, 0.999))
dis_optimizer = torch.optim.Adam(dis_params, lr=1e-4, betas=(0.5, 0.999))
# Parse torch version for autocast
# ######################################################
version = torch.__version__
version = tuple(int(n) for n in version.split('.')[:-1])
has_autocast = version >= (1, 6)
# ######################################################
def train(munit, dataloader, optimizers, device):
max_iters = 1000000
decay_every = 100000
cur_iter = 0
display_every = 500
mean_losses = [0., 0.]
while cur_iter < max_iters:
for (x_a, x_b) in tqdm(dataloader):
x_a = x_a.to(device)
x_b = x_b.to(device)
# Enable autocast to FP16 tensors (new feature since torch==1.6.0)
# If you're running older versions of torch, comment this out
# and use NVIDIA apex for mixed/half precision training
if has_autocast:
with torch.cuda.amp.autocast(enabled=(device=='cuda')):
outputs = munit(x_a, x_b)
else:
outputs = munit(x_a, x_b)
losses, x_ab, x_ba = outputs[:-2], outputs[-2], outputs[-1]
munit.zero_grad()
for i, (optimizer, loss) in enumerate(zip(optimizers, losses)):
optimizer.zero_grad()
loss.backward()
optimizer.step()
mean_losses[i] += loss.item() / display_every
cur_iter += 1
if cur_iter % display_every == 0:
print('Step {}: [G loss: {:.5f}][D loss: {:.5f}]'
.format(cur_iter, *mean_losses))
show_tensor_images(x_ab, x_a)
show_tensor_images(x_ba, x_b)
mean_losses = [0., 0.]
if cur_iter == max_iters:
break
# Schedule learning rate by 0.5
if cur_iter % decay_every == 0:
for optimizer in optimizers:
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.5
train(
munit, dataloader,
[gen_optimizer, dis_optimizer],
device,
)
47%|████▋ | 499/1067 [05:06<05:48, 1.63it/s]
Step 500: [G loss: 11.74086][D loss: 0.30821]
94%|█████████▎| 999/1067 [10:08<00:40, 1.68it/s]
Step 1000: [G loss: 10.53732][D loss: 0.02970]
100%|██████████| 1067/1067 [10:49<00:00, 1.64it/s]
40%|████ | 432/1067 [04:13<06:04, 1.74it/s]
Step 1500: [G loss: 10.07501][D loss: 0.05018]
87%|████████▋ | 932/1067 [08:58<01:15, 1.79it/s]
Step 2000: [G loss: 9.69387][D loss: 0.01541]
100%|██████████| 1067/1067 [10:15<00:00, 1.73it/s]
34%|███▍ | 365/1067 [03:26<06:32, 1.79it/s]
Step 2500: [G loss: 9.45771][D loss: 0.01282]
81%|████████ | 865/1067 [08:07<01:52, 1.80it/s]
Step 3000: [G loss: 9.21778][D loss: 0.00660]
100%|██████████| 1067/1067 [10:00<00:00, 1.78it/s]
28%|██▊ | 298/1067 [02:46<07:00, 1.83it/s]
Step 3500: [G loss: 9.00961][D loss: 0.01060]
75%|███████▍ | 798/1067 [07:27<02:34, 1.75it/s]
Step 4000: [G loss: 8.87639][D loss: 0.01009]
100%|██████████| 1067/1067 [09:59<00:00, 1.78it/s]
22%|██▏ | 231/1067 [02:06<07:35, 1.84it/s]
Step 4500: [G loss: 8.72139][D loss: 0.00777]
69%|██████▊ | 731/1067 [06:42<03:01, 1.85it/s]
Step 5000: [G loss: 8.56996][D loss: 0.00359]
100%|██████████| 1067/1067 [09:47<00:00, 1.82it/s]
15%|█▌ | 164/1067 [01:30<08:15, 1.82it/s]
Step 5500: [G loss: 8.57948][D loss: 0.00540]
62%|██████▏ | 664/1067 [06:08<03:45, 1.79it/s]
Step 6000: [G loss: 8.43045][D loss: 0.00491]
100%|██████████| 1067/1067 [09:50<00:00, 1.81it/s]
9%|▉ | 97/1067 [00:53<08:47, 1.84it/s]
Step 6500: [G loss: 8.42890][D loss: 0.00393]
56%|█████▌ | 597/1067 [05:29<04:16, 1.83it/s]
Step 7000: [G loss: 8.29681][D loss: 0.00428]
100%|██████████| 1067/1067 [09:44<00:00, 1.83it/s]
3%|▎ | 30/1067 [00:16<09:33, 1.81it/s]
Step 7500: [G loss: 8.25187][D loss: 0.00296]
50%|████▉ | 530/1067 [04:45<04:51, 1.84it/s]
Step 8000: [G loss: 8.12992][D loss: 0.00211]
97%|█████████▋| 1030/1067 [09:13<00:20, 1.84it/s]
Step 8500: [G loss: 8.11098][D loss: 0.00283]
100%|██████████| 1067/1067 [09:33<00:00, 1.86it/s]
43%|████▎ | 463/1067 [04:08<05:24, 1.86it/s]
Step 9000: [G loss: 8.17234][D loss: 0.00518]
90%|█████████ | 963/1067 [08:37<00:55, 1.89it/s]
Step 9500: [G loss: 8.16071][D loss: 0.00234]
100%|██████████| 1067/1067 [09:32<00:00, 1.86it/s]
37%|███▋ | 396/1067 [03:32<06:05, 1.83it/s]
Step 10000: [G loss: 8.14926][D loss: 0.00227]
84%|████████▍ | 896/1067 [08:02<01:30, 1.89it/s]
Step 10500: [G loss: 7.89429][D loss: 0.00395]
100%|██████████| 1067/1067 [09:35<00:00, 1.86it/s]
31%|███ | 329/1067 [02:55<06:44, 1.82it/s]
Step 11000: [G loss: 7.89454][D loss: 0.00154]
78%|███████▊ | 829/1067 [07:25<02:10, 1.82it/s]
Step 11500: [G loss: 7.84051][D loss: 0.00255]
100%|██████████| 1067/1067 [09:33<00:00, 1.86it/s]
25%|██▍ | 262/1067 [02:20<07:16, 1.84it/s]
Step 12000: [G loss: 7.81340][D loss: 0.00118]
71%|███████▏ | 762/1067 [06:49<02:40, 1.90it/s]
Step 12500: [G loss: 7.79256][D loss: 0.00100]
100%|██████████| 1067/1067 [09:32<00:00, 1.86it/s]
18%|█▊ | 195/1067 [01:44<07:54, 1.84it/s]
Step 13000: [G loss: 7.77127][D loss: 0.00220]
65%|██████▌ | 695/1067 [06:14<03:28, 1.79it/s]
Step 13500: [G loss: 7.67707][D loss: 0.00148]
100%|██████████| 1067/1067 [09:36<00:00, 1.85it/s]
12%|█▏ | 128/1067 [01:09<08:24, 1.86it/s]
Step 14000: [G loss: 7.71080][D loss: 0.00076]
59%|█████▉ | 628/1067 [05:42<04:00, 1.82it/s]
Step 14500: [G loss: 7.64382][D loss: 0.00227]
100%|██████████| 1067/1067 [09:40<00:00, 1.84it/s]
6%|▌ | 61/1067 [00:32<09:11, 1.82it/s]
Step 15000: [G loss: 7.69827][D loss: 0.00103]
53%|█████▎ | 561/1067 [04:59<04:35, 1.84it/s]
Step 15500: [G loss: 7.62718][D loss: 0.00268]
99%|█████████▉| 1061/1067 [09:26<00:03, 1.91it/s]
Step 16000: [G loss: 7.62309][D loss: 0.00060]
100%|██████████| 1067/1067 [09:29<00:00, 1.87it/s]
46%|████▋ | 494/1067 [04:25<05:07, 1.86it/s]
Step 16500: [G loss: 7.54315][D loss: 0.00055]
93%|█████████▎| 994/1067 [08:53<00:38, 1.89it/s]
Step 17000: [G loss: 7.54358][D loss: 0.00049]
100%|██████████| 1067/1067 [09:33<00:00, 1.86it/s]
40%|████ | 427/1067 [03:49<05:36, 1.90it/s]
Step 17500: [G loss: 7.57658][D loss: 0.00252]
87%|████████▋ | 927/1067 [08:16<01:14, 1.87it/s]
Step 18000: [G loss: 7.56780][D loss: 0.00026]
100%|██████████| 1067/1067 [09:31<00:00, 1.87it/s]
34%|███▎ | 360/1067 [03:12<06:16, 1.88it/s]
Step 18500: [G loss: 7.49320][D loss: 0.00053]
81%|████████ | 860/1067 [07:40<01:49, 1.89it/s]
Step 19000: [G loss: 7.47671][D loss: 0.00076]
100%|██████████| 1067/1067 [09:31<00:00, 1.87it/s]
27%|██▋ | 293/1067 [02:36<06:51, 1.88it/s]
Step 19500: [G loss: 7.44688][D loss: 0.00047]
74%|███████▍ | 793/1067 [07:03<02:27, 1.86it/s]
Step 20000: [G loss: 7.47304][D loss: 0.00089]
100%|██████████| 1067/1067 [09:30<00:00, 1.87it/s]
21%|██ | 226/1067 [02:00<07:33, 1.85it/s]
Step 20500: [G loss: 7.35975][D loss: 0.00081]
68%|██████▊ | 726/1067 [06:28<03:06, 1.83it/s]
Step 21000: [G loss: 7.42969][D loss: 0.00046]
100%|██████████| 1067/1067 [09:31<00:00, 1.87it/s]
15%|█▍ | 159/1067 [01:25<08:00, 1.89it/s]
Step 21500: [G loss: 7.37134][D loss: 0.00021]
62%|██████▏ | 659/1067 [05:53<03:38, 1.86it/s]
Step 22000: [G loss: 7.33494][D loss: 0.00061]
100%|██████████| 1067/1067 [09:32<00:00, 1.86it/s]
9%|▊ | 92/1067 [00:49<08:31, 1.90it/s]
Step 22500: [G loss: 7.34315][D loss: 0.00076]
55%|█████▌ | 592/1067 [05:17<04:11, 1.89it/s]
Step 23000: [G loss: 7.30039][D loss: 0.00104]
100%|██████████| 1067/1067 [09:30<00:00, 1.87it/s]
2%|▏ | 25/1067 [00:13<09:15, 1.88it/s]
Step 23500: [G loss: 7.33213][D loss: 0.00043]
49%|████▉ | 525/1067 [04:41<04:45, 1.90it/s]
Step 24000: [G loss: 7.23658][D loss: 0.00020]
96%|█████████▌| 1025/1067 [09:09<00:22, 1.91it/s]
Step 24500: [G loss: 7.14538][D loss: 0.00028]
100%|██████████| 1067/1067 [09:32<00:00, 1.86it/s]
43%|████▎ | 458/1067 [04:05<05:25, 1.87it/s]
Step 25000: [G loss: 7.14678][D loss: 0.00028]
90%|████████▉ | 958/1067 [08:36<00:57, 1.89it/s]
Step 25500: [G loss: 7.20100][D loss: 0.00042]
100%|██████████| 1067/1067 [09:35<00:00, 1.85it/s]
37%|███▋ | 391/1067 [03:33<06:10, 1.82it/s]
Step 26000: [G loss: 7.17250][D loss: 0.00079]
84%|████████▎ | 891/1067 [08:07<01:34, 1.86it/s]
Step 26500: [G loss: 7.12385][D loss: 0.00080]
100%|██████████| 1067/1067 [09:44<00:00, 1.83it/s]
30%|███ | 324/1067 [02:57<06:49, 1.82it/s]
Step 27000: [G loss: 7.13010][D loss: 0.00136]
77%|███████▋ | 824/1067 [07:29<02:08, 1.89it/s]
Step 27500: [G loss: 7.18703][D loss: 0.00051]
100%|██████████| 1067/1067 [09:40<00:00, 1.84it/s]
24%|██▍ | 257/1067 [02:18<07:21, 1.83it/s]
Step 28000: [G loss: 7.13295][D loss: 0.00031]
71%|███████ | 757/1067 [06:46<02:45, 1.87it/s]
Step 28500: [G loss: 7.09443][D loss: 0.00021]
100%|██████████| 1067/1067 [09:32<00:00, 1.86it/s]
18%|█▊ | 190/1067 [01:41<07:49, 1.87it/s]
Step 29000: [G loss: 7.09862][D loss: 0.00057]
65%|██████▍ | 690/1067 [06:09<03:20, 1.88it/s]
Step 29500: [G loss: 7.08326][D loss: 0.00042]
100%|██████████| 1067/1067 [09:31<00:00, 1.87it/s]
12%|█▏ | 123/1067 [01:06<08:41, 1.81it/s]
Step 30000: [G loss: 6.99207][D loss: 0.00095]
58%|█████▊ | 623/1067 [05:33<04:02, 1.83it/s]
Step 30500: [G loss: 7.05251][D loss: 0.00054]
100%|██████████| 1067/1067 [09:31<00:00, 1.87it/s]
5%|▌ | 56/1067 [00:29<08:53, 1.90it/s]
Step 31000: [G loss: 6.99728][D loss: 0.00009]
52%|█████▏ | 556/1067 [04:57<04:32, 1.87it/s]
Step 31500: [G loss: 7.00414][D loss: 0.00013]
99%|█████████▉| 1056/1067 [09:23<00:06, 1.79it/s]
Step 32000: [G loss: 7.06001][D loss: 0.00056]
100%|██████████| 1067/1067 [09:30<00:00, 1.87it/s]
46%|████▌ | 489/1067 [04:20<05:04, 1.90it/s]
Step 32500: [G loss: 6.95155][D loss: 0.00017]
93%|█████████▎| 989/1067 [08:48<00:42, 1.85it/s]
Step 33000: [G loss: 6.95831][D loss: 0.00051]
100%|██████████| 1067/1067 [09:30<00:00, 1.87it/s]
40%|███▉ | 422/1067 [03:45<05:37, 1.91it/s]
Step 33500: [G loss: 6.98002][D loss: 0.00065]
86%|████████▋ | 922/1067 [08:12<01:17, 1.86it/s]
Step 34000: [G loss: 6.89810][D loss: 0.00018]
100%|██████████| 1067/1067 [09:29<00:00, 1.87it/s]
33%|███▎ | 355/1067 [03:08<06:11, 1.92it/s]
Step 34500: [G loss: 6.96272][D loss: 0.00048]
80%|████████ | 855/1067 [07:36<01:54, 1.85it/s]
Step 35000: [G loss: 6.81202][D loss: 0.00029]
100%|██████████| 1067/1067 [09:30<00:00, 1.87it/s]
27%|██▋ | 288/1067 [02:33<06:53, 1.88it/s]
Step 35500: [G loss: 6.92884][D loss: 0.00067]
74%|███████▍ | 788/1067 [06:59<02:33, 1.81it/s]
Step 36000: [G loss: 6.87659][D loss: 0.00019]
100%|██████████| 1067/1067 [09:29<00:00, 1.87it/s]
21%|██ | 221/1067 [01:58<07:22, 1.91it/s]
Step 36500: [G loss: 6.82683][D loss: 0.00004]
68%|██████▊ | 721/1067 [06:24<03:01, 1.91it/s]
Step 37000: [G loss: 6.79883][D loss: 0.00014]
100%|██████████| 1067/1067 [09:29<00:00, 1.88it/s]
14%|█▍ | 154/1067 [01:22<08:00, 1.90it/s]
Step 37500: [G loss: 6.80369][D loss: 0.00021]
61%|██████▏ | 654/1067 [05:48<03:40, 1.88it/s]
Step 38000: [G loss: 6.86273][D loss: 0.00015]
100%|██████████| 1067/1067 [09:29<00:00, 1.87it/s]
8%|▊ | 87/1067 [00:46<08:37, 1.89it/s]
Step 38500: [G loss: 6.82430][D loss: 0.00018]
55%|█████▌ | 587/1067 [05:13<04:14, 1.89it/s]
Step 39000: [G loss: 6.80032][D loss: 0.00007]
100%|██████████| 1067/1067 [09:30<00:00, 1.87it/s]
2%|▏ | 20/1067 [00:10<09:23, 1.86it/s]
Step 39500: [G loss: 6.76142][D loss: 0.00103]
49%|████▊ | 520/1067 [04:37<04:49, 1.89it/s]
Step 40000: [G loss: 6.74708][D loss: 0.00010]
96%|█████████▌| 1020/1067 [09:04<00:25, 1.82it/s]
Step 40500: [G loss: 6.73421][D loss: 0.00022]
100%|██████████| 1067/1067 [09:29<00:00, 1.87it/s]
42%|████▏ | 453/1067 [04:01<05:30, 1.86it/s]
Step 41000: [G loss: 6.69312][D loss: 0.00047]
71%|███████ | 760/1067 [06:46<02:40, 1.92it/s]