Untitled
unknown
plain_text
7 months ago
15 kB
4
Indexable
import logging
from io import BytesIO
from typing import List
import openpyxl
import pandas as pd
import strawberry
from sqlalchemy.sql import text
from strawberry import mutation, type
from strawberry.file_uploads import Upload
from app.graphql.schema import Info
from app.graphql.types.group_type import GroupType
from app.graphql.types.lov_type import LovType
from app.graphql.types.product_category_type import ProductCategoryType
from app.graphql.types.product_country_type import ProductCountryType
from app.models import (
    Group,
    Lov,
    ProductCategory,
    ProductCategoryVersion,
    ProductCountry,
    ProductSubcategory,
)
# Module-level logger for this import module.
logger = logging.getLogger(__name__)
@type
class ImportMutationPayload:
    """Result returned by the spreadsheet import mutation.

    NOTE(review): this is a generic success/message payload shared by the
    mutation; a dedicated payload type per mutation would be clearer.
    """

    success: bool  # True only when the import completed without errors
    message: str = ""  # status text, or newline-terminated per-row error messages
def get_last_version_id(info):
    """Return the next version id to assign to imported product categories.

    Looks up the ``ProductCategoryVersion`` row whose ``version`` column is
    ``"Version"`` and returns its stored ``version_id`` plus one, or ``1`` when
    no such row exists.

    The row is deliberately NOT modified here. Incrementing it as a side effect
    of a read dirtied the session even when the import was later rejected. The
    final value is written by ``update_product_category_version`` once the rows
    have been saved.

    Args:
        info: resolver info; only ``info.context.session`` is used.

    Returns:
        The next version id as an ``int``.
    """
    version = (
        info.context.session.query(ProductCategoryVersion)
        .filter(ProductCategoryVersion.version == "Version")
        .first()
    )
    if version is None:
        return 1
    return version.version_id + 1
def get_full_product_country_list(info):
    """Return every ``ProductCountry`` with ``is_deleted == 0`` as a ``ProductCountryType``.

    Args:
        info: resolver info; only ``info.context.session`` is used.
    """
    query = info.context.session.query(ProductCountry)
    active_rows = query.filter(ProductCountry.is_deleted == 0).all()
    result = []
    for row in active_rows:
        result.append(
            ProductCountryType(
                id=row.id,
                country_id=row.country_id,
                product_name=row.product_name,
                is_deleted=row.is_deleted,
            )
        )
    return result
def get_full_product_category_list(info):
    """Return every ``ProductCategory`` with ``is_deleted == 0`` as a ``ProductCategoryType``.

    Args:
        info: resolver info; only ``info.context.session`` is used.
    """
    # Attributes copied one-to-one from the ORM row into the GraphQL type.
    copied_fields = (
        "id",
        "version_id",
        "mdm_id",
        "prod_cat_grp",
        "category_id",
        "ta_sub_type",
        "last_active",
        "is_deleted",
        "created_date",
        "updated_date",
        "updated_by_id",
        "created_by_id",
    )
    session = info.context.session
    rows = session.query(ProductCategory).filter(ProductCategory.is_deleted == 0).all()
    return [
        ProductCategoryType(**{field: getattr(row, field) for field in copied_fields})
        for row in rows
    ]
def get_full_group_list(info):
    """Return every ``Group`` with ``is_deleted == 0`` as a ``GroupType``.

    Args:
        info: resolver info; only ``info.context.session`` is used.
    """
    session = info.context.session
    active_groups = session.query(Group).filter(Group.is_deleted == 0).all()
    converted = []
    for grp in active_groups:
        converted.append(
            GroupType(
                id=grp.id,
                name=grp.name,
                type=grp.type,
                version_id=grp.version_id,
                status_id=grp.status_id,
                created_date=grp.created_date,
                created_by_id=grp.created_by_id,
                updated_date=grp.updated_date,
                updated_by_id=grp.updated_by_id,
                is_deleted=grp.is_deleted,
                description=grp.description,
            )
        )
    return converted
