Обнаружение таблиц на сканах с использованием Fast-RСNN на PyTorch
Компьютерное зрение — очень интересная и востребованная область искусственного интеллекта. Компьютерное зрение сейчас используется повсеместно, начиная от сегментации медицинских изображений, заканчивая управлением автомобилем. Сейчас мы коснемся одной из основных задач компьютерного зрения — обнаружения объектов.
Множество информации хранится в форме сканированных документов. Извлечь текст можно при помощи методов оптического распознавания. Но если нас интересует структурированная информация, например, таблицы, то ограничиться методами оптического распознавания мы не сможем.
Для того, чтобы извлечь информацию из таблицы, нужно в первую очередь эту таблицу найти. С этой проблемой нам поможет справиться PyTorch и готовая модель Fast-RCNN из библиотеки torchvision. Примеры ее использования, можно прочитать в официальной документации или, например, в этом руководстве.
Данные для обучения возьмем с github. Все действия будем производить в облачном сервисе Google Colaboratory.
Извлечем нужные библиотеки.
Скопируем репозиторий из github с данными для обучения и валидации.
В репозитории мы найдем два csv-файла и папку со сканами различных таблиц. Загрузим их в pandas.DataFrame и посмотрим, что внутри.
Число изображений в тренировочной выборке: 338
Число изображений в валидационной выборке: 65
Внутри мы видим название картинки, координаты границ рамки, огранивающую таблицу, и класс. К слову, в нашем примере единственным классом будет таблица.
Объединим координаты рамки в одну колонку. И поскольку на одном скане может присутствовать несколько таблиц, сгруппируем данные по названию картинки и обернем их в список. Также заменим класс таблицы на единицу.
После таких манипуляций мы получим следующую таблицу:
Давайте взглянем на какое-нибудь изображение из тренировочного набора.
Создадим экземпляр класса torch.utils.data.Dataset. В конструктор класса передадим получившуюся таблицу и путь до папки с изображениями. Переопределим методы __len__ и __getitem__. На выходе класс будет возвращать по индексу изображение в формате torch.tensor и словарь target, в котором будет информация о искомом объекте: координаты рамки и класс.
Далее напишем функцию, которая будет создавать модель. Исходная модель рассчитана на детекцию 91 класса. Нам нужно детектировать только один объект, поэтому необходимо заменить box_predictor внутри модели. Меняем количество выходов на 2, потому что нулевым классом должен быть фон.
Также создадим вспомогательную функцию, которая будет применяться к бачам при итерации по torch.utils.data.DataLoader. Она поможет избежать ошибок c размерностями внутри бачей.
Теперь, когда все подготовительные этапы завершены, мы можем приступать к основной части.В начале определим устройство, на котором будем обучать модель. Создадим саму модель, выберем оптимизатор и регулятор скорости обучения, который каждую эпоху будет уменьшать коэффициент скорости обучения. Функции потерь уже заданы внутри модели, поэтому нам не нужно их прописывать вручную.Обернем наши данные в torch.utils.data.DataLoader, установим размер бача в три изображения. Изображения из тренировочной выборки будем перемешивать перед подачей в модель.
Напишем функции для тренировки и валидации модели. Единственным отличием между ними будет отсутствие расчета градиента при валидации.Внутри функций мы переводим все данные внутри бача на устройство, на котором будут производится расчеты. Затем подаем полученные тензоры в модель и получаем словарь со значениями функций потерь. После этого посчитаем их сумму и для удобства запишем ее в переменную running_loss для отслеживания прогресса обучения. В конце каждой эпохи будем выводить среднее значение функций потерь.
Начинаем обучение.Создадим два списка, куда будем сохранять значения функций потерь после каждой эпохи на тренировке и валидации
Построим график истории обучения модели
Посмотрим, на сколько хорошо обучилась модель. Применим алгоритм non-maximum suppression, который реализован в библиотеке torchvision. Он объединяет похожие рамки на основе их взаимного пересечения.
Сохраним веса модели
Для последующего доступа к модели необходимо выполнить следующие команды:
Теперь у нас есть обученная модель, которая может находить таблицы на сканах. Конечно, результат не всегда будет соответствовать ожиданию. Для улучшения результата можно обучить модель на более большом объеме данных. Также можно вообще поменять модель или воспользоваться уже готовым решением.