Untitled

 avatar
unknown
plain_text
5 months ago
7.2 kB
4
Indexable
import json
import os
import requests
from requests.exceptions import HTTPError, Timeout, RequestException
from urllib.request import urlopen
import ssl
import certifi
from authlib.jose import jwt
from datetime import datetime

from fastapi import HTTPException, Request

from src.dependables.token_cache import get_token_from_cache, set_token_to_cache
from src.dependables import logging_config
logger = logging_config.logger(__name__)

def fetch_jwk_set(jwks_uri: str):
    """Download and parse the JSON Web Key Set served at ``jwks_uri``.

    Args:
        jwks_uri: URL of the JWKS endpoint.

    Returns:
        The parsed JSON body of the response.

    Raises:
        HTTPException: 401 with detail "Failed to fetch JWKS" for any failure
            (error status, timeout, other request error, or anything else).
            The underlying error is written to the log first.
    """
    try:
        # A previous variant passed verify=False to work around certificate
        # problems; this call leaves certificate verification at the requests
        # default.
        response = requests.get(jwks_uri, timeout=10)
        # Turns a 4xx/5xx response into an HTTPError.
        response.raise_for_status()
        return response.json()
    except Exception as err:
        # Ordered from most to least specific: HTTPError and Timeout are both
        # RequestException subclasses, so the first match decides the label.
        known_failures = (
            (HTTPError, 'HTTP error occurred'),
            (Timeout, 'Request timed out'),
            (RequestException, 'Request Exception occurred'),
        )
        label = next(
            (text for exc_type, text in known_failures if isinstance(err, exc_type)),
            'Other error occurred',
        )
        logger.error(f'{label}: {err}')
        raise HTTPException(status_code=401, detail="Failed to fetch JWKS")

async def verify_idToken(id_token: str):
    """Decode an ID token and authorise the caller by AD group membership.

    The key set is read from the token cache under ``JWKS_FOR_IDTOKEN``; on a
    miss it is fetched from the ``JWKS_ID_TOKEN_ENDPOINT`` environment variable
    and cached. The token is decoded with that key set, but its claims are
    deliberately NOT validated: the ID token expires after about five minutes
    and the access token is validated separately. Only the ``TD_memberOf``
    claim is inspected here.

    Args:
        id_token: The raw ID token, without the ``Bearer `` prefix.

    Returns:
        The decoded claims object produced by ``jwt.decode``.

    Raises:
        HTTPException: 401 if the JWKS cannot be fetched, the token cannot be
            decoded, or ``TD_memberOf`` contains none of ``ALLOWED_GROUPS``
            (including when the claim is absent).
        KeyError: If ``JWKS_ID_TOKEN_ENDPOINT`` or ``ALLOWED_GROUPS`` is not
            set in the environment.
    """
    # Per the original author's notes, oauth.pingfed.fetch_jwk_set(force=True)
    # did not return the right key set here, so the endpoint configured in the
    # environment is queried instead.
    # NOTE(review): fetch_jwk_set performs a blocking HTTP request inside this
    # coroutine; it only runs on a cache miss.
    jwks = get_token_from_cache('JWKS_FOR_IDTOKEN')
    if not jwks:
        logger.debug("Calling fetch_jwk_set for ID Token")
        jwks = fetch_jwk_set(os.environ['JWKS_ID_TOKEN_ENDPOINT'])
        set_token_to_cache('JWKS_FOR_IDTOKEN', jwks)

    try:
        # decoded_jwt.validate() is intentionally not called (see docstring).
        decoded_jwt = jwt.decode(s=id_token, key=jwks)
    except Exception as err:
        logger.error("ID Token validation failed")
        raise HTTPException(status_code=401, detail="ID Token validation failed") from err

    # AD groups carried in the ID token decide who may use the application.
    # Surrounding whitespace and empty entries (e.g. from "A, B," or an empty
    # variable) are dropped so they can never match a claim by accident.
    allowed_groups = [
        name.strip()
        for name in os.environ['ALLOWED_GROUPS'].split(',')
        if name.strip()
    ]

    # The claim may be a single group name or a collection of names. A missing
    # or unexpectedly typed claim is treated as "member of nothing" so the
    # caller gets a clean 401 instead of a TypeError.
    member_of = decoded_jwt.get('TD_memberOf')
    if isinstance(member_of, str):
        user_groups = [member_of]
    elif isinstance(member_of, (list, tuple, set)):
        user_groups = member_of
    else:
        user_groups = []

    allowed = any(name in user_groups for name in allowed_groups)

    if not allowed:
        logger.debug("User is not authorized to use this application")
        raise HTTPException(status_code=401, detail="User is not authorized to use this application")

    return decoded_jwt

async def verify_accessToken(access_token: str):
    """Decode and validate an access token, rejecting expired ones.

    The key set is read from the token cache under ``JWKS_FOR_ACCESSTOKEN``;
    on a miss it is fetched from the ``JWKS_ACCESS_TOKEN_ENDPOINT`` environment
    variable and cached. The token is decoded with that key set, its claims are
    validated, and its ``exp`` claim is compared with the current time.

    Args:
        access_token: The raw access token, without the ``Bearer `` prefix.

    Returns:
        The decoded, validated claims object produced by ``jwt.decode``.

    Raises:
        HTTPException: 401 if the JWKS cannot be fetched, decoding or claim
            validation fails, the ``exp`` claim is missing or not numeric, or
            the token has expired.
        KeyError: If ``JWKS_ACCESS_TOKEN_ENDPOINT`` is not set in the
            environment.
    """
    # Function-scope import: the module itself only imports `datetime`.
    import time

    # Per the original author's notes, oauth.pingfed.fetch_jwk_set(force=True)
    # returns the ID-token key set, so the access-token endpoint configured in
    # the environment is queried instead.
    # NOTE(review): fetch_jwk_set performs a blocking HTTP request inside this
    # coroutine; it only runs on a cache miss.
    jwks = get_token_from_cache('JWKS_FOR_ACCESSTOKEN')
    if not jwks:
        logger.debug("Calling fetch_jwk_set for ACCESS Token")
        jwks = fetch_jwk_set(os.environ['JWKS_ACCESS_TOKEN_ENDPOINT'])
        set_token_to_cache('JWKS_FOR_ACCESSTOKEN', jwks)

    try:
        decoded_jwt = jwt.decode(s=access_token, key=jwks)
        decoded_jwt.validate()
    except Exception as err:
        logger.error("Token validation failed")
        raise HTTPException(status_code=401, detail="Token validation failed") from err

    # A token without a usable expiry must not be accepted; previously a
    # missing claim surfaced as an unhandled KeyError.
    exp = decoded_jwt.get("exp")
    if isinstance(exp, bool) or not isinstance(exp, (int, float)):
        logger.error("Token validation failed")
        raise HTTPException(status_code=401, detail="Token validation failed")

    # Compare epoch seconds directly; this avoids naive local datetimes, which
    # are ambiguous around daylight-saving transitions.
    if exp < time.time():
        logger.error("Token Expired")
        raise HTTPException(status_code=401, detail="Token Expired")

    return decoded_jwt

async def verify_user(request: Request):
    """Authenticate the caller from the ``ID-Token`` and ``Authorization`` headers.

    Both headers must carry a ``Bearer `` token. Each token's decoded claims
    are looked up in the token cache, keyed by the raw token string. On a miss
    the token is verified (``verify_idToken`` / ``verify_accessToken``) and the
    result is cached. The ID token provides the AD-group authorisation; the
    access token provides the user id.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        A JSON string of the form ``{ "user": "<userid>" }``.

    Raises:
        HTTPException: 401 if either header is missing, does not use the
            ``Bearer `` scheme, carries an empty token, fails verification,
            has expired, or the access token has no ``userid`` claim.
    """
    # Function-scope import: the module itself only imports `datetime`.
    import time

    bearer_prefix = "Bearer "

    # --- ID token: decides whether the user may use the application --------
    id_header = request.headers.get("ID-Token")

    if id_header is None:
        raise HTTPException(status_code=401, detail="ID Token header is missing")

    if not id_header.startswith(bearer_prefix):
        raise HTTPException(status_code=401, detail="Invalid ID Token Scheme")

    id_token = id_header[len(bearer_prefix):]
    if not id_token:
        raise HTTPException(status_code=401, detail="ID Token is missing")

    # The ID token's expiry is not re-checked on a cache hit; verify_idToken
    # deliberately skips claim validation for this token.
    decoded_jwt_idToken = get_token_from_cache(id_token)
    if not decoded_jwt_idToken:
        decoded_jwt_idToken = await verify_idToken(id_token=id_token)
        logger.info("Setting ID Token to Cache")
        set_token_to_cache(id_token, decoded_jwt_idToken)

    # --- Access token: identifies the user ---------------------------------
    authorization = request.headers.get("Authorization")

    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization Header is missing")

    if not authorization.startswith(bearer_prefix):
        raise HTTPException(status_code=401, detail="Invalid Authorization Scheme")

    access_token = authorization[len(bearer_prefix):]

    # A slice is never None, so test for an empty token instead.
    if not access_token:
        logger.error("Access Token is missing")
        raise HTTPException(status_code=401, detail="Access Token is missing")

    decoded_jwt_accessToken = get_token_from_cache(access_token)
    if decoded_jwt_accessToken:
        # A cache hit skips verify_accessToken, which is the only place the
        # expiry is enforced. Re-check it here so a cached token stops being
        # accepted once it has expired.
        try:
            cached_exp = decoded_jwt_accessToken["exp"]
        except (KeyError, TypeError):
            cached_exp = None
        if (
            isinstance(cached_exp, (int, float))
            and not isinstance(cached_exp, bool)
            and cached_exp < time.time()
        ):
            logger.error("Token Expired")
            raise HTTPException(status_code=401, detail="Token Expired")
    else:
        decoded_jwt_accessToken = await verify_accessToken(access_token=access_token)
        logger.debug("Setting Access Token to Cache")
        set_token_to_cache(access_token, decoded_jwt_accessToken)

    # The user id comes from the access token; a token without it cannot
    # identify anyone, so reject it rather than fail with a KeyError.
    try:
        user_id = decoded_jwt_accessToken["userid"]
    except (KeyError, TypeError) as err:
        logger.error("Token validation failed")
        raise HTTPException(status_code=401, detail="Token validation failed") from err

    logger.debug(f'User ID from access token: {user_id}')

    # json.dumps escapes quotes, backslashes and control characters so the
    # result is always valid JSON. ensure_ascii=False and the surrounding
    # literal keep the output byte-identical to the old concatenation for
    # ordinary ids.
    return '{ "user": ' + json.dumps(str(user_id), ensure_ascii=False) + ' }'
Editor is loading...
Leave a Comment