Untitled

 avatar
unknown
plain_text
5 months ago
16 kB
1
Indexable
import os
from src.dependables import logging_config
logger = logging_config.logger(__name__)

from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi.responses import RedirectResponse
from alembic import command
from alembic.config import Config
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import select
from src.dependables.token_cache import remove_token_from_cache, clear_token_cache

from src.database import dbutils
from src.database.models.Metadata import MetadataModel
from src.dependables import verify_admin_token
from src.dependables import db
from urllib import parse
from sqlalchemy import create_engine, text
from . import api, app, auth
import traceback
# import logging
# from httpx import AsyncClient
# import httpx

# logging.basicConfig(level=logging.DEBUG)
# httpx_logger = logging.getLogger("httpx")
# httpx_logger.setLevel(logging.DEBUG)

router = APIRouter()

# OIDC (PingFederate) client settings, read once at import time.
# Each value is None when the corresponding environment variable is unset.
authorize_url = os.getenv("AUTHORIZATION_ENDPOINT")
access_token_url = os.getenv("ACCESS_TOKEN_ENDPOINT")
jwks_uri = os.getenv("JWKS_ID_TOKEN_ENDPOINT")
scope = os.getenv("OIDC_SCOPE")
client_id = os.getenv("OIDC_CLIENT_ID")
client_secret = os.getenv("OIDC_CLIENT_SECRET")
redirect_uri = os.getenv("AUTHCALLBACK_REDIRECT_URI")

#### Login Flow Start ####
class UnauthenticatedError(HTTPException):
    """HTTP 401 error carrying a fixed "not authenticated" detail message."""

    def __init__(self) -> None:
        # HTTPException(status_code, detail) -- passed positionally.
        super().__init__(401, "You are not authenticated.")

oauth = OAuth()

# Register the PingFederate OIDC client under the name 'pingfed'; once
# registered it is reachable as ``oauth.pingfed``. All settings come from the
# environment-derived module variables and may be None if the variables are unset.
oauth.register(
    'pingfed',
    authorize_url=authorize_url,
    access_token_url=access_token_url,
    jwks_uri=jwks_uri,
    scope=scope,
    client_id=client_id,
    client_secret=client_secret,
)

# define the endpoints for the OAuth2 flow
@router.get('/login')
async def get_authorization_code(request: Request):
    """OAuth2 flow, step 1: send the user to PingFederate to obtain an authorization code.

    The session is cleared first so a new login does not inherit state left
    over from a previous one. The user is then redirected to the provider's
    authorization endpoint, which will call back to ``redirect_uri``
    (from ``AUTHCALLBACK_REDIRECT_URI``).
    """
    request.session.clear()
    return await oauth.pingfed.authorize_redirect(request, redirect_uri)

# Change the path to /authCallback later on and give the PCR to ping fed accordingly
# @router.get('/AuthPlayground/authorization_code/callback', name='authCallback')
@router.get('/authcallback', name='authCallback')
async def authCallback(request: Request):
    """OAuth2 flow, step 2: exchange the authorization code for tokens.

    Calls ``oauth.pingfed.authorize_access_token`` with the incoming callback
    request, then issues a 307 redirect to ``FRONTEND_URL`` carrying the
    access, ID and refresh tokens as URL-encoded query parameters. A token
    missing from the provider's response is sent as the literal string "None".

    Raises:
        UnauthenticatedError: the provider rejected the exchange (``OAuthError``).
        HTTPException: status 500 for any other failure during the exchange.
    """
    # Log the OIDC configuration for troubleshooting, but never the client
    # secret itself -- only whether one is configured.
    logger.debug(
        "OIDC config: authorize_url=%s access_token_url=%s jwks_uri=%s "
        "scope=%s client_id=%s client_secret_set=%s",
        authorize_url, access_token_url, jwks_uri, scope, client_id,
        bool(client_secret),
    )

    # Exchange the authorization code for tokens.
    try:
        token = await oauth.pingfed.authorize_access_token(request)
        access_token = token.get("access_token")
        id_token = token.get("id_token")
        refresh_token = token.get("refresh_token")
    except OAuthError as e:
        logger.exception("OAuthError: %s", e)
        raise UnauthenticatedError() from e
    except Exception as e:
        logger.exception("Exception: %s", e)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e

    # Setting cookies here does not work when the frontend is on a different
    # domain, so the tokens are handed over in the redirect's query string.
    # urlencode escapes any reserved characters an opaque token may contain.
    # NOTE(review): tokens in a query string can end up in browser history and
    # proxy/server logs -- consider a URL fragment or a one-time code exchange.
    query = parse.urlencode({
        "access_token": access_token,
        "id_token": id_token,
        "refresh_token": refresh_token,
    })
    redirect_url = f"{os.environ['FRONTEND_URL']}?{query}"
    return RedirectResponse(url=redirect_url, status_code=307)

#### Login Flow End ####

# @router.get("/delete", dependencies=[Depends(authDependable)])
# def delete(request: Request):
#     if request.state.auth.role != "admin":
#         raise HTTPException(status_code=403)

#     os.remove(os.environ.get("DATABASE_URL"))


@router.get("/logout")
async def logout(request: Request):
    """Log the caller out by evicting their bearer tokens from the token cache.

    Reads ``Authorization: Bearer <access token>`` and, optionally,
    ``ID-Token: Bearer <id token>``; every token found is passed to
    ``remove_token_from_cache``. A message dict is returned in every case,
    including when the Authorization header is missing or uses another scheme.
    """
    bearer = "Bearer "

    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return {"message": "You are not logged in"}
    if not auth_header.startswith(bearer):
        return {"message": "Invalid Authorization Scheme"}
    remove_token_from_cache(auth_header[len(bearer):])

    # The ID token is optional; a header without the Bearer prefix is ignored.
    id_header = request.headers.get("ID-Token")
    if id_header and id_header.startswith(bearer):
        remove_token_from_cache(id_header[len(bearer):])

    # NOTE(review): "sucessfully" is misspelled; left as-is in case clients match on it.
    return {"message": "You are sucessfully logged out"}

@router.get("/upgrade", dependencies=[Depends(db), Depends(verify_admin_token)])
def upgrade(request: Request):
    """Run Alembic migrations up to ``head`` (guarded by ``verify_admin_token``).

    The database is locked via ``dbutils.lock_db`` while migrating and is
    always unlocked afterwards, even when the migration raises. After a
    successful upgrade a ``MetadataModel(is_db_locked=False)`` row is inserted
    if none exists yet. Exceptions propagate to the caller unchanged.
    """
    SQL_SERVER = os.environ.get('SQLSERVER')
    SQL_DATABASE = os.environ.get('SQLDATABASE')

    if os.environ.get("ENV") != "local":
        connecting_string = (
            f"Driver={{ODBC Driver 18 for SQL Server}};Server=tcp:{SQL_SERVER},1433;"
            f"Database={SQL_DATABASE};Encrypt=yes;TrustServerCertificate=no;"
            "Connection Timeout=30;Authentication=ActiveDirectoryMSI"
        )
        sqlalchemy_url = "mssql+pyodbc:///?odbc_connect=" + parse.quote_plus(connecting_string)
    else:
        # NOTE(review): in this URL form SQLAlchemy treats SQLSERVER as the user,
        # SQLDATABASE as the password and "MSSQL" as the host/DSN -- confirm
        # that matches what the local environment variables contain.
        sqlalchemy_url = f"mssql+pyodbc://{SQL_SERVER}:{SQL_DATABASE}@MSSQL?TrustServerCertificate=yes"

    alembic_cfg = Config()
    alembic_cfg.set_main_option("script_location", "src/database/migrations")
    # Alembic's Config is ConfigParser-backed: a literal '%' must be written as
    # '%%' or interpolation fails. quote_plus output is full of '%', and the
    # local URL may contain one too, so escape in both cases.
    alembic_cfg.set_main_option("sqlalchemy.url", sqlalchemy_url.replace("%", "%%"))

    try:
        dbutils.lock_db(request.state.db)
        command.upgrade(alembic_cfg, "head")

        # If there is no metadata row yet, create one; it tracks the DB lock
        # flag and data-update timestamps.
        metadata = request.state.db.scalar(select(MetadataModel))
        if metadata is None:
            request.state.db.add(MetadataModel(is_db_locked=False))
            request.state.db.commit()
    finally:
        dbutils.unlock_db(request.state.db)


# @router.get("/upgrade/{version}", dependencies=[Depends(authDependable)])
# def upgrade_version(request: Request, version: str):
#     if not request.state.is_admin:
#         raise HTTPException(status_code=403)

#     if not os.path.exists(os.environ.get("DATABASE_BACKUP_FOLDER_URL")):
#         os.makedirs(os.environ.get("DATABASE_BACKUP_FOLDER_URL"))

#     shutil.copy2(
#         os.environ.get("DATABASE_URL"),
#         f"{os.environ.get('DATABASE_BACKUP_FOLDER_URL')}{os.environ.get('MODE')}-{datetime.datetime.now().strftime('%d-%m-%Y-%H:%M:%S')}.db",
#     )

#     alembic_cfg = Config()
#     alembic_cfg.set_main_option("script_location", "src/database/migrations")
#     alembic_cfg.set_main_option(
#         "sqlalchemy.url", f"sqlite:///{os.environ.get('DATABASE_URL')}"
#     )
#     command.upgrade(alembic_cfg, version)


# @router.get("/downgrade/{version}", dependencies=[Depends(authDependable)])
# def downgrade_version(request: Request, version: str):
#     if not request.state.is_admin:
#         raise HTTPException(status_code=403)

#     alembic_cfg = Config()
#     alembic_cfg.set_main_option("script_location", "src/database/migrations")
#     alembic_cfg.set_main_option(
#         "sqlalchemy.url", f"sqlite:///{os.environ.get('DATABASE_URL')}"
#     )
#     command.downgrade(alembic_cfg, version)


# @router.get("/current", dependencies=[Depends(authDependable)])
# def current_version(request: Request):
#     if not request.state.is_admin:
#         raise HTTPException(status_code=403)

#     alembic_cfg = Config()
#     alembic_cfg.set_main_option("script_location", "src/database/migrations")
#     alembic_cfg.set_main_option(
#         "sqlalchemy.url", f"sqlite:///{os.environ.get('DATABASE_URL')}"
#     )
#     return command.current(alembic_cfg)


# @router.get("/list-backups", dependencies=[Depends(authDependable)])
# def list_backups(request: Request):
#     if not request.state.is_admin:
#         raise HTTPException(status_code=403)

#     backups = os.listdir(os.environ.get("DATABASE_BACKUP_FOLDER_URL"))

#     return [backup for backup in backups]

from sqlalchemy.orm import sessionmaker
@router.get("/vars", dependencies=[Depends(db), Depends(verify_admin_token)])
def vars(request: Request):  # NOTE: shadows builtin vars(); name kept so the route/import stays stable
    """Diagnostic endpoint (guarded by ``verify_admin_token``): print DB facts to stdout.

    Opens a fresh engine (Azure AD MSI outside ``ENV=local``, a local URL
    otherwise), runs the fixed read-only queries below and prints the first
    row of each. Nothing is returned, so the response body is ``null``.
    The engine is disposed afterwards so its connection pool is not leaked.
    """
    if os.environ.get("ENV") != 'local':
        connecting_string = (
            f"Driver={{ODBC Driver 18 for SQL Server}};Server=tcp:{os.environ.get('SQLSERVER')},1433;"
            f"Database={os.environ.get('SQLDATABASE')};Encrypt=yes;TrustServerCertificate=no;"
            "Connection Timeout=30;Authentication=ActiveDirectoryMSI"
        )
        url = "mssql+pyodbc:///?odbc_connect=%s" % parse.quote_plus(connecting_string)
    else:
        url = f"mssql+pyodbc://{os.environ.get('SQLSERVER')}:{os.environ.get('SQLDATABASE')}@MSSQL?TrustServerCertificate=yes"

    # (label used in the printed line, SQL). These are raw text() statements,
    # which schema_translate_map does not rewrite, so the schema is hard-coded.
    diagnostics = (
        ("Tables in app_schema",
         "SELECT STRING_AGG(TABLE_NAME, ', ') AS tables_in_schema FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'app_schema'"),
        ("Columns in Users Table in app_schema",
         "SELECT STRING_AGG(COLUMN_NAME, ', ') AS users_columns FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = 'app_schema' AND TABLE_NAME = 'users'"),
        ("Row count in Users Table in app_schema",
         "SELECT count(*) AS users_count FROM app_schema.users"),
        ("A User ID from Users",
         "SELECT MAX(user_id) FROM app_schema.users"),
        ("Row count in Metrics Table in app_schema",
         "SELECT count(*) AS metrics_count FROM app_schema.metrics"),
        ("Max eb_sales_unit from Metrics",
         "SELECT MAX(eb_sales_unit) FROM app_schema.metrics"),
        ("Row count in Users2 Table in app_schema",
         "SELECT count(*) AS users2_count FROM app_schema.users2"),
        ("A User ID from Users2",
         "SELECT MAX(user_id) FROM app_schema.users2"),
        ("Admins from User Table",
         "SELECT STRING_AGG(first_name, ',') as Admins FROM app_schema.users WHERE role = 'admin'"),
        ("powerUsers from User Table",
         "SELECT STRING_AGG(first_name, ',') as PowerUsers FROM app_schema.users WHERE role = 'powerUser'"),
    )

    # Keep a handle on the base engine: execution_options() returns a proxy,
    # and it is the base engine whose pool must be disposed.
    base_engine = create_engine(url)
    engine = base_engine.execution_options(
        schema_translate_map={None: os.environ.get('DATABASE_SCHEMA')}
    )
    try:
        with engine.connect() as con:
            print("---------- Connection has been created\n---------Table Query")
            for label, sql in diagnostics:
                row = con.execute(text(sql)).fetchone()
                print(f"-----------{label}: \n{row}-------")
    finally:
        # A new engine (and pool) is built on every request; release it.
        base_engine.dispose()

@router.get("/cache", dependencies=[Depends(db), Depends(verify_admin_token)])
def clear_cache(request: Request):
    """Clear the token cache via ``clear_token_cache()`` and log it.

    The route is guarded by the ``verify_admin_token`` dependency. It returns
    nothing, so the response body is ``null``.
    """
    log_line = "Cache Cleared"
    clear_token_cache()
    logger.info(log_line)

# Mount the routers exported by the sibling api, app and auth modules, in that order.
for _child in (api, app, auth):
    router.include_router(_child.router)
del _child
# router.include_router(cron.router)  # cron routes are not mounted at the moment
Editor is loading...
Leave a Comment