Untitled
unknown
plain_text
11 days ago
2.8 kB
2
Indexable
Never
import argparse
import multiprocessing
import os
from functools import partial

import joblib
import pandas as pd
from atom import ATOMClassifier

# Directory (relative to the working directory) that receives the pickled
# ATOM objects.
RESULTS_DIR = "results"


def _label_from_index(row_name):
    """Return the class label encoded in a row's index value.

    The label is 0 when the index value starts with "1", 1 when it starts
    with "2", and 2 otherwise.

    The value is converted to ``str`` first: ``pd.read_csv`` infers an
    integer index when the index column is entirely numeric, and integers
    have no ``startswith`` method.
    """
    text = str(row_name)
    if text.startswith("1"):
        return 0
    if text.startswith("2"):
        return 1
    return 2


def process_solver(solver, X, models, base_filename):
    """Select features with one solver, train the models, and pickle the result.

    Args:
        solver: Acronym passed as ``solver`` to the "sfs" feature selection.
            It also names the branch, tags every model and is part of the
            output filename.
        X: DataFrame containing the features and the target column ``"y"``.
        models: Model acronyms to train; each one is suffixed with
            ``_<solver>`` so the runs of different solvers stay distinct.
        base_filename: Prefix used for the output pickle's filename.

    Returns:
        None. The fitted ATOM object is written to
        ``results/<base_filename>_atom_sfm_<solver>.pkl``.
    """
    print(f"Starting process for solver: {solver}")

    # Make sure the output directory exists *before* the long training run,
    # so a missing directory cannot make the final dump fail.
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Initialize ATOM classifier
    atom = ATOMClassifier(
        X,
        y="y",
        test_size=0.2,
        verbose=2,
        random_state=1,
        index=True,
    )

    # Create a new branch for the solver
    atom.branch = f"sfs_{solver}"
    print(f"Created branch: {atom.branch}")

    # Perform feature selection using the current solver
    print(f"Performing feature selection with solver: {solver}")
    atom.feature_selection(
        strategy="sfs",
        solver=solver,
        n_features=10,
        random_state=0,
    )

    # Append solver name to each model for distinction
    models_named = [f"{model}_{solver}" for model in models]

    # Train and evaluate models
    print(f"Training and evaluating models for solver: {solver}")
    atom.run(models=models_named, n_trials=50, metric="AUC", n_bootstrap=5)

    # Save the ATOM object with the trained models and results.
    # NOTE(review): the filename says "sfm" although the strategy used is
    # "sfs"; kept unchanged so existing output paths stay valid -- confirm
    # whether this is a typo.
    joblib_filename = f"{RESULTS_DIR}/{base_filename}_atom_sfm_{solver}.pkl"
    joblib.dump(atom, joblib_filename)
    print(f"Model saved to {joblib_filename}")


def run_atom_classification(csv_file):
    """Load a CSV file and run ``process_solver`` for every solver in parallel.

    The first CSV column is used as the index. A target column ``"y"`` is
    derived from each index value (see ``_label_from_index``), overwriting
    any existing ``"y"`` column. One worker process is started per solver.

    Args:
        csv_file: Path to the input CSV file.

    Returns:
        None. Each worker writes its own pickle into ``results/``.
    """
    print(f"Loading CSV file: {csv_file}")

    # Load the CSV file
    X = pd.read_csv(csv_file, index_col=0)

    # Assign values to the 'y' column based on the starting character of the
    # index (coerced to str so numeric indexes work too).
    X["y"] = [_label_from_index(row_name) for row_name in X.index]

    solvers = ["LGB", "LR", "RF", "LDA", "XGB"]
    models = ["RF", "XGB", "LDA", "GBM", "LR", "SVM"]

    # Generate a derivative filename from the CSV filename
    base_filename = os.path.splitext(os.path.basename(csv_file))[0]

    # Ensure results directory exists
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Bind the arguments shared by all workers; only the solver varies.
    worker = partial(
        process_solver,
        X=X,
        models=models,
        base_filename=base_filename,
    )

    # Create a process pool with one process per solver
    print("Starting multiprocessing pool")
    with multiprocessing.Pool(processes=len(solvers)) as pool:
        # pool.map blocks until every solver has finished.
        pool.map(worker, solvers)

    print("All processes completed")


if __name__ == "__main__":
    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Run ATOM classification with feature selection and model training."
    )
    parser.add_argument("csv_file", type=str, help="Path to the input CSV file.")

    # Parse the arguments
    args = parser.parse_args()

    # Call the function with the provided CSV file path
    run_atom_classification(args.csv_file)
Leave a Comment