# Pasted from pastecode.io ("Untitled", plain_text, 5.0 kB, author unknown).
# The paste-site page header that preceded the code was reduced to this comment.
import os
from typing import List, Optional, Dict, Any

import sqlalchemy
import strawberry
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy import Column, Integer, String, ForeignKey, create_engine, select, inspect, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from strawberry.fastapi import GraphQLRouter

# Database configuration
# The connection URL is taken from the SQLALCHEMY_DATABASE_URL environment
# variable so that credentials do not have to live in source control. The
# literal second argument is the fallback used when the variable is unset.
# NOTE(review): an empty URL cannot be parsed by create_engine(), so either
# export the variable or fill in the fallback before importing this module.
SQLALCHEMY_DATABASE_URL = os.environ.get("SQLALCHEMY_DATABASE_URL", "")

# Create SQLAlchemy engine
engine = create_engine(SQLALCHEMY_DATABASE_URL)

# Session factory bound to the engine; autocommit and autoflush are both off,
# so callers flush/commit explicitly.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base declarative class
Base = declarative_base()

# Database models
class User(Base):
    """ORM mapping for the ``mdm_users`` table."""

    __tablename__ = "mdm_users"

    # Integer primary key; the USER_ID foreign keys of UserCountry and
    # UserTherapeutic reference this column.
    MDM_USER_ID = Column(Integer, primary_key=True)
    # Login string, declared unique.
    MDM_LOGIN = Column(String, unique=True)

    # Collections of related rows, kept in sync with the ``user`` attribute
    # on the other side of each relationship.
    countries = relationship("UserCountry", back_populates="user")
    therapeutics = relationship("UserTherapeutic", back_populates="user")


class UserCountry(Base):
    """ORM mapping for the ``mdm_users_countries`` table."""

    __tablename__ = "mdm_users_countries"

    # Integer primary key of the association row.
    USER_CNTRY_ID = Column(Integer, primary_key=True)
    # Foreign key to mdm_users.MDM_USER_ID.
    USER_ID = Column(Integer, ForeignKey("mdm_users.MDM_USER_ID"))
    # Plain integer column; no foreign key is declared for it here.
    COUNTRY_ID = Column(Integer)

    # Owning user; mirrors User.countries.
    user = relationship("User", back_populates="countries")


class UserTherapeutic(Base):
    """ORM mapping for the ``mdm_users_therapeutics`` table."""

    __tablename__ = "mdm_users_therapeutics"

    # Integer primary key of the association row.
    USER_TA_ID = Column(Integer, primary_key=True)
    # Foreign key to mdm_users.MDM_USER_ID.
    USER_ID = Column(Integer, ForeignKey("mdm_users.MDM_USER_ID"))
    # Plain integer column; no foreign key is declared for it here.
    TA_ID = Column(Integer)

    # Owning user; mirrors User.therapeutics.
    user = relationship("User", back_populates="therapeutics")


# FastAPI app initialization
# Created with default settings; the GraphQL router is attached to this
# instance further down in the module.
app = FastAPI()

# Dependency to get DB session
def get_db():
    """Yield a fresh ``SessionLocal`` session and close it when the caller is done.

    This is a generator: the session is handed out at the ``yield`` and is
    closed once the generator is resumed or torn down, including when an
    exception is raised at the ``yield``.
    """
    with SessionLocal() as session:
        yield session

# Introspect database schema
def get_table_columns(table_name: str):
    """Return a ``{column name: column type}`` dict for *table_name*.

    The columns are read by running SQLAlchemy's inspector against the
    module-level ``engine``; the values are the type objects the inspector
    reports for each column.
    """
    column_types = {}
    for column_info in inspect(engine).get_columns(table_name):
        column_types[column_info['name']] = column_info['type']
    return column_types

# Dynamically create GraphQL type for product view
def create_dynamic_product_type():
    """Build a plain class whose annotations mirror ``mdm_product_view``.

    Every column reported by ``get_table_columns('mdm_product_view')`` becomes
    an annotated attribute: ``Optional[int]`` for SQLAlchemy ``Integer``
    columns and ``Optional[str]`` for everything else. Each attribute also
    gets a class-level default of ``None``.

    Returns:
        An undecorated class named ``ProductViewType``. The hints are stored
        in ``__annotations__`` (not as class attribute values) so that
        ``strawberry.type`` can discover them as fields.
    """
    annotations = {}
    for col_name, col_type in get_table_columns('mdm_product_view').items():
        # Map SQLAlchemy types to GraphQL types. String columns and any
        # unrecognised type both end up as a nullable string.
        if isinstance(col_type, sqlalchemy.Integer):
            annotations[col_name] = Optional[int]
        else:
            annotations[col_name] = Optional[str]

    # None defaults make every generated field optional at construction time.
    product_cls = type('ProductViewType', (), dict.fromkeys(annotations))
    # Assign the annotations after class creation; passing the hints as the
    # class namespace would turn them into attribute values instead.
    product_cls.__annotations__ = annotations
    return product_cls


# Create dynamic GraphQL type
# Decorating the generated class turns its annotations into GraphQL fields.
ProductViewType = strawberry.type(create_dynamic_product_type())

# GraphQL Queries
@strawberry.type
class UserType:
    """GraphQL object returned by ``Query.getUser``.

    With Strawberry's default camel-casing the fields are exposed as
    ``mdmUserId`` and ``mdmLogin``.
    """

    mdm_user_id: int
    mdm_login: Optional[str]


def _get_user_by_login(db, login: str):
    """Return the ``User`` row whose ``MDM_LOGIN`` equals *login*.

    Args:
        db: An open SQLAlchemy session.
        login: Login string to look up.

    Raises:
        HTTPException: status 404, detail "User not found", when no row matches.
    """
    user_obj = db.execute(select(User).filter_by(MDM_LOGIN=login)).scalar_one_or_none()
    if user_obj is None:
        raise HTTPException(status_code=404, detail="User not found")
    return user_obj


@strawberry.type
class Query:
    """Root GraphQL query type."""

    @strawberry.field
    async def getUser(self, user: str) -> UserType:
        """Look up a user by login.

        Args:
            user: Value matched against ``User.MDM_LOGIN``.

        Returns:
            The user's id and login as a ``UserType``.

        Raises:
            HTTPException: 404 when no user has that login.
        """
        with SessionLocal() as db:
            user_obj = _get_user_by_login(db, user)
            return UserType(
                mdm_user_id=user_obj.MDM_USER_ID,
                mdm_login=user_obj.MDM_LOGIN,
            )

    @strawberry.field
    async def getProductData(self, user: str, page: int = 1, rows: int = 10) -> List[ProductViewType]:
        """Return one page of ``mdm_product_view`` rows visible to *user*.

        Rows are limited to ``current_record=1`` and, when the user has any,
        to the user's country ids and therapeutic-area ids.

        Args:
            user: Value matched against ``User.MDM_LOGIN``.
            page: 1-based page number.
            rows: Page size.

        Raises:
            HTTPException: 400 when ``page`` < 1 or ``rows`` < 0; 404 when no
                user has that login.
        """
        # Reject values that would produce a negative OFFSET or LIMIT.
        if page < 1 or rows < 0:
            raise HTTPException(
                status_code=400,
                detail="page must be >= 1 and rows must be >= 0",
            )

        with SessionLocal() as db:
            user_obj = _get_user_by_login(db, user)

            # Fetch associated country IDs and therapeutic area IDs
            country_ids = [uc.COUNTRY_ID for uc in user_obj.countries]
            ta_ids = [ut.TA_ID for ut in user_obj.therapeutics]

            # Build the SQL query dynamically
            # NOTE(review): an empty id list adds no filter at all, so such a
            # user sees every current record — confirm this is intended.
            sql = "SELECT * FROM mdm_product_view WHERE current_record=1"
            params = {}
            in_params = []

            if country_ids:
                sql += " AND COUNTRY_ID IN :countryFilterIds"
                params['countryFilterIds'] = country_ids
                in_params.append(sqlalchemy.bindparam('countryFilterIds', expanding=True))

            if ta_ids:
                sql += " AND TH_AREA IN :taFilterIds"
                params['taFilterIds'] = ta_ids
                in_params.append(sqlalchemy.bindparam('taFilterIds', expanding=True))

            # Pagination logic
            # NOTE(review): there is no ORDER BY, so the database does not
            # guarantee a stable row order between pages.
            sql += " LIMIT :rows OFFSET :start_index"
            params['start_index'] = (page - 1) * rows
            params['rows'] = rows

            # Expanding bind parameters let SQLAlchemy render each IN list as
            # individual placeholders instead of relying on the driver to
            # adapt a tuple.
            statement = text(sql).bindparams(*in_params)
            result = db.execute(statement, params)

            # mappings() yields dict-like rows keyed by column name.
            return [ProductViewType(**row) for row in result.mappings()]


# Create GraphQL schema and router
# Only a query root is registered; no mutation or subscription type is passed.
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter(schema)

# Include the GraphQL router in FastAPI app
# The router's routes are mounted under the "/graphql" prefix.
app.include_router(graphql_app, prefix="/graphql")
# End of pasted snippet (paste-site footer text removed).