Better random forests
This commit is contained in:
1
classifiers/__init__.py
Normal file
1
classifiers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from classifiers.evaluation import *
|
||||
35
classifiers/evaluation.py
Normal file
35
classifiers/evaluation.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from data.data_processing import process_train_test
|
||||
from sklearn.model_selection import cross_validate, ParameterGrid
|
||||
|
||||
|
||||
def crossvalidate_params(classifier, params, experiments_train, metadata_train, y_train, cv=5):
    """Cross-validate one combined preprocessing/classifier configuration.

    ``params`` mixes two kinds of settings. The four preprocessing keys
    (``baseline_lam``, ``baseline_p``, ``smooth_window_length`` and
    ``smooth_polyorder``) are required and are handed to
    ``process_train_test``; every remaining key is forwarded to
    ``classifier`` as a keyword argument.

    Args:
        classifier: Callable (typically an estimator class) invoked with the
            non-preprocessing entries of ``params`` to build the estimator.
        params: Mapping holding the four preprocessing keys plus any
            classifier keyword arguments.
        experiments_train: Passed through unchanged to ``process_train_test``.
        metadata_train: Passed through unchanged to ``process_train_test``.
        y_train: Training targets; must expose ``to_numpy()`` (e.g. a pandas
            DataFrame or Series). The result is flattened with ``ravel()``.
        cv: Forwarded to ``cross_validate`` as its ``cv`` argument.

    Returns:
        The value returned by ``cross_validate`` when called with
        ``return_estimator=True``.

    Raises:
        KeyError: If any of the four preprocessing keys is absent from
            ``params``.
    """
    preprocessing_keys = ('baseline_lam', 'baseline_p', 'smooth_window_length', 'smooth_polyorder')

    # Mandatory preprocessing settings: a missing key fails loudly here.
    preprocessing = {}
    for name in preprocessing_keys:
        preprocessing[name] = params[name]

    # Everything that is not a preprocessing setting belongs to the estimator.
    estimator_kwargs = {}
    for name, value in params.items():
        if name not in preprocessing_keys:
            estimator_kwargs[name] = value

    # Only the first element of the returned pair is used for cross-validation.
    features, _ = process_train_test(preprocessing, experiments_train, metadata_train, scale=True)
    estimator = classifier(**estimator_kwargs)
    labels = y_train.to_numpy().ravel()
    return cross_validate(estimator, features, labels, cv=cv, return_estimator=True)
|
||||
|
||||
|
||||
def param_grid_search(classifier, param_grid, experiments_train, metadata_train, y_train, cv=5):
    """Cross-validate every parameter combination of ``param_grid``.

    Each combination produced by ``ParameterGrid(param_grid)`` is evaluated
    with ``crossvalidate_params``. The search is best-effort: a combination
    whose evaluation raises is skipped so that one invalid setting does not
    abort the whole sweep, but the failure is logged as a warning instead of
    being silently discarded.

    Args:
        classifier: Forwarded to ``crossvalidate_params``.
        param_grid: Grid specification accepted by ``ParameterGrid``.
        experiments_train: Forwarded to ``crossvalidate_params``.
        metadata_train: Forwarded to ``crossvalidate_params``.
        y_train: Forwarded to ``crossvalidate_params``.
        cv: Forwarded to ``crossvalidate_params``.

    Returns:
        A list of ``[params, cv_result]`` pairs, one per combination that was
        evaluated successfully, in grid iteration order.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    logger = logging.getLogger(__name__)
    results = []
    for params in ParameterGrid(param_grid):
        # Keep the try body limited to the single call that is expected to fail.
        try:
            cv_result = crossvalidate_params(
                classifier, params, experiments_train, metadata_train, y_train, cv=cv
            )
        except Exception as exc:
            # Deliberately broad: any failure skips this combination, but it
            # must stay visible, otherwise an all-failing grid looks like an
            # empty search space.
            logger.warning("Skipping parameter combination %r: %r", params, exc)
            continue
        results.append([params, cv_result])
        print(results[-1])
    return results
|
||||
|
||||
|
||||
def evaluate_classifier_params(classifier, params, X_train, y_train, X_test, y_test, iters=10):
    """Fit a classifier repeatedly and average its train and test scores.

    A fresh estimator is built with ``classifier(**params)`` on every
    iteration, fitted on the training data and scored on both the training
    and the test data. Repeating the fit averages out run-to-run variation
    for estimators whose training is not deterministic.

    Args:
        classifier: Callable (typically an estimator class) returning an
            object with ``fit(X, y)`` and ``score(X, y)`` methods.
        params: Keyword arguments passed to ``classifier``.
        X_train: Training features, passed unchanged to ``fit``/``score``.
        y_train: Training targets; must expose ``to_numpy()`` (e.g. a pandas
            DataFrame or Series). Flattened with ``ravel()`` before use.
        X_test: Test features, passed unchanged to ``score``.
        y_test: Test targets; same requirements as ``y_train``.
        iters: Number of fit/score repetitions; must be at least 1.

    Returns:
        A ``(mean_train_score, mean_test_score)`` tuple.

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    if iters < 1:
        # Without this guard iters == 0 fails with a bare ZeroDivisionError
        # and negative values silently yield meaningless (-0.0) scores.
        raise ValueError(f"iters must be at least 1, got {iters!r}")

    # The flattened targets do not change between iterations: convert once.
    train_labels = y_train.to_numpy().ravel()
    test_labels = y_test.to_numpy().ravel()

    train_score_total = 0
    test_score_total = 0
    for _ in range(iters):
        clf = classifier(**params)
        clf.fit(X_train, train_labels)
        train_score_total += clf.score(X_train, train_labels)
        test_score_total += clf.score(X_test, test_labels)
    return train_score_total / iters, test_score_total / iters
|
||||
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user