ML.Net Задача регрессии

Обучать модель регрессии будем на примере решения классической задачи по предсказанию выживаемости пассажиров Титаника.

Подготовка IDE

Первым делом необходимо обновить Visual Studio и установить расширение ML.NET Model Builder (ссылка).

На момент написания статьи последняя версия библиотеки Microsoft.ML 1.6.0. Она написана под .NET Standard 2.0, таким образом ее можно использовать в проектах .NET Framework 4.6.1+\.NET Core 2.0+\.NET 5.0+.

Создание проекта

Я создал проект .net Framework 4.7.2. Расширение ML.NET Model Builder находится в стадии активной разработки и для его включения нужно перейти в раздел меню средства\параметры\окружение\функции предварительной версии и поставить галочку:

ML.Net Задача регрессии

Обучение модели

Добавим модель машинного обучения в проект. Для этого откройте меню «Добавить» и выберите «Машинной обучение»

ML.Net Задача регрессии

Назовем модель «TitanikModel»

ML.Net Задача регрессии

После добавления должен открыться Model Builder.

Мы решаем задачу регрессии, поэтому выберем сценарий «Прогнозирование значений».

ML.Net Задача регрессии

Я предварительно загрузил файлы Train.csv и Test.csv. Укажем путь до первого в разделе «Данные». Выберем прогнозируемый столбец «Survived».

ML.Net Задача регрессии

Можно более детально настроить данные перейдя в расширенные параметры:

ML.Net Задача регрессии

Идентификатор пассажира, ФИО, номер билета и номер каюты я счел лишними и проставил напротив этих столбцов значение «Ignore».

Перейдем в раздел «обучение». Здесь нужно указать время обучения. Чем больше это значение, тем точнее будет модель. Я оставлю 10 секунд и запущу обучение.

ML.Net Задача регрессии

После обучения в проект добавится 3 файла:

TitanikModel.consumption.cs – описание входных\выходных данных, реализация функции Predict

TitanikModel.training.cs – код обучения модели

TitanikModel.zip – модель

Использование модели

Попробуем предсказать выживаемость для тестовой выборки

string testPath = "test.csv"; string resultPath="predict.csv"; //индексы колонок в .csv файле int colPassengerId=0, colPclass=1, colName=2, colSex=3, colAge=4, colSibSp=5, colParch=6, colTicket=7, colFare=8, colCabin=9, colEmbarked=10; float thresold = 0.5f; Regex CSVParser = new Regex(",(?=(?:[^\"]*\"[^\"]*\")*(?![^\"]*\"))"); using (StreamWriter writer = File.CreateText(resultPath)) { writer.WriteLine("PassengerId,Survived"); //заполняю заголовок string[] testRows = File.ReadAllLines(testPath); for(int i=1;i<testRows.Count();i++) { string[] row = CSVParser.Split(testRows[i]); //Подготовка данных TitanikModel.ModelInput sampleData = new TitanikModel.ModelInput() { Pclass = float.Parse(row[colPclass], CultureInfo.InvariantCulture.NumberFormat), Sex = row[colSex], Age = string.IsNullOrEmpty(row[colAge])?0f:float.Parse(row[colAge], CultureInfo.InvariantCulture.NumberFormat), SibSp = float.Parse(row[colSibSp], CultureInfo.InvariantCulture.NumberFormat), Parch = float.Parse(row[colParch], CultureInfo.InvariantCulture.NumberFormat), Fare = string.IsNullOrEmpty(row[colFare]) ? 0f : float.Parse(row[colFare], CultureInfo.InvariantCulture.NumberFormat), Embarked = row[colEmbarked], }; float result = TitanikModel.Predict(sampleData).Score; //Predict string toWrite = $"{row[colPassengerId]},{(result > thresold ? "1" : "0")}"; Console.WriteLine(toWrite); writer.WriteLine(toWrite); //если predict > порогового значения - значит выжил } } Console.WriteLine("Завершено"); Console.ReadKey();

Этот код предскажет значение столбца Survived для каждой строки выборки и сохранит его в predict.csv. Загрузим этот файл на Kaggle

ML.Net Задача регрессии

Я никак не подготавливал данные, не экспериментировал с пороговым значением (threshold) и получил результат статистически выше среднего. Библиотека Micorsoft.ML показала себя с хорошей стороны: наличие документации, постоянные обновления, хорошая производительность и графический интерфейс для обучения моделей.

Исходный код проекта можно скачать с Github.

66 показов
594594 открытия
Начать дискуссию