Untitled

 avatar
unknown
plain_text
a month ago
14 kB
4
Indexable
import logging
from io import BytesIO
from typing import List

import openpyxl
import pandas as pd
import strawberry
from sqlalchemy.sql import text
from strawberry import mutation, type
from strawberry.file_uploads import Upload

from app.graphql.schema import Info
from app.graphql.types.group_type import GroupType
from app.graphql.types.lov_type import LovType
from app.graphql.types.product_category_type import ProductCategoryType
from app.graphql.types.product_country_type import ProductCountryType
from app.models import (
    Group,
    Lov,
    ProductCategory,
    ProductCategoryVersion,
    ProductCountry,
    ProductSubcategory,
)

logger = logging.getLogger(__name__)


def get_last_version_id(info):
    """Return the next product-category version id.

    Looks up the ``ProductCategoryVersion`` row whose ``version`` column is
    ``"Version"``. When the row exists, its ``version_id`` is incremented on
    the session-attached object (this function does not commit) and the
    incremented value is returned. When no such row exists, ``1`` is returned.
    """
    session = info.context.session
    counter_row = (
        session.query(ProductCategoryVersion)
        .filter(ProductCategoryVersion.version == "Version")
        .first()
    )
    if not counter_row:
        return 1

    previous_id = counter_row.version_id
    # Side effect: the bumped value stays pending in the session until the
    # caller commits or rolls back.
    counter_row.version_id = previous_id + 1
    return previous_id + 1


def get_full_product_country_list(info):
    """Return all ``ProductCountry`` rows with ``is_deleted == 0``.

    Each row is converted to a ``ProductCountryType`` carrying the row's
    ``id``, ``country_id``, ``product_name`` and ``is_deleted`` values.
    """
    copied_fields = ("id", "country_id", "product_name", "is_deleted")
    session = info.context.session
    active_rows = (
        session.query(ProductCountry).filter(ProductCountry.is_deleted == 0).all()
    )
    return [
        ProductCountryType(**{field: getattr(row, field) for field in copied_fields})
        for row in active_rows
    ]


def get_full_product_category_list(info):
    """Return all ``ProductCategory`` rows with ``is_deleted == 0``.

    Each row is converted to a ``ProductCategoryType`` by copying the columns
    listed in ``copied_fields`` one-to-one.
    """
    copied_fields = (
        "id",
        "version_id",
        "mdm_id",
        "prod_cat_grp",
        "category_id",
        "ta_sub_type",
        "last_active",
        "is_deleted",
        "created_date",
        "updated_date",
        "updated_by_id",
        "created_by_id",
    )
    session = info.context.session
    active_rows = (
        session.query(ProductCategory).filter(ProductCategory.is_deleted == 0).all()
    )
    return [
        ProductCategoryType(**{field: getattr(row, field) for field in copied_fields})
        for row in active_rows
    ]


def get_full_group_list(info):
    """Return all ``Group`` rows with ``is_deleted == 0`` as ``GroupType`` objects.

    Every column listed in ``copied_fields`` is copied one-to-one from the
    database row onto the GraphQL type.
    """
    copied_fields = (
        "id",
        "name",
        "type",
        "version_id",
        "status_id",
        "created_date",
        "created_by_id",
        "updated_date",
        "updated_by_id",
        "is_deleted",
        "description",
    )
    session = info.context.session
    active_rows = session.query(Group).filter(Group.is_deleted == 0).all()
    return [
        GroupType(**{field: getattr(row, field) for field in copied_fields})
        for row in active_rows
    ]


def get_lov_values(info, identifier_field_name):
    """Return the list-of-values (LOV) entries for one identifier field.

    Selects every ``Lov`` row whose ``identifier_field_name`` equals the given
    value and converts each to a ``LovType``. No ``is_deleted`` filter is
    applied here, so soft-deleted entries are returned as well.
    """
    copied_fields = (
        "id",
        "name",
        "value",
        "identifier_field_name",
        "is_deleted",
        "created_date",
        "created_by_id",
        "updated_date",
        "updated_by_id",
    )
    session = info.context.session
    matching_rows = (
        session.query(Lov)
        .filter(Lov.identifier_field_name == identifier_field_name)
        .all()
    )
    return [
        LovType(**{field: getattr(row, field) for field in copied_fields})
        for row in matching_rows
    ]


