Untitled
unknown
plain_text
12 days ago
8.4 kB
2
Indexable
Never
import os
from typing import List, Optional

import strawberry
from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    String,
    bindparam,
    create_engine,
    select,
    text,
)
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from strawberry.fastapi import GraphQLRouter

# Database configuration.
# The URL was hard-coded as "" in the original; it may now be supplied through
# the SQLALCHEMY_DATABASE_URL environment variable. The fallback is the same
# empty string, so an unset variable behaves exactly as before (create_engine
# will not accept an empty URL, so set it before importing this module).
SQLALCHEMY_DATABASE_URL = os.environ.get("SQLALCHEMY_DATABASE_URL", "")

# Create SQLAlchemy engine and session factory.
engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base declarative class (imported from sqlalchemy.orm; the
# sqlalchemy.ext.declarative location is deprecated).
Base = declarative_base()


# Database models
class User(Base):
    """An MDM user, identified by MDM_LOGIN, with country and TA assignments."""

    __tablename__ = 'mdm_users'
    MDM_USER_ID = Column(Integer, primary_key=True)
    MDM_LOGIN = Column(String, unique=True)
    countries = relationship('UserCountry', back_populates='user')
    therapeutics = relationship('UserTherapeutic', back_populates='user')


class UserCountry(Base):
    """Link row assigning a COUNTRY_ID to a user."""

    __tablename__ = 'mdm_users_countries'
    USER_CNTRY_ID = Column(Integer, primary_key=True)
    USER_ID = Column(Integer, ForeignKey('mdm_users.MDM_USER_ID'))
    COUNTRY_ID = Column(Integer)
    user = relationship('User', back_populates='countries')


class UserTherapeutic(Base):
    """Link row assigning a therapeutic-area id (TA_ID) to a user."""

    __tablename__ = 'mdm_users_therapeutics'
    USER_TA_ID = Column(Integer, primary_key=True)
    USER_ID = Column(Integer, ForeignKey('mdm_users.MDM_USER_ID'))
    TA_ID = Column(Integer)
    user = relationship('User', back_populates='therapeutics')


# FastAPI app initialization
app = FastAPI()


# Dependency to get DB session
def get_db():
    """Yield a database session and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# GraphQL Type Definitions
@strawberry.type
class UserType:
    mdmUserId: int
    mdmLogin: str


# A single @strawberry.type decorator (the original applied it twice).
@strawberry.type
class ProductViewType:
    lastActive: Optional[str]
    aespqcsBeReported: Optional[str]
    analyticsDisplayName: Optional[str]
    blackboxWarning: Optional[str]
    brandStrength: Optional[str]
    # Nullable in the schema, although getProductData substitutes "No Partner".
    comarketingPartner: Optional[str]
    comments: Optional[int]
    countryId: Optional[int]
    countryName: Optional[str]
    createdById: Optional[str]
    # NOTE(review): if CREATED_DATE / UPDATED_DATE are DATE/DATETIME columns,
    # the driver returns date objects, which a GraphQL String may refuse to
    # serialize — confirm the view exposes them as strings.
    createdDate: Optional[str]
    currentRecord: Optional[str]
    deleteState: Optional[str]
    fdaApproved: Optional[str]
    fieldTeamBeSubmittingMirsOnThisProduct: Optional[str]
    genericName: Optional[str]
    groupType: Optional[str]
    indication: Optional[int]
    isDeleted: Optional[str]
    isItTrademarkOrRegistered: Optional[str]
    janssenMstrPrdctNm: Optional[str]
    jjOperatingCompany: Optional[str]
    jnjFlag: Optional[str]
    jnjFullCompoundId: Optional[int]
    marketedBy: Optional[int]
    mdmId: Optional[str]
    piLink: Optional[str]
    productName: Optional[str]
    productPhase: Optional[str]
    productStatus: Optional[str]
    quadrant: Optional[int]
    recordId: Optional[int]
    recordId2: Optional[int]
    regionId: Optional[str]
    regionName: Optional[str]
    reltioId: Optional[str]
    requestStatus: Optional[int]
    requestStatusId: Optional[int]
    seriesId: Optional[str]
    taName: Optional[str]
    taSubType: Optional[int]
    thArea: Optional[str]
    updatedById: Optional[str]
    updatedDate: Optional[str]
    websiteProductName: Optional[str]


def _fetch_user(db, login: str) -> User:
    """Return the User whose MDM_LOGIN equals *login*.

    Raises:
        HTTPException: status 404 with detail "User not found" when no row
            matches.
    """
    user_obj = db.execute(
        select(User).filter_by(MDM_LOGIN=login)
    ).scalar_one_or_none()
    if user_obj is None:
        # NOTE(review): raised from a GraphQL resolver, this presumably reaches
        # the client as a GraphQL error entry rather than an HTTP 404 status.
        raise HTTPException(status_code=404, detail="User not found")
    return user_obj


def _build_product_query(country_ids, ta_ids, page: int, rows: int):
    """Build the filtered, paginated mdm_product_view statement.

    Args:
        country_ids: COUNTRY_ID values to restrict to; empty means no filter.
        ta_ids: TA_ID values matched against TH_AREA; empty means no filter.
        page: 1-based page number.
        rows: page size.

    Returns:
        A ``(TextClause, params)`` pair ready for ``Session.execute``.
    """
    clauses = ["SELECT * FROM mdm_product_view WHERE current_record=1"]
    params = {}
    expanding = []

    if country_ids:
        clauses.append("AND COUNTRY_ID IN :countryFilterIds")
        params['countryFilterIds'] = list(country_ids)
        # An "expanding" bind parameter renders one placeholder per list
        # element, so IN (...) works on every DBAPI driver, not only those
        # that happen to escape a tuple into a parenthesized list.
        expanding.append(bindparam('countryFilterIds', expanding=True))

    if ta_ids:
        # NOTE(review): TA_ID is an Integer column while ProductViewType.thArea
        # is declared Optional[str]; confirm TH_AREA is the right column to
        # compare against TA_ID.
        clauses.append("AND TH_AREA IN :taFilterIds")
        params['taFilterIds'] = list(ta_ids)
        expanding.append(bindparam('taFilterIds', expanding=True))

    # LIMIT ... OFFSET ... is accepted by MySQL, PostgreSQL and SQLite, unlike
    # the MySQL-only "LIMIT offset, count" form.
    clauses.append("LIMIT :rows OFFSET :start_index")
    params['rows'] = rows
    params['start_index'] = (page - 1) * rows

    stmt = text(" ".join(clauses))
    if expanding:
        stmt = stmt.bindparams(*expanding)
    return stmt, params


def _to_product_view(product) -> ProductViewType:
    """Map one mdm_product_view result row onto the GraphQL ProductViewType."""
    return ProductViewType(
        lastActive=product.LAST_ACTIVE,
        aespqcsBeReported=product.AESPQCS_BE_REPORTED,
        analyticsDisplayName=product.ANALYTICS_DISPLAY_NAME,
        blackboxWarning=product.BLACKBOX_WARNING,
        brandStrength=product.BRAND_STRENGTH,
        comarketingPartner=product.COMARKETING_PARTNER or "No Partner",  # Default value if None
        comments=product.COMMENTS,
        countryId=product.COUNTRY_ID,
        countryName=product.COUNTRY_NAME,
        createdById=product.CREATED_BY_ID,
        createdDate=product.CREATED_DATE,
        currentRecord=product.CURRENT_RECORD,
        deleteState=product.DELETE_STATE,
        fdaApproved=product.FDA_APPROVED,
        fieldTeamBeSubmittingMirsOnThisProduct=product.FIELD_TEAM_BE_SUBMITTING_MIRS_ON_THIS_PRODUCT,
        genericName=product.GENERIC_NAME,
        groupType=product.GROUP_TYPE,
        indication=product.INDICATION,
        isDeleted=product.IS_DELETED,
        isItTrademarkOrRegistered=product.IS_IT_TRADEMARK_OR_REGISTERED,
        janssenMstrPrdctNm=product.JANSSEN_MSTR_PRDCT_NM,
        jjOperatingCompany=product.JJ_OPERATING_COMPANY,
        jnjFlag=product.JNJ_FLAG,
        jnjFullCompoundId=product.JNJ_FULL_COMPOUND_ID,
        marketedBy=product.MARKETED_BY,
        mdmId=product.MDM_ID,
        piLink=product.PI_LINK,
        productName=product.PRODUCT_NAME,
        productPhase=product.PRODUCT_PHASE,
        productStatus=product.PRODUCT_STATUS,
        quadrant=product.QUADRANT,
        recordId=product.RECORD_ID,
        recordId2=product.RECORD_ID_2,
        regionId=product.REGION_ID,
        regionName=product.REGION_NAME,
        reltioId=product.RELTIO_ID,
        requestStatus=product.REQUEST_STATUS,
        requestStatusId=product.REQUEST_STATUS_ID,
        seriesId=product.SERIES_ID,
        taName=product.TA_NAME,
        taSubType=product.TA_SUB_TYPE,
        thArea=product.TH_AREA,
        updatedById=product.UPDATED_BY_ID,
        updatedDate=product.UPDATED_DATE,
        websiteProductName=product.WEBSITE_PRODUCT_NAME,
    )


# NOTE(review): both resolvers are `async` but use the synchronous Session, so
# each DB round-trip blocks the event loop while it runs — consider moving the
# DB work to a thread pool or an AsyncSession if request concurrency matters.
@strawberry.type
class Query:
    @strawberry.field
    async def getUser(self, user: str) -> UserType:
        """Return the user whose login is *user*; errors if none exists."""
        with SessionLocal() as db:
            user_obj = _fetch_user(db, user)
            return UserType(mdmUserId=user_obj.MDM_USER_ID, mdmLogin=user_obj.MDM_LOGIN)

    @strawberry.field
    async def getProductData(self, user: str, page: int = 1, rows: int = 10) -> List[ProductViewType]:
        """Return one page of current product-view rows visible to *user*.

        Rows are restricted to the user's assigned COUNTRY_IDs and TA_IDs
        (each filter applies only when the user has at least one assignment).

        Raises:
            ValueError: if ``page`` < 1 or ``rows`` < 0 (these would otherwise
                produce a negative LIMIT/OFFSET and a database syntax error).
            HTTPException: 404 when the user does not exist.
        """
        if page < 1:
            raise ValueError("page must be >= 1")
        if rows < 0:
            raise ValueError("rows must be >= 0")

        with SessionLocal() as db:
            user_obj = _fetch_user(db, user)

            # Relationships are lazy-loaded, so read them while the session is open.
            country_ids = [uc.COUNTRY_ID for uc in user_obj.countries]
            ta_ids = [ut.TA_ID for ut in user_obj.therapeutics]

            stmt, params = _build_product_query(country_ids, ta_ids, page, rows)
            product_list = db.execute(stmt, params).fetchall()

            return [_to_product_view(product) for product in product_list]


# Create GraphQL schema and router
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter(schema)

# Include the GraphQL router in your FastAPI app
app.include_router(graphql_app, prefix="/graphql")
Leave a Comment