ensemble

 avatar
unknown
python
2 years ago
2.0 kB
3
Indexable
# In[]

# Candidate classifiers as (short name, unfitted estimator) pairs. The short
# name is what the evaluation reports print; the estimator is fitted in place
# by the evaluation loop.
models = [
    ('LR', LogisticRegression(solver='lbfgs', max_iter=200)),
    ('KNN', KNeighborsClassifier()),
    ('NB', GaussianNB()),
    ('RF', RandomForestClassifier(max_depth=10, n_estimators=40)),
    # probability=True so predict_proba is available for blending.
    ('SVM', SVC(kernel='linear', C=2, probability=True)),
]

# Evaluate each model in turn: fit on the training split, then report its
# test-set accuracy and a weighted confusion-matrix score ("profit").
# NOTE(review): results, names and scoring are not used in this cell; they are
# kept in case another notebook cell reads them — confirm and remove if dead.
results = []
names = []
scoring = 'recall'

for name, model in models:
    model.fit(X_train, y_train)
    test_predictions = model.predict(X_test)
    # For binary labels sklearn's confusion matrix ravels as tn, fp, fn, tp.
    tn, fp, fn, tp = confusion_matrix(y_test, test_predictions).ravel()
    # +1 per true positive, -1 per false negative,
    # +5 per true negative, -5 per false positive.
    # NOTE(review): the business meaning of these weights is not visible here.
    profit = 1 * (tp - fn) + 5 * (tn - fp)
    test_accuracy = model.score(X_test, y_test)

    print('===============')
    print('Model : ', name)
    print('Accuracy : ', test_accuracy)
    print('Benifits : ', profit)
    print('===============\n')

import itertools

# Probability of class index 1 from each fitted model, in the same order as
# `models`. NOTE(review): assumes a binary target whose positive class is
# column 1 of predict_proba — confirm against the label encoding.
list_y_pred = [clf.predict_proba(X_test)[:, 1] for _, clf in models]

if len(models) < 2:
    raise ValueError(
        'Blending needs at least two models, got %d' % len(models))

# Candidate weights 0.00, 0.01, ..., 1.00 for the first model of each pair.
# linspace gives exactly 101 evenly spaced points; a float-step np.arange
# accumulates rounding error and its length is not guaranteed.
blend_weights = np.linspace(0, 1, 101)
y_true = np.asarray(y_test)

# Best blend found so far: [index of model A, index of model B, weight of A].
optModels = [None, None, None]
accMax = -1

# NOTE(review): the weight and the pair are both chosen by test-set accuracy,
# and that same accuracy is reported afterwards, so it is optimistically
# biased. A separate validation split would give an honest estimate.
for idx1, idx2 in itertools.combinations(range(len(models)), 2):
    accMaxInner = -1
    optP = None
    for p in blend_weights:
        # Weighted average of the two probability vectors, computed in one
        # vectorised step, then thresholded at 0.5 into 0/1 labels.
        blended = p * list_y_pred[idx1] + (1 - p) * list_y_pred[idx2]
        y_pred = (blended >= 0.5).astype(int)
        acc = float(np.mean(y_pred == y_true))
        # Strict '>' keeps the first (smallest) weight among ties.
        if acc > accMaxInner:
            optP = p
            accMaxInner = acc
    if accMaxInner > accMax:
        accMax = accMaxInner
        optModels = [idx1, idx2, optP]
            
# Rebuild the predictions of the selected blend and report them.
# optModels holds [index of first model, index of second model, weight of
# the first model]; the second model gets the complementary weight.
idx_a, idx_b, weight_a = optModels
if idx_a is None:
    # The search never recorded a pair (e.g. fewer than two models).
    raise ValueError('No model pair was selected for blending')

blended = weight_a * list_y_pred[idx_a] + (1 - weight_a) * list_y_pred[idx_b]
# Threshold at 0.5 into 0/1 labels in a single vectorised comparison.
y_pred = (blended >= 0.5).astype(int)

# For binary labels sklearn's confusion matrix ravels as tn, fp, fn, tp.
tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
# +1 per true positive, -1 per false negative,
# +5 per true negative, -5 per false positive.
# NOTE(review): the business meaning of these weights is not visible here.
profit = 1 * (tp - fn) + 5 * (tn - fp)

print('===============')
print('Model : ', models[idx_a][0], '+', models[idx_b][0], ', weight = ', round(weight_a, 2), ':', round(1 - weight_a, 2))
print('Accuracy : ', accMax)
print('Benifits : ', profit)
print('===============\n')
Editor is loading...