def get_lov_values(info, identifier_field_name):
    """Return the list-of-values rows for one identifier field as ``LovType``.

    Args:
        info: resolver info; only ``info.context.session`` is used.
        identifier_field_name: value matched exactly against ``Lov.identifier_field_name``.

    NOTE(review): unlike the other loaders in this module, rows with
    ``is_deleted`` set are not filtered out here — confirm this is intended.
    """
    matching = (
        info.context.session.query(Lov)
        .filter(Lov.identifier_field_name == identifier_field_name)
        .all()
    )

    def to_lov_type(row):
        # Copy the ORM row into its GraphQL representation field by field.
        return LovType(
            id=row.id,
            name=row.name,
            value=row.value,
            identifier_field_name=row.identifier_field_name,
            is_deleted=row.is_deleted,
            created_date=row.created_date,
            created_by_id=row.created_by_id,
            updated_date=row.updated_date,
            updated_by_id=row.updated_by_id,
        )

    return [to_lov_type(row) for row in matching]
async def get_mdm_id(info, product_name, country, product_country_list):
    """Return the id of the product-country entry for ``product_name`` in ``country``.

    Args:
        info: resolver info; only ``info.context.load_country_by_id`` is used.
        product_name: compared exactly against each entry's ``product_name``.
        country: compared exactly against the loaded country's ``name``.
        product_country_list: objects exposing ``id``, ``country_id`` and
            ``product_name`` (as built by ``get_full_product_country_list``).

    Returns:
        The ``id`` of the first entry matching both product and country, or
        ``None`` when there is no match.
    """
    for pc in product_country_list:
        # Check the cheap attribute first so countries are only loaded for candidates.
        if pc.product_name != product_name:
            continue
        country_data = await info.context.load_country_by_id.load(pc.country_id)
        # If the loader yields None (e.g. an unknown id), treat it as "no match"
        # instead of failing on `.name`.
        if country_data is not None and country_data.name == country:
            # `pc` already is the matching entry. The former re-query filtered on
            # the Python bool `country_data.name == country` (not a column), so it
            # returned the first row with this product name in *any* country.
            return pc.id
    return None
def get_category_id(cat_name, group_list):
    """Return the id of the first ``CATEGORY`` group named ``cat_name``, or ``None``.

    Args:
        cat_name: category name, compared exactly with ``group.name``.
        group_list: objects exposing ``id``, ``name`` and ``type``.
    """
    candidates = (
        grp.id
        for grp in group_list
        if grp.name == cat_name and grp.type == "CATEGORY"
    )
    return next(candidates, None)
def get_subcategory_ids(topic_names, group_list):
    """Resolve a TOPIC cell into the ids of matching ``SUBCATEGORY`` groups.

    ``topic_names`` may hold several names separated by commas and/or line
    breaks. Each name is stripped and looked up exactly. Unknown names are
    skipped, and repeated names yield repeated ids. When several groups share a
    name, the first one in ``group_list`` wins.

    Args:
        topic_names: raw text of the TOPIC cell.
        group_list: objects exposing ``id``, ``name`` and ``type``.

    Returns:
        A list of group ids in the order the topics appear.
    """
    # name -> id of the first SUBCATEGORY group carrying that name
    first_id_by_name = {}
    for grp in group_list:
        if grp.type == "SUBCATEGORY":
            first_id_by_name.setdefault(grp.name, grp.id)

    found = []
    for text_line in topic_names.splitlines():
        for piece in text_line.split(","):
            name = piece.strip()
            if name in first_id_by_name:
                found.append(first_id_by_name[name])
    return found
