Автор оригинала: Mihajlo Pavloski.
TensorFlow: Сохранение и восстановление моделей
Обучение модели глубокой нейронной сети может занять довольно много времени, в зависимости от сложности вашей модели, объема имеющихся данных, аппаратного обеспечения, на котором вы запускаете свои модели, и т. Д. В большинстве случаев вам нужно будет сохранить свой прогресс в файл, так что в случае прерывания (или ошибки) вы сможете продолжить с того места, где остановились.
Более того, после успешного обучения вам наверняка придется повторно использовать изученные параметры модели, чтобы делать прогнозы на новых данных. Это относится к любой платформе глубокого обучения, как и к TensorFlow.
В этом посте мы рассмотрим сохранение и восстановление модели ТензорНого потока, в которой мы опишем некоторые из наиболее полезных вариантов на этом пути и приведем несколько примеров.
Краткое введение в модель тензорного потока
Основная функциональность TensorFlow реализуется через tensors – его базовую структуру данных, аналогичную многомерным массивам в NumPy, и graphs – представление вычислений по данным. Это библиотека symbolic , означающая, что определение графа и тензоров будет только создавать модель, в то время как тензоры получают конкретные значения, а операции выполняются в рамках сеанса – механизма выполнения моделируемых операций в графе. Любые конкретные значения тензоров теряются при закрытии сеанса, что является еще одной причиной сохранения ваших моделей в файл после запуска сеанса.
Это всегда легче понять на примерах, поэтому давайте создадим простую модель тензорного потока для линейной регрессии двумерных данных.
Во-первых, мы импортируем наши библиотеки:
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt %matplotlib inline
Следующий шаг-создание модели. Мы создадим модель, которая будет оценивать горизонтальный и вертикальный сдвиг квадратичной функции в виде:
y = (x - h) ^ 2 + v
где h
и v
– горизонтальные и вертикальные сдвиги.
Следующие строки генерируют модель (подробнее см. Комментарии в коде):
# Clear the current graph in each run, to avoid variable duplication tf.reset_default_graph() # Create placeholders for the x and y points X = tf.placeholder("float") Y = tf.placeholder("float") # Initialize the two parameters that need to be learned h_est = tf.Variable(0.0, name='hor_estimate') v_est = tf.Variable(0.0, name='ver_estimate') # y_est holds the estimated values on y-axis y_est = tf.square(X - h_est) + v_est # Define a cost function as the squared distance between Y and y_est cost = (tf.pow(Y - y_est, 2)) # The training operation for minimizing the cost function. The # learning rate is 0.001 trainop = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
На данный момент у нас есть модель , которую нужно запустить в Сессии , передавая ей некоторые реальные данные. Давайте сгенерируем несколько примеров квадратичных данных и добавим к ним шум.
# Use some values for the horizontal and vertical shift h = 1 v = -2 # Generate training data with noise x_train = np.linspace(-2,4,201) noise = np.random.randn(*x_train.shape) * 0.4 y_train = (x_train - h) ** 2 + v + noise # Visualize the data plt.rcParams['figure.figsize'] = (10, 6) plt.scatter(x_train, y_train) plt.xlabel('x_train') plt.ylabel('y_train')
Класс Saver
Класс Saver
, предоставляемый библиотекой TensorFlow, является рекомендуемым способом сохранения структуры графиков и переменных.
Сохранение Моделей
В следующих нескольких строках мы определяем объект Saver
и в методе train_graph()
проходим 100 итераций, чтобы минимизировать функцию затрат. Затем модель сохраняется на диск на каждой итерации, а также после завершения оптимизации. Каждое сохранение создает двоичные файлы на диске, называемые “контрольными точками”.
# Create a Saver object saver = tf.train.Saver() init = tf.global_variables_initializer() # Run a session. Go through 100 iterations to minimize the cost def train_graph(): with tf.Session() as sess: sess.run(init) for i in range(100): for (x, y) in zip(x_train, y_train): # Feed actual data to the train operation sess.run(trainop, feed_dict={X: x, Y: y}) # Create a checkpoint in every iteration saver.save(sess, 'model_iter', global_step=i) # Save the final model saver.save(sess, 'model_final') h_ = sess.run(h_est) v_ = sess.run(v_est) return h_, v_
Теперь давайте обучим модель с помощью вышеприведенной функции и распечатаем изученные параметры.
result = train_graph() print("h_est = %.2f, v_est = %.2f" % result)
$ python tf_save.py h_est = 1.01, v_est = -1.96
Ладно, параметры были довольно точно оценены. Если мы проверим вашу файловую систему, то обнаружим файлы, сохраненные за последние 4 итерации, а также окончательную модель.
При сохранении модели вы заметите, что для ее сохранения требуется 4 типа файлов:
- файлы “.meta”: содержащие структуру графа
- файлы “.data”: содержащие значения переменных
- файлы “.index”: идентификация контрольной точки
- файл “checkpoint”: буфер протокола со списком последних контрольных точек
Рис. 1. Файлы контрольных точек, сохраненные на диске
Вызов тф.поезда.Метод Saver ()
, как показано выше, сохранит все переменные в файл. Сохранение подмножества ваших переменных возможно путем передачи их в качестве аргумента через список или дикт, например: tf.train.Saver({'hor_estimate': h_est})
.
Несколько других полезных аргументов конструктора Saver
, которые позволяют контролировать весь процесс, являются:
max_to_keep
: максимальное количество контрольных точек для хранения,keep_checkpoint_every_n_hours
: интервал времени для сохранения контрольных точек
Для получения дополнительной информации ознакомьтесь с официальной документацией для класса Saver
, который предлагает другие полезные аргументы, которые вы можете изучить.
Восстановление моделей
Первое, что нужно сделать при восстановлении модели TensorFlow, – это загрузить структуру графа из файла “.meta” в текущий граф.
tf.reset_default_graph() imported_meta = tf.train.import_meta_graph("model_final.meta")
Текущий график можно исследовать с помощью следующей команды tf.get_default_graph()
. Теперь второй шаг – загрузить значения переменных.
Напоминание: значения существуют только в рамках сеанса.
with tf.Session() as sess: imported_meta.restore(sess, tf.train.latest_checkpoint('./')) h_est2 = sess.run('hor_estimate:0') v_est2 = sess.run('ver_estimate:0') print("h_est: %.2f, v_est: %.2f" % (h_est2, v_est2))
$ python tf_restore.py INFO:tensorflow:Restoring parameters from ./model_final h_est: 1.01, v_est: -1.96
Как упоминалось ранее, этот подход сохраняет только структуру графа и переменные, что означает, что обучающие данные, вводимые через наши заполнители ” X ” и “Y”, не сохраняются.
В любом случае, для этого примера мы будем использовать наши обучающие данные , определенные из tf
, и визуализировать соответствие модели.
plt.scatter(x_train, y_train, label='train data') plt.plot(x_train, (x_train - h_est2) ** 2 + v_est2, color='red', label='model') plt.xlabel('x_train') plt.ylabel('y_train') plt.legend()
В качестве нижней строки для этой части класс Saver
позволяет легко сохранять и восстанавливать вашу модель TensorFlow (график и переменные) в/из файла, а также сохранять несколько контрольных точек вашей работы, которые могут быть полезны для опробования вашей модели на новых данных, продолжения ее обучения и дальнейшей тонкой настройки.
Сохраненный Формат Модели
Одним из новых подходов к сохранению и восстановлению модели в TensorFlow является использование функций Saved Model, builder и loader . Это фактически обертывает класс Saver
, чтобы обеспечить сериализацию более высокого уровня, которая более подходит для производственных целей.
Хотя подход Saved Model
, похоже, еще не полностью принят разработчиками, его создатели указывают, что это явно будущее. По сравнению с классом Saver
, который фокусируется в основном на переменных, SavedModel
пытается объединить в один пакет множество полезных функций, таких как Signatures
, которые позволяют сохранять графики, имеющие набор входов и выходов, и Assets
содержащие внешние файлы, используемые при инициализации.
Сохранение моделей с помощью Saved ModelBuilder
Сохранение модели выполняется с помощью класса Saved ModelBuilder
. В нашем примере мы не используем никаких подписей или активов, но этого достаточно, чтобы проиллюстрировать процесс.
tf.reset_default_graph() # Re-initialize our two variables h_est = tf.Variable(h_est2, name='hor_estimate2') v_est = tf.Variable(v_est2, name='ver_estimate2') # Create a builder builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/') # Add graph and variables to builder and save with tf.Session() as sess: sess.run(h_est.initializer) sess.run(v_est.initializer) builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None) builder.save()
$ python tf_saved_model_builder.py INFO:tensorflow:No assets to save. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: b'./SavedModel/saved_model.pb'
Запустив этот код, вы заметите, что наша модель сохраняется в файле, расположенном по адресу “./Saved Model/saved_model.pb”.
Восстановление моделей с помощью сохраненного загрузчика моделей
Восстановление модели выполняется с помощью tf.saved_model.loader
и восстанавливает сохраненные переменные, сигнатуры и активы в области сеанса.
В следующем примере мы загрузим модель и распечатаем значения наших двух коэффициентов h_est
и v_est
.
with tf.Session() as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], './SavedModel/') h_est = sess.run('hor_estimate2:0') v_est = sess.run('ver_estimate2:0') print("h_est: %.2f, v_est: %.2f" % (h_est, v_est))
$ python tf_saved_model_loader.py INFO:tensorflow:Restoring parameters from b'./SavedModel/variables/variables' h_est: 1.01, v_est: -1.96
И снова, как и ожидалось, наша модель была успешно восстановлена с правильными обученными параметрами.
Вывод
Сохранение и восстановление модели тензорного потока-очень полезная функция, зная, что обучение глубоких сетей может занять много времени. Эта тема слишком широка, чтобы подробно освещаться в одном блоге, поэтому мы можем вернуться к ней в следующем посте.
Во всяком случае, в этом посте мы представили два инструмента: базовый класс Saver
, который сохраняет модель в виде контрольных точек, и SavedModel
| builder | loader/|, который строится поверх
Saver и создает файловую структуру, простую в использовании в производстве. Для иллюстрации примеров была использована простая линейная регрессия.