# Untitled
# unknown
# plain_text
# a year ago
# 2.8 kB
# 8
# Indexable
import pandas as pd
from atom import ATOMClassifier
import joblib
import argparse
import os
import multiprocessing
from functools import partial
def process_solver(solver, X, models, base_filename):
    """Run feature selection and model training for one solver, then save the result.

    Builds an ATOMClassifier on ``X`` (target column ``"y"``), switches to a
    branch named after the solver, runs ``feature_selection`` with the "sfs"
    strategy, trains every requested model and dumps the whole ATOM object to
    ``results/`` with joblib.

    Args:
        solver: Name of the estimator handed to ``feature_selection`` as its
            ``solver``; also used in the branch name, model tags and filename.
        X: Dataset passed straight to ATOMClassifier; must contain a "y" column.
        models: Iterable of model acronyms; each one is suffixed with
            ``_{solver}`` before training.
        base_filename: Stem used to build the output pickle's filename.

    Returns:
        None. The outcome is the pickle file written under ``results/``.
    """
    print(f"Starting process for solver: {solver}")

    # Fresh ATOM instance per call, so nothing is shared between solvers.
    pipeline = ATOMClassifier(
        X,
        y="y",
        test_size=0.2,
        verbose=2,
        random_state=1,
        index=True,
    )

    # Assigning to .branch selects/creates the solver-specific branch; the
    # attribute is read back afterwards so the log shows ATOM's own view of it.
    pipeline.branch = f"sfs_{solver}"
    print(f"Created branch: {pipeline.branch}")

    print(f"Performing feature selection with solver: {solver}")
    pipeline.feature_selection(
        strategy="sfs",
        solver=solver,
        n_features=10,
        random_state=0,
    )

    # Tag each model with the solver so runs from different solvers stay distinguishable.
    tagged_models = []
    for model in models:
        tagged_models.append(f"{model}_{solver}")

    print(f"Training and evaluating models for solver: {solver}")
    pipeline.run(
        models=tagged_models,
        n_trials=50,
        metric="AUC",
        n_bootstrap=5,
    )

    # NOTE(review): the filename says "sfm" while the strategy above is "sfs";
    # kept as-is because downstream consumers may rely on this exact name.
    # The "results" directory is expected to exist already (the caller creates it).
    output_path = f"results/{base_filename}_atom_sfm_{solver}.pkl"
    joblib.dump(pipeline, output_path)
    print(f"Model saved to {output_path}")
def _label_from_row_name(row_name):
    """Return the class label for a row name: prefix '1' -> 0, prefix '2' -> 1, else 2."""
    # pandas infers a numeric dtype for an all-numeric index column, in which
    # case row names are ints and have no .startswith; compare on the text form.
    text = str(row_name)
    if text.startswith('1'):
        return 0
    if text.startswith('2'):
        return 1
    return 2


def run_atom_classification(csv_file):
    """Load a CSV, derive the target column and run one worker process per solver.

    The first CSV column is used as the index. A ``"y"`` column is added whose
    value depends on the first character of each row name (see
    ``_label_from_row_name``). Each solver is then handled by
    ``process_solver`` in its own process, and the results are written under
    ``results/``.

    Args:
        csv_file: Path to the input CSV file.

    Returns:
        None.
    """
    print(f"Loading CSV file: {csv_file}")
    # Load the CSV file
    X = pd.read_csv(csv_file, index_col=0)

    # Assign values to the 'y' column based on the starting character of the index
    X['y'] = [_label_from_row_name(row_name) for row_name in X.index]

    solvers = ["LGB", "LR", "RF", "LDA", "XGB"]
    models = ["RF", "XGB", "LDA", "GBM", "LR", "SVM"]

    # Generate a derivative filename from the CSV filename
    base_filename = os.path.splitext(os.path.basename(csv_file))[0]

    # Ensure results directory exists before any worker tries to write into it
    os.makedirs("results", exist_ok=True)

    # Create a partial function with fixed arguments so pool.map only varies the solver
    process_solver_partial = partial(
        process_solver, X=X, models=models, base_filename=base_filename
    )

    # Create a process pool with one worker per solver
    print("Starting multiprocessing pool")
    with multiprocessing.Pool(processes=len(solvers)) as pool:
        # Map the process_solver function to each solver; blocks until all finish
        pool.map(process_solver_partial, solvers)
    print("All processes completed")
if __name__ == "__main__":
    # Script entry point: take the CSV path from the command line and run the
    # full classification workflow on it.
    parser = argparse.ArgumentParser(
        description="Run ATOM classification with feature selection and model training."
    )
    parser.add_argument("csv_file", type=str, help="Path to the input CSV file.")

    # Parse the arguments
    args = parser.parse_args()

    # Call the function with the provided CSV file path
    run_atom_classification(args.csv_file)
# Leave a Comment