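# Paste-site metadata captured along with this script (kept as comments so the file parses):
# Untitled
#
# avatar
# unknown
# plain_text
# a year ago
# 3.2 kB
# 18
# Indexable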
import pandas as pd
from atom import ATOMClassifier
import joblib
import argparse
import os
import multiprocessing
from functools import partial
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Manager

def load_and_preprocess_data(csv_file):
    """Load *csv_file* and add a class-label column ``y`` derived from the row index.

    The first CSV column becomes the DataFrame index. Each row is labelled from
    the leading characters of its index value:

    * starts with ``'1'`` -> ``0``
    * starts with ``'2'`` -> ``1``
    * anything else     -> ``2``

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        The loaded DataFrame with an added integer column ``y``.
    """
    print(f"Loading CSV file: {csv_file}")
    X = pd.read_csv(csv_file, index_col=0)
    # pandas infers an integer/float index when the first column is numeric;
    # cast each value to str so ``startswith`` works regardless of that dtype.
    X['y'] = [
        0 if name.startswith('1') else (1 if name.startswith('2') else 2)
        for name in map(str, X.index)
    ]
    return X

def parallel_feature_selection(atom, solver, n_features=10):
    """Run ATOM's ``"sfs"`` feature-selection strategy on *atom* in place.

    Despite the name, this function itself does no parallel work; it makes a
    single ``atom.feature_selection`` call with a fixed ``random_state=0``.

    Args:
        atom: ATOM instance whose current branch is transformed.
        solver: Estimator acronym forwarded as ``solver=``.
        n_features: Number of features to keep (default 10).

    Returns:
        The same *atom* object, for call chaining.
    """
    print(f"Performing feature selection with solver: {solver}")
    selection_kwargs = {
        "strategy": "sfs",
        "solver": solver,
        "n_features": n_features,
        "random_state": 0,
    }
    atom.feature_selection(**selection_kwargs)
    return atom

def train_model(atom, model, solver):
    """Train one model on *atom*'s current branch, tagged with the solver name.

    The model is requested as ``<model>_<solver>`` so the same estimator trained
    after different feature selections gets a distinct name. It is run with
    ``n_trials=50``, ``metric="AUC"`` and ``n_bootstrap=5``.

    Returns:
        The same *atom* object, for call chaining.
    """
    tagged = "{}_{}".format(model, solver)
    print(f"Training model: {tagged}")
    run_options = dict(n_trials=50, metric="AUC", n_bootstrap=5)
    atom.run(models=[tagged], **run_options)
    return atom

def process_solver(solver, X, models, results_dict):
    """Run feature selection with *solver*, then train every model in *models*.

    Builds a fresh ``ATOMClassifier`` on *X* (target column ``"y"``,
    ``test_size=0.2``, ``random_state=1``). It selects the ``main`` branch and
    then a branch named ``sfs_<solver>``, and runs ``parallel_feature_selection``
    with *solver*. Each model is trained via ``train_model``, so it is tagged
    ``<model>_<solver>``. The resulting instance is stored in
    ``results_dict[solver]``.

    Args:
        solver: Estimator acronym used for feature selection.
        X: DataFrame holding the features plus a ``"y"`` label column.
        models: Iterable of model acronyms to train after feature selection.
        results_dict: Mapping that receives the trained instance, e.g. a
            ``multiprocessing.Manager().dict()`` shared with the parent process.

    Returns:
        The ATOM instance that was stored in ``results_dict`` (callers that
        ignore the return value are unaffected).
    """
    print(f"Starting process for solver: {solver}")

    atom = ATOMClassifier(X, y="y", test_size=0.2, verbose=2, random_state=1, index=True)
    # Select main, then move to a solver-specific branch so this solver's
    # feature selection is kept under its own branch name.
    atom.branch = "main"
    atom.branch = f"sfs_{solver}"

    atom = parallel_feature_selection(atom, solver)

    # Train the models sequentially on this single instance. Training them in
    # separate worker processes would operate on pickled *copies* of ``atom``,
    # and only one returned copy could be kept, silently discarding every other
    # trained model. It would also spawn a nested process pool inside a pool
    # worker and oversubscribe the CPUs. Parallelism across solvers is
    # provided by the caller.
    for model in models:
        try:
            atom = train_model(atom, model, solver)
        except Exception as exc:
            # Best effort: one failing model must not discard the others.
            print(f'{model} generated an exception: {exc}')

    # Store results in the shared dictionary
    results_dict[solver] = atom
    print(f"Results for solver {solver} stored")
    return atom

def run_atom_classification(csv_file, solvers=None, models=None):
    """Run feature selection + model training for every solver and pickle the results.

    Each solver is handled by ``process_solver`` in its own worker process. The
    per-solver ATOM instances are collected in a ``Manager`` dict. That dict is
    written with ``joblib.dump`` as a plain ``{solver: atom}`` dict to
    ``results/<csv stem>_atom_sfm_all.pkl``. Solvers whose worker raised an
    exception are reported on stdout and are absent from the saved dict.

    Args:
        csv_file: Path to the input CSV (parsed by ``load_and_preprocess_data``).
        solvers: Acronyms used as feature-selection estimators. Defaults to
            ``["LGB", "LR", "RF", "LDA", "XGB"]``.
        models: Acronyms of the models trained after each feature selection.
            Defaults to ``["RF", "XGB", "LDA", "GBM", "LR", "SVM"]``.
    """
    X = load_and_preprocess_data(csv_file)

    if solvers is None:
        solvers = ["LGB", "LR", "RF", "LDA", "XGB"]
    if models is None:
        models = ["RF", "XGB", "LDA", "GBM", "LR", "SVM"]

    base_filename = os.path.splitext(os.path.basename(csv_file))[0]
    os.makedirs("results", exist_ok=True)

    # Create a manager to share the results dictionary across processes
    with Manager() as manager:
        results_dict = manager.dict()

        # ProcessPoolExecutor rejects max_workers=0, so guard an empty list.
        with ProcessPoolExecutor(max_workers=max(1, len(solvers))) as executor:
            future_to_solver = {
                executor.submit(
                    process_solver, solver, X=X, models=models, results_dict=results_dict
                ): solver
                for solver in solvers
            }
            # Retrieve every result: exceptions raised in a worker are only
            # re-raised by ``future.result()``; without this they vanish silently.
            for future in as_completed(future_to_solver):
                solver = future_to_solver[future]
                try:
                    future.result()
                except Exception as exc:
                    print(f'Solver {solver} generated an exception: {exc}')

        print("All processes completed")

        # Save all results to a single file
        # NOTE(review): the strategy used is "sfs" (branches are "sfs_<solver>"),
        # but the filename says "sfm"; kept unchanged for downstream consumers.
        joblib_filename = os.path.join("results", f"{base_filename}_atom_sfm_all.pkl")
        joblib.dump(dict(results_dict), joblib_filename)
        print(f"All results saved to {joblib_filename}")

if __name__ == "__main__":
    # CLI entry point; the __main__ guard is also required for safe use of
    # process pools under the "spawn" start method.
    cli = argparse.ArgumentParser(
        description="Run ATOM classification with parallel feature selection and model training."
    )
    cli.add_argument("csv_file", type=str, help="Path to the input CSV file.")
    cli_args = cli.parse_args()
    run_atom_classification(cli_args.csv_file)
# Paste-site residue (not code):
# Editor is loading...
# Leave a Comment