Untitled

 avatar
unknown
python
9 months ago
2.1 kB
7
Indexable
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

# Load the data: the labelled training rows, the rows to predict, and the
# sample submission file, all read from the same data directory.
train, test, sample_submission = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/mnt/data/train.csv',
        '/mnt/data/test.csv',
        '/mnt/data/sample_submission.csv',
    )
)

# Preprocessing
def preprocess_data(df: pd.DataFrame) -> pd.DataFrame:
    """Replace the ``launch_date`` column with a numeric ``launch_year`` feature.

    The input frame is not modified; a new frame is returned.

    Args:
        df: Frame containing a ``launch_date`` column that
            ``pd.to_datetime`` can parse.

    Returns:
        A new frame with every column of ``df`` except ``launch_date``,
        plus a ``launch_year`` column appended as the last column.
    """
    launch_year = pd.to_datetime(df['launch_date']).dt.year
    # `assign` returns a new frame, so the caller's DataFrame keeps its
    # original `launch_date` values and gains no extra column.
    return df.assign(launch_year=launch_year).drop(columns=['launch_date'])

# Run the same feature preprocessing over both the training and test frames.
train, test = preprocess_data(train), preprocess_data(test)

# Split the regression target (`score`) from the feature columns; the test
# frame is handed to the model unchanged for inference.
y = train['score']
X = train.drop(columns=['score'])
X_test = test

# Column-wise feature extraction:
#   * TF-IDF over the review text, limited to the 1000 most frequent terms.
#     The column is given as a bare string (not a list) so the vectorizer
#     receives 1-D text input.
#   * One-hot encoding of `disease_type`; categories unseen during fit are
#     ignored rather than raising.
#   * `market_value` and `launch_year` passed through unchanged.
# remainder='drop' is the scikit-learn default, written out to make explicit
# that every other column (e.g. `medicine_no`) is excluded from the features.
preprocessor = ColumnTransformer(
    [
        ('med_review', TfidfVectorizer(max_features=1000), 'medicine_review'),
        ('cat', OneHotEncoder(handle_unknown='ignore'), ['disease_type']),
        ('num', 'passthrough', ['market_value', 'launch_year']),
    ],
    remainder='drop',
)

# End-to-end model: the feature extraction above feeding a seeded random
# forest regressor (100 trees).
model = Pipeline(
    [
        ('preprocessor', preprocessor),
        ('regressor', RandomForestRegressor(n_estimators=100, random_state=42)),
    ]
)

# Hold out 20% of the labelled rows (fixed seed) to estimate how well the
# pipeline generalises to unseen data.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    random_state=42,
    test_size=0.2,
)

# Fit preprocessing and regressor on the training portion only, keeping the
# held-out rows unseen.
model.fit(X_train, y_train)

# Score the held-out rows: RMSE = sqrt(MSE), in the same units as `score`.
y_pred = model.predict(X_val)
print(f'Validation RMSE: {np.sqrt(mean_squared_error(y_val, y_pred))}')

# Refit on all labelled rows before predicting the test set. The hold-out
# split is only for estimating generalisation error. Training the final model
# on the full data lets it learn from the 20% that was held back for
# validation.
model.fit(X, y)
test_predictions = model.predict(X_test)

# Write the submission: one row per test row, keyed by `medicine_no`, with
# the predicted `score`.
submission = pd.DataFrame({
    'medicine_no': test['medicine_no'],
    'score': test_predictions,
})
submission.to_csv('/mnt/data/submission.csv', index=False)
Editor is loading...
Leave a Comment