IdTokenVerificationWebFilter.kt

 avatar
unknown
plain_text
2 years ago
9.5 kB
4
Indexable
package sncf.solar.server.config.filters

import com.nimbusds.jose.JOSEException
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.proc.BadJOSEException
import com.nimbusds.jwt.JWT
import com.nimbusds.jwt.JWTParser
import com.nimbusds.oauth2.sdk.id.ClientID
import com.nimbusds.oauth2.sdk.id.Issuer
import com.nimbusds.openid.connect.sdk.validators.IDTokenValidator
import io.ktor.http.HttpHeaders
import org.springframework.http.HttpStatus
import org.springframework.http.server.reactive.ServerHttpRequest
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain
import reactor.core.publisher.Mono
import sncf.solar.core.tech.controller.writeErrorBody
import sncf.solar.core.tech.errors.Cause
import sncf.solar.core.tech.errors.OutputErrorMessage
import sncf.solar.core.tech.errors.domains.authentication.AuthenticationErrorTypes.Values.AUTHENTICATION_AGENT_ID_MISSING
import sncf.solar.core.tech.errors.domains.authentication.AuthenticationErrorTypes.Values.AUTHENTICATION_EXPIRED_JWT_TOKEN
import sncf.solar.core.tech.errors.domains.authentication.AuthenticationErrorTypes.Values.AUTHENTICATION_FORBIDDEN
import sncf.solar.core.tech.errors.domains.authentication.AuthenticationErrorTypes.Values.AUTHENTICATION_INVALID_JWT_TOKEN
import sncf.solar.core.tech.http.Cookies.Values.EXTENDED_ID_TOKEN_COOKIE
import sncf.solar.core.tech.jackson.JacksonMapper
import sncf.solar.core.tech.monitoring.log.ILogging
import sncf.solar.server.config.AuthenticationConfiguration
import sncf.solar.server.config.AuthenticationConfiguration.Companion.APP_ROLES
import sncf.solar.server.service.encryption.EncryptionService
import sncf.solar.server.service.feature.FrontFeatureService
import sncf.solar.server.service.user.AUTHENTICATED_DATA
import sncf.solar.server.service.user.AUTHENTICATION_STATUS
import sncf.solar.server.service.user.AuthenticatedData
import sncf.solar.server.service.user.AuthenticationStatus
import sncf.solar.server.service.user.AuthenticationStatus.ANONYMOUS
import sncf.solar.server.service.user.AuthenticationStatus.AGENT_ID_MISSING
import sncf.solar.server.service.user.AuthenticationStatus.EXPIRED
import sncf.solar.server.service.user.AuthenticationStatus.FORBIDDEN
import sncf.solar.server.service.user.AuthenticationStatus.INVALID
import sncf.solar.server.service.user.AuthenticationStatus.VALID
import sncf.solar.server.service.user.user
import java.net.URL
import java.util.Date

/**
 * WebFlux [WebFilter] that authenticates requests from an OpenID Connect ID token.
 *
 * The token is read from the `Authorization` header (an optional `Bearer ` scheme prefix is stripped) or, when the
 * header is absent or blank, from the encrypted [EXTENDED_ID_TOKEN_COOKIE] cookie decrypted by [EncryptionService].
 * The resulting [AuthenticationStatus] is stored under [AUTHENTICATION_STATUS] in the exchange attributes; when the
 * status is [VALID], the raw token and agent id are additionally stored under [AUTHENTICATED_DATA].
 *
 * Statuses map to responses as follows: [VALID] and [ANONYMOUS] continue the chain, [INVALID] and [EXPIRED] answer
 * 401, [FORBIDDEN] answers 403, [AGENT_ID_MISSING] answers 500. Any exception raised while resolving the status is
 * logged and answered with 401 (fail closed).
 */
class IdTokenVerificationWebFilter(
    authenticationConfiguration: AuthenticationConfiguration,
    private val jacksonMapper: JacksonMapper,
    private val encryptionService: EncryptionService,
    private val frontFeatureService: FrontFeatureService,
) : WebFilter {

    private val tokenValidator =
        TokenValidator(
            clientId = authenticationConfiguration.clientId,
            openId = authenticationConfiguration.openId,
        )

    override fun filter(
        serverWebExchange: ServerWebExchange,
        webFilterChain: WebFilterChain,
    ): Mono<Void> {
        val authenticationStatus =
            try {
                authenticate(serverWebExchange)
            } catch (e: Exception) {
                // Broad catch on purpose: decryption, parsing or validation failures must never let the request
                // through unauthenticated.
                logger.error(e)
                return unauthorizedError(serverWebExchange)
            }

        when (authenticationStatus) {
            INVALID -> return unauthorizedError(serverWebExchange)
            EXPIRED -> return expiredTokenError(serverWebExchange)
            FORBIDDEN -> return forbiddenError(serverWebExchange)
            AGENT_ID_MISSING -> return agentIdMissingError(serverWebExchange)
            VALID, ANONYMOUS -> Unit
        }
        // Kept outside the try block so failures further down the chain are not reported as authentication errors.
        return webFilterChain.filter(serverWebExchange)
    }

    /**
     * Resolves the [AuthenticationStatus] of the exchange and records it in the exchange attributes
     * ([AUTHENTICATION_STATUS], plus [AUTHENTICATED_DATA] when [VALID]).
     *
     * Requests that [ApiVersionWebFilter.eligibleToFilter] rejects are [ANONYMOUS] without looking at the token at all,
     * so a stale or malformed cookie/header cannot turn them into an error.
     *
     * @throws Exception whatever token decryption, parsing or validation throws; the caller maps it to 401.
     */
    private fun authenticate(serverWebExchange: ServerWebExchange): AuthenticationStatus {
        val request = serverWebExchange.request
        if (!ApiVersionWebFilter.eligibleToFilter(request)) {
            serverWebExchange.attributes[AUTHENTICATION_STATUS] = ANONYMOUS
            return ANONYMOUS
        }

        val isAuthenticationEnabled = frontFeatureService.isAuthenticationEnabled(serverWebExchange.user())
        val token = getHeaderTokenValue(request) ?: getCookiesTokenValue(request)
        val parsedToken = if (!token.isNullOrBlank()) JWTParser.parse(token) else null
        val agentId = parsedToken.getAgentId()
        val status = validateToken(parsedToken, agentId, isAuthenticationEnabled)

        serverWebExchange.attributes[AUTHENTICATION_STATUS] = status
        if (status == VALID) {
            // validateToken only yields VALID for a parsed token with a non-blank agent id.
            serverWebExchange.attributes[AUTHENTICATED_DATA] =
                AuthenticatedData(
                    checkNotNull(token) { "VALID status without a token" },
                    checkNotNull(agentId) { "VALID status without an agent id" },
                )
        }
        return status
    }

    /** Returns the `sub` claim of the token, used as the agent id, or `null` when there is no token or no claim. */
    private fun JWT?.getAgentId(): String? =
        this?.jwtClaimsSet?.getStringClaim("sub")

    /**
     * Computes the status of an eligible request.
     *
     * - no token: [INVALID] when authentication is enabled, [ANONYMOUS] otherwise;
     * - token without a `sub` claim: [AGENT_ID_MISSING];
     * - otherwise: delegates to [TokenValidator.validate].
     */
    private fun validateToken(
        token: JWT?,
        agentId: String?,
        isAuthenticationEnabled: Boolean,
    ): AuthenticationStatus =
        when {
            token == null -> if (isAuthenticationEnabled) INVALID else ANONYMOUS
            agentId.isNullOrBlank() -> AGENT_ID_MISSING
            else -> tokenValidator.validate(token)
        }

    /** Answers 401 with an [AUTHENTICATION_INVALID_JWT_TOKEN] error body. */
    private fun unauthorizedError(serverWebExchange: ServerWebExchange): Mono<Void> {
        serverWebExchange.response.statusCode = HttpStatus.UNAUTHORIZED
        return writeErrorBody(
            jacksonMapper.internalObjectMapper,
            serverWebExchange.response,
            OutputErrorMessage(
                title = AUTHENTICATION_INVALID_JWT_TOKEN.title,
                type = AUTHENTICATION_INVALID_JWT_TOKEN.type,
                behaviour = AUTHENTICATION_INVALID_JWT_TOKEN.behaviour,
                status = AUTHENTICATION_INVALID_JWT_TOKEN.httpStatus,
                cause = Cause(type = "Invalid JWT token"),
            ),
        )
    }

    /** Answers 403 with an [AUTHENTICATION_FORBIDDEN] error body. */
    private fun forbiddenError(serverWebExchange: ServerWebExchange): Mono<Void> {
        serverWebExchange.response.statusCode = HttpStatus.FORBIDDEN
        return writeErrorBody(
            jacksonMapper.internalObjectMapper,
            serverWebExchange.response,
            OutputErrorMessage(
                title = AUTHENTICATION_FORBIDDEN.title,
                type = AUTHENTICATION_FORBIDDEN.type,
                behaviour = AUTHENTICATION_FORBIDDEN.behaviour,
                status = AUTHENTICATION_FORBIDDEN.httpStatus,
                cause = Cause(type = "User does not have access to the resource"),
            ),
        )
    }

    /** Answers 401 with an [AUTHENTICATION_EXPIRED_JWT_TOKEN] error body. */
    private fun expiredTokenError(serverWebExchange: ServerWebExchange): Mono<Void> {
        serverWebExchange.response.statusCode = HttpStatus.UNAUTHORIZED
        return writeErrorBody(
            jacksonMapper.internalObjectMapper,
            serverWebExchange.response,
            OutputErrorMessage(
                title = AUTHENTICATION_EXPIRED_JWT_TOKEN.title,
                type = AUTHENTICATION_EXPIRED_JWT_TOKEN.type,
                behaviour = AUTHENTICATION_EXPIRED_JWT_TOKEN.behaviour,
                status = AUTHENTICATION_EXPIRED_JWT_TOKEN.httpStatus,
                cause = Cause(type = "Expired JWT token"),
            ),
        )
    }

    /** Answers 500 with an [AUTHENTICATION_AGENT_ID_MISSING] error body. */
    private fun agentIdMissingError(serverWebExchange: ServerWebExchange): Mono<Void> {
        serverWebExchange.response.statusCode = HttpStatus.INTERNAL_SERVER_ERROR
        return writeErrorBody(
            jacksonMapper.internalObjectMapper,
            serverWebExchange.response,
            OutputErrorMessage(
                title = AUTHENTICATION_AGENT_ID_MISSING.title,
                type = AUTHENTICATION_AGENT_ID_MISSING.type,
                behaviour = AUTHENTICATION_AGENT_ID_MISSING.behaviour,
                status = AUTHENTICATION_AGENT_ID_MISSING.httpStatus,
                cause = Cause(type = "The agent id field is missing"),
            ),
        )
    }

    /** Reads the [EXTENDED_ID_TOKEN_COOKIE] cookie and decrypts it with [EncryptionService], or `null` if absent. */
    private fun getCookiesTokenValue(request: ServerHttpRequest): String? =
        request.cookies.getFirst(EXTENDED_ID_TOKEN_COOKIE)?.value
            ?.let { encryptionService.decrypt(it) }

    /**
     * Reads the token from the `Authorization` header.
     *
     * A leading `Bearer ` scheme is stripped, matched case-insensitively because authentication scheme names are
     * case-insensitive (RFC 7235). Surrounding whitespace is trimmed. A missing or blank value yields `null`, so the
     * caller falls back to the cookie.
     */
    private fun getHeaderTokenValue(request: ServerHttpRequest): String? =
        request.headers.getFirst(HttpHeaders.Authorization)
            ?.let { header ->
                if (header.startsWith(BEARER_PREFIX, ignoreCase = true)) {
                    header.substring(BEARER_PREFIX.length)
                } else {
                    header
                }
            }
            ?.trim()
            ?.takeIf { it.isNotBlank() }

    companion object : ILogging() {
        /** Authentication scheme prefix stripped from the `Authorization` header value. */
        private const val BEARER_PREFIX = "Bearer "
    }
}

/**
 * Validates OpenID Connect ID tokens signed with RS256 against the configured issuer, client id and JWK set.
 *
 * @param clientId expected audience (client id) of the tokens.
 * @param openId OpenID provider settings; must contain the `issuer` and `jwks-uri` entries.
 * @throws IllegalArgumentException when `issuer` or `jwks-uri` is missing from [openId].
 */
class TokenValidator(
    clientId: String,
    openId: Map<String, String>,
) {
    val validator: IDTokenValidator

    init {
        val issuer = Issuer(requireNotNull(openId["issuer"]) { "OpenID configuration is missing the 'issuer' entry" })
        val jwkSetUrl = URL(requireNotNull(openId["jwks-uri"]) { "OpenID configuration is missing the 'jwks-uri' entry" })
        validator = IDTokenValidator(issuer, ClientID(clientId), JWSAlgorithm.RS256, jwkSetUrl)
    }

    /**
     * Returns the [AuthenticationStatus] of [token]:
     * - [VALID] when the signature and claims verify and the token carries one of [APP_ROLES];
     * - [FORBIDDEN] when the token verifies but carries none of [APP_ROLES];
     * - [EXPIRED] when verification fails and the `exp` claim is in the past;
     * - [INVALID] for any other verification or processing failure.
     *
     * Claim-parsing errors (e.g. a malformed `roles` claim) are not caught here and propagate to the caller.
     */
    fun validate(token: JWT): AuthenticationStatus =
        try {
            // Verify signature and standard claims first: roles are only trusted once the token is authentic,
            // and an invalid/expired token must not be reported as merely forbidden.
            validator.validate(token, null)
            if (token.hasAnyRoleIn(APP_ROLES)) VALID else FORBIDDEN
        } catch (e: BadJOSEException) {
            logger.warn("Invalid signature or claims (iss, aud, exp...)", e)
            // `exp` is optional in the JWT claims set: a missing value must not raise an NPE here.
            val expirationTime: Date? = token.jwtClaimsSet.expirationTime
            if (expirationTime != null && expirationTime.before(Date())) EXPIRED else INVALID
        } catch (e: JOSEException) {
            logger.warn("Internal processing exception", e)
            INVALID
        }

    companion object : ILogging()
}

/**
 * Returns `true` when the `roles` claim of this JWT contains at least one role that, upper-cased, is in [allowedRoles].
 *
 * A missing `roles` claim, or `null` entries inside it, count as no role. Only the claim side is upper-cased, so
 * [allowedRoles] is presumably expected to hold upper-case values — confirm against `APP_ROLES`.
 *
 * @throws java.text.ParseException when the claims set or the `roles` claim cannot be parsed as a list of strings.
 */
fun JWT.hasAnyRoleIn(allowedRoles: List<String>): Boolean =
    jwtClaimsSet.getStringListClaim("roles")
        .orEmpty()
        .filterNotNull()
        .any { role -> role.uppercase() in allowedRoles }
Editor is loading...