def is_valid_data(
    prd_group,
    mdm_id,
    cat_id,
    ta_sub_type,
    prdt_sbctgries,
    func_group_list,
    sub_ta_list,
):
    """Return True when one spreadsheet row's resolved values are all acceptable.

    Args:
        prd_group: FUNCTIONAL GROUP cell text.
        mdm_id: resolved product-country id, or ``None`` if unresolved.
        cat_id: resolved category group id, or ``None`` if unresolved.
        ta_sub_type: TA SUB TYPE cell text.
        prdt_sbctgries: resolved subcategory ids (must be non-empty).
        func_group_list: LOV rows (``.value``) of allowed functional groups.
        sub_ta_list: LOV rows (``.value``) of allowed TA sub types.

    Returns:
        ``True`` if every check passes, otherwise ``False``.
    """
    # The functional group must equal one of the configured LOV values, mirroring
    # the TA sub type check below. The previous loop required it to be a
    # *substring* of *every* value, which rejected valid groups whenever more
    # than one functional group was configured.
    # NOTE(review): assumes each LOV row holds a single group name — confirm.
    allowed_groups = {fg.value for fg in func_group_list}
    if prd_group not in allowed_groups:
        return False
    if mdm_id is None or cat_id is None:
        return False
    if not prdt_sbctgries:
        return False
    allowed_sub_types = {st.value for st in sub_ta_list}
    return ta_sub_type in allowed_sub_types
def is_existing_record(mdm_id, cat_id, prd_group, ta_sub_type, product_category_list):
    """Report whether a product category with the same four-field key is already present.

    The key is (``mdm_id``, ``category_id``, ``prod_cat_grp``, ``ta_sub_type``),
    each compared with ``==``.
    """
    return any(
        existing.mdm_id == mdm_id
        and existing.category_id == cat_id
        and existing.prod_cat_grp == prd_group
        and existing.ta_sub_type == ta_sub_type
        for existing in product_category_list
    )
def truncate_tables(info):
    """Empty the ``mdm_product_category`` and ``mdm_product_subcategory`` tables.

    Foreign-key checks are disabled for the truncation and always re-enabled
    afterwards on the same connection, even when a TRUNCATE fails. Otherwise the
    connection could be handed back with checks still off.

    Args:
        info: resolver info; only ``info.context.session`` is used.

    Returns:
        ``None`` on success, or an ``"Error: ..."`` string after rolling back.
    """
    session = info.context.session
    try:
        session.execute(text("SET FOREIGN_KEY_CHECKS=0"))
        try:
            session.execute(text("TRUNCATE TABLE mdm_product_category"))
            session.execute(text("TRUNCATE TABLE mdm_product_subcategory"))
        finally:
            # Restore before commit/rollback releases the connection.
            session.execute(text("SET FOREIGN_KEY_CHECKS=1"))
        session.commit()
    except Exception as e:
        logger.exception("Failed to truncate product category tables")
        session.rollback()
        return f"Error: {str(e)}"
    return None
def save_product_categories(info, prd_cat_list):
    """Insert one ``ProductCategory`` per dict in ``prd_cat_list`` and commit.

    Args:
        info: resolver info; only ``info.context.session`` is used.
        prd_cat_list: keyword-argument dicts passed to ``ProductCategory(**...)``.

    Returns:
        ``None`` on success, or an ``"Error: ..."`` string after rolling back.
    """
    session = info.context.session
    try:
        session.add_all(ProductCategory(**fields) for fields in prd_cat_list)
        session.commit()
    except Exception as e:
        # Log the traceback; the returned string alone loses it.
        logger.exception("Failed to save product categories")
        session.rollback()
        return f"Error: {str(e)}"
    return None
def update_product_category_version(info, last_version_id):
    """Store ``last_version_id`` on the ``ProductCategoryVersion`` row named ``"Version"``.

    Args:
        info: resolver info; only ``info.context.session`` is used.
        last_version_id: value to write into ``version_id``.

    Returns:
        ``None`` on success (also when no such row exists), or an
        ``"Error: ..."`` string after rolling back.
    """
    session = info.context.session
    try:
        version = (
            session.query(ProductCategoryVersion)
            .filter(ProductCategoryVersion.version == "Version")
            .first()
        )
        if version is None:
            # Without this row get_last_version_id keeps returning 1, so make the
            # skipped update visible instead of silently doing nothing.
            logger.warning(
                "No ProductCategoryVersion row with version='Version'; version id not saved"
            )
            return None
        version.version_id = last_version_id
        session.commit()
    except Exception as e:
        logger.exception("Failed to update product category version")
        session.rollback()
        return f"Error: {str(e)}"
    return None
def save_topics_from_excel(info, df, prd_cat_list, group_list):
    """Link each imported product category to the subcategories named in its TOPIC cell.

    Args:
        info: resolver info; only ``info.context.session`` is used.
        df: the imported sheet. Row label ``i`` must correspond to
            ``prd_cat_list[i]``.
        prd_cat_list: the dicts that were used to create the ``ProductCategory`` rows.
        group_list: groups used to resolve topic names to subcategory ids.

    Returns:
        ``None`` on success, or an ``"Error: ..."`` string after rolling back.
    """
    session = info.context.session
    try:
        for i, row in df.iterrows():
            raw_topics = row["TOPIC"]
            # Empty cells arrive as None/NaN and numeric cells are not str;
            # normalise instead of failing on `.strip()`.
            topic_names = "" if pd.isna(raw_topics) else str(raw_topics).strip()
            subcategory_ids = get_subcategory_ids(topic_names, group_list)
            if not subcategory_ids:
                continue
            prd_cat = prd_cat_list[i]
            # Locate the row created for *this* line. Filtering on category_id
            # alone returned the first product category sharing that category
            # (possibly a pre-existing one), attaching topics to the wrong parent.
            # The lookup is loop-invariant, so it runs once per line, not per topic.
            product_category = (
                session.query(ProductCategory)
                .filter(
                    ProductCategory.version_id == prd_cat["version_id"],
                    ProductCategory.mdm_id == prd_cat["mdm_id"],
                    ProductCategory.category_id == prd_cat["category_id"],
                    ProductCategory.prod_cat_grp == prd_cat["prod_cat_grp"],
                    ProductCategory.ta_sub_type == prd_cat["ta_sub_type"],
                )
                .first()
            )
            if product_category is None:
                continue
            session.add_all(
                ProductSubcategory(
                    pro_cat_id=product_category.id, subcategory_id=subcat_id
                )
                for subcat_id in subcategory_ids
            )
        session.commit()
    except Exception as e:
        logger.exception("Failed to save product subcategories")
        session.rollback()
        return f"Error: {str(e)}"
    return None
def _cell_text(row, column):
    """Return ``row[column]`` as a stripped string, or ``""`` for an empty cell.

    Empty cells come through as None (and pandas can represent them as NaN);
    numeric cells are converted with ``str`` instead of failing on ``.strip()``.
    """
    value = row[column]
    if pd.isna(value):
        return ""
    return str(value).strip()
@type
class ImportFunctionalityMutation:
    """GraphQL mutations for importing product categories from an Excel workbook."""

    @mutation
    async def read_file(self, file: Upload, import_type: str, info: Info) -> ImportMutationPayload:
        """Validate an uploaded workbook and import its rows as product categories.

        The active sheet's header row must match ``expected_headers`` exactly.
        Every row is validated before anything is written. If any row is invalid
        (or already exists when not truncating), nothing is saved and all row
        errors are returned together.

        Args:
            file: the uploaded ``.xlsx`` workbook.
            import_type: ``"truncateandproceed"`` (case-insensitive) empties the
                product category tables before saving; any other value appends.
            info: resolver info carrying the DB session and data loaders.

        Returns:
            ``ImportMutationPayload`` with ``success`` and a status/error message.
        """
        logger.info("Current User %s", info.context.current_user)
        logger.info("Reading file %s", file)
        bad_file_message = "This is not a correct excel to upload"
        expected_headers = [
            "FUNCTIONAL GROUP",
            "COUNTRY",
            "PRODUCT",
            "CATEGORY",
            "TOPIC",
            "TA SUB TYPE",
        ]
        try:
            is_truncate = import_type.lower() == "truncateandproceed"

            # Load the first (active) sheet into a DataFrame, header row first.
            file_content = await file.read()
            workbook = openpyxl.load_workbook(filename=BytesIO(file_content))
            data = list(workbook.active.iter_rows(values_only=True))
            if not data:
                return ImportMutationPayload(success=False, message=bad_file_message)
            df = pd.DataFrame(data[1:], columns=data[0])
            if df.columns.tolist() != expected_headers:
                return ImportMutationPayload(success=False, message=bad_file_message)

            data_errors = []
            prd_cat_list = []
            last_version_id = 1 if is_truncate else get_last_version_id(info)

            # Reference data used to validate every row.
            product_country_list = get_full_product_country_list(info)
            product_category_list = get_full_product_category_list(info)
            group_list = get_full_group_list(info)
            func_group_list = get_lov_values(info, "PROD CAT GRP")
            sub_ta_list = get_lov_values(info, "therapeutic")

            for i, row in df.iterrows():
                # +1 for the header row and +1 because spreadsheet lines are 1-based.
                line_no = i + 2
                prd_group = _cell_text(row, "FUNCTIONAL GROUP")
                country = _cell_text(row, "COUNTRY")
                product_name = _cell_text(row, "PRODUCT")
                cat_name = _cell_text(row, "CATEGORY")
                ta_sub_type = _cell_text(row, "TA SUB TYPE")
                topic_names = _cell_text(row, "TOPIC")

                mdm_id = await get_mdm_id(
                    info, product_name, country, product_country_list
                )
                cat_id = get_category_id(cat_name, group_list)
                prdt_sbctgries = get_subcategory_ids(topic_names, group_list)
                if not is_valid_data(
                    prd_group,
                    mdm_id,
                    cat_id,
                    ta_sub_type,
                    prdt_sbctgries,
                    func_group_list,
                    sub_ta_list,
                ):
                    data_errors.append(f"Error in data at line {line_no}: Invalid data\n")
                    continue
                # Duplicates only matter when appending; truncation clears old rows.
                if not is_truncate and is_existing_record(
                    mdm_id, cat_id, prd_group, ta_sub_type, product_category_list
                ):
                    data_errors.append(f"The record at line {line_no} already exists.\n")
                    continue

                now = pd.Timestamp.now()
                prd_cat_list.append(
                    {
                        "version_id": last_version_id,
                        "prod_cat_grp": prd_group,
                        "mdm_id": mdm_id,
                        "category_id": cat_id,
                        "ta_sub_type": ta_sub_type,
                        "last_active": 1,
                        "is_deleted": 0,
                        "created_date": now,
                        "created_by_id": "SYSTEM",
                        "updated_date": now,
                        "updated_by_id": "SYSTEM",
                    }
                )
                last_version_id += 1

            if data_errors:
                # `message` is a str field: join the errors (was a set literal).
                return ImportMutationPayload(success=False, message="".join(data_errors))

            # Each helper returns None on success or an "Error: ..." string; stop at
            # the first failure instead of reporting success regardless.
            # NOTE(review): the steps commit separately (and on MySQL, which the
            # FOREIGN_KEY_CHECKS statements suggest, TRUNCATE commits implicitly),
            # so a failure part-way leaves a partial import.
            steps = []
            if is_truncate:
                steps.append(lambda: truncate_tables(info))
            steps.extend(
                [
                    lambda: save_product_categories(info, prd_cat_list),
                    lambda: update_product_category_version(info, last_version_id),
                    lambda: save_topics_from_excel(info, df, prd_cat_list, group_list),
                ]
            )
            for step in steps:
                error = step()
                if error:
                    return ImportMutationPayload(success=False, message=error)

            return ImportMutationPayload(success=True, message="Import completed successfully")
        except Exception as e:
            logger.exception("Import of uploaded file failed")
            info.context.session.rollback()
            return ImportMutationPayload(success=False, message=f"Error: {str(e)}")
Mutations should be primarily concerned with GraphQL and types. Any non-trivial business logic should be moved into a module under app/actions/, and all of the utility functions (get_last_version_id, get_full_product_country_list, ...) should live in that action. Actions should not use GraphQL types, only plain Python types. Instead of passing in the heavy GraphQL info object, pass in the session, the logged-in user, and other fields as needed. Each mutation should have its own unique payload type rather than a generic success/message string payload. Mutation parameters should follow Python/GraphQL conventions: a parameter called ImportType should be import_type in Python, which appears as importType in GraphQL. However, having an import_type parameter at all is an anti-pattern in itself. The names in this file are all poorly chosen — the file name, the mutation name, etc. — so it is unclear what is being imported or why.
Leave a Comment