Untitled

mail@pastecode.io avatar
unknown
plain_text
2 months ago
2.1 kB
4
Indexable
Never
def generate_lstm_data(path, cols=CSV_COLUMNS + [DATETIME_COLUMN], label_column=LABEL_COLUMN, y_column=-1,
                       norm_cols=cols_to_norm, history_size=LSTM_HISTORY, target_size=LSTM_FUTURE_TARGET,
                       step=LSTM_STEP, cities=CATEGORIES['city'], index_col=DATETIME_COLUMN, single_step=False,
                       train_frac=TRAIN_DATASET_FRAC, train_scale=None, scale_cols=None, prepend_with_file=None,
                       extra_columns=None, group_by_column=False):
    """Build per-city windowed datasets from a data file.

    The file at ``path`` is loaded with ``extract_data``, split by the
    ``'city'`` column, preprocessed per city with ``preproc_data`` and then
    passed, city by city, to ``generate_multivariate_data``.

    Args:
        path: Location of the main data file, forwarded to ``extract_data``.
        cols: Columns requested from ``extract_data``.
        label_column: Name of the label column. If the loaded data lacks
            it, a column of zeros is added under this name.
        y_column: Forwarded to ``generate_multivariate_data`` as
            ``target_index``.
        norm_cols: Column names forwarded to ``preproc_data`` as
            ``norm_cols``; also the first part of the selected columns.
        history_size: Forwarded to ``generate_multivariate_data``. Also
            determines how many trailing rows (``history_size + 1``) are
            taken per city from ``prepend_with_file``.
        target_size: Forwarded to ``generate_multivariate_data``.
        step: Forwarded to ``generate_multivariate_data``.
        cities: Iterable of values matched against the ``'city'`` column;
            one output entry is produced per value, in this order.
        index_col: Column whose values become each city frame's index.
        single_step: Forwarded to ``generate_multivariate_data``.
        train_frac: Forwarded to ``generate_multivariate_data``.
        train_scale: Reference data forwarded to ``preproc_data``. When
            ``None``, a copy of the first city's frame is used for every
            city.
        scale_cols: Column names forwarded to ``preproc_data`` as
            ``scale_cols``. ``None`` means no such columns.
        prepend_with_file: Optional second file. When given, the last
            ``history_size + 1`` rows of each city from that file are
            placed in front of the city's rows from ``path``.
        extra_columns: Additional column names to keep alongside
            ``norm_cols`` and ``scale_cols``. ``None`` means none.
        group_by_column: When true, the result is post-processed with
            ``group_data_by_columns`` and the column list is returned too.

    Returns:
        ``(datasets, scale)``, or ``(datasets, scale, columns)`` when
        ``group_by_column`` is true. ``scale`` is the second value returned
        by the last ``preproc_data`` call, or ``None`` if ``cities`` is
        empty.
    """
    # None replaces the former shared ``[]`` defaults so that a helper
    # mutating the list cannot leak state into later calls.
    scale_cols = [] if scale_cols is None else scale_cols
    extra_columns = [] if extra_columns is None else extra_columns
    feature_cols = norm_cols + scale_cols + extra_columns

    dataset = extract_data(path, cols, categorical_columns=None)
    pre_dataset = None
    if prepend_with_file is not None:
        pre_dataset = extract_data(prepend_with_file, cols, categorical_columns=None)

    if label_column not in dataset.columns:
        # Size the placeholder from the frame itself rather than from one
        # specific column, so this works whatever the index column is.
        dataset[label_column] = np.zeros(len(dataset))

    scale = None
    per_city_values = []
    for city_name in cities:
        city_data = dataset[dataset['city'] == city_name]
        if pre_dataset is not None:
            pre_data = pre_dataset[pre_dataset['city'] == city_name].iloc[-(history_size + 1):]
            city_data = pd.concat([pre_data, city_data], ignore_index=True)
        if train_scale is None:
            # NOTE(review): only the first city populates train_scale; all
            # later cities reuse it. Confirm this is the intended reference.
            train_scale = city_data.copy()
        city_data.index = city_data[index_col]
        # NOTE(review): ``scale`` is overwritten on every iteration, so the
        # caller only sees the value produced for the last city.
        city_data, scale = preproc_data(city_data[feature_cols + [label_column]], norm_cols=norm_cols,
                                        scale_cols=scale_cols, train_scale=train_scale)
        per_city_values.append(city_data.values)

    datasets = [
        generate_multivariate_data(values, target_index=y_column, single_step=single_step,
                                   history_size=history_size, target_size=target_size,
                                   step=step, train_frac=train_frac)
        for values in per_city_values
    ]

    if group_by_column:
        datasets = group_data_by_columns(datasets, columns=feature_cols)
        return datasets, scale, feature_cols
    return datasets, scale
Leave a Comment