Обзор универсальных оптимизаций нейросетей

Для решения задачи снижения потребления памяти и повышения скорости расчета моделей нейронных сетей без существенной потери точности используется различные методы оптимизации.

Бывает так, что очень большая модель не помещается на видеоадаптере и требуется 250 ГБ оперативной памяти. В этой связи надо находить баланс, можно уменьшить размер модели в сто раз, при этом уменьшить точность всего на половину процентного пункта. Например, Bert можно сжать с 560 Мб до 2 Мб, почти без потери качества.

Рассмотрим три наиболее часто встречающихся метода оптимизации размера сети такие как дистилляция, квантизация и прунинг.

Основная идея дистилляции, это обучение маленькой модели (модели студента) с помощью предобученной большой модели (модели учителя). Пусть у нас есть предобученная модель «учитель», она выдаёт логиты, это последний слой до Softmax. И есть модель студента, только необученная, которая выдает логиты, такой же размерности. Далее мы логиты учителя и логиты студента, прогоняем через Softmax с температурой. Температура нужна для сглаживания распределения.

Softmax c температурой

Логиты модели-студента отправляются на Softmax как с температурой, так и без температуры.

При Softmax без температуры мы получаем классическую модель-студент обучения. И обычную функцию потерь. Эту функцию мы возьмем с коэффициентом альфа.

При Softmax с температурой мы сравниваем модель учителя и модель студента дивергенцией Кульбака – Лейблера, и этот loss берем с коэффициентом бетта. Потом складываем два loss как на схеме выше и берем от него backward. (обратное распределение потерь). Пример кода представлен ниже.

def loss_fn_kd(outputs, labels, teacher_outputs, params): #дивергенция Кульбака – Лейблера alpha , betta = params.alpha , params.betta T = params.temperature KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1), F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \ F.cross_entropy(outputs, labels) * (betta) return KD_loss with torch.no_grad(): output_teacher_batch = teacher_model(train_batch) if params.cuda: output_teacher_batch = output_teacher_batch.cuda(async=True) loss = loss_fn_kd(output_batch, labels_batch, output_teacher_batch, params) optimizer.zero_grad() loss.backward()

Иначе говоря, вместо классической функции потерь для обучения модели-студента, мы берем средневзвешенную функцию потерь с реальными данными и с теми что выдает модель учитель.

Квантизация

Идея квантизации предельно проста, все операции проводятся в целочисленных значениях. Чаще всего это накладывается на слой или на какую-то часть сети. Ряд слоёв работает в int8, благодаря чему он потребляет очень мало вычислений и памяти, последний слой, float32. Самый простой способ — это округление весов к определённым значениям.

Квантизовать можно до, после и во время обучения. Приведем пример простой квантизации для однослойной модели и четырех нейронов с четырьмя признаками

import torch # определяем модель class M(torch.nn.Module): def __init__(self): super(M, self).__init__() self.fc = torch.nn.Linear(4, 4) def forward(self, x): x = self.fc(x) return x # создаём модель model_fp32 = M() # создаём квантизованную модель с динамической квантизацией model_int8 = torch.quantization.quantize_dynamic( model_fp32, # оригинальная модель {torch.nn.Linear}, # количество слоев для квантизации dtype=torch.qint8) # количество бит # Запуск модели input_fp32 = torch.randn(4, 4, 4, 4) res = model_int8(input_fp32).

Кроме того, можно квантизовать уже дистиллированную модель, что ещё больше сэкономит время и вычислительные ресурсы, практически без серьёзной потери качества.

Прунинг

Прунинг нейронной сети это метод сжатия модели, путем удаления части параметров.

Прунинг, как и квантизация, может быть использован до, после и во время обучения на основе:

● Амплитуды весов, активаций, градиентов, гессианов

● Заданных правил, Байесовских подходов

● Реинициализации, дообучения

Самый простой способ прунинга нейронной сети есть под капотом PyTorch, и он реализуется всего одной строчкой кода.

import torch from torch.nn.utils import prune prune.random_unstructed(module , name = ‘weight’ , amount = 0.3)

В заключение хотелось бы сказать, что все три способа оптимизации можно комбинировать, тем самым получать результаты сопоставимые по качеству с огромными моделями, при этом затрачивая гораздо меньше ресурсов как временных, так и информационных.

0
Комментарии
Читать все 0 комментариев
null