Untitled
unknown
plain_text
a year ago
8.5 kB
9
Indexable
import os
from typing import List, Optional

import strawberry
from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy import Column, ForeignKey, Integer, String, bindparam, create_engine, select, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from strawberry.fastapi import GraphQLRouter
# Database configuration
# The connection URL is read from the SQLALCHEMY_DATABASE_URL environment
# variable when it is set, so deployments can supply their own credentials
# without editing this file.
# SECURITY: the fallback below embeds a username and password in source code.
# It is kept only so existing setups keep working; rotate that password and
# remove the literal once every environment provides the variable.
SQLALCHEMY_DATABASE_URL = os.getenv(
    "SQLALCHEMY_DATABASE_URL",
    "mysql+pymysql://jsa_dev_admin:RM#AXR02NhAp@itx-acm-jsa-mdm-dev.czijpxum5el7.us-east-1.rds.amazonaws.com/jsamdm_dev",
)

# Create SQLAlchemy engine
engine = create_engine(SQLALCHEMY_DATABASE_URL)

# Session factory bound to the engine; sessions neither autocommit nor autoflush.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base declarative class
Base = declarative_base()
# Database models
class User(Base):
    """ORM mapping for a row of the ``mdm_users`` table."""

    __tablename__ = "mdm_users"

    # Columns
    MDM_USER_ID = Column(Integer, primary_key=True)
    MDM_LOGIN = Column(String, unique=True)

    # Child rows that point back at this user through their ``USER_ID`` foreign key.
    countries = relationship("UserCountry", back_populates="user")
    therapeutics = relationship("UserTherapeutic", back_populates="user")
class UserCountry(Base):
    """ORM mapping for ``mdm_users_countries``: one country id attached to a user."""

    __tablename__ = "mdm_users_countries"

    # Columns
    USER_CNTRY_ID = Column(Integer, primary_key=True)
    USER_ID = Column(Integer, ForeignKey("mdm_users.MDM_USER_ID"))
    COUNTRY_ID = Column(Integer)

    # Owning user; mirrors ``User.countries``.
    user = relationship("User", back_populates="countries")
class UserTherapeutic(Base):
    """ORM mapping for ``mdm_users_therapeutics``: one TA id attached to a user."""

    __tablename__ = "mdm_users_therapeutics"

    # Columns
    USER_TA_ID = Column(Integer, primary_key=True)
    USER_ID = Column(Integer, ForeignKey("mdm_users.MDM_USER_ID"))
    TA_ID = Column(Integer)

    # Owning user; mirrors ``User.therapeutics``.
    user = relationship("User", back_populates="therapeutics")
# FastAPI app initialization
# Module-level application object; the GraphQL router is attached to it at the
# bottom of this file.
app = FastAPI()
# Dependency to get DB session
def get_db():
    """Yield a new database session and close it when the consumer is done.

    Written as a generator so the session is closed whether the consumer
    finishes normally or raises.
    """
    with SessionLocal() as session:
        yield session
# GraphQL Type Definitions
# GraphQL representation of a user: the numeric id and login copied from the
# ``User`` ORM model.
@strawberry.type
class UserType:
    mdmUserId: int  # populated from User.MDM_USER_ID
    mdmLogin: str  # populated from User.MDM_LOGIN
# GraphQL representation of one row of the ``mdm_product_view`` query result.
# Every field is optional so NULL column values can be returned as null.
# The decorator is applied once; the original applied it twice by accident.
# NOTE(review): date-like fields (createdDate, updatedDate, lastActive) are
# declared as strings - confirm the view returns text for these columns.
@strawberry.type
class ProductViewType:
    lastActive: Optional[str]
    aespqcsBeReported: Optional[str]
    analyticsDisplayName: Optional[str]
    blackboxWarning: Optional[str]
    brandStrength: Optional[str]
    comarketingPartner: Optional[str]  # Allow null values
    comments: Optional[int]
    countryId: Optional[int]
    countryName: Optional[str]
    createdById: Optional[str]
    createdDate: Optional[str]
    currentRecord: Optional[str]
    deleteState: Optional[str]
    fdaApproved: Optional[str]
    fieldTeamBeSubmittingMirsOnThisProduct: Optional[str]
    genericName: Optional[str]
    groupType: Optional[str]
    indication: Optional[int]
    isDeleted: Optional[str]
    isItTrademarkOrRegistered: Optional[str]
    janssenMstrPrdctNm: Optional[str]
    jjOperatingCompany: Optional[str]
    jnjFlag: Optional[str]
    jnjFullCompoundId: Optional[int]
    marketedBy: Optional[int]
    mdmId: Optional[str]
    piLink: Optional[str]
    productName: Optional[str]
    productPhase: Optional[str]
    productStatus: Optional[str]
    quadrant: Optional[int]
    recordId: Optional[int]
    recordId2: Optional[int]
    regionId: Optional[str]
    regionName: Optional[str]
    reltioId: Optional[str]
    requestStatus: Optional[int]
    requestStatusId: Optional[int]
    seriesId: Optional[str]
    taName: Optional[str]
    taSubType: Optional[int]
    thArea: Optional[str]
    updatedById: Optional[str]
    updatedDate: Optional[str]
    websiteProductName: Optional[str]
def _get_user_by_login(db, login: str):
    """Return the ``User`` whose ``MDM_LOGIN`` equals *login*.

    Raises:
        HTTPException: status 404 when no such user exists.
    """
    user_obj = db.execute(select(User).filter_by(MDM_LOGIN=login)).scalar_one_or_none()
    if user_obj is None:
        raise HTTPException(status_code=404, detail="User not found")
    return user_obj


def _build_product_query(country_ids, ta_ids, page: int, rows: int):
    """Build the paginated ``mdm_product_view`` statement and its parameters.

    Args:
        country_ids: country ids to filter on; no country filter when empty.
        ta_ids: therapeutic-area ids to filter on; no TA filter when empty.
        page: 1-based page number.
        rows: number of rows per page.

    Returns:
        A ``(statement, params)`` pair ready for ``Session.execute``.

    Raises:
        HTTPException: status 400 when *page* < 1 or *rows* < 0, which would
            otherwise reach the database as a negative OFFSET/LIMIT.
    """
    if page < 1 or rows < 0:
        raise HTTPException(
            status_code=400,
            detail="page must be >= 1 and rows must be >= 0",
        )

    sql = "SELECT * FROM mdm_product_view WHERE current_record=1"
    params = {}
    expanding_params = []

    # Expanding bind parameters let SQLAlchemy render one placeholder per list
    # element, instead of relying on the DB driver to escape a tuple.
    if country_ids:
        sql += " AND COUNTRY_ID IN :countryFilterIds"
        params["countryFilterIds"] = list(country_ids)
        expanding_params.append(bindparam("countryFilterIds", expanding=True))
    if ta_ids:
        sql += " AND TH_AREA IN :taFilterIds"
        params["taFilterIds"] = list(ta_ids)
        expanding_params.append(bindparam("taFilterIds", expanding=True))

    # Pagination: "LIMIT offset, count".
    # NOTE(review): there is no ORDER BY, so page contents depend on whatever
    # order the database returns rows in - consider adding a stable sort key.
    sql += " LIMIT :start_index, :rows"
    params["start_index"] = (page - 1) * rows
    params["rows"] = rows

    statement = text(sql)
    if expanding_params:
        statement = statement.bindparams(*expanding_params)
    return statement, params


def _to_product_view(product) -> ProductViewType:
    """Convert one ``mdm_product_view`` result row into a ``ProductViewType``."""
    return ProductViewType(
        lastActive=product.LAST_ACTIVE,
        aespqcsBeReported=product.AESPQCS_BE_REPORTED,
        analyticsDisplayName=product.ANALYTICS_DISPLAY_NAME,
        blackboxWarning=product.BLACKBOX_WARNING,
        brandStrength=product.BRAND_STRENGTH,
        # Falsy values (NULL or empty string) are replaced by a placeholder.
        comarketingPartner=product.COMARKETING_PARTNER or "No Partner",
        comments=product.COMMENTS,
        countryId=product.COUNTRY_ID,
        countryName=product.COUNTRY_NAME,
        createdById=product.CREATED_BY_ID,
        createdDate=product.CREATED_DATE,
        currentRecord=product.CURRENT_RECORD,
        deleteState=product.DELETE_STATE,
        fdaApproved=product.FDA_APPROVED,
        fieldTeamBeSubmittingMirsOnThisProduct=product.FIELD_TEAM_BE_SUBMITTING_MIRS_ON_THIS_PRODUCT,
        genericName=product.GENERIC_NAME,
        groupType=product.GROUP_TYPE,
        indication=product.INDICATION,
        isDeleted=product.IS_DELETED,
        isItTrademarkOrRegistered=product.IS_IT_TRADEMARK_OR_REGISTERED,
        janssenMstrPrdctNm=product.JANSSEN_MSTR_PRDCT_NM,
        jjOperatingCompany=product.JJ_OPERATING_COMPANY,
        jnjFlag=product.JNJ_FLAG,
        jnjFullCompoundId=product.JNJ_FULL_COMPOUND_ID,
        marketedBy=product.MARKETED_BY,
        mdmId=product.MDM_ID,
        piLink=product.PI_LINK,
        productName=product.PRODUCT_NAME,
        productPhase=product.PRODUCT_PHASE,
        productStatus=product.PRODUCT_STATUS,
        quadrant=product.QUADRANT,
        recordId=product.RECORD_ID,
        recordId2=product.RECORD_ID_2,
        regionId=product.REGION_ID,
        regionName=product.REGION_NAME,
        reltioId=product.RELTIO_ID,
        requestStatus=product.REQUEST_STATUS,
        requestStatusId=product.REQUEST_STATUS_ID,
        seriesId=product.SERIES_ID,
        taName=product.TA_NAME,
        taSubType=product.TA_SUB_TYPE,
        thArea=product.TH_AREA,
        updatedById=product.UPDATED_BY_ID,
        updatedDate=product.UPDATED_DATE,
        websiteProductName=product.WEBSITE_PRODUCT_NAME,
    )


# Root GraphQL query type.
# NOTE(review): both resolvers are ``async`` but use the synchronous
# ``SessionLocal``, so each database call blocks while it runs.
@strawberry.type
class Query:
    # Look up a single user by login; raises HTTPException(404) when absent.
    @strawberry.field
    async def getUser(self, user: str) -> UserType:
        with SessionLocal() as db:
            user_obj = _get_user_by_login(db, user)
            return UserType(mdmUserId=user_obj.MDM_USER_ID, mdmLogin=user_obj.MDM_LOGIN)

    # Return one page of current product rows, restricted to the countries and
    # therapeutic areas linked to the given user (no restriction is applied
    # for a dimension in which the user has no linked ids).
    # Raises HTTPException(404) for an unknown user and HTTPException(400) for
    # an invalid page/rows combination.
    @strawberry.field
    async def getProductData(self, user: str, page: int = 1, rows: int = 10) -> List[ProductViewType]:
        with SessionLocal() as db:
            user_obj = _get_user_by_login(db, user)

            # Relationship attributes are read while the session is still open.
            country_ids = [uc.COUNTRY_ID for uc in user_obj.countries]
            ta_ids = [ut.TA_ID for ut in user_obj.therapeutics]

            statement, params = _build_product_query(country_ids, ta_ids, page, rows)
            product_list = db.execute(statement, params).fetchall()

            return [_to_product_view(product) for product in product_list]
# Create GraphQL schema and router
# Only a query root is passed to the schema; no mutation or subscription type
# is registered here.
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter(schema)
# Include the GraphQL router in your FastAPI app
# The router is mounted under the "/graphql" path prefix.
app.include_router(graphql_app, prefix="/graphql")
Editor is loading...
Leave a Comment