Рубрики
Без рубрики

Сделать Печатные Ссылки Кликабельными С помощью API Обнаружения объектов TensorFlow 2

📃 TL;DR В этой статье мы начнем решать вопрос о том, чтобы сделать печатные ссылки (например, в книге или журнале) кликабельными с помощью камеры вашего смартфона. Мы будем использовать обнаружение объектов TensorFlow 2…

Автор оригинала: Oleksii Trekhleb.

📃 TL;DR

В этой статье мы начнем решать вопрос о том, чтобы сделать печатные ссылки (например, в книге или журнале) кликабельными с помощью камеры вашего смартфона.

Мы будем использовать TensorFlow 2 Object Detection API для обучения пользовательской модели детектора объектов для поиска позиций и ограничивающих полей подстрок, таких как https:// в текстовом изображении (т. Е. в потоке камеры смартфона).

Текст каждой ссылки (правое продолжение https:// ограничивающая рамка) будет распознан с помощью библиотеки Tesseract . Часть распознавания не будет рассмотрена в этой статье, но вы можете найти полный пример кода приложения в репозитории links-detector .

🚀 Запустите демо-версию детектора ссылок со своего смартфона, чтобы увидеть конечный результат.

📝 Откройте репозиторий links-detector на GitHub, чтобы увидеть полный исходный код приложения.

Вот как будет выглядеть окончательное решение:

Демонстрация детектора ссылок

⚠️ В настоящее время приложение находится в экспериментальной |/альфа-стадии и имеет много проблем и ограничений . Поэтому не повышайте уровень своих ожиданий слишком высоко, пока эти проблемы не будут решены.. Кроме того, цель этой статьи-больше узнать, как работать с API обнаружения объектов TensorFlow 2, а не создавать готовую к производству модель.

В случае, если блоки кода Python в этой статье не будут иметь надлежащего форматирования на этой платформе, не стесняйтесь читать статью на GitHub

🤷🏻 ‍️ Проблема

Я работаю инженером-программистом и в свободное время изучаю машинное обучение как хобби. Но проблема пока не в этом.

Недавно я купил печатную книгу о машинном обучении, и пока я читал первые несколько глав, я столкнулся со многими печатными ссылками в тексте, которые выглядели как https://tensorflow.org/ или https://some-url.com/which/may/be/even/longer?and_with_params=true .

Печатные ссылки

Я видел все эти ссылки, но не мог нажать на них, так как они были напечатаны (спасибо, кэп!). Чтобы перейти по этим ссылкам, мне нужно было начать вводить их символ за символом в адресной строке браузера, что было довольно раздражающим и подверженным ошибкам.

💡 Возможное Решение

Итак, я подумал, что, если, подобно обнаружению QR-кода, мы попытаемся “научить” смартфон (1) обнаруживать и (2) | распознавать печатные ссылки для нас и сделать их кликабельными ? Таким образом, вы сделаете только один щелчок вместо нескольких нажатий клавиш. Оперативная сложность “щелчка” по печатным ссылкам варьируется от O(N) до O(1) .

Вот как будет выглядеть окончательный рабочий процесс:

Демонстрация детектора ссылок

📝 Требования к решению

Как я уже упоминал ранее, я просто изучаю машинное обучение как хобби. Таким образом, цель этой статьи больше связана с изучением того, как работать с API обнаружения объектов TensorFlow 2, а не с созданием готового к производству приложения.

С учетом сказанного я упростил требования к решению до следующего:

  1. Процессы обнаружения и распознавания должны иметь производительность, близкую к реальному времени (т. е. 0.5-1 кадров в секунду) на устройстве, таком как iPhone X. Это означает, что весь процесс обнаружения + распознавания должен занимать до 2 секунды (довольно терпимо, как для любительского проекта).
  2. Должны поддерживаться только Английские ссылки.
  3. Следует поддерживать только темный текст (т. е. черный или темно-серый) на светлом фоне (т. е. белый или светло-серый).
  4. На данный момент должны поддерживаться только https:// ссылки (это нормально, если наша модель не распознает http:// , ftp:// , tcp:// или другие типы ссылок).

🧩 Разбивка решения

Разбивка на высоком уровне

Давайте посмотрим, как мы могли бы подойти к проблеме на высоком уровне.

Вариант 1: Модель обнаружения на задней панели

Поток:

  1. Получите поток камеры (кадр за кадром) на стороне клиента.
  2. Отправляйте каждый кадр один за другим по сети на серверную часть.
  3. Выполните обнаружение и распознавание ссылок на бэкэнде и отправьте ответ обратно клиенту.
  4. Клиент рисует поля обнаружения с кликабельными ссылками.
Модель на задней панели

Плюсы:

  • 💚 Производительность обнаружения не ограничивается устройством клиента. Мы можем ускорить обнаружение, масштабируя службу по горизонтали (добавляя больше экземпляров) и по вертикали (добавляя больше ядер/графических процессоров).
  • 💚 Модель может быть больше, так как нет необходимости загружать ее на клиентскую сторону. Загрузка модели ~10 Мб на стороне клиента может быть в порядке, но в противном случае загрузка модели ~100 Мб может стать большой проблемой для клиентской сети и UX приложений (пользовательский опыт).
  • 💚 Можно контролировать, кто использует модель. Модель защищена API, поэтому у нас будет полный контроль над ее абонентами/клиентами.

Аферы:

  • 💔 Рост сложности системы. Технический стек приложений вырос с простого JavaScript до, скажем, JavaScript + Python . Нам нужно позаботиться об автоматическом масштабировании.
  • 💔 Автономный режим для приложения невозможен, так как для его работы требуется подключение к Интернету.
  • 💔 Слишком много HTTP-запросов между клиентом и сервером в какой-то момент может стать узким местом. Представьте себе, если бы мы хотели улучшить производительность обнаружения, скажем, от 1 чтобы 10+ кадров в секунду. Это означает, что каждый клиент отправит 10+ запросы в секунду. Для 10 одновременные клиенты это уже 100+ запросы в секунду. В этом случае может быть полезна HTTP/2 двунаправленная потоковая передача и gRPC , но здесь мы возвращаемся к возросшей сложности системы.
  • 💔 Система становится дороже. Почти все баллы из раздела ” За ” должны быть оплачены.

Вариант 2: Модель обнаружения на интерфейсе

Поток:

  1. Получите поток камеры (кадр за кадром) на стороне клиента.
  2. Выполните обнаружение и распознавание ссылок на стороне клиента (без отправки чего-либо на серверную часть).
  3. Клиент рисует поля обнаружения с кликабельными ссылками.
Модель на передней панели

Плюсы:

  • 💚 Система менее сложна. Нам не нужно настраивать серверы, создавать API и вводить в систему дополнительный стек Python.
  • 💚 Возможен автономный режим. Приложение не нуждается в подключении к Интернету для работы, так как модель полностью загружена на устройство. Таким образом, Прогрессивное веб-приложение ( PWA ) может быть создано для поддержки этого.
  • 💚 Система “своего рода” автоматически масштабируется. Чем больше у вас клиентов, тем больше ядер и графических процессоров они приносят. Однако это не правильное решение для масштабирования (подробнее об этом в разделе “Минусы” ниже).
  • 💚 Система дешевле. Нам нужен только сервер для статических ресурсов ( HTML , JS , CSS , файлы моделей и т. Д.). Это можно сделать бесплатно, скажем, на GitHub.
  • 💚 Нет проблем с растущим количеством HTTP-запросов в секунду на стороне сервера.

Аферы:

  • 💔 Возможно только горизонтальное масштабирование (каждый клиент будет иметь свой собственный процессор/GPU). Вертикальное масштабирование невозможно, так как мы не можем повлиять на производительность устройства клиента. В результате мы не можем гарантировать быстрое обнаружение устройств с низкой производительностью.
  • 💔 Невозможно защитить использование модели и контролировать абонентов/клиентов модели. Каждый мог скачать модель и повторно использовать ее.
  • 💔 Потребление батареи устройства клиента может стать проблемой. Для работы модели требуются вычислительные ресурсы. Таким образом, клиенты могут быть недовольны тем, что их iPhone становится все теплее и теплее, пока приложение работает.

Заключение высокого уровня

Поскольку целью проекта было больше обучение, а не создание готового к производству решения Я решил пойти со вторым вариантом обслуживания модели со стороны клиента . Это сделало весь проект намного дешевле (на самом деле с GitHub его можно было бесплатно разместить), и я мог больше сосредоточиться на машинном обучении, чем на внутренней инфраструктуре автоматического масштабирования.

Разбивка на более низком уровне

Итак, мы решили использовать бессерверное решение. Теперь у нас есть изображение из потока камеры в качестве входных данных, которое выглядит примерно так:

Ввод печатных ссылок

Нам нужно решить две подзадачи для этого изображения:

  1. Ссылки обнаружение (поиск положения и ограничивающих рамок ссылок)
  2. Ссылки распознавание (распознавание текста ссылок)

Вариант 1: Решение на основе Тессеракта

Первым и наиболее очевидным подходом было бы решение задачи оптического распознавания символов ( OCR ) путем распознавания всего текста изображения с помощью, скажем, Tesseract.js библиотека. Он возвращает ограничительные рамки абзацев, текстовых строк и текстовых блоков вместе с распознанным текстом.

Распознанный текст с ограничивающими рамками

Затем мы можем попытаться извлечь ссылки из распознанных текстовых строк или текстовых блоков с помощью регулярного выражения, такого как этот (пример приведен на TypeScript):

const URL_REG_EXP = /https?:\/\/(www\.)?[-a-zA-Z0-9@:%._+~#=]{2,256}\.[a-z]{2,4}\b([-a-zA-Z0-9@:%_+.~#?&/=]*)/gi;

const extractLinkFromText = (text: string): string | null => {
  const urls: string[] | null = text.match(URL_REG_EXP);
  if (!urls || !urls.length) {
    return null;
  }
  return urls[0];
};

💚 Похоже, что проблема решена довольно простым и простым способом:

  • Мы знаем ограничительные рамки ссылок
  • Мы также знаем текст ссылок, чтобы сделать их кликабельными

💔 Дело в том, что распознавание + обнаружение время может варьироваться от 2 чтобы 20+ секунды в зависимости от размера текста, от количества “чего-то похожего на текст” на изображении, от качества изображения и от других факторов. Так что добиться этого будет очень трудно 0.5-1 кадров в секунду, чтобы сделать пользовательский опыт по крайней мере близким к реальному времени.

💔 Кроме того, если мы подумаем об этом, мы попросим библиотеку распознать весь текст из изображения для нас, даже если он может содержать только одну или две ссылки (т. Е. Только ~10% текста может быть полезно для нас), или он может даже не содержать ссылок вообще. В данном случае это звучит как пустая трата вычислительных ресурсов.

Вариант 2: Решение на основе тессеракта + тензорного потока

Мы могли бы заставить Тессеракт работать быстрее, если бы использовали некоторый дополнительный алгоритм “советника” до распознавания текста ссылок. Этот алгоритм “советника” должен обнаруживать, но не распознавать крайнюю левую позицию каждой ссылки на изображении, если таковые имеются. Это позволит нам ускорить процесс распознавания, следуя этим правилам:

  1. Если изображение не содержит никакой ссылки, мы вообще не должны вызывать обнаружение/распознавание тессеракта.
  2. Если на изображении есть ссылки, нам нужно попросить Тессеракт распознать только те части изображения, которые содержат ссылки. Мы не заинтересованы в том, чтобы тратить время на распознавание нерелевантного текста, который не содержит ссылок.

Алгоритм “советник”, который будет проходить перед Тессерактом, должен работать с постоянным временем независимо от качества изображения или наличия/отсутствия текста на изображении. Он также должен быть довольно быстрым и обнаруживать крайние левые позиции ссылок менее чем за 1 с , чтобы мы могли удовлетворить требование “близко к реальному времени” (т. Е. На iPhone X).

💡 Итак, что, если мы будем использовать другую модель обнаружения объектов, чтобы помочь нам найти все вхождения подстрок https:// (каждая защищенная ссылка имеет этот префикс, не так ли) в изображении? Затем, имея эти https:// ограничительные поля в тексте, мы можем извлечь их правое продолжение и отправить их в Тессеракт для распознавания текста.

Взгляните на картинку ниже:

Решение на основе тессеракта и тензорного потока

Вы можете заметить, что Тессеракт должен делать гораздо меньше работы в случае, если у него будут некоторые подсказки о том, где могут быть расположены ссылки (см. Количество синих квадратов на обеих картинках).

Таким образом, теперь вопрос в том, какую модель обнаружения объектов мы должны выбрать и как переобучить ее для поддержки обнаружения пользовательских https:// объектов.

Наконец-то! Мы подошли ближе к части статьи о тензорном потоке 😀

🤖 Выбор модели обнаружения объекта

Обучение новой модели обнаружения объектов не является разумным вариантом в нашем контексте по следующим причинам:

  • 💔 Процесс обучения может занять несколько дней/недель и долларов.
  • 💔 Мы, скорее всего, не сможем собрать сотни тысяч помеченных изображений книг, в которых есть ссылки (мы могли бы попытаться сгенерировать их, но об этом позже).

Поэтому вместо того, чтобы создавать новую модель, мы должны лучше научить существующую модель обнаружения объектов выполнять обнаружение пользовательских объектов для нас (выполнять обучение передаче ). В нашем случае “пользовательскими объектами” будут изображения с нарисованным в них текстом |/https:/ /. Этот подход имеет следующие преимущества:

  • 💚 Набор данных может быть намного меньше. Нам не нужно собирать сотни тысяч помеченных изображений. Вместо этого мы можем сделать ~100 изображений и пометить их вручную. Это связано с тем, что модель уже предварительно обучена на общем наборе данных, таком как COCO dataset , и уже научилась извлекать общие объекты изображения.
  • 💚 Процесс обучения будет намного быстрее (минуты/часы на GPU вместо дней/недель). Опять же, это связано с меньшим набором данных (меньшими пакетами) и меньшим количеством обучаемых параметров.

Мы можем выбрать существующую модель из TensorFlow 2 Detection Model Zoo , которая предоставляет набор моделей обнаружения, предварительно обученных на наборе данных COCO 2017 . Теперь он содержит ~40 вариантов моделей на выбор.

Для переобучения и точной настройки модели на пользовательском наборе данных мы будем использовать API обнаружения объектов TensorFlow 2 . API обнаружения объектов TensorFlow-это платформа с открытым исходным кодом, построенная поверх TensorFlow , которая позволяет легко создавать, обучать и развертывать модели обнаружения объектов.

Если вы перейдете по ссылке Model Zoo , вы найдете скорость обнаружения и точность для каждой модели.

Модель Зоопарка

Источник изображения: Зоопарк модели тензорного потока хранилище

Конечно, мы хотели бы найти правильный баланс между обнаружением скоростью и точностью при выборе модели. Но что может быть еще более важным в нашем случае, так это размер модели, поскольку она будет загружена на клиентскую сторону.

Размер архивной модели может сильно варьироваться от ~20 Мб до ~1 Гб . Есть несколько примеров:

  • 1386 (Мб) centernet_hg 104_1024x1024_kpts_coco 17_tp u32
  • 330 (Мб) |/centernet_resnet101_v1_png_512x512_coco 17_tpu-8 195 (Мб)
  • centernet_resnet 50_v1_png_512x512_coco 17_tpu-8 198 (Мб)
  • centernet_resnet 50_v1_png_512x512_kits_coco 17_tpu-8 227 (Мб)
  • centernet_resnet 50_v2_512x512_coco 17_tpu-8 230 (Мб)
  • centernet_resnet 50_v2_512x512_kits_coco 17_tpu-8 29 (Мб)
  • эффективный det_d0_coco 17_tp u32 49 (Мб)
  • эффективный det_d1_coco 17_tp u32 60 (Мб)
  • эффективный det_d2_coco 17_tp u32 89 (Мб)
  • эффективный data_d3_coco 17_tp u32 151 (Мб)
  • эффективный data_d4_coco 17_tp u32 244 (Мб)
  • эффективный data_d5_coco 17_tp u32 376 (Мб)
  • эффективный data_d6_coco 17_tp u32 376 (Мб)
  • эффективный data_d7_coco 17_tp u32 665 (Мб)
  • extremenet 427 (Мб)
  • faster_rcnn_inception_resnet_v2_1024x1024_coco 17_tpu-8 424 (Мб)
  • faster_rcnn_inception_resnet_v2_640 x 640_coco 17_tpu-8 337 (Мб)
  • faster_rcnn_resnet101_v1_1024x1024_coco 17_tpu-8 337 (Мб)
  • faster_rcnn_resnet101_v1_640x640_coco 17_tpu-8 343 (Мб)
  • faster_rcnn_resnet101_v1_800x1333_coco17_gpu-8 449 (Мб)
  • faster_rcnn_resnet152_v1_1024x1024_coco 17_tpu-8 449 (Мб)
  • faster_rcnn_resnet152_v1_640x640_coco 17_tpu-8 454 (Мб)
  • faster_rcnn_resnet152_v1_800x1333_coco17_gpu-8 202 (Мб)
  • faster_rcnn_resnet 50_v1_1024x1024_coco 17_tpu-8 202 (Мб)
  • faster_rcnn_resnet 50_v1_640x640_coco 17_tpu-8 207 (Мб)
  • faster_rcnn_resnet 50_v1_800 x 1333_coco 17_gpu-8 462 (Мб)
  • |/mask_rcnn_inception_resnet_v2_1024x1024_coco 17_gpu-8 86 (Мб) |/ssd_mobile net_v1_fpn_640 x 640_coco 17_tpu-8
  • 44 (Мб) |/ssd_mobile net_v2_320x320_coco 17_tpu-8 20 (Мб)
  • |/ssd_mobile net_v2_fpnlite_320x320_coco 17_tpu-8 20 (Мб) |/ssd_mobile net_v2_fpnlite_640x640_coco 17_tpu-8
  • 369 (Мб) |/ssd_resnet101_v1_fpn_1024x1024_coco 17_tpu-8 369 (Мб)
  • |/ssd_resnet101_v1_fpn_640 x 640_coco 17_tpu-8 481 (Мб) |/ssd_resnet152_v1_fpn_1024x1024_coco 17_tpu-8
  • 480 (Мб) |/ssd_resnet152_v1_fpn_640 x 640_coco 17_tpu-8 233 (Мб)
  • |/ssd_resnet 50_v1_fpn_1024x1024_coco 17_tpu-8 233 (Мб) |/ssd_resnet 50_v1_fpn_640 x 640_coco 17_tpu-8

Модель ssd_mobile net_v2_fpnlite_640 x 640_coco 17_tpu-8 может хорошо подойти в нашем случае:

  • 💚 Он относительно легкий: 20 Мб архивирован.
  • 💚 Это довольно быстро: 39 мс для обнаружения.
  • 💚 Он использует сеть Mobile Net v2 в качестве экстрактора функций, который оптимизирован для использования на мобильных устройствах для снижения энергопотребления.
  • 💚 Он выполняет обнаружение объектов для всего изображения и для всех объектов в нем за один раз независимо от содержимого изображения (не задействовано предложение регионов , что ускоряет обнаружение).
  • 💔 Это не самая точная модель, хотя (все является компромиссом ⚖ )).

Название модели кодирует несколько важных характеристик, о которых вы можете прочитать больше, если хотите:

🛠 Установка API обнаружения объектов

В этой статье мы собираемся установить API обнаружения объектов Tensorflow 2 в качестве пакета Python . Это удобно в том случае, если вы экспериментируете в Google Colab (рекомендуется) или в Jupiter . В обоих случаях локальная установка не требуется, вы можете поэкспериментировать прямо в своем браузере.

Вы также можете следовать официальной документации , если вы предпочитаете устанавливать API обнаружения объектов через Docker.

Если вы застряли с чем-то во время установки API или во время подготовки набора данных, попробуйте прочитать учебник TensorFlow 2 Object Detection API , который добавляет много полезных деталей в этот процесс.

Во-первых, давайте клонируем репозиторий API :

git clone --depth 1 https://github.com/tensorflow/models

выход →

Cloning into 'models'...
remote: Enumerating objects: 2301, done.
remote: Counting objects: 100% (2301/2301), done.
remote: Compressing objects: 100% (2000/2000), done.
remote: Total 2301 (delta 561), reused 922 (delta 278), pack-reused 0
Receiving objects: 100% (2301/2301), 30.60 MiB | 13.90 MiB/s, done.
Resolving deltas: 100% (561/561), done.

Теперь давайте скомпилируем файлы API proto в файлы Python с помощью инструмента protoc :

cd ./models/research
protoc object_detection/protos/*.proto --python_out=.

Наконец, давайте установим версию TF2 setup.py через pip :

cp ./object_detection/packages/tf2/setup.py .
pip install . --quiet

Возможно, что последний шаг завершится неудачей из-за некоторых ошибок зависимостей. В этом случае вы можете запустить pip install . --тихо еще раз.

Мы можем проверить, что установка прошла успешно, выполнив следующие тесты:

python object_detection/builders/model_builder_tf2_test.py

Вы должны увидеть журналы, которые заканчиваются чем-то похожим на это:

[       OK ] ModelBuilderTF2Test.test_unknown_ssd_feature_extractor
----------------------------------------------------------------------
Ran 20 tests in 45.072s

OK (skipped=1)

Установлен API обнаружения объектов TensorFlow! Теперь вы можете использовать сценарии, предоставляемые API для выполнения модели вывода , обучения или тонкой настройки .

⬇️ Загрузка предварительно подготовленной модели

Давайте загрузим нашу выбранную ssd_mobile net_v2_fpnlite_640x640_coco 17_tpu-8 модель из био модели TensorFlow и проверим, как она выполняет общее обнаружение объектов (обнаружение объектов классов из набора данных COCO, таких как “кошка”, “собака”, “автомобиль” и т. Д.).

Мы будем использовать помощник get_file() TensorFlow, чтобы загрузить архивную модель с URL-адреса и распаковать ее.

import tensorflow as tf
import pathlib

MODEL_NAME = 'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8'
TF_MODELS_BASE_PATH = 'http://download.tensorflow.org/models/object_detection/tf2/20200711/'
CACHE_FOLDER = './cache'

def download_tf_model(model_name, cache_folder):
    model_url = TF_MODELS_BASE_PATH + model_name + '.tar.gz'
    model_dir = tf.keras.utils.get_file(
        fname=model_name, 
        origin=model_url,
        untar=True,
        cache_dir=pathlib.Path(cache_folder).absolute()
    )
    return model_dir

# Start the model download.
model_dir = download_tf_model(MODEL_NAME, CACHE_FOLDER)
print(model_dir)

выход →

/content/cache/datasets/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8

Вот как выглядит структура папок до сих пор:

Папка кэша

Папка checkpoint содержит моментальный снимок предварительно обученной модели.

Файл pipeline.config содержит параметры обнаружения модели. Мы вернемся к этому файлу позже, когда нам нужно будет точно настроить модель.

🏄 🏻 ️ Пробуем модель (Делаем вывод)

На данный момент модель может обнаруживать объекты 90 классов набора данных COCO , такие как автомобиль , птица , хот-дог и т. Д.

Классы КАКАО

Источник изображения: Набор данных COCO вебсайт

Давайте посмотрим, как модель работает с некоторыми общими изображениями, содержащими объекты этих классов.

Загрузка этикеток COCO

API обнаружения объектов уже имеет полный набор меток (классов) COCO, определенных для нас.

import os

# Import Object Detection API helpers.
from object_detection.utils import label_map_util

# Loads the COCO labels data (class names and indices relations).
def load_coco_labels():
    # Object Detection API already has a complete set of COCO classes defined for us.
    label_map_path = os.path.join(
        'models/research/object_detection/data',
        'mscoco_complete_label_map.pbtxt'
    )
    label_map = label_map_util.load_labelmap(label_map_path)

    # Class ID to Class Name mapping.
    categories = label_map_util.convert_label_map_to_categories(
        label_map,
        max_num_classes=label_map_util.get_max_label_map_index(label_map),
        use_display_name=True
    )
    category_index = label_map_util.create_category_index(categories)
    
    # Class Name to Class ID mapping.
    label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)

    return category_index, label_map_dict

# Load COCO labels.
coco_category_index, coco_label_map_dict = load_coco_labels()

print('coco_category_index:', coco_category_index)
print('coco_label_map_dict:', coco_label_map_dict)

выход →

coco_category_index:
{
    1: {'id': 1, 'name': 'person'},
    2: {'id': 2, 'name': 'bicycle'},
    ...
    90: {'id': 90, 'name': 'toothbrush'},
}

coco_label_map_dict:
{
    'background': 0,
    'person': 1,
    'bicycle': 2,
    'car': 3,
    ...
    'toothbrush': 90,
}

Постройте функцию обнаружения

Нам нужно создать функцию обнаружения, которая будет использовать предварительно обученную модель, которую мы загрузили, для обнаружения объекта.

import tensorflow as tf

# Import Object Detection API helpers.
from object_detection.utils import config_util
from object_detection.builders import model_builder

# Generates the detection function for specific model and specific model's checkpoint
def detection_fn_from_checkpoint(config_path, checkpoint_path):
    # Build the model.
    pipeline_config = config_util.get_configs_from_pipeline_file(config_path)
    model_config = pipeline_config['model']
    model = model_builder.build(
        model_config=model_config,
        is_training=False,
    )

    # Restore checkpoints.
    ckpt = tf.compat.v2.train.Checkpoint(model=model)
    ckpt.restore(checkpoint_path).expect_partial()

    # This is a function that will do the detection.
    @tf.function
    def detect_fn(image):
        image, shapes = model.preprocess(image)
        prediction_dict = model.predict(image, shapes)
        detections = model.postprocess(prediction_dict, shapes)

        return detections, prediction_dict, tf.reshape(shapes, [-1])
    
    return detect_fn

inference_detect_fn = detection_fn_from_checkpoint(
    config_path=os.path.join('cache', 'datasets', MODEL_NAME, 'pipeline.config'),
    checkpoint_path=os.path.join('cache', 'datasets', MODEL_NAME, 'checkpoint', 'ckpt-0'),
)

Эта функция inference_detect_fn примет изображение и вернет информацию об обнаруженных объектах.

Загрузка изображений для вывода

Давайте попробуем обнаружить объект на этом изображении:

Общий вывод объекта

Источник изображения: алексей_трехлеб Инстаграм

Для этого давайте сохраним изображение в папку inference/test/ вашего проекта. Если вы используете Google Colab, вы можете создать эту папку и загрузить изображение вручную.

Вот как выглядит структура папок до сих пор:

Структура папок
import matplotlib.pyplot as plt
%matplotlib inline

# Creating a TensorFlow dataset of just one image.
inference_ds = tf.keras.preprocessing.image_dataset_from_directory(
  directory='inference',
  image_size=(640, 640),
  batch_size=1,
  shuffle=False,
  label_mode=None
)
# Numpy version of the dataset.
inference_ds_numpy = list(inference_ds.as_numpy_iterator())

# You may preview the images in dataset like this.
plt.figure(figsize=(14, 14))
for i, image in enumerate(inference_ds_numpy):
    plt.subplot(2, 2, i + 1)
    plt.imshow(image[0].astype("uint8"))
    plt.axis("off")
plt.show()

Запуск обнаружения по тестовым данным

Теперь мы готовы запустить обнаружение. Массив inference_ds_numpy[0] хранит данные пикселей для первого изображения в формате Numpy .

detections, predictions_dict, shapes = inference_detect_fn(
    inference_ds_numpy[0]
)

Давайте посмотрим на формы вывода:

boxes = detections['detection_boxes'].numpy()
scores = detections['detection_scores'].numpy()
classes = detections['detection_classes'].numpy()
num_detections = detections['num_detections'].numpy()[0]

print('boxes.shape: ', boxes.shape)
print('scores.shape: ', scores.shape)
print('classes.shape: ', classes.shape)
print('num_detections:', num_detections)

выход →

boxes.shape:  (1, 100, 4)
scores.shape:  (1, 100)
classes.shape:  (1, 100)
num_detections: 100.0

Модель сделала 100 обнаружение для нас. Это не значит, что он нашел 100 однако объекты на изображении. Это означает, что модель имеет 100 слоты, и он может обнаружить 100 объекты на максимуме на одном изображении. Каждое обнаружение имеет оценку, которая отражает уверенность модели в нем. Ограничительные рамки для каждого обнаружения хранятся в массиве boxes . Оценки или доверительные данные модели о каждом обнаружении хранятся в массиве scores . Наконец, массив classes хранит метки (классы) для каждого обнаружения.

Давайте проверим первые 5 обнаружений:

print('First 5 boxes:')
print(boxes[0,:5])

print('First 5 scores:')
print(scores[0,:5])

print('First 5 classes:')
print(classes[0,:5])

class_names = [coco_category_index[idx + 1]['name'] for idx in classes[0]]
print('First 5 class names:')
print(class_names[:5])

выход →

First 5 boxes:
[[0.17576033 0.84654826 0.25642633 0.88327974]
 [0.5187813  0.12410264 0.6344235  0.34545377]
 [0.5220358  0.5181462  0.6329132  0.7669856 ]
 [0.50933677 0.7045719  0.5619138  0.7446198 ]
 [0.44761637 0.51942706 0.61237675 0.75963426]]

First 5 scores:
[0.6950246 0.6343004 0.591157  0.5827219 0.5415643]

First 5 classes:
[9. 8. 8. 0. 8.]

First 5 class names:
['traffic light', 'boat', 'boat', 'person', 'boat']

Модель видит светофор , три лодки и человека на изображении. Мы можем подтвердить, что действительно эти объекты видны на изображении.

Из массива scores можно увидеть, что модель наиболее уверена (близка к 70% вероятности) в объекте светофор .

Каждая запись массива boxes является [y1, x1, y2, x2] , где (x1, y1) и (x2, y2) являются верхним левым и нижним правым углами ограничивающего прямоугольника.

Давайте визуализируем поля обнаружения:

# Importing Object Detection API helpers.
from object_detection.utils import visualization_utils

# Visualizes the bounding boxes on top of the image.
def visualize_detections(image_np, detections, category_index):
    label_id_offset = 1
    image_np_with_detections = image_np.copy()

    visualization_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections,
        detections['detection_boxes'][0].numpy(),
        (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
        detections['detection_scores'][0].numpy(),
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=200,
        min_score_thresh=.4,
        agnostic_mode=False,
    )

    plt.figure(figsize=(12, 16))
    plt.imshow(image_np_with_detections)
    plt.show()

# Visualizing the detections.
visualize_detections(
    image_np=tf.cast(inference_ds_numpy[0][0], dtype=tf.uint32).numpy(),
    detections=detections,
    category_index=coco_category_index,
)

Вот результат:

Результат вывода

Если мы сделаем обнаружение для текстового изображения вот что мы увидим:

Результат вывода для текстового изображения

Модель ничего не смогла обнаружить на этом изображении. Это то, что мы собираемся изменить, мы хотим научить модель “видеть” префиксы https:// на этом изображении.

📝 Подготовка пользовательского набора данных

Чтобы “научить” модель ssd_mobile net_v2_fpnlite_640 x 640_coco 17_tpu-8 обнаруживать пользовательские объекты, которые не являются частью набора данных COCO, нам нужно выполнить точную настройку обучения на новом пользовательском наборе данных.

Наборы данных для обнаружения объектов состоят из двух частей:

  1. Само изображение (т. е. изображение страницы книги)
  2. Поля границ, которые показывают, где именно на изображении расположены пользовательские объекты.
Ограничительные Рамки

В приведенном выше примере каждое поле имеет слева вверху и справа внизу координаты в абсолютных значениях (в пикселях). Однако существуют также различные форматы записи местоположения ограничивающих прямоугольников. Например, мы можем найти ограничивающую рамку, установив координату ее центральной точки и ее ширину и высоту . Мы также можем использовать относительные значения (в процентах от ширины и высоты изображения) для настройки координат. Но у вас есть идея, что сеть должна знать, что такое изображение и где на изображении расположены объекты.

Теперь, как мы можем получить пользовательский набор данных для обучения? У нас есть три варианта:

  1. Повторное использование существующего набора данных.
  2. Создайте новый набор данных изображений поддельных книг.
  3. Создайте набор данных вручную, сделав или загрузив фотографии реальных страниц книги, которые содержат https:// ссылки и пометив все ограничительные рамки.

Вариант 1: Повторное использование существующего набора данных

Существует множество наборов данных, которые являются общими для повторного использования исследователями. Мы могли бы начать со следующих ресурсов, чтобы найти правильный набор данных:

💚 Если вы смогли найти необходимый набор данных и его лицензия позволяет вам повторно использовать его, это, вероятно, самый быстрый способ сразу перейти к обучению модели.

💔 Я не смог найти набор данных с помеченными префиксами https://|/.

Поэтому нам нужно пропустить этот вариант.

Вариант 2: Создание синтетического набора данных

Существуют инструменты (например, keras_ocr ), которые могут помочь нам генерировать случайный текст, включать в него ссылку и рисовать его на изображениях с некоторым фоном и искажениями.

💚 Самое интересное в этом подходе заключается в том, что у нас есть свобода генерировать обучающие примеры для различных шрифтов , лигатур , цветов текста , цветов фона . Это очень полезно, если мы хотим избежать переоснащения модели во время обучения (чтобы модель могла хорошо обобщаться на невидимые примеры реального мира, а не терпеть неудачу, как только немного изменится оттенок фона).

💚 Также возможно генерировать различные типы ссылок, такие как http:// , http:// , ftp:// , tcp:// и т.д. В противном случае, возможно, будет трудно найти достаточно реальных примеров такого рода ссылок для обучения.

💚 Еще одно преимущество этого подхода заключается в том, что мы можем генерировать столько обучающих примеров, сколько захотим. Мы не ограничиваемся количеством страниц печатной книги, которую мы нашли для набора данных. Увеличение числа обучающих примеров также может повысить точность модели.

💔 Однако можно злоупотреблять генератором и генерировать обучающие изображения, которые будут сильно отличаться от реальных примеров. Допустим, мы можем использовать неправильные и нереалистичные искажения для страницы (т. Е. Использовать изгиб волн вместо дугового). В этом случае модель не будет хорошо обобщаться на реальные примеры.

Я считаю этот подход действительно многообещающим. Это может помочь преодолеть многие проблемы с моделью (подробнее об этом ниже). Хотя я еще не пробовал. Но это может быть хорошим кандидатом для другой статьи.

Вариант 3: Создание набора данных вручную

Однако самый простой способ-получить книгу (или книги), сделать фотографии страниц со ссылками и пометить их все вручную.

Хорошей новостью является то, что набор данных может быть довольно небольшим (сотен изображений может быть достаточно), потому что мы не собираемся обучать модель с нуля , но вместо этого мы собираемся сделать обучение передаче (также см. Обучение с несколькими кадрами .)

💚 В этом случае обучающий набор данных будет действительно близок к реальным данным. Вы буквально возьмете печатную книгу, сфотографируете ее с реалистичными шрифтами, изгибами, оттенками, перспективами и цветами.

💔 Даже если для этого не требуется много изображений, это все равно может занять много времени.

💔 Трудно придумать разнообразную базу данных, в которой обучающие примеры имели бы разные шрифты, цвета фона и разные типы ссылок (для этого нам нужно найти много разных книг и журналов).

Поскольку статья имеет цель обучения и поскольку мы не пытаемся выиграть конкурс по обнаружению объектов, давайте пока воспользуемся этой опцией и попробуем создать набор данных самостоятельно.

Предварительная обработка данных

Итак, я закончил съемку 125 изображения страниц книги, содержащих одну или несколько ссылок https:// на них.

Необработанный Набор данных

Я поместил все эти изображения в папку dataset/printed_links/raw .

Далее я собираюсь предварительно обработать изображения, выполнив следующие действия:

  • Измените размер каждого изображения на ширину 1024px (изначально они слишком большие и имеют ширину 3024px )
  • Обрезайте каждое изображение, чтобы сделать их квадратными (это необязательно, и мы могли бы просто изменить размер изображения, просто сжав его, но я хочу, чтобы модель была обучена реалистичным пропорциям https: ящиков).
  • При необходимости поверните изображение, применив метаданные exif .
  • Оттенки серого изображение (нам не нужна модель, чтобы учитывать цвета).
  • Увеличьте яркость
  • Увеличение контрастности
  • Увеличьте резкость

Помните, что после того, как мы решили применить эти преобразования и корректировки к набору данных, нам нужно сделать то же самое в будущем для каждого изображения, которое мы отправим в модель для обнаружения.

Вот как мы могли бы применить эти настройки к изображению с помощью Python:

import os
import math
import shutil

from pathlib import Path
from PIL import Image, ImageOps, ImageEnhance

# Resize an image.
def preprocess_resize(target_width):
    def preprocess(image: Image.Image, log) -> Image.Image:
        (width, height) = image.size
        ratio = width / height

        if width > target_width:
            target_height = math.floor(target_width / ratio)
            log(f'Resizing: To size {target_width}x{target_height}')
            image = image.resize((target_width, target_height))
        else:
            log('Resizing: Image already resized, skipping...')

        return image
    return preprocess

# Crop an image.
def preprocess_crop_square():
    def preprocess(image: Image.Image, log) -> Image.Image:
        (width, height) = image.size
        
        left = 0
        top = 0
        right = width
        bottom = height
        
        crop_size = min(width, height)
        
        if width >= height:
            # Horizontal image.
            log(f'Squre cropping: Horizontal {crop_size}x{crop_size}')
            left = width // 2 - crop_size // 2
            right = left + crop_size
        else:
            # Vetyical image.
            log(f'Squre cropping: Vertical {crop_size}x{crop_size}')
            top = height // 2 - crop_size // 2
            bottom = top + crop_size

        image = image.crop((left, top, right, bottom))
        return image
    return preprocess

# Apply exif transpose to an image.
def preprocess_exif_transpose():
    # @see: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html
    def preprocess(image: Image.Image, log) -> Image.Image:
        log('EXif transpose')
        image = ImageOps.exif_transpose(image)
        return image
    return preprocess

# Apply color transformations to the image.
def preprocess_color(brightness, contrast, color, sharpness):
    # @see: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
    def preprocess(image: Image.Image, log) -> Image.Image:
        log('Coloring')
        
        enhancer = ImageEnhance.Color(image)
        image = enhancer.enhance(color)

        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(brightness)
        
        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(contrast)
        
        enhancer = ImageEnhance.Sharpness(image)
        image = enhancer.enhance(sharpness)
        
        return image
    return preprocess

# Image pre-processing pipeline.
def preprocess_pipeline(src_dir, dest_dir, preprocessors=[], files_num_limit=0, override=False):
    # Create destination folder if not exists.
    Path(dest_dir).mkdir(parents=False, exist_ok=True)
    
    # Get the list of files to be copied.
    src_file_names = os.listdir(src_dir)
    files_total = files_num_limit if files_num_limit > 0 else len(src_file_names)
    files_processed = 0
    
    # Logger function.
    def preprocessor_log(message):
        print('  ' + message)
    
    # Iterate through files.
    for src_file_index, src_file_name in enumerate(src_file_names):
        if files_num_limit > 0 and src_file_index >= files_num_limit:
            break
            
        # Copy file.        
        src_file_path = os.path.join(src_dir, src_file_name)
        dest_file_path = os.path.join(dest_dir, src_file_name)
        
        progress = math.floor(100 * (src_file_index + 1) / files_total)
        print(f'Image {src_file_index + 1}/{files_total} | {progress}% |  {src_file_path}')
        
        if not os.path.isfile(src_file_path):
            preprocessor_log('Source is not a file, skipping...\n')
            continue
        
        if not override and os.path.exists(dest_file_path):
            preprocessor_log('File already exists, skipping...\n')
            continue
            
        shutil.copy(src_file_path, dest_file_path)
        files_processed += 1
        
        # Preprocess file.
        image = Image.open(dest_file_path)
        
        for preprocessor in preprocessors:
            image = preprocessor(image, preprocessor_log)
        
        image.save(dest_file_path, quality=95)
        print('')
        
    print(f'{files_processed} out of {files_total} files have been processed')

# Launching the image preprocessing pipeline.
preprocess_pipeline(
    src_dir='dataset/printed_links/raw',
    dest_dir='dataset/printed_links/processed',
    override=True,
    # files_num_limit=1,
    preprocessors=[
        preprocess_exif_transpose(),
        preprocess_resize(target_width=1024),
        preprocess_crop_square(),
        preprocess_color(brightness=2, contrast=1.3, color=0, sharpness=1),
    ]
)

В результате все обработанные изображения были сохранены в папке dataset/printed_links/processed .

Обработанный набор данных

Вы можете просмотреть изображения следующим образом:

import matplotlib.pyplot as plt
import numpy as np

def preview_images(images_dir, images_num=1, figsize=(15, 15)):
    image_names = os.listdir(images_dir)
    image_names = image_names[:images_num]
    
    num_cells = math.ceil(math.sqrt(images_num))
    figure = plt.figure(figsize=figsize)
    
    for image_index, image_name in enumerate(image_names):
        image_path = os.path.join(images_dir, image_name)
        image = Image.open(image_path)
        
        figure.add_subplot(num_cells, num_cells, image_index + 1)
        plt.imshow(np.asarray(image))
    
    plt.show()

preview_images('dataset/printed_links/processed', images_num=4, figsize=(16, 16))

Маркировка набора данных

Для выполнения маркировки (для обозначения местоположения объектов, которые нас интересуют, а именно префиксов https:// ), мы можем использовать инструмент аннотации графических изображений LabelImg .

Для этого шага вы можете установить инструмент маркировки на локальном компьютере (не в Colab). Подробные инструкции по установке можно найти в файле Labeling README .

Как только у вас установлен инструмент маркировки, вы можете запустить его для папки dataset/printed_links/processed из корня вашего проекта следующим образом:

labelImg dataset/printed_links/processed

Затем вам нужно будет пометить все изображения из папки dataset/printed_links/processed и сохранить аннотации в виде XML-файлов в папку dataset/printed_links/labels/xml/|/.

Маркировка
Процесс Маркировки

После маркировки у нас должен быть XML-файл с данными ограничивающих рамок для каждого изображения:

Структура папок ярлыков

Разделение набора данных на подмножества train, test и validation

Чтобы определить проблему переоснащения или недооснащения модели, нам нужно разделить набор данных на train и test dataset. Допустим, 80% наших изображений будут использоваться для обучения модели и 20% изображений будут использоваться для проверки того, насколько хорошо модель обобщается на изображения, которые она не видела раньше.

В этом разделе мы разделим файлы, скопировав их в разные папки ( test и train папки). Однако это может быть не самым оптимальным способом. Вместо этого разделение набора данных может быть выполнено на tf.data.Набор данных уровень.

import re
import random

def partition_dataset(
    images_dir,
    xml_labels_dir,
    train_dir,
    test_dir,
    val_dir,
    train_ratio,
    test_ratio,
    val_ratio,
    copy_xml
):    
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
        
    if not os.path.exists(test_dir):
        os.makedirs(test_dir)
        
    if not os.path.exists(val_dir):
        os.makedirs(val_dir)

    images = [f for f in os.listdir(images_dir)
              if re.search(r'([a-zA-Z0-9\s_\\.\-\(\):])+(.jpg|.jpeg|.png)$', f, re.IGNORECASE)]

    num_images = len(images)
    
    num_train_images = math.ceil(train_ratio * num_images)
    num_test_images = math.ceil(test_ratio * num_images)
    num_val_images = math.ceil(val_ratio * num_images)
    
    print('Intended split')
    print(f'  train: {num_train_images}/{num_images} images')
    print(f'  test: {num_test_images}/{num_images} images')
    print(f'  val: {num_val_images}/{num_images} images')
    
    actual_num_train_images = 0
    actual_num_test_images = 0
    actual_num_val_images = 0
    
    def copy_random_images(num_images, dest_dir):
        copied_num = 0
        
        if not num_images:
            return copied_num
        
        for i in range(num_images):
            if not len(images):
                break
                
            idx = random.randint(0, len(images)-1)
            filename = images[idx]
            shutil.copyfile(os.path.join(images_dir, filename), os.path.join(dest_dir, filename))
            
            if copy_xml:
                xml_filename = os.path.splitext(filename)[0]+'.xml'
                shutil.copyfile(os.path.join(xml_labels_dir, xml_filename), os.path.join(dest_dir, xml_filename))
            
            images.remove(images[idx])
            copied_num += 1
        
        return copied_num
    
    actual_num_train_images = copy_random_images(num_train_images, train_dir)
    actual_num_test_images = copy_random_images(num_test_images, test_dir)
    actual_num_val_images = copy_random_images(num_val_images, val_dir)
    
    print('\n', 'Actual split')
    print(f'  train: {actual_num_train_images}/{num_images} images')
    print(f'  test: {actual_num_test_images}/{num_images} images')
    print(f'  val: {actual_num_val_images}/{num_images} images')

partition_dataset(
    images_dir='dataset/printed_links/processed',
    train_dir='dataset/printed_links/partitioned/train',
    test_dir='dataset/printed_links/partitioned/test',
    val_dir='dataset/printed_links/partitioned/val',
    xml_labels_dir='dataset/printed_links/labels/xml',
    train_ratio=0.8,
    test_ratio=0.2,
    val_ratio=0,
    copy_xml=True
)

После разделения структура папок набора данных должна выглядеть примерно так:

dataset/
└── printed_links
    ├── labels
    │   └── xml
    ├── partitioned
    │   ├── test
    │   └── train
    │       ├── IMG_9140.JPG
    │       ├── IMG_9140.xml
    │       ├── IMG_9141.JPG
    │       ├── IMG_9141.xml
    │       ...
    ├── processed
    └── raw

Экспорт набора данных

Последняя манипуляция, которую мы должны сделать с данными, – это преобразовать наши наборы данных в формат TF Record . Формат записи TF – это формат, который TensorFlow использует для хранения последовательности двоичных записей.

Во-первых, давайте создадим две папки: одна-для меток в формате CSV , а другая-для конечного набора данных в формате TF Record .

mkdir -p dataset/printed_links/labels/csv
mkdir -p dataset/printed_links/tfrecords

Теперь нам нужно создать файл dataset/printed_links/labels/label_map.pbtxt proto, который будет описывать классы объектов в нашем наборе данных. В нашем случае у нас есть только один класс , который мы можем назвать http . Вот содержание этого файла:

item {
  id: 1
  name: 'http'
}

Теперь мы готовы генерировать наборы данных записей TF из изображений в формате jpg и меток в формате xml :

import os
import io
import math
import glob
import tensorflow as tf
import pandas as pd
import xml.etree.ElementTree as ET
from PIL import Image
from collections import namedtuple
from object_detection.utils import dataset_util, label_map_util

tf1 = tf.compat.v1

# Convers labels from XML format to CSV.
def xml_to_csv(path):
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            value = (root.find('filename').text,
                int(root.find('size')[0].text),
                int(root.find('size')[1].text),
                member[0].text,
                int(member[4][0].text),
                int(member[4][1].text),
                int(member[4][2].text),
                int(member[4][3].text)
            )
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df


def class_text_to_int(row_label, label_map_dict):
    return label_map_dict[row_label]


def split(df, group):
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]


# Creates a TFRecord.
def create_tf_example(group, path, label_map_dict):
    with tf1.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_jpg = fid.read()
        
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size

    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for index, row in group.object.iterrows():
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class'], label_map_dict))

    tf_example = tf1.train.Example(features=tf1.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    
    return tf_example


def dataset_to_tfrecord(
    images_dir,
    xmls_dir, 
    label_map_path,
    output_path,
    csv_path=None
):
    label_map = label_map_util.load_labelmap(label_map_path)
    label_map_dict = label_map_util.get_label_map_dict(label_map)
    
    tfrecord_writer = tf1.python_io.TFRecordWriter(output_path)
    images_path = os.path.join(images_dir)
    csv_examples = xml_to_csv(xmls_dir)
    grouped_examples = split(csv_examples, 'filename')
    
    for group in grouped_examples:
        tf_example = create_tf_example(group, images_path, label_map_dict)
        tfrecord_writer.write(tf_example.SerializeToString())
        
    tfrecord_writer.close()
    
    print('Successfully created the TFRecord file: {}'.format(output_path))
    
    if csv_path is not None:
        csv_examples.to_csv(csv_path, index=None)
        print('Successfully created the CSV file: {}'.format(csv_path))

# Generate a TFRecord for train dataset.
dataset_to_tfrecord(
    images_dir='dataset/printed_links/partitioned/train',
    xmls_dir='dataset/printed_links/partitioned/train',
    label_map_path='dataset/printed_links/labels/label_map.pbtxt',
    output_path='dataset/printed_links/tfrecords/train.record',
    csv_path='dataset/printed_links/labels/csv/train.csv'
)

# Generate a TFRecord for test dataset.
dataset_to_tfrecord(
    images_dir='dataset/printed_links/partitioned/test',
    xmls_dir='dataset/printed_links/partitioned/test',
    label_map_path='dataset/printed_links/labels/label_map.pbtxt',
    output_path='dataset/printed_links/tfrecords/test.record',
    csv_path='dataset/printed_links/labels/csv/test.csv'
)

В результате теперь у нас должно быть два файла: test.record и train.record в dataset/printed_links/tfrecords/| папка:

dataset/
└── printed_links
    ├── labels
    │   ├── csv
    │   ├── label_map.pbtxt
    │   └── xml
    ├── partitioned
    │   ├── test
    │   ├── train
    │   └── val
    ├── processed
    ├── raw
    └── tfrecords
        ├── test.record
        └── train.record

Эти два файла test.record и train.record являются нашими окончательными наборами данных, которые мы будем использовать для точной настройки модели ssb_mobile net_v2_fpnlite_640x640_coco17_tpu-8 .

📖 Изучение наборов данных записей TF

В этом разделе мы рассмотрим, как мы можем использовать API обнаружения объектов TensorFlow 2 для изучения наборов данных в формате TF Record .

Проверка количества элементов в наборе данных

Чтобы подсчитать количество элементов в наборе данных, мы можем сделать следующее:

import tensorflow as tf

# Count the number of examples in the dataset.
def count_tfrecords(tfrecords_filename):
    raw_dataset = tf.data.TFRecordDataset(tfrecords_filename)
    # Keep in mind that the list() operation might be
    # a performance bottleneck for large datasets. 
    return len(list(raw_dataset))

TRAIN_RECORDS_NUM = count_tfrecords('dataset/printed_links/tfrecords/train.record')
TEST_RECORDS_NUM = count_tfrecords('dataset/printed_links/tfrecords/test.record')

print('TRAIN_RECORDS_NUM: ', TRAIN_RECORDS_NUM)
print('TEST_RECORDS_NUM:  ', TEST_RECORDS_NUM)

выход →

TRAIN_RECORDS_NUM:  100
TEST_RECORDS_NUM:   25

Поэтому мы будем тренировать модель дальше 100 примеры, и мы проверим точность модели на 25 тестовые изображения.

Предварительный просмотр изображений набора данных с ограничивающими рамками

Для предварительного просмотра изображений с полями обнаружения мы можем сделать следующее:

import tensorflow as tf
import numpy as np
from google.protobuf import text_format
import matplotlib.pyplot as plt

# Import Object Detection API.
from object_detection.utils import visualization_utils
from object_detection.protos import string_int_label_map_pb2
from object_detection.data_decoders.tf_example_decoder import TfExampleDecoder

%matplotlib inline

# Visualize the TFRecord dataset.
def visualize_tfrecords(tfrecords_filename, label_map=None, print_num=1):
    decoder = TfExampleDecoder(
        label_map_proto_file=label_map,
        use_display_name=False
    )

    if label_map is not None:
        label_map_proto = string_int_label_map_pb2.StringIntLabelMap()

        with tf.io.gfile.GFile(label_map,'r') as f:
            text_format.Merge(f.read(), label_map_proto)
            class_dict = {}
            
            for entry in label_map_proto.item:
                class_dict[entry.id] = {'name': entry.name}

    raw_dataset = tf.data.TFRecordDataset(tfrecords_filename)

    for raw_record in raw_dataset.take(print_num):
        example = decoder.decode(raw_record)

        image = example['image'].numpy()
        boxes = example['groundtruth_boxes'].numpy()
        confidences = example['groundtruth_image_confidences']
        filename = example['filename']
        area = example['groundtruth_area']
        classes = example['groundtruth_classes'].numpy()
        image_classes = example['groundtruth_image_classes']
        weights = example['groundtruth_weights']

        scores = np.ones(boxes.shape[0])

        visualization_utils.visualize_boxes_and_labels_on_image_array( 
            image,                                               
            boxes,                                                     
            classes,
            scores,
            class_dict,
            max_boxes_to_draw=None,
            use_normalized_coordinates=True
        )

        plt.figure(figsize=(8, 8))
        plt.imshow(image)

    plt.show()

# Visualizing the training TFRecord dataset.
visualize_tfrecords(
    tfrecords_filename='dataset/printed_links/tfrecords/train.record',
    label_map='dataset/printed_links/labels/label_map.pbtxt',
    print_num=3
)

В результате мы должны увидеть несколько изображений с ограничивающими рамками, нарисованными поверх каждого изображения.

Предварительный просмотр записи TF

📈 Настройка тензорной Доски

Перед началом тренировочного процесса нам необходимо запустить Тензорную доску .

TensorBoard позволит нам отслеживать процесс обучения и видеть, действительно ли модель чему-то учится, или нам лучше прекратить обучение и настроить параметры обучения. Это также поможет нам проанализировать, какие объекты и в каком месте обнаруживает модель.

Тензорная доска

Источник изображения: Домашняя страница TensorBoard

Самое интересное в TensorBoard то, что мы можем запустить его непосредственно в Google Colab. Однако, если вы используете ноутбук в локальной установке Jupiter, вы также можете установить его как пакет Python и запустить его с терминала.

Во-первых, давайте создадим папку ./logs , в которую будут записаны все журналы обучения:

mkdir -p logs

Затем мы можем загрузить расширение TensorBoard в Google Colab:

%load_ext tensorboard

И, наконец, мы можем запустить тензорную плату для мониторинга папки ./logs :

%tensorboard --logdir ./logs

В результате вы должны увидеть пустую панель тензорной доски:

Пустая Панель Тензорной платы

После начала обучения модели вы можете вернуться к этой панели и посмотреть ход процесса обучения.

🏋 🏻 ️ Обучение модели

Настройка конвейера обнаружения

Теперь пришло время вернуться к файлу cache/datasets/ssd_mobile net_v2_fpnlite_640 x 640_coco 17_tpu-8/pipeline.config , о котором мы упоминали ранее. Этот файл определяет параметры ssd_mobile net_v2_fpnlite_640 x 640_coco 17_tpu-8 обучение модели.

Нам нужно скопировать файл pipeline.config в корень проекта и настроить в нем пару вещей:

  1. Мы должны изменить количество классов с 90 (классы COCO), чтобы просто 1 (класс http ).
  2. Мы должны уменьшить размер пакета до 8 чтобы избежать ошибок, связанных с недостаточной памятью.
  3. Нам нужно указать модели на ее контрольные точки , так как мы не хотим обучать модель с нуля.
  4. Нам нужно изменить fine_tune_checkpoint_type на detection .
  5. Нам нужно указать модель на правильную карту меток .
  6. Наконец, нам нужно распечатать модель в наборах данных train и test .

Все эти изменения могут быть сделаны вручную непосредственно в файле pipeline.config . Но мы также можем выполнять их с помощью кода:

import tensorflow as tf
from shutil import copyfile
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2

# Adjust pipeline config modification here if needed.
def modify_config(pipeline):
    # Model config.
    pipeline.model.ssd.num_classes = 1    

    # Train config.
    pipeline.train_config.batch_size = 8

    pipeline.train_config.fine_tune_checkpoint = 'cache/datasets/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/checkpoint/ckpt-0'
    pipeline.train_config.fine_tune_checkpoint_type = 'detection'

    # Train input reader config.
    pipeline.train_input_reader.label_map_path = 'dataset/printed_links/labels/label_map.pbtxt'
    pipeline.train_input_reader.tf_record_input_reader.input_path[0] = 'dataset/printed_links/tfrecords/train.record'

    # Eval input reader config.
    pipeline.eval_input_reader[0].label_map_path = 'dataset/printed_links/labels/label_map.pbtxt'
    pipeline.eval_input_reader[0].tf_record_input_reader.input_path[0] = 'dataset/printed_links/tfrecords/test.record'

    return pipeline

def clone_pipeline_config():
    copyfile(
        'cache/datasets/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/pipeline.config',
        'pipeline.config'
    )

def setup_pipeline(pipeline_config_path):
    clone_pipeline_config()
    pipeline = read_pipeline_config(pipeline_config_path)
    pipeline = modify_config(pipeline)
    write_pipeline_config(pipeline_config_path, pipeline)
    return pipeline

def read_pipeline_config(pipeline_config_path):
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()                                                                                                                                                                                                          
    with tf.io.gfile.GFile(pipeline_config_path, "r") as f:                                                                                                                                                                                                                     
        proto_str = f.read()                                                                                                                                                                                                                                          
        text_format.Merge(proto_str, pipeline)
    return pipeline

def write_pipeline_config(pipeline_config_path, pipeline):
    config_text = text_format.MessageToString(pipeline)                                                                                                                                                                                                        
    with tf.io.gfile.GFile(pipeline_config_path, "wb") as f:                                                                                                                                                                                                                       
        f.write(config_text)

# Adjusting the pipeline configuration.
pipeline = setup_pipeline('pipeline.config')

print(pipeline)

Вот содержимое файла pipeline.config :

model {
  ssd {
    num_classes: 1
    image_resizer {
      fixed_shape_resizer {
        height: 640
        width: 640
      }
    }
    feature_extractor {
      type: "ssd_mobilenet_v2_fpn_keras"
      depth_multiplier: 1.0
      min_depth: 16
      conv_hyperparams {
        regularizer {
          l2_regularizer {
            weight: 3.9999998989515007e-05
          }
        }
        initializer {
          random_normal_initializer {
            mean: 0.0
            stddev: 0.009999999776482582
          }
        }
        activation: RELU_6
        batch_norm {
          decay: 0.996999979019165
          scale: true
          epsilon: 0.0010000000474974513
        }
      }
      use_depthwise: true
      override_base_feature_extractor_hyperparams: true
      fpn {
        min_level: 3
        max_level: 7
        additional_layer_depth: 128
      }
    }
    box_coder {
      faster_rcnn_box_coder {
        y_scale: 10.0
        x_scale: 10.0
        height_scale: 5.0
        width_scale: 5.0
      }
    }
    matcher {
      argmax_matcher {
        matched_threshold: 0.5
        unmatched_threshold: 0.5
        ignore_thresholds: false
        negatives_lower_than_unmatched: true
        force_match_for_each_row: true
        use_matmul_gather: true
      }
    }
    similarity_calculator {
      iou_similarity {
      }
    }
    box_predictor {
      weight_shared_convolutional_box_predictor {
        conv_hyperparams {
          regularizer {
            l2_regularizer {
              weight: 3.9999998989515007e-05
            }
          }
          initializer {
            random_normal_initializer {
              mean: 0.0
              stddev: 0.009999999776482582
            }
          }
          activation: RELU_6
          batch_norm {
            decay: 0.996999979019165
            scale: true
            epsilon: 0.0010000000474974513
          }
        }
        depth: 128
        num_layers_before_predictor: 4
        kernel_size: 3
        class_prediction_bias_init: -4.599999904632568
        share_prediction_tower: true
        use_depthwise: true
      }
    }
    anchor_generator {
      multiscale_anchor_generator {
        min_level: 3
        max_level: 7
        anchor_scale: 4.0
        aspect_ratios: 1.0
        aspect_ratios: 2.0
        aspect_ratios: 0.5
        scales_per_octave: 2
      }
    }
    post_processing {
      batch_non_max_suppression {
        score_threshold: 9.99999993922529e-09
        iou_threshold: 0.6000000238418579
        max_detections_per_class: 100
        max_total_detections: 100
        use_static_shapes: false
      }
      score_converter: SIGMOID
    }
    normalize_loss_by_num_matches: true
    loss {
      localization_loss {
        weighted_smooth_l1 {
        }
      }
      classification_loss {
        weighted_sigmoid_focal {
          gamma: 2.0
          alpha: 0.25
        }
      }
      classification_weight: 1.0
      localization_weight: 1.0
    }
    encode_background_as_zeros: true
    normalize_loc_loss_by_codesize: true
    inplace_batchnorm_update: true
    freeze_batchnorm: false
  }
}
train_config {
  batch_size: 8
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
  data_augmentation_options {
    random_crop_image {
      min_object_covered: 0.0
      min_aspect_ratio: 0.75
      max_aspect_ratio: 3.0
      min_area: 0.75
      max_area: 1.0
      overlap_thresh: 0.0
    }
  }
  sync_replicas: true
  optimizer {
    momentum_optimizer {
      learning_rate {
        cosine_decay_learning_rate {
          learning_rate_base: 0.07999999821186066
          total_steps: 50000
          warmup_learning_rate: 0.026666000485420227
          warmup_steps: 1000
        }
      }
      momentum_optimizer_value: 0.8999999761581421
    }
    use_moving_average: false
  }
  fine_tune_checkpoint: "cache/datasets/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/checkpoint/ckpt-0"
  num_steps: 50000
  startup_delay_steps: 0.0
  replicas_to_aggregate: 8
  max_number_of_boxes: 100
  unpad_groundtruth_tensors: false
  fine_tune_checkpoint_type: "detection"
  fine_tune_checkpoint_version: V2
}
train_input_reader {
  label_map_path: "dataset/printed_links/labels/label_map.pbtxt"
  tf_record_input_reader {
    input_path: "dataset/printed_links/tfrecords/train.record"
  }
}
eval_config {
  metrics_set: "coco_detection_metrics"
  use_moving_averages: false
}
eval_input_reader {
  label_map_path: "dataset/printed_links/labels/label_map.pbtxt"
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: "dataset/printed_links/tfrecords/test.record"
  }
}

Запуск учебного процесса

Теперь мы готовы запустить процесс обучения с использованием API обнаружения объектов TensorFlow 2. API содержит model_main_tf2.py скрипт, который будет запускать обучение для нас. Не стесняйтесь исследовать флаги, которые этот скрипт Python поддерживает в исходном коде (т. Е. num_train_steps , model_dir и другие), чтобы увидеть их значения.

Мы будем обучать модель для 1000 итерации (эпохи). Не стесняйтесь тренировать его для меньшего или большего числа итераций в зависимости от прогресса обучения (см. Диаграммы тензорной доски).

%%bash

NUM_TRAIN_STEPS=1000
CHECKPOINT_EVERY_N=1000

PIPELINE_CONFIG_PATH=pipeline.config
MODEL_DIR=./logs
SAMPLE_1_OF_N_EVAL_EXAMPLES=1

python ./models/research/object_detection/model_main_tf2.py \
  --model_dir=$MODEL_DIR \
  --num_train_steps=$NUM_TRAIN_STEPS \
  --sample_1_of_n_eval_examples=$SAMPLE_1_OF_N_EVAL_EXAMPLES \
  --pipeline_config_path=$PIPELINE_CONFIG_PATH \
  --checkpoint_every_n=$CHECKPOINT_EVERY_N \
  --alsologtostderr

Пока модель обучается (это может занять около ~10 минут для 1000 итерации в Google Co lab GPU runtime) вы должны иметь возможность наблюдать за ходом обучения в TensorBoard. Потери локализации и классификации должны уменьшиться, что означает, что модель хорошо справляется с локализацией и классификацией новых пользовательских объектов.

Процесс обучения

Также во время обучения новые контрольные точки модели (параметры, которые модель изучила во время обучения) будут сохранены в папке logs .

Структура папок logs теперь выглядит следующим образом:

logs
├── checkpoint
├── ckpt-1.data-00000-of-00001
├── ckpt-1.index
└── train
    └── events.out.tfevents.1606560330.b314c371fa10.1747.1628.v2

Оценка модели (необязательно)

Процесс оценки использует обученные контрольные точки модели и оценивает, насколько хорошо модель работает при обнаружении объектов в тестовом наборе данных. Результаты этой оценки суммируются в виде некоторых метрик , которые могут быть рассмотрены с течением времени. Вы можете прочитать больше о том, как оценить эти показатели здесь .

В этой статье мы пропустим этап оценки метрик. Но мы все еще можем использовать шаг оценки, чтобы увидеть обнаружение модели в тензорной доске:

%%bash

PIPELINE_CONFIG_PATH=pipeline.config
MODEL_DIR=logs

python ./models/research/object_detection/model_main_tf2.py \
  --model_dir=$MODEL_DIR \
  --pipeline_config_path=$PIPELINE_CONFIG_PATH \
  --checkpoint_dir=$MODEL_DIR \

После запуска скрипта вы должны увидеть несколько изображений бок о бок с полями обнаружения:

Оценка модели

🗜 Экспорт модели

После завершения процесса обучения мы должны сохранить обученную модель для дальнейшего использования. Для экспорта модели мы будем использовать exporter_main_v2.py скрипт из API обнаружения объектов. Он подготавливает график тензорного потока обнаружения объектов для вывода с использованием конфигурации модели и обученной контрольной точки. Сценарий выводит связанные файлы контрольных точек, сохраненную модель и копию конфигурации модели:

%%bash

python ./models/research/object_detection/exporter_main_v2.py \
    --input_type=image_tensor \
    --pipeline_config_path=pipeline.config \
    --trained_checkpoint_dir=logs \
    --output_directory=exported/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8

Вот что содержит папка exported после экспорта:

exported
└── ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8
    ├── checkpoint
    │   ├── checkpoint
    │   ├── ckpt-0.data-00000-of-00001
    │   └── ckpt-0.index
    ├── pipeline.config
    └── saved_model
        ├── assets
        ├── saved_model.pb
        └── variables
            ├── variables.data-00000-of-00001
            └── variables.index

В данный момент у нас есть saved_model , который может быть использован для вывода.

🚀 Использование экспортированной модели

Давайте посмотрим, как мы можем использовать сохраненную модель из предыдущего шага для обнаружения объектов.

Во-первых, нам нужно создать функцию обнаружения, которая будет использовать сохраненную модель. Он примет изображение и выведет обнаруженные объекты:

import time
import math

PATH_TO_SAVED_MODEL = 'exported/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/saved_model'

def detection_function_from_saved_model(saved_model_path):
    print('Loading saved model...', end='')
    start_time = time.time()

    # Load saved model and build the detection function
    detect_fn = tf.saved_model.load(saved_model_path)

    end_time = time.time()
    elapsed_time = end_time - start_time

    print('Done! Took {} seconds'.format(math.ceil(elapsed_time)))

    return detect_fn

exported_detect_fn = detection_function_from_saved_model(
    PATH_TO_SAVED_MODEL
)

выход →

Loading saved model...Done! Took 9 seconds

Чтобы сопоставить идентификаторы обнаруженных классов с именами классов, нам также нужно загрузить карту меток:

from object_detection.utils import label_map_util

category_index = label_map_util.create_category_index_from_labelmap(
    'dataset/printed_links/labels/label_map.pbtxt',
    use_display_name=True
)

print(category_index)

выход →

{1: {'id': 1, 'name': 'http'}}

Тестирование модели на тестовом наборе данных.

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

from object_detection.utils import visualization_utils
from object_detection.data_decoders.tf_example_decoder import TfExampleDecoder

%matplotlib inline

def tensors_from_tfrecord(
    tfrecords_filename,
    tfrecords_num,
    dtype=tf.float32
):
    decoder = TfExampleDecoder()
    raw_dataset = tf.data.TFRecordDataset(tfrecords_filename)
    images = []

    for raw_record in raw_dataset.take(tfrecords_num):
        example = decoder.decode(raw_record)
        image = example['image']
        image = tf.cast(image, dtype=dtype)
        images.append(image)
    
    return images

def test_detection(tfrecords_filename, tfrecords_num, detect_fn):
    image_tensors = tensors_from_tfrecord(
        tfrecords_filename,
        tfrecords_num,
        dtype=tf.uint8
    )

    for image_tensor in image_tensors:   
        image_np = image_tensor.numpy()
    
        # The model expects a batch of images, so add an axis with `tf.newaxis`.
        input_tensor = tf.expand_dims(image_tensor, 0)

        detections = detect_fn(input_tensor)

        # All outputs are batches tensors.
        # Convert to numpy arrays, and take index [0] to remove the batch dimension.
        # We're only interested in the first num_detections.
        num_detections = int(detections.pop('num_detections'))
        
        detections = {key: value[0, :num_detections].numpy() for key, value in detections.items()}
        detections['num_detections'] = num_detections

        # detection_classes should be ints.
        detections['detection_classes'] = detections['detection_classes'].astype(np.int64)
        
        image_np_with_detections = image_np.astype(int).copy()

        visualization_utils.visualize_boxes_and_labels_on_image_array(
            image_np_with_detections,
            detections['detection_boxes'],
            detections['detection_classes'],
            detections['detection_scores'],
            category_index,
            use_normalized_coordinates=True,
            max_boxes_to_draw=100,
            min_score_thresh=.3,
            agnostic_mode=False
        )

        plt.figure(figsize=(8, 8))
        plt.imshow(image_np_with_detections)
        
    plt.show()


test_detection(
    tfrecords_filename='dataset/printed_links/tfrecords/test.record',
    tfrecords_num=10,
    detect_fn=exported_detect_fn
)

В результате вы должны увидеть 10 изображения из тестового набора данных и выделенные https: префиксы, которые были обнаружены моделью:

Тестирование модели на тестовом наборе данных

Тот факт, что модель способна обнаруживать пользовательские объекты (в нашем случае префиксы https:// ) на изображениях, которые она раньше не видела, является хорошим знаком и тем, чего мы хотели достичь.

🗜 Преобразование модели для Web

Как вы помните из начала этой статьи, наша цель состояла в том, чтобы использовать пользовательскую модель обнаружения объектов в браузере. К счастью, есть TensorFlow.js Существует JavaScript-версия библиотеки TensorFlow. В JavaScript мы не можем напрямую работать с нашей сохраненной моделью. Вместо этого нам нужно преобразовать его в формат tfjs_graph_model .

Для этого нам нужно установить пакет tensorflow Python:

pip install tensorflowjs --quiet

Модель может быть экспортирована следующим образом:

%%bash

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_format=tfjs_graph_model \
    exported/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/saved_model \
    exported_web/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8

Папка exported_web содержит файл .json с метаданными модели и кучу файлов .bin с обученными параметрами модели:

exported_web
└── ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8
    ├── group1-shard1of4.bin
    ├── group1-shard2of4.bin
    ├── group1-shard3of4.bin
    ├── group1-shard4of4.bin
    └── model.json

Наконец, у нас есть модель, которая способна обнаруживать префиксы https:// для нас, и она сохраняется в понятном для JavaScript формате.

Давайте проверим размер модели, чтобы увидеть, достаточно ли она легка, чтобы быть полностью загруженной на клиентскую сторону:

import pathlib

def get_folder_size(folder_path):
    mB = 1000000
    root_dir = pathlib.Path(folder_path)
    sizeBytes = sum(f.stat().st_size for f in root_dir.glob('**/*') if f.is_file())
    return f'{sizeBytes//mB} MB'


print(f'Original model size:      {get_folder_size("cache/datasets/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8")}')
print(f'Exported model size:      {get_folder_size("exported/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8")}')
print(f'Exported WEB model size:  {get_folder_size("exported_web/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8")}')

выход →

Original model size:      31 MB
Exported model size:      28 MB
Exported WEB model size:  13 MB

Как вы можете видеть, модель, которую мы собираемся использовать для Интернета, имеет 13 МБ , что вполне приемлемо в нашем случае.

Позже в JavaScript мы можем начать использовать модель следующим образом:

import * as tf from '@tensorflow/tfjs';
const model = await tf.loadGraphModel(modelURL);

🧭 Следующим шагом является реализация пользовательского интерфейса детектора ссылок, который будет использовать эту модель, но это уже другая история для другой статьи. Окончательный исходный код приложения можно найти в репозитории links-detector на GitHub.

Выводы 🤔

В этой статье мы начали решать проблему с обнаружением печатных ссылок. В итоге мы создали пользовательский детектор объектов для распознавания префиксов https:// на текстовых изображениях (т. Е. на потоковых изображениях камеры смартфона). Мы также преобразовали модель в tfjs_graph_model , чтобы иметь возможность повторно использовать ее на стороне клиента.

Вы можете 🚀 запустить демо-версию детектора ссылок со своего смартфона, чтобы увидеть конечный результат и попробовать, как модель работает в ваших книгах или журналах.

Вот как выглядит окончательное решение:

Демонстрация детектора ссылок

Вы также можете 📝 просмотреть репозиторий links-detector на GitHub, чтобы увидеть полный исходный код части пользовательского интерфейса приложения.

⚠️ В настоящее время приложение находится в экспериментальной |/Альфа-стадии и имеет много проблем и ограничений . Поэтому не повышайте уровень своих ожиданий слишком высоко, пока эти проблемы не будут решены..

В качестве следующих шагов, которые могут улучшить производительность модели, мы могли бы сделать следующее:

  • Расширьте набор данных с помощью большего количества типов ссылок ( http:// , tcp:// , ftp:// и т. Д)
  • Расширил набор данных изображениями с темным фоном
  • Расширьте набор данных с помощью подчеркнутых ссылок
  • Расширьте набор данных примерами различных шрифтов и лигатур
  • и т.д.

Несмотря на то, что модель еще многое предстоит улучшить, чтобы приблизить ее к состоянию готовности к производству, я все еще надеюсь, что эта статья была полезна для вас и дала вам некоторые рекомендации и вдохновение для игры с вашими пользовательскими детекторами объектов.

Счастливого обучения, ребята!