Квадратичный дискриминантный анализ в Python (шаг за шагом)


Квадратичный дискриминантный анализ — это метод, который вы можете использовать, когда у вас есть набор переменных-предикторов, и вы хотите классифицировать переменную ответа по двум или более классам.

Он считается нелинейным эквивалентом линейного дискриминантного анализа .

В этом руководстве представлен пошаговый пример выполнения квадратичного дискриминантного анализа в Python.

Шаг 1: Загрузите необходимые библиотеки

Во-первых, мы загрузим необходимые функции и библиотеки для этого примера:

from sklearn. model_selection import train_test_split
from sklearn. model_selection import RepeatedStratifiedKFold
from sklearn. model_selection import cross_val_score
from sklearn. discriminant_analysis import QuadraticDiscriminantAnalysis 
from sklearn import datasets
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

Шаг 2: Загрузите данные

В этом примере мы будем использовать набор данных iris из библиотеки sklearn. В следующем коде показано, как загрузить этот набор данных и преобразовать его в кадр данных pandas, чтобы с ним было легко работать:

#load *iris* dataset
iris = datasets. load_iris ()

#convert dataset to pandas DataFrame
df = pd.DataFrame(data = np.c_[iris['data'], iris['target']],
 columns = iris['feature_names'] + ['target'])
df['species'] = pd.Categorical.from_codes (iris.target, iris.target_names)
df.columns = ['s_length', 's_width', 'p_length', 'p_width', 'target', 'species']

#view first six rows of DataFrame
df.head ()

 s_length s_width p_length p_width target species
0 5.1 3.5 1.4 0.2 0.0 setosa
1 4.9 3.0 1.4 0.2 0.0 setosa
2 4.7 3.2 1.3 0.2 0.0 setosa
3 4.6 3.1 1.5 0.2 0.0 setosa
4 5.0 3.6 1.4 0.2 0.0 setosa

#find how many total observations are in dataset
len(df.index)

150

Мы видим, что набор данных содержит всего 150 наблюдений.

В этом примере мы построим модель квадратичного дискриминантного анализа, чтобы классифицировать, к какому виду принадлежит данный цветок.

Мы будем использовать следующие переменные-предикторы в модели:

  • Длина чашелистика
  • Ширина чашелистика
  • Длина лепестка
  • Ширина лепестка

И мы будем использовать их для прогнозирования переменной отклика Species , которая принимает следующие три потенциальных класса:

  • сетоза
  • лишай
  • виргиния

Шаг 3: Соответствуйте модели QDA

Далее мы подгоним модель QDA к нашим данным, используя функцию QuadraticDiscriminantAnalsyis из sklearn:

#define predictor and response variables
X = df[['s_length', 's_width', 'p_length', 'p_width']]
y = df['species']

#Fit the QDA model
model = QuadraticDiscriminantAnalysis()
model. fit (X, y)

Шаг 4: Используйте модель для прогнозирования

После того, как мы подогнали модель, используя наши данные, мы можем оценить, насколько хорошо модель работает, используя повторную стратифицированную k-кратную перекрестную проверку.

В этом примере мы будем использовать 10 сгибов и 3 повтора:

#Define method to evaluate model
cv = RepeatedStratifiedKFold(n_splits= 10 , n_repeats= 3 , random_state= 1 )

#evaluate model
scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
print(np.mean (scores))  

0.97333333333334

Мы видим, что модель показала среднюю точность 97,33% .

Мы также можем использовать модель, чтобы предсказать, к какому классу принадлежит новый цветок, на основе входных значений:

#define new observation
new = [5, 3, 1, .4]

#predict which class the new observation belongs to
model. predict([new])

array(['setosa'], dtype='<U10')

Мы видим, что модель предсказывает, что это новое наблюдение принадлежит виду, называемому setosa .

Вы можете найти полный код Python, использованный в этом руководстве , здесь .

Замечательно! Вы успешно подписались.
Добро пожаловать обратно! Вы успешно вошли
Вы успешно подписались на кодкамп.
Срок действия вашей ссылки истек.
Ура! Проверьте свою электронную почту на наличие волшебной ссылки для входа.
Успех! Ваша платежная информация обновлена.
Ваша платежная информация не была обновлена.