async def get_mdm_id(info, product_name, country, product_country_list):
    """Return the id of the product-country entry matching a product and country.

    Args:
        info: Resolver info; ``info.context.load_country_by_id.load`` is awaited
            to resolve a ``country_id`` into an object with a ``name``.
        product_name: Product name to match against ``product_name``.
        country: Country name to match against the loaded country's ``name``.
        product_country_list: Candidate entries exposing ``id``,
            ``country_id`` and ``product_name``.

    Returns:
        The ``id`` of the first entry whose product name and country name both
        match, or ``None`` when there is no such entry.
    """
    for pc in product_country_list:
        # Compare the in-memory product name first so the country loader is
        # only awaited for entries that can still match.
        if pc.product_name != product_name:
            continue

        country_data = await info.context.load_country_by_id.load(pc.country_id)
        if country_data is not None and country_data.name == country:
            # The matched entry already is the answer. Re-querying by product
            # name alone (the country comparison is a plain Python bool, not a
            # SQL criterion) could return a row belonging to another country.
            return pc.id
    return None


def get_category_id(cat_name, group_list):
    """Return the id of the first ``CATEGORY`` group named ``cat_name``.

    ``group_list`` entries are expected to expose ``id``, ``name`` and
    ``type``. ``None`` is returned when no group matches.
    """
    matching_ids = (
        candidate.id
        for candidate in group_list
        if candidate.name == cat_name and candidate.type == "CATEGORY"
    )
    return next(matching_ids, None)


def get_subcategory_ids(topic_names, group_list):
    """Resolve topic names to ``SUBCATEGORY`` group ids.

    ``topic_names`` is split on line breaks and then on commas; each piece is
    stripped of surrounding whitespace. For every piece, the id of the first
    group with that name and type ``"SUBCATEGORY"`` is collected. Pieces
    without a matching group are skipped. Ids are returned in topic order.
    """
    found_ids = []
    for text_line in topic_names.splitlines():
        for piece in text_line.split(","):
            wanted_name = piece.strip()
            first_match = next(
                (
                    candidate
                    for candidate in group_list
                    if candidate.name == wanted_name
                    and candidate.type == "SUBCATEGORY"
                ),
                None,
            )
            if first_match is not None:
                found_ids.append(first_match.id)
    return found_ids


def is_valid_data(
    prd_group,
    mdm_id,
    cat_id,
    ta_sub_type,
    prdt_sbctgries,
    func_group_list,
    sub_ta_list,
):
    """Check whether one imported row carries usable, recognised values.

    Args:
        prd_group: Functional group text of the row.
        mdm_id: Resolved product-country id, or ``None`` if not found.
        cat_id: Resolved category id, or ``None`` if not found.
        ta_sub_type: TA sub type text of the row.
        prdt_sbctgries: Resolved subcategory ids; must be non-empty.
        func_group_list: LOV entries (objects with ``value``) listing the
            allowed functional groups.
        sub_ta_list: LOV entries (objects with ``value``) listing the allowed
            TA sub types.

    Returns:
        ``True`` when every check passes, otherwise ``False``.
    """
    # A blank group would pass the containment test below trivially
    # ("" is contained in every string), so reject it explicitly.
    if not prd_group:
        return False

    # The group is valid when at least one configured LOV value contains it.
    # Requiring it in every value would reject all rows as soon as two distinct
    # functional groups are configured. An empty LOV list performs no check,
    # as before.
    # NOTE(review): `in fg.value` is substring containment when value is a
    # string - confirm whether exact equality is intended.
    if func_group_list and not any(prd_group in fg.value for fg in func_group_list):
        return False

    # Both ids come from lookups that yield None when nothing matched.
    if mdm_id is None or cat_id is None:
        return False

    # At least one topic must have resolved to a subcategory.
    if not prdt_sbctgries:
        return False

    # The TA sub type must equal one of the configured LOV values.
    return ta_sub_type in [st.value for st in sub_ta_list]


def is_existing_record(mdm_id, cat_id, prd_group, ta_sub_type, product_category_list):
    """Tell whether an equivalent product category is already in the list.

    An entry counts as a duplicate when its ``mdm_id``, ``category_id``,
    ``prod_cat_grp`` and ``ta_sub_type`` all equal the given values.
    """
    return any(
        existing.mdm_id == mdm_id
        and existing.category_id == cat_id
        and existing.prod_cat_grp == prd_group
        and existing.ta_sub_type == ta_sub_type
        for existing in product_category_list
    )


def truncate_tables(info):
    """Empty the product category and product subcategory tables.

    Foreign key checks are switched off for the duration of the two
    ``TRUNCATE`` statements and switched back on afterwards, including when a
    ``TRUNCATE`` fails, so the connection is not left with checks disabled.

    Returns:
        ``None`` on success, or an ``"Error: ..."`` string when any statement
        fails (the session is rolled back in that case).
    """
    session = info.context.session
    try:
        session.execute(text("SET FOREIGN_KEY_CHECKS=0"))
        try:
            session.execute(text("TRUNCATE TABLE mdm_product_category"))
            session.execute(text("TRUNCATE TABLE mdm_product_subcategory"))
        finally:
            # Restore the setting even if a TRUNCATE raised; otherwise later
            # statements on this connection would run without FK enforcement.
            session.execute(text("SET FOREIGN_KEY_CHECKS=1"))
        session.commit()
    except Exception as e:
        session.rollback()
        return f"Error: {str(e)}"


def save_product_categories(info, prd_cat_list):
    """Insert one ``ProductCategory`` per dict in ``prd_cat_list`` and commit.

    Each dict is expanded as keyword arguments to ``ProductCategory``.

    Returns:
        ``None`` on success, or an ``"Error: ..."`` string after rolling the
        session back when anything fails.
    """
    db = info.context.session
    try:
        for field_values in prd_cat_list:
            db.add(ProductCategory(**field_values))
        db.commit()
    except Exception as exc:
        db.rollback()
        return f"Error: {str(exc)}"
    return None


def update_product_category_version(info, last_version_id):
    """Store ``last_version_id`` on the ``"Version"`` counter row and commit.

    Nothing is written when the counter row does not exist.

    Returns:
        ``None`` on success (or when there is no counter row), or an
        ``"Error: ..."`` string after rolling the session back on failure.
    """
    db = info.context.session
    try:
        counter_row = (
            db.query(ProductCategoryVersion)
            .filter(ProductCategoryVersion.version == "Version")
            .first()
        )
        if not counter_row:
            return None
        counter_row.version_id = last_version_id
        db.commit()
    except Exception as exc:
        db.rollback()
        return f"Error: {str(exc)}"
    return None


def save_topics_from_excel(info, df, prd_cat_list, group_list):
    """Link each imported product category to the subcategories in its TOPIC cell.

    Args:
        info: Resolver info providing ``info.context.session``.
        df: DataFrame of the imported sheet; the ``"TOPIC"`` column is read.
        prd_cat_list: Dicts describing the saved product categories, one per
            DataFrame row and in the same order. The keys ``version_id``,
            ``mdm_id``, ``category_id``, ``prod_cat_grp`` and ``ta_sub_type``
            are read.
        group_list: Groups used to resolve topic names to subcategory ids.

    Returns:
        ``None`` on success, or an ``"Error: ..."`` string after rolling the
        session back when anything fails.
    """
    session = info.context.session
    try:
        # Pair rows with prd_cat_list by position rather than by index label,
        # so a DataFrame with a non-default index still lines up.
        for position, (_, row) in enumerate(df.iterrows()):
            topic_names = row["TOPIC"].strip()
            subcategory_ids = get_subcategory_ids(topic_names, group_list)
            if not subcategory_ids:
                continue

            imported = prd_cat_list[position]
            # Match on every identifying field of the imported record. Looking
            # up by category_id alone returns an arbitrary row whenever several
            # product categories share the same category.
            product_category = (
                session.query(ProductCategory)
                .filter(
                    ProductCategory.version_id == imported["version_id"],
                    ProductCategory.mdm_id == imported["mdm_id"],
                    ProductCategory.category_id == imported["category_id"],
                    ProductCategory.prod_cat_grp == imported["prod_cat_grp"],
                    ProductCategory.ta_sub_type == imported["ta_sub_type"],
                )
                .first()
            )
            if not product_category:
                continue

            # The parent row is the same for every subcategory of this row, so
            # it is fetched once above instead of once per subcategory.
            for subcat_id in subcategory_ids:
                session.add(
                    ProductSubcategory(
                        pro_cat_id=product_category.id, subcategory_id=subcat_id
                    )
                )
        session.commit()
    except Exception as e:
        session.rollback()
        return f"Error: {str(e)}"


def _cell_text(value):
    """Return a spreadsheet cell as stripped text; missing cells become ""."""
    if value is None or pd.isna(value):
        return ""
    return str(value).strip()


@type
class ImportFunctionalityMutation:
    """Mutations for importing product categories from an uploaded Excel file."""

    @mutation
    async def read_file(self, file: Upload, ImportType: str, info: Info) -> str:
        """Validate an uploaded workbook and import its rows as product categories.

        The active sheet must have exactly the headers FUNCTIONAL GROUP,
        COUNTRY, PRODUCT, CATEGORY, TOPIC and TA SUB TYPE. Every data row is
        validated first; nothing is saved if any row is invalid or (unless
        truncating) already exists.

        Args:
            file: Uploaded Excel workbook.
            ImportType: ``"truncateandproceed"`` (case-insensitive) empties the
                product category tables before saving and restarts version
                numbering at 1; any other value appends to the existing data.
            info: Resolver info providing the session and data loaders.

        Returns:
            ``"Import completed successfully"``, a ``"message=..."`` string
            listing one line per rejected row, the wrong-file message, or an
            ``"Error: ..."`` string when reading or saving fails.
        """
        logger.info("Current User %s", info.context.current_user)
        logger.info("Reading file %s", file)

        try:
            # Read the upload into memory and open the active sheet.
            file_content = await file.read()
            workbook = openpyxl.load_workbook(filename=BytesIO(file_content))
            sheet = workbook.active
            sheet_rows = list(sheet.iter_rows(values_only=True))

            expected_headers = [
                "FUNCTIONAL GROUP",
                "COUNTRY",
                "PRODUCT",
                "CATEGORY",
                "TOPIC",
                "TA SUB TYPE",
            ]
            # A completely empty sheet has no header row at all; treat it like
            # any other workbook with the wrong layout.
            if not sheet_rows or list(sheet_rows[0]) != expected_headers:
                return "This is not a correct excel to upload"
            df = pd.DataFrame(sheet_rows[1:], columns=expected_headers)

            truncate_first = ImportType.lower() == "truncateandproceed"
            last_version_id = 1 if truncate_first else get_last_version_id(info)

            # Reference data used to validate every row.
            product_country_list = get_full_product_country_list(info)
            product_category_list = get_full_product_category_list(info)
            group_list = get_full_group_list(info)
            func_group_list = get_lov_values(info, "PROD CAT GRP")
            sub_ta_list = get_lov_values(info, "therapeutic")

            row_errors = []
            prd_cat_list = []
            for i, row in df.iterrows():
                # +2: one for the header row, one for 1-based sheet numbering.
                line_no = i + 2
                prd_group = _cell_text(row["FUNCTIONAL GROUP"])
                country = _cell_text(row["COUNTRY"])
                product_name = _cell_text(row["PRODUCT"])
                cat_name = _cell_text(row["CATEGORY"])
                ta_sub_type = _cell_text(row["TA SUB TYPE"])
                topic_names = _cell_text(row["TOPIC"])

                # Blank cells are reported as invalid data instead of aborting
                # the whole import with an attribute error.
                required = (
                    prd_group,
                    country,
                    product_name,
                    cat_name,
                    ta_sub_type,
                    topic_names,
                )
                if not all(required):
                    row_errors.append(
                        f"Error in data at line {line_no}: Invalid data\n"
                    )
                    continue

                mdm_id = await get_mdm_id(
                    info, product_name, country, product_country_list
                )
                cat_id = get_category_id(cat_name, group_list)
                prdt_sbctgries = get_subcategory_ids(topic_names, group_list)

                if not is_valid_data(
                    prd_group,
                    mdm_id,
                    cat_id,
                    ta_sub_type,
                    prdt_sbctgries,
                    func_group_list,
                    sub_ta_list,
                ):
                    row_errors.append(
                        f"Error in data at line {line_no}: Invalid data\n"
                    )
                    continue

                # Duplicates only matter when existing data is kept.
                if not truncate_first and is_existing_record(
                    mdm_id, cat_id, prd_group, ta_sub_type, product_category_list
                ):
                    row_errors.append(
                        f"The record at line {line_no} already exists.\n"
                    )
                    continue

                now = pd.Timestamp.now()
                prd_cat_list.append(
                    {
                        "version_id": last_version_id,
                        "prod_cat_grp": prd_group,
                        "mdm_id": mdm_id,
                        "category_id": cat_id,
                        "ta_sub_type": ta_sub_type,
                        "last_active": 1,
                        "is_deleted": 0,
                        "created_date": now,
                        "created_by_id": "SYSTEM",
                        "updated_date": now,
                        "updated_by_id": "SYSTEM",
                    }
                )
                last_version_id += 1

            if row_errors:
                # Discard the version bump get_last_version_id left pending in
                # the session; nothing from a rejected file may be persisted.
                info.context.session.rollback()
                return f"message={''.join(row_errors)}"

            # Each save helper returns None on success or an "Error: ..."
            # string on failure; stop at the first failure and report it
            # instead of claiming success.
            if truncate_first:
                error = truncate_tables(info)
                if error:
                    logger.error("Import failed: %s", error)
                    return error

            error = (
                save_product_categories(info, prd_cat_list)
                or update_product_category_version(info, last_version_id)
                or save_topics_from_excel(info, df, prd_cat_list, group_list)
            )
            if error:
                logger.error("Import failed: %s", error)
                return error

            return "Import completed successfully"

        except Exception as e:
            logger.exception("Import failed")
            info.context.session.rollback()
            return f"Error: {str(e)}"

Convert the above code to the SQLAlchemy 3.x API. Split it so that the helper functions are in one file and the mutation is in another file.
Editor is loading...
Leave a Comment