Untitled

mail@pastecode.io avatar
unknown
plain_text
7 months ago
2.6 kB
0
Indexable
Never
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, recall_score, precision_score, confusion_matrix

# Преобразование целевой переменной в бинарный формат
ferma_main['vkus_milk'] = ferma_main['vkus_milk'].apply(lambda x: 1 if x == 'вкусно' else 0)

# Выделение в отдельные переменные целевого признака и входных признаков
X = ferma_main.drop(columns=['vkus_milk'])
y = ferma_main['vkus_milk']

# Разделение данных на тренировочные и тестовые, зафиксируем random_state
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=RANDOM_STATE)

# Список категориальных признаков
cat_col_names = ['poroda', 'type_pastbicha', 'poroda_dad_bull', 'age', 'name_dad']

# Список количественных признаков
num_col_names = ['eke', 'syroi_protein_g', 'spo', 'ghirnost_per', 'belok_per']

# Подготовка признаков (масштабирование и кодирование)
encoder = OneHotEncoder(drop='first', sparse=False)
scaler = StandardScaler()

# Создание transformer для категориальных и количественных признаков
preprocessor = ColumnTransformer(
    transformers=[
        ('cat', encoder, cat_col_names),
        ('num', scaler, num_col_names)
    ])

# Создание пайплайна с преобразованием и обучением модели
model_lr = LogisticRegression(random_state=RANDOM_STATE, max_iter=1000)
pipeline = Pipeline([
    ('preprocessor', preprocessor),
    ('model', model_lr)
])

# Обучение модели
pipeline.fit(X_train, y_train)

# Получение предсказаний
predictions = pipeline.predict(X_test)

# Оценка качества модели
accuracy = accuracy_score(y_test, predictions)
recall = recall_score(y_test, predictions)
precision = precision_score(y_test, predictions)
conf_matrix = confusion_matrix(y_test, predictions)

# Вывод результатов
print("Accuracy:", accuracy)
print("Recall:", recall)
print("Precision:", precision)
print("Confusion Matrix:")
print(conf_matrix)
Leave a Comment