Рубрики
Без рубрики

Введение в GANs с Python и TensorFlow

Автор оригинала: Daniele Paliotta.

Введение в GANs с Python и TensorFlow

Вступление

Генеративные модели-это семейство архитектур ИИ, целью которых является создание образцов данных с нуля. Они достигают этого, фиксируя распределение данных того типа вещей, которые мы хотим генерировать.

Такого рода модели активно исследуются, и вокруг них существует огромное количество шумихи. Просто посмотрите на диаграмму, которая показывает количество статей, опубликованных в этой области за последние несколько лет:

Gan papers

С 2014 года, когда была опубликована первая статья о генеративных состязательных сетях, генеративные модели становятся невероятно мощными, и теперь мы можем генерировать гиперреалистичные образцы данных для широкого спектра распределений: изображений, видео, музыки, произведений письма и т. Д.

Вот несколько примеров изображений, сгенерированных a GAN :

Лицо, сгенерированное с помощью GANs
GAN-генерируемые изображения

Что такое Генеративные модели?

Структура GANs

Наиболее успешная структура, предложенная для генеративных моделей, по крайней мере за последние годы, носит название Генеративные состязательные сети ( GANs ).

Проще говоря, GAN состоит из двух отдельных моделей, представленных нейронными сетями: генератора G и дискриминатора D . Цель дискриминатора состоит в том, чтобы определить, происходит ли выборка данных из реального распределения данных или же она генерируется G .

Цель генератора состоит в том, чтобы генерировать образцы данных, такие как обмануть дискриминатор.

Генератор-это не что иное, как глубокая нейронная сеть. Он принимает в качестве входных данных вектор случайного шума (обычно гауссовского или из равномерного распределения) и выводит выборку данных из распределения, которое мы хотим захватить.

Дискриминатор-это, опять же, просто нейронная сеть. Его цель, как гласит его название, состоит в том, чтобы различать между реальными и поддельными образцами. Следовательно, его входным сигналом является выборка данных, поступающая либо от генератора, либо от фактического распределения данных.

Выход-это простое число, представляющее вероятность того, что вход был реальным. Высокая вероятность означает, что дискриминатор уверен, что образцы, которые ему скармливают, являются подлинными. Напротив, низкая вероятность свидетельствует о высокой уверенности в том, что выборка поступает из генераторной сети:

Рамки

Представьте себе фальшивомонетчика, который пытается создать фальшивые произведения искусства, и искусствоведа, который должен отличать настоящие картины от фальшивых.

В этом сценарии критик действует как наш дискриминатор, а фальсификатор-как генератор, принимая обратную связь от критика, чтобы улучшить его навыки и сделать его подделанное искусство более убедительным:

Упрощенная структура

Обучение

Тренировка снова может быть болезненной вещью. Нестабильность обучения всегда была проблемой, и многие исследования были сосредоточены на том, чтобы сделать обучение более стабильным.

Основная целевая функция модели vanilla GAN заключается в следующем:

Функция потерь GANs

Здесь D относится к сети дискриминаторов, в то время как G , очевидно, относится к генератору.

Как видно из формулы, генератор оптимизируется для максимального запутывания дискриминатора, пытаясь заставить его выдавать высокие вероятности для поддельных выборок данных.

Напротив, дискриминатор пытается лучше отличить выборки, поступающие из G , от выборок, поступающих из реального распределения.

Термин “состязательный” происходит именно от того, как тренируются ПУШКИ, натравливая две сети друг на друга.

После того, как мы обучили нашу модель, дискриминатор больше не требуется. Все, что нам нужно сделать, это подать генератору случайный вектор шума, и мы надеемся получить в результате реалистичную, искусственную выборку данных.

Проблемы GANs

Так почему же банды так трудно обучить? Как уже говорилось ранее, в ванильной форме их очень трудно тренировать. Мы кратко рассмотрим, почему это так.

Труднодостижимое равновесие Нэша

Поскольку эти две сети посылают информацию друг другу, ее можно представить как игру, в которой можно угадать, реален вход или нет.

Структура GAN-это невыпуклая, двухпользовательская, некооперативная игра с непрерывными многомерными параметрами, в которой каждый игрок хочет минимизировать свою функцию затрат. Оптимум этого процесса имеет название Равновесие Нэша – где каждый игрок не будет работать лучше, меняя стратегию, учитывая тот факт, что другой игрок не меняет свою стратегию.

Однако выигрыши обычно обучаются с использованием методов градиентного спуска , которые предназначены для нахождения низкого значения функции затрат , а не для нахождения равновесия Нэша игры.

Коллапс режима

Большинство распределений данных являются мультимодальными. Возьмем набор данных MNIST : существует 10 “режимов” данных, относящихся к различным цифрам между 0 и 9.

Хорошая генеративная модель была бы в состоянии производить образцы с достаточной изменчивостью, таким образом, будучи в состоянии генерировать образцы из всех различных классов.

Однако это происходит не всегда.

Допустим, генератор становится действительно хорош в производстве цифры “3”. Если полученные образцы достаточно убедительны, то дискриминатор, скорее всего, присвоит им высокие вероятности.

В результате генератор будет подталкиваться к созданию образцов, которые поступают из этого конкретного режима, игнорируя другие классы большую часть времени. По сути, он будет спамить один и тот же номер, и с каждым номером, который проходит дискриминатор, это поведение будет только усиливаться.

Пример коллапса режима

Уменьшающийся градиент

Очень похоже на предыдущий пример, дискриминатор может оказаться слишком успешным в различении выборок данных. Когда это верно, градиент генератора исчезает, он начинает учиться все меньше и меньше, не сходясь.

Этот дисбаланс, как и предыдущий, может быть вызван, если мы будем тренировать сети отдельно. Эволюция нейронных сетей может быть совершенно непредсказуемой, что может привести к тому, что один будет опережать другого на милю. Если мы тренируем их вместе, мы в основном гарантируем, что этого не произойдет.

произведение искусства

Было бы невозможно дать исчерпывающее представление обо всех усовершенствованиях и разработках, которые сделали оружие более мощным и стабильным в последние годы.

Вместо этого я составлю список наиболее успешных архитектур и методов, предоставив ссылки на соответствующие ресурсы для более глубокого изучения.

DCGANs

Глубокие сверточные GANs (DCGANs) ввели свертки в генераторные и дискриминаторные сети.

Однако это было не просто добавление сверточных слоев к модели, поскольку обучение стало еще более нестабильным.

Чтобы сделать DCGANs полезными, пришлось применить несколько трюков:

  • Пакетная нормализация применялась как к генератору, так и к сети дискриминаторов
  • Отсев используется в качестве метода регуляризации
  • Генератор нуждался в способе апсемплировать случайный входной вектор к выходному изображению. Здесь используется транспонирование сверточных слоев
  • В обеих сетях используются активации LeakyReLU и TanH
DCGANs

Ганс

Ганы Вассерштейна (Gans) направлены на повышение устойчивости тренировок. За этим типом модели стоит большое количество математики. Более доступное объяснение можно найти здесь .

Основная идея здесь состояла в том, чтобы предложить новую функцию затрат, которая имеет более плавный градиент везде.

Новая функция стоимости использует метрику под названием расстояние Вассерштейна , которая имеет более плавный градиент везде.

В результате дискриминатор, который теперь называется критик , выводит доверительные значения, которые больше не следует воспринимать как вероятность. Высокие значения означают, что модель уверена в том, что входные данные являются реальными.

Два существенных улучшения для УИГАНА:

  • Он не имеет никаких признаков коллапса режима в экспериментах
  • Генератор все еще может учиться, когда критик работает хорошо

Саган

Само-внимание GANs (SAGANs) вводит механизм внимания в структуру GAN.

Механизмы внимания позволяют использовать глобальную информацию локально . Это означает, что мы можем уловить смысл из разных частей изображения и использовать эту информацию для получения лучших образцов.

Это происходит из наблюдения, что свертки довольно плохо улавливают долгосрочные зависимости во входных выборках, поскольку свертка является локальной операцией, рецептивное поле которой зависит от пространственного размера ядра.

Это означает, что, например, выход в левом верхнем углу изображения не может иметь никакого отношения к выходу в правом нижнем углу.

Одним из способов решения этой проблемы было бы использование ядер с большими размерами, чтобы захватить больше информации. Однако это привело бы к тому, что модель была бы вычислительно неэффективной и очень медленной в обучении.

Самостоятельное внимание решает эту проблему, обеспечивая эффективный способ захвата глобальной информации и использования ее локально, когда она может оказаться полезной.

БигГАНы

БигГАНы на момент написания статьи считались более или менее современными в том, что касается качества генерируемых образцов.

То, что исследователи сделали здесь, состояло в том, чтобы собрать воедино все, что работало до этого момента, а затем масштабировать его массово. Их базовая модель была фактически САГАНОМ, к которому они добавили некоторые трюки для повышения стабильности.

Они доказали, что могут значительно выиграть от масштабирования, даже если в модель не будут введены дополнительные функциональные улучшения, как указано в оригинальной статье:

Мы показали, что генеративные состязательные сети, обученные моделировать естественные изображения нескольких категорий, в значительной степени выигрывают от масштабирования, как с точки зрения точности, так и разнообразия генерируемых выборок. В результате наши модели установили новый уровень производительности среди моделей ImageNet GAN, улучшив уровень техники с большим отрывом

Простой GAN в Python

Реализация кода

Учитывая все сказанное, давайте продолжим и реализуем простой пистолет, который генерирует цифры от 0 до 9, довольно классический пример:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

# Sample z from uniform distribution
def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])

def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

Теперь мы можем определить заполнитель для наших входных выборок и векторов шума:

# Input image, for discriminator model.
X = tf.placeholder(tf.float32, shape=[None, 784])

# Input noise for generator.
Z = tf.placeholder(tf.float32, shape=[None, 100])

Теперь мы определяем наши генераторные и дискриминаторные сети. Это простые персептроны с одним скрытым слоем.

Мы используем relu активации в нейронах скрытого слоя и sigmoid для выходных слоев.

def generator(z):
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(z, 128, activation=tf.nn.relu)
        x = tf.layers.dense(z, 784)
        x = tf.nn.sigmoid(x)
    return x

def discriminator(x):
    with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(x, 128, activation=tf.nn.relu)
        x = tf.layers.dense(x, 1)
        x = tf.nn.sigmoid(x)
    return x

Теперь мы можем определить наши модели, функции потерь и оптимизаторы:

# Generator model
G_sample = generator(Z)

# Discriminator models
D_real = discriminator(X)
D_fake = discriminator(G_sample)


# Loss function
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

# Select parameters
disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("disc")]
gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("gen")]

# Optimizers
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=disc_vars)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=gen_vars)

Наконец, мы можем написать программу тренировок. На каждой итерации мы выполняем один шаг оптимизации для дискриминатора и один для генератора.

Каждые 100 итераций мы сохраняем некоторые сгенерированные образцы, чтобы иметь возможность взглянуть на наш прогресс.

# Batch size
mb_size = 128

# Dimension of input noise
Z_dim = 100

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

if not os.path.exists('out2/'):
    os.makedirs('out2/')

i = 0

for it in range(1000000):

    # Save generated images every 1000 iterations.
    if it % 1000 == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

        fig = plot(samples)
        plt.savefig('out2/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)


    # Get next batch of images. Each batch has mb_size samples.
    X_mb, _ = mnist.train.next_batch(mb_size)


    # Run disciminator solver
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})

    # Run generator solver
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})

    # Print loss
    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))

Результаты и возможные улучшения

Во время первых итераций все, что мы видим, – это случайный шум:

Первые итерации

Здесь сети еще ничему не научились. Хотя всего через пару минут мы уже видим, как складываются наши цифры!

68000-я итерация

Ресурсы

Если вы хотите поиграть с кодом, он находится на GitHub !