Автор оригинала: Pankaj Kumar.
ГАН был предметом разговоров в городе с момента его создания в 2014 году Гудфеллоу. В этом уроке вы научитесь тренировать свой первый ПИСТОЛЕТ в тангаже. Мы также попытаемся объяснить внутреннюю работу GAN и рассмотрим простую реализацию GAN с помощью PyTorch.
Библиотеки для импорта
Сначала мы импортируем библиотеки и функции, которые будут использоваться в реализации.
import torch from torch import nn from torchvision import transforms from torchvision.utils import make_grid from torchvision.datasets import MNIST from torch.utils.data import DataLoader import matplotlib.pyplot as plt from IPython.display import clear_output
Что такое ГАН?
Генеративная сеть может быть просто описана как сеть, которая может учиться на обучающих данных и генерировать данные, подобные обучающим данным. Существуют различные способы разработки генеративной модели, одним из которых является состязательность.
В генеративной состязательной сети есть две подмодели – генератор и дискриминатор. Мы рассмотрим эти подмодели более подробно:
1. Генератор
Генератору, как следует из названия, назначается задача создания изображения.
Генератор принимает небольшие низкоразмерные входные данные(обычно 1-D вектор) и выдает данные изображения размером 128x128x3 в качестве выходных данных.
Эта операция масштабирования нижнего измерения до более высокого измерения достигается с помощью последовательных слоев деконволюции и свертки.
Наш генератор можно рассматривать как функцию, которая принимает данные низкой размерности и сопоставляет их с данными изображения высокой размерности.
В течение периода обучения генератор учится все более и более эффективно сопоставлять данные низкой размерности с данными высокой размерности.
Цель генератора состоит в том, чтобы сгенерировать изображение, которое может обмануть дискриминатор для реального изображения.
Класс генератора:
class Generator(nn.Module): def __init__(self, z_dim, im_chan, hidden_dim=64): super().__init__() self.z_dim = z_dim self.gen = nn.Sequential( # We define the generator as stacks of deconvolution layers # with batch normalization and non-linear activation function # You can try to play with the values of the layers nn.ConvTranspose2d(z_dim, 4*hidden_dim, 3, 2), nn.BatchNorm2d(4*hidden_dim), nn.ReLU(inplace=True), nn.ConvTranspose2d(hidden_dim * 4, hidden_dim * 2, 4, 1), nn.BatchNorm2d(hidden_dim*2), nn.ReLU(inplace=True), nn.ConvTranspose2d(hidden_dim * 2, hidden_dim ,3 ,2), nn.BatchNorm2d(hidden_dim), nn.ReLU(inplace=True), nn.ConvTranspose2d(hidden_dim, im_chan, 4, 2), nn.Tanh() ) def forward(self, noise): # Define how the generator computes the output noise = noise.view(len(noise), self.z_dim, 1, 1) return self.gen(noise)
# We define a generator with latent dimension 100 and img_dim 1 gen = Generator(100, 1) print("Composition of the Generator:", end="\n\n") print(gen)
Compostion of the Generator: Generator( (gen): Sequential( (0): ConvTranspose2d(100, 256, kernel_size=(3, 3), stride=(2, 2)) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU(inplace=True) (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1)) (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU(inplace=True) (6): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2)) (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (8): ReLU(inplace=True) (9): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2)) (10): Tanh() ) )
Дополнительные примечания: Изображение представляет собой очень объемные данные. Даже RGB-изображение размером 3x128x128 размером 49152.
Изображения, которые нам нужны, лежат в подпространстве или многообразии такого огромного пространства.
В идеале генератор должен узнать, где находится подпространство, и произвольно выбирается из изученного подпространства для получения выходных данных.
Поиск этого идеального подпространства-очень дорогостоящая вычислительная задача, для решения которой наиболее распространенным способом является сопоставление скрытого векторного пространства с пространством данных с помощью толчка вперед.
2. Дискриминатор
У нашего дискриминатора D есть более простая, но, тем не менее, важная задача. Дискриминатор-это двоичный классификатор, который указывает, являются ли входные данные из исходного источника или из нашего генератора. Идеальный дискриминатор
class Discriminator(nn.Module): def __init__(self, im_chan, hidden_dim=16): super().__init__() self.disc = nn.Sequential( # Discriminator is defined as a stack of # convolution layers with batch normalization # and non-linear activations. nn.Conv2d(im_chan, hidden_dim, 4, 2), nn.BatchNorm2d(hidden_dim), nn.LeakyReLU(0.2,inplace=True), nn.Conv2d(hidden_dim, hidden_dim * 2, 4, 2), nn.BatchNorm2d(hidden_dim*2), nn.LeakyReLU(0.2,inplace=True), nn.Conv2d(hidden_dim*2, 1, 4, 2) ) def forward(self, image): disc_pred = self.disc(image) return disc_pred.view(len(disc_pred), -1)
# We define a discriminator for one class classification disc = Discriminator(1) print("Composition of the Discriminator:", end="\n\n") print(disc)
Composition of the Discriminator: Discriminator( (disc): Sequential( (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2)) (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): LeakyReLU(negative_slope=0.2, inplace=True) (3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2)) (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): LeakyReLU(negative_slope=0.2, inplace=True) (6): Conv2d(32, 1, kernel_size=(4, 4), stride=(2, 2)) )
Функции потерь в БАНДЕ
Теперь мы определяем потери для генератора и дискриминатора.
1. Потеря генератора
Генератор пытается генерировать изображения, которые могут обмануть дискриминатор, чтобы считать их реальными.
Таким образом, генератор пытается максимизировать вероятность присвоения поддельных изображений истинной метке.
Таким образом, потеря генератора-это ожидаемая вероятность того, что дискриминатор классифицирует сгенерированное изображение как поддельное.
def gen_loss(gen, disc, num_images, latent_dim, device): # Generate the the fake images noise = random_noise(num_images, latent_dim).to(device) gen_img = gen(noise) # Pass through discriminator and find the binary cross entropy loss disc_gen = disc(gen_img) gen_loss = Loss(disc_gen, torch.ones_like(disc_gen)) return gen_loss
2. Потеря дискриминатора
Мы хотим, чтобы дискриминатор максимизировал вероятность присвоения истинной метки реальным изображениям и максимизировал вероятность присвоения поддельной метки поддельным изображениям.
Подобно потере генератора, потеря дискриминатора-это вероятность того, что реальное изображение классифицируется как поддельное, а поддельное изображение классифицируется как реальное.
Обратите внимание, как функция потерь наших двух моделей действует друг против друга.
def disc_loss(gen, disc, real_images, num_images, latent_dim, device): # Generate the fake images noise = random_noise(num_images, latent_dim).to(device); img_gen = gen(noise).detach() # Pass the real and fake images through discriminator disc_gen = disc(img_gen) disc_real = disc(real_images) # Find loss for the generator and discriminator gen_loss = Loss(disc_gen, torch.zeros_like(disc_gen)) real_loss = Loss(disc_real, torch.ones_like(disc_real)) # Average over the losses for the discriminator loss disc_loss = ((gen_loss + real_loss) /2).mean() return disc_loss
Загрузка обучающего набора данных MNIST
Мы загружаем данные обучения MNIST . Мы будем использовать пакет torch vision для загрузки необходимого набора данных.
# Set the batch size BATCH_SIZE = 512 # Download the data in the Data folder in the directory above the current folder data_iter = DataLoader( MNIST('../Data', download=True, transform=transforms.ToTensor()), batch_size=BATCH_SIZE, shuffle=True)
Инициализация модели
Установите гиперпараметры моделей.
# Set Loss as Binary CrossEntropy with logits Loss = nn.BCEWithLogitsLoss() # Set the latent dimension latent_dim = 100 display_step = 500 # Set the learning rate lr = 0.0002 # Set the beta_1 and beta_2 for the optimizer beta_1 = 0.5 beta_2 = 0.999
Установите устройство на cpu или cuda в зависимости от того, включено ли аппаратное ускорение.
device = "cpu" if torch.cuda.is_available(): device = "cuda" device
Теперь мы инициализируем генератор, дискриминатор и оптимизаторы. Мы также инициализируем начальные/начальные веса слоя.
# Initialize the Generator and the Discriminator along with # their optimizer gen_opt and disc_opt # We choose ADAM as the optimizer for both models gen = Generator(latent_dim, 1).to(device) gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2)) disc = Discriminator(1 ).to(device) disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2)) # Initialize the weights of the various layers def weights_init(m): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): torch.nn.init.normal_(m.weight, 0.0, 0.02) if isinstance(m, nn.BatchNorm2d): torch.nn.init.normal_(m.weight, 0.0, 0.02) torch.nn.init.constant_(m.bias, 0) # Apply the initial weights on the generator and discriminator gen = gen.apply(weights_init) disc = disc.apply(weights_init)
Настройка функций утилиты
Нам всегда нужны некоторые служебные функции, которые не вписываются конкретно в наше приложение, но облегчают некоторые из наших задач. Мы определяем функцию, которая может отображать изображения в сетке, используя функцию torch vision make_grid.
def display_images(image_tensor, num_images=25, size=(1, 28, 28)): image_unflat = image_tensor.detach().cpu().view(-1, *size) image_grid = make_grid(image_unflat[:num_images], nrow=5) plt.imshow(image_grid.permute(1, 2, 0).squeeze()) plt.show()
Мы определяем функцию шума для генерации случайного шума, который будет использоваться для ввода в генератор.
def random_noise(n_samples, z_dim): return torch.randn(n_samples, z_dim)
Тренировочный цикл для нашей БАНДЫ в Пайторче
# Set the number of epochs num_epochs = 100 # Set the interval at which generated images will be displayed display_step = 100 # Inter parameter itr = 0 for epoch in range(num_epochs): for images, _ in data_iter: num_images = len(images) # Transfer the images to cuda if harware accleration is present real_images = images.to(device) # Discriminator step disc_opt.zero_grad() D_loss = disc_loss(gen, disc, real_images, num_images, latent_dim, device) D_loss.backward(retain_graph=True) disc_opt.step() # Generator Step gen_opt.zero_grad() G_loss = gen_loss(gen, disc, num_images, latent_dim, device) G_loss.backward(retain_graph=True) gen_opt.step() if itr% display_step ==0 : with torch.no_grad(): # Clear the previous output clear_output(wait=True) noise = noise = random_noise(25,latent_dim).to(device) img = gen(noise) # Display the generated images display_images(img) itr+=1
Результаты
Вот некоторые из результатов нашего GAN.
Вывод
Мы видели, как мы можем генерировать новые изображения из набора изображений. Оружие не ограничивается изображениями чисел. Современные пушки достаточно мощны, чтобы создавать настоящие человеческие лица. Банды теперь используются для создания музыки, искусства и т. Д. Если вы хотите узнать больше о работе Gains, вы можете обратиться к этой оригинальной статье GUN/| от Goodfellow.