Untitled

mail@pastecode.io avatar
unknown
plain_text
2 months ago
1.1 kB
1
Indexable
Never
async def _train_model(step: Step) -> types.StepMessage:
    """Train a model from the artifacts attached to ``step``.

    Three artifacts are fetched through ``get_artifact``, in this order:
    "Merged dataset" (JSON carrying the dataset ``id``), "Field selection"
    (JSON passed on to the feature-metadata helper) and "Use case" (its
    raw embed is handed to ``_gen_model_target``).

    Args:
        step: Step whose project and artifacts drive the training run.

    Returns:
        An assistant ``StepMessage`` with an empty message and a single
        ``mlmodel`` embed whose payload is the trained model's primary
        key rendered as a string.
    """
    owning_project: Project = step.project

    # Resolve the dataset row referenced by the "Merged dataset" artifact.
    merged_artifact = await get_artifact(step, "Merged dataset")
    merged_payload = json.loads(merged_artifact.embed)
    raw_dataset_id = merged_payload.get("id")
    dataset = await Dataset.objects.aget(id=int(raw_dataset_id))

    # Combine the dataset's data dictionary with the selected fields into
    # a feature-metadata file; the helper returns the value later passed
    # to ``train_model`` as its last argument.
    data_dictionary = dataset.content.metadata.data_dictionary
    fields_artifact = await get_artifact(step, "Field selection")
    chosen_fields = json.loads(fields_artifact.embed)
    metadata_path = _create_feature_metadata_file(
        data_dictionary,
        chosen_fields,
        owning_project.team,
    )

    # Derive the model's base name and target from the use-case artifact.
    use_case = await get_artifact(step, "Use case")
    model_name, target = await _gen_model_target(use_case.embed, dataset)

    # Prefix the name with a minute-resolution local timestamp.
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M")

    trained = await train_model(
        owning_project.team,
        dataset.file.path,
        f"{timestamp}_{model_name}",
        target,
        metadata_path,
    )

    return types.StepMessage(
        role="assistant",
        message="",
        embeds=[
            types.MessageEmbed(
                type="mlmodel",
                embed=str(trained.pk),
                description="",
            ),
        ],
    )
Leave a Comment