# Untitled
#
# Pasted from pastecode.io (mail@pastecode.io avatar); author: unknown.
# Paste metadata: plain_text · 11 days ago · 2.8 kB · 2 · Indexable · Never
import pandas as pd
from atom import ATOMClassifier
import joblib
import argparse
import os
import multiprocessing
from functools import partial

def process_solver(solver, X, models, base_filename):
    """Run SFS feature selection with one solver, train the models, and pickle ATOM.

    Args:
        solver: Estimator acronym handed to ATOM as the SFS ``solver``. It is
            also used to name the branch, tag every model, and name the output file.
        X: DataFrame holding the features plus a ``"y"`` target column.
        models: Model acronyms to train. Each one is tagged as ``"<model>_<solver>"``
            so results from different solvers stay distinguishable.
        base_filename: Stem of the output pickle written under ``results/``.
    """
    print(f"Starting process for solver: {solver}")

    # A fresh ATOM instance per call, split 80/20, keeping X's own index.
    atom = ATOMClassifier(
        X,
        y="y",
        test_size=0.2,
        verbose=2,
        random_state=1,
        index=True,
    )

    # Work on a dedicated branch named after the solver.
    atom.branch = f"sfs_{solver}"
    print(f"Created branch: {atom.branch}")

    print(f"Performing feature selection with solver: {solver}")
    atom.feature_selection(
        strategy="sfs",
        solver=solver,
        n_features=10,
        random_state=0,
    )

    tagged_models = [f"{acronym}_{solver}" for acronym in models]

    print(f"Training and evaluating models for solver: {solver}")
    atom.run(models=tagged_models, n_trials=50, metric="AUC", n_bootstrap=5)

    # NOTE(review): the filename says "sfm" although the strategy is "sfs";
    # kept unchanged in case downstream loaders rely on this name — confirm.
    out_path = f"results/{base_filename}_atom_sfm_{solver}.pkl"
    joblib.dump(atom, out_path)
    print(f"Model saved to {out_path}")

def run_atom_classification(csv_file):
    print(f"Loading CSV file: {csv_file}")
    # Load the CSV file
    X = pd.read_csv(csv_file, index_col=0)
    
    # Assign values to the 'y' column based on the starting character of the index
    X['y'] = [0 if row_name.startswith('1') else (1 if row_name.startswith('2') else 2) for row_name in X.index]
    
    solvers = ["LGB", "LR", "RF", "LDA", "XGB"]
    models = ["RF", "XGB", "LDA", "GBM", "LR", "SVM"]
    
    # Generate a derivative filename from the CSV filename
    base_filename = os.path.splitext(os.path.basename(csv_file))[0]
    
    # Ensure results directory exists
    os.makedirs("results", exist_ok=True)
    
    # Create a partial function with fixed arguments
    process_solver_partial = partial(process_solver, X=X, models=models, base_filename=base_filename)
    
    # Create a process pool
    print("Starting multiprocessing pool")
    with multiprocessing.Pool(processes=len(solvers)) as pool:
        # Map the process_solver function to each solver
        pool.map(process_solver_partial, solvers)
    
    print("All processes completed")

if __name__ == "__main__":
    # Command-line entry point: one positional argument naming the CSV to process.
    cli = argparse.ArgumentParser(
        description="Run ATOM classification with feature selection and model training."
    )
    cli.add_argument("csv_file", type=str, help="Path to the input CSV file.")

    cli_args = cli.parse_args()
    run_atom_classification(cli_args.csv_file)
# Leave a Comment