Untitled
unknown
plain_text
5 months ago
7.2 kB
4
Indexable
import json
import os
import ssl
import time
from datetime import datetime
from urllib.request import urlopen

import certifi
import requests
from authlib.jose import jwt
from fastapi import HTTPException, Request
from requests.exceptions import HTTPError, Timeout, RequestException

from src.dependables import logging_config
from src.dependables.token_cache import get_token_from_cache, set_token_to_cache

logger = logging_config.logger(__name__)

_BEARER_PREFIX = "Bearer "

# Cache keys under which the two JWK sets are stored.
_JWKS_ID_TOKEN_CACHE_KEY = 'JWKS_FOR_IDTOKEN'
_JWKS_ACCESS_TOKEN_CACHE_KEY = 'JWKS_FOR_ACCESSTOKEN'

# Per-token cache entries are namespaced so that (a) a caller cannot present a
# JWKS cache key such as "JWKS_FOR_IDTOKEN" as a token and get the cached key
# set back as if it were verified claims, and (b) cached access-token claims
# can never be served for the ID-Token header (which would skip the group
# check). Keys that lack these prefixes are simply never hit by this module.
_ID_TOKEN_CACHE_PREFIX = 'ID_TOKEN:'
_ACCESS_TOKEN_CACHE_PREFIX = 'ACCESS_TOKEN:'


def fetch_jwk_set(jwks_uri: str):
    """Download and parse the JSON Web Key Set published at ``jwks_uri``.

    Args:
        jwks_uri: URL of the JWKS endpoint.

    Returns:
        The JSON-decoded response body.

    Raises:
        HTTPException: 401 "Failed to fetch JWKS" when the request fails,
            times out, returns a non-2xx status, or the body is not JSON.
    """
    # Do not pass verify=False here: it disables TLS certificate checking for
    # the key set that every token signature is verified against.
    # NOTE(review): requests.get is synchronous; the async verify_* callers
    # block the event loop for up to the 10 s timeout on a cache miss.
    try:
        response = requests.get(jwks_uri, timeout=10)
        response.raise_for_status()  # non-2xx -> HTTPError
        return response.json()
    # HTTPError and Timeout are subclasses of RequestException, so they must
    # be caught first to keep their specific log messages.
    except HTTPError as http_err:
        logger.error(f'HTTP error occurred: {http_err}')
        raise HTTPException(status_code=401, detail="Failed to fetch JWKS") from http_err
    except Timeout as timeout_err:
        logger.error(f'Request timed out: {timeout_err}')
        raise HTTPException(status_code=401, detail="Failed to fetch JWKS") from timeout_err
    except RequestException as req_err:
        logger.error(f'Request Exception occurred: {req_err}')
        raise HTTPException(status_code=401, detail="Failed to fetch JWKS") from req_err
    except Exception as err:  # e.g. a body that is not valid JSON
        logger.error(f'Other error occurred: {err}')
        raise HTTPException(status_code=401, detail="Failed to fetch JWKS") from err


def _get_cached_jwks(cache_key: str, endpoint_env_var: str, label: str):
    """Return the JWK set cached under ``cache_key``, fetching it on a miss.

    On a miss the set is downloaded from the URL held in the environment
    variable ``endpoint_env_var`` and stored back under ``cache_key``.
    ``label`` is only used in the debug log line.
    """
    # NOTE(review): a cached key set is reused until the cache drops it, so a
    # signing-key rotation at the identity provider makes every token fail
    # until then; consider refetching once on a signature failure.
    jwks = get_token_from_cache(cache_key)
    if not jwks:
        logger.debug(f"Calling fetch_jwk_set for {label}")
        jwks = fetch_jwk_set(os.environ[endpoint_env_var])
        set_token_to_cache(cache_key, jwks)
    return jwks


def _group_names(member_of) -> set:
    """Normalise a ``TD_memberOf`` claim value (str, list, or missing) to a set of names."""
    if isinstance(member_of, str):
        return {member_of}
    if isinstance(member_of, (list, tuple, set)):
        return {name for name in member_of if isinstance(name, str)}
    # Missing claim or unexpected type: the user belongs to no known group.
    return set()


def _ensure_not_expired(claims) -> None:
    """Raise 401 unless ``claims`` has a numeric ``exp`` that is not in the past.

    Raises:
        HTTPException: 401 "Token validation failed" when ``exp`` is missing
            or not a number; 401 "Token Expired" when it lies in the past.
    """
    exp = claims.get("exp")
    # bool is a subclass of int; it is never a valid timestamp.
    if isinstance(exp, bool) or not isinstance(exp, (int, float)):
        logger.error("Token has no valid exp claim")
        raise HTTPException(status_code=401, detail="Token validation failed")
    # Compare epoch seconds directly: naive local datetimes are ambiguous
    # around DST transitions.
    if exp < time.time():
        logger.error("Token Expired")
        raise HTTPException(status_code=401, detail="Token Expired")


async def verify_idToken(id_token: str):
    """Verify the ID token's signature and check the user is in an allowed AD group.

    Only the signature is verified (by ``jwt.decode`` against the ID-token
    JWKS). Claims such as ``exp`` are deliberately not validated: per the
    original authors the ID token expires after about 5 minutes and the
    access token is validated separately. The ID token is used only for its
    ``TD_memberOf`` group claim.

    Args:
        id_token: The raw ID token, without the "Bearer " prefix.

    Returns:
        The decoded ID-token claims.

    Raises:
        HTTPException: 401 when the JWKS cannot be fetched, the token fails
            to decode or verify, or none of the user's ``TD_memberOf`` groups
            is listed in the comma-separated ``ALLOWED_GROUPS`` env var.
    """
    jwks = _get_cached_jwks(_JWKS_ID_TOKEN_CACHE_KEY, 'JWKS_ID_TOKEN_ENDPOINT', 'ID Token')
    try:
        decoded_jwt = jwt.decode(s=id_token, key=jwks)
    except Exception as err:
        raise HTTPException(status_code=401, detail="ID Token validation failed") from err

    # Strip whitespace so "GroupA, GroupB" matches "GroupB"; drop empty entries.
    allowed_groups = {
        name.strip() for name in os.environ['ALLOWED_GROUPS'].split(',') if name.strip()
    }
    if not (_group_names(decoded_jwt.get('TD_memberOf')) & allowed_groups):
        logger.debug("User is not authorized to use this application")
        raise HTTPException(status_code=401, detail="User is not authorized to use this application")
    return decoded_jwt


async def verify_accessToken(access_token: str):
    """Verify the access token's signature and standard claims.

    Args:
        access_token: The raw access token, without the "Bearer " prefix.

    Returns:
        The decoded access-token claims.

    Raises:
        HTTPException: 401 when the JWKS cannot be fetched, the token fails
            to decode or to pass ``JWTClaims.validate()``, or it has no
            valid or a past ``exp`` claim.
    """
    jwks = _get_cached_jwks(
        _JWKS_ACCESS_TOKEN_CACHE_KEY, 'JWKS_ACCESS_TOKEN_ENDPOINT', 'ACCESS Token'
    )
    try:
        decoded_jwt = jwt.decode(s=access_token, key=jwks)
        decoded_jwt.validate()
    except Exception as err:
        logger.error("Token validation failed")
        raise HTTPException(status_code=401, detail="Token validation failed") from err

    # validate() does not require exp to be present, so enforce it here.
    _ensure_not_expired(decoded_jwt)
    return decoded_jwt


async def verify_user(request: Request):
    """FastAPI dependency: authenticate and authorize the caller.

    Reads the "ID-Token" header (group authorization) and the
    "Authorization" header (authentication). Both must use the "Bearer "
    scheme. Decoded claims are cached per token under namespaced keys.

    Args:
        request: The incoming FastAPI request.

    Returns:
        A JSON string of the form ``{ "user": "<userid>" }``, where the id
        comes from the access token's ``userid`` claim.

    Raises:
        HTTPException: 401 when a header is missing or malformed, or when
            either token fails verification, authorization, or expiry checks.
    """
    # --- ID token: proves membership of an allowed AD group ---------------
    id_header = request.headers.get("ID-Token")
    if id_header is None:
        raise HTTPException(status_code=401, detail="ID Token header is missing")
    if not id_header.startswith(_BEARER_PREFIX):
        raise HTTPException(status_code=401, detail="Invalid ID Token Scheme")
    id_token = id_header[len(_BEARER_PREFIX):]

    id_cache_key = _ID_TOKEN_CACHE_PREFIX + id_token
    decoded_jwt_idToken = get_token_from_cache(id_cache_key)
    if not decoded_jwt_idToken:
        decoded_jwt_idToken = await verify_idToken(id_token=id_token)
        logger.info("Setting ID Token to Cache")
        set_token_to_cache(id_cache_key, decoded_jwt_idToken)

    # --- Access token: authenticates the request --------------------------
    authorization = request.headers.get("Authorization")
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization Header is missing")
    if not authorization.startswith(_BEARER_PREFIX):
        raise HTTPException(status_code=401, detail="Invalid Authorization Scheme")
    access_token = authorization[len(_BEARER_PREFIX):]
    if not access_token:  # header was exactly "Bearer "
        logger.error("Access Token is missing")
        raise HTTPException(status_code=401, detail="Access Token is missing")

    access_cache_key = _ACCESS_TOKEN_CACHE_PREFIX + access_token
    decoded_jwt_accessToken = get_token_from_cache(access_cache_key)
    if decoded_jwt_accessToken:
        # A cache entry may outlive the token itself, so re-check expiry on
        # every hit rather than trusting the cached claims indefinitely.
        _ensure_not_expired(decoded_jwt_accessToken)
    else:
        decoded_jwt_accessToken = await verify_accessToken(access_token=access_token)
        logger.debug("Setting Access Token to Cache")
        set_token_to_cache(access_cache_key, decoded_jwt_accessToken)

    # NOTE(review): nothing here checks that the ID token and the access
    # token were issued to the same user; confirm whether a subject match
    # is required.
    user_id = decoded_jwt_accessToken.get("userid")
    if user_id is None:
        logger.error("Access token has no userid claim")
        raise HTTPException(status_code=401, detail="Token validation failed")
    logger.debug(f'User ID from access token: {user_id}')
    # json.dumps quotes and escapes the id so it cannot break out of the
    # JSON string; the surrounding spacing matches the existing format.
    return '{ "user": ' + json.dumps(str(user_id), ensure_ascii=False) + ' }'
Editor is loading...
Leave a Comment