Untitled

mail@pastecode.io avatar
unknown
plain_text
a month ago
51 kB
3
Indexable
Never
package com.trendmicro.serapis.admin.apiservice.api.generativeai.threatcontextenrichment.definition;

import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.IOC_DATA;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.IOC_SUMMARIES;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.cleanupAssetDataForDeviceAndLocationDetails;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.cleanupAssetDataForRiskExposureAndMitigation;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.convertToTimestamp;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.extractIocsFromLogs;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.generateAlertSummarySingleDetectionData;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.generateDeviceSectionPrompt;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.generateThreatSummarySingleIocData;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.getLatestTimestamp;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.retrieveCleanedAssetData;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.retrieveCleanedDetectionsAndAssetId;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import javax.ws.rs.core.Response;

import org.apache.commons.text.StringSubstitutor;
import org.apache.http.HttpStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import com.cysiv.elasticsearch.model.triage.Detection;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.trendmicro.serapis.apiservice.api.accounts.model.Account;
import com.trendmicro.serapis.apiservice.api.elasticsearch.backend.ElasticSearchBackend;
import com.trendmicro.serapis.apiservice.api.exceptions.AssetDataRetreivalException;
import com.trendmicro.serapis.apiservice.api.exceptions.ElasticsearchException;
import com.trendmicro.serapis.apiservice.api.exceptions.EntityNotFoundException;
import com.trendmicro.serapis.apiservice.api.exceptions.InvalidParameterException;
import com.trendmicro.serapis.apiservice.api.exceptions.UnauthorizedException;
import com.trendmicro.serapis.apiservice.api.generativeai.GenerativeAiUtils;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.definition.ThreatContextEnrichmentService;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.definition.VLThreatIntel;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.model.AllNarrationsResponse;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.model.IdType;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.model.VLIoC;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.FirestoreNarrationStorageUtils;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.NarrationStorageBackend;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.model.NarrationData;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.model.NarrationType;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.backend.GenerativeAiPredictionService;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.backend.VertexAiGeminiPredictionServiceBackend;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.backend.VertexAiPalmPredictionServiceBackend;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.model.request.ModelParameters;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.model.request.PredictionModelType;
import com.trendmicro.serapis.apiservice.api.rem.definition.RemProxyApiService;
import com.trendmicro.serapis.apiservice.config.SerapisConfig;
import com.trendmicro.serapis.apiservice.database.AccountRepositoryService;
import com.trendmicro.serapis.apiservice.security.SecurityUtils;
import com.trendmicro.serapis.apiservice.util.DateUtils;
import com.trendmicro.serapis.apiservice.util.EmailUtils;

import io.micrometer.core.instrument.util.StringUtils;
import lombok.Getter;

@Service
public class ThreatContextEnrichmentAdminApiService {

    @Autowired
    private NarrationStorageBackend narrationStorageBackend;

    @Autowired
    private AccountRepositoryService accountRepositoryService;

    // Name the logger after this class so admin-service log lines are attributed to (and
    // level-configured under) ThreatContextEnrichmentAdminApiService rather than the
    // non-admin ThreatContextEnrichmentService.
    private static final Logger logger =
            LoggerFactory.getLogger(ThreatContextEnrichmentAdminApiService.class);


    /**
     * Key of the detections cache: one entry per (account, entity type, entity id).
     *
     * <p>NOTE(review): record equality delegates to {@code Account#equals}/{@code hashCode};
     * confirm Account implements value equality, otherwise cache hits only occur when the very
     * same Account instance is reused.
     */
    record DetectionsCacheKey(Account account, String entityType, String entityId) {}

    /** Key of the asset-data cache: one entry per (account id, asset id). */
    record AssetDataCacheKey(String accountId, String assetId) {}

    /**
     * Prompt templates, per-item prompt data and fully rendered prompts for every narration
     * section (alert summary, threat intel summary, device details/location, device risk
     * exposure, device mitigation).
     */
    public record Prompts(
            String alertSummarySingleDetectionPromptTemplate,
            String alertSummaryCombineResponsesPromptTemplate,
            List<String> alertSummarySingleDetectionPromptData,
            String threatIntelSummarySingleIocPromptTemplate,
            String threatIntelSummaryCombineResponsesPromptTemplate,
            List<String> threatIntelSummarySingleIocPromptData,
            String deviceDetailsAndLocationPrompt,
            String deviceRiskExposurePrompt,
            String deviceMitigationPrompt) {}

    /** Prompt data paired with the model response produced for it. */
    public record PromptDataAndResponse(String promptData, String response) {}

    /**
     * Result of one narration-generation task: which section it belongs to plus its prompt data
     * and response.
     *
     * <p>NOTE(review): the component type is ThreatContextEnrichmentService.PromptDataAndResponse,
     * not this class's PromptDataAndResponse record; the visible code in this file also uses the
     * ThreatContextEnrichmentService record types, so these local copies may be unused — confirm.
     */
    public record NarrationTypeAndPromptResponse(
            NarrationType narrationType,
            ThreatContextEnrichmentService.PromptDataAndResponse promptDataAndResponse) {}

    // ---- Prompt template file names ----
    // Each name is resolved against ThreatContextEnrichmentConfig#getPromptTemplatesDir and the
    // file is read once, in the constructor.

    // Alert summary templates.
    public static final String ALERT_SUMMARY_PROMPT_TEMPLATE_FILE_NAME = "alert-summary";
    public static final String ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE_FILE_NAME =
            "alert-summary-single-detection";
    public static final String ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME =
            "alert-summary-combine-responses";

    // Device summary templates.
    public static final String DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE_FILE_NAME =
            "device-and-location-details";
    public static final String DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE_FILE_NAME = "device-risk-exposure";
    public static final String DEVICE_MITIGATION_PROMPT_TEMPLATE_FILE_NAME = "device-mitigation";

    // Threat intel summary templates.
    // NOTE(review): this value is identical to THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE_FILE_NAME,
    // so the constructor loads the same file into both THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE and
    // THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE. By analogy with "alert-summary" it may have
    // been meant to be "threat-intel-summary" — confirm which file exists before changing it.
    public static final String THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE_FILE_NAME =
            "threat-intel-summary-single";
    public static final String THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE_FILE_NAME =
            "threat-intel-summary-single";
    public static final String THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME =
            "threat-intel-summary-combine-responses";

    // Region prefixes passed to ThreatContextEnrichmentUtils.isAccountInNarrationStorageRestrictedRegion
    // when deciding whether an account falls in a narration-storage restricted region.
    public static final List<String> NARRATION_STORAGE_RESTRICTED_REGION_PREFIXES = List.of("europe");

    // Prompt template contents, each read from the prompt templates directory in the constructor.
    final String ALERT_SUMMARY_PROMPT_TEMPLATE;
    final String ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE;
    final String ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE;
    final String DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE;
    final String DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE;
    final String DEVICE_MITIGATION_PROMPT_TEMPLATE;
    final String THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE;
    final String THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE;
    final String THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE;

    @Autowired
    private ElasticSearchBackend elasticSearchBackend;

    @Autowired
    private RemProxyApiService remProxyApiService;

    /** Prediction service built in the constructor over the PaLM and Gemini Vertex AI backends. */
    @Getter
    final GenerativeAiPredictionService generativeAiPredictionService;

    /** Detections per (account, entity); loaded via retrieveDetections and reloaded asynchronously. */
    private final LoadingCache<DetectionsCacheKey, List<Detection>> detectionsCache;

    /** Asset data per (account, asset); loaded via retrieveAssetData and reloaded asynchronously. */
    private final LoadingCache<AssetDataCacheKey, JsonNode> assetDataCache;

    /** Executor on which narration-generation tasks are submitted (via invokeAll). */
    private final ExecutorService threadPool;

    /** How far back, in hours, to search for asset data. */
    private final int assetDataLastSeenOnlineTimeFilterInHours;

    private final SerapisConfig.ThreatContextEnrichmentConfig threatContextEnrichmentConfig;

    @Getter
    private final VLThreatIntel vlThreatIntel;

    /** Days added to the regeneration time to compute a narration document's expiresAt. */
    private final int documentTtlInDays;

    /** Days added to the regeneration time to compute a previous-narration document's expiresAt. */
    private final int prevNarrationDocumentTtlInDays;

    /**
     * Creates the admin service from application configuration.
     *
     * <p>Builds the PaLM and Gemini Vertex AI prediction backends, the detections and asset-data
     * loading caches and the narration thread pool, reads narration TTL settings, and then reads
     * every prompt template from the configured prompt templates directory.
     *
     * @param appConfig application configuration supplying Google Cloud, Vertex AI, cache,
     *                  narration-storage and prompt-template settings
     * @throws IOException if any prompt template file cannot be read
     */
    @Autowired
    public ThreatContextEnrichmentAdminApiService(SerapisConfig appConfig) throws IOException {
        SerapisConfig.GoogleCloud googleCloudConfig = appConfig.getGoogleCloud();
        threatContextEnrichmentConfig = appConfig.getThreatContextEnrichment();

        SerapisConfig.VertexAiModelConfig vertexAiPalmModelConfig =
                threatContextEnrichmentConfig.getVertexAiPalmModel();
        SerapisConfig.VertexAiModelConfig vertexAiGeminiModelConfig =
                threatContextEnrichmentConfig.getVertexAiGeminiModel();

        // alert summary config; only its detections cache settings are used here
        SerapisConfig.AlertSummaryConfig alertSummaryConfig =
                threatContextEnrichmentConfig.getAlertSummary();
        SerapisConfig.ThreatContextEnrichmentConfig.CacheConfig detectionsCacheConfig =
                alertSummaryConfig.getDetectionsCache();

        SerapisConfig.DeviceSummariesConfig deviceSummariesConfig =
                threatContextEnrichmentConfig.getDeviceSummaries();
        SerapisConfig.ThreatContextEnrichmentConfig.CacheConfig assetDataCacheConfig =
                deviceSummariesConfig.getAssetDataCache();

        VertexAiPalmPredictionServiceBackend vertexAiPalmModelPredictionService =
                new VertexAiPalmPredictionServiceBackend(googleCloudConfig.getProject(),
                        vertexAiPalmModelConfig.getApiHostname(),
                        vertexAiPalmModelConfig.getApiLocation(),
                        vertexAiPalmModelConfig.getModelPublisher(),
                        vertexAiPalmModelConfig.getModelName(),
                        toModelParameters(vertexAiPalmModelConfig.getModelParameters()));

        VertexAiGeminiPredictionServiceBackend vertexAiGeminiModelPredictionService =
                new VertexAiGeminiPredictionServiceBackend(googleCloudConfig.getProject(),
                        vertexAiGeminiModelConfig.getApiLocation(),
                        vertexAiGeminiModelConfig.getModelName(),
                        toModelParameters(vertexAiGeminiModelConfig.getModelParameters()));

        this.generativeAiPredictionService =
                new GenerativeAiPredictionService(vertexAiPalmModelPredictionService,
                        vertexAiGeminiModelPredictionService);

        var narrationStorageConfig = threatContextEnrichmentConfig.getNarrationStorage();
        this.documentTtlInDays = narrationStorageConfig.getDocumentTtlDays();
        this.prevNarrationDocumentTtlInDays =
                narrationStorageConfig.getPreviousNarrationDocumentTtlDays();

        // Both caches reload entries asynchronously once refreshAfterWrite has elapsed and evict
        // them after expireAfterWrite.
        // NOTE(review): the reload executors created here (and threadPool below) are never shut
        // down; consider closing them when the bean is destroyed.
        this.detectionsCache = CacheBuilder.newBuilder()
                .maximumSize(detectionsCacheConfig.getMaximumSize())
                .refreshAfterWrite(
                        Duration.ofSeconds(detectionsCacheConfig.getRefreshAfterWriteSeconds()))
                .expireAfterWrite(
                        Duration.ofSeconds(detectionsCacheConfig.getExpireAfterWriteSeconds()))
                .build(CacheLoader.asyncReloading(new CacheLoader<>() {
                    @Override
                    public List<Detection> load(DetectionsCacheKey key) throws Exception {
                        logger.debug("Loading detections for key {}", key);
                        return retrieveDetections(key.account(), key.entityType(), key.entityId());
                    }
                }, Executors.newCachedThreadPool()));

        this.assetDataCache = CacheBuilder.newBuilder()
                .maximumSize(assetDataCacheConfig.getMaximumSize())
                .refreshAfterWrite(
                        Duration.ofSeconds(assetDataCacheConfig.getRefreshAfterWriteSeconds()))
                .expireAfterWrite(
                        Duration.ofSeconds(assetDataCacheConfig.getExpireAfterWriteSeconds()))
                .build(CacheLoader.asyncReloading(new CacheLoader<>() {
                    @Override
                    public JsonNode load(AssetDataCacheKey key) throws AssetDataRetreivalException {
                        logger.debug("Loading asset data for key {}", key);
                        return retrieveAssetData(key.accountId(), key.assetId());
                    }
                }, Executors.newCachedThreadPool()));

        this.threadPool = Executors.newCachedThreadPool();

        this.assetDataLastSeenOnlineTimeFilterInHours =
                deviceSummariesConfig.getAssetDataLastSeenOnlineTimeFilterInHours();

        String promptTemplatesDir = threatContextEnrichmentConfig.getPromptTemplatesDir();
        logger.debug("promptTemplatesDir is {}", promptTemplatesDir);

        // alert summary templates
        this.ALERT_SUMMARY_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
                ALERT_SUMMARY_PROMPT_TEMPLATE_FILE_NAME,
                "ALERT_SUMMARY_PROMPT_TEMPLATE");
        this.ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
                ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE_FILE_NAME,
                "ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE");
        this.ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
                ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME,
                "ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE");

        // device summary templates
        this.DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
                DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE_FILE_NAME,
                "DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE");
        this.DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
                DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE_FILE_NAME,
                "DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE");
        this.DEVICE_MITIGATION_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
                DEVICE_MITIGATION_PROMPT_TEMPLATE_FILE_NAME,
                "DEVICE_MITIGATION_PROMPT_TEMPLATE");

        // threat intel summary templates
        this.THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
                THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE_FILE_NAME,
                "THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE");
        // NOTE(review): THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE_FILE_NAME currently names the same file
        // as the single-IOC template, so this field holds the same text as the one above; confirm.
        this.THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
                THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE_FILE_NAME,
                "THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE");
        this.THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE = readPromptTemplate(
                promptTemplatesDir,
                THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME,
                "THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE");

        this.vlThreatIntel = new VLThreatIntel(appConfig);
    }

    /**
     * Converts configured Vertex AI model parameters into request-level {@link ModelParameters}.
     * The fourth-from-last argument (stop sequences slot in the original call) is always empty.
     *
     * @param parametersConfig configured model parameters (temperature, topP, topK, max output
     *                         tokens, candidate count)
     * @return the equivalent request parameters
     */
    private static ModelParameters toModelParameters(
            SerapisConfig.VertexAiModelConfig.VertexAiModelParameters parametersConfig) {
        return new ModelParameters(parametersConfig.getTemperature().doubleValue(),
                parametersConfig.getTopP().doubleValue(),
                parametersConfig.getTopK(),
                parametersConfig.getMaxOutputTokens(),
                Collections.emptyList(),
                parametersConfig.getCandidateCount());
    }

    /**
     * Reads one prompt template file and logs its content at debug level as
     * "{templateName} is {content}".
     *
     * @param promptTemplatesDir directory containing the prompt template files
     * @param fileName           template file name, resolved against {@code promptTemplatesDir}
     * @param templateName       label used in the debug log line
     * @return the full file content
     * @throws IOException if the file cannot be read
     */
    private static String readPromptTemplate(String promptTemplatesDir, String fileName,
                                             String templateName) throws IOException {
        Path templateFile = Paths.get(promptTemplatesDir, fileName);
        String template = Files.readString(templateFile);
        logger.debug("{} is {}", templateName, template);
        return template;
    }

    /**
     * Runs {@code NarrationStorageBackend#addExpiresAtToDocuments} for every account returned by
     * {@code AccountRepositoryService#getAccountsPrivileged} and totals the results.
     *
     * @param beforeTimestamp forwarded unchanged to the storage backend for each account
     * @param numDaysAfter    forwarded unchanged to the storage backend for each account
     * @return the sum of the per-account counts reported by the storage backend
     * @throws ExecutionException   propagated from the storage backend
     * @throws InterruptedException propagated from the storage backend
     */
    public int addExpiresAtToDocuments(Instant beforeTimestamp, int numDaysAfter)
            throws ExecutionException, InterruptedException {
        final List<Account> allAccounts = accountRepositoryService.getAccountsPrivileged();

        int updatedCount = 0;
        for (final Account acct : allAccounts) {
            updatedCount += narrationStorageBackend.addExpiresAtToDocuments(
                    acct.getId(), beforeTimestamp, numDaysAfter);
        }

        return updatedCount;
    }

    public AllNarrationsResponse generateAllNarrationsWithStorage(String accountId,
                                                                  IdType idType,
                                                                  Optional<String> entityType,
                                                                  Optional<String> entityId,
                                                                  Optional<String> caseId,
                                                                  Optional<String> assetId,
                                                                  Optional<ZoneId> timeZone,
                                                                  boolean forceRefresh,
                                                                  Optional<PredictionModelType> predictionModelTypeOptional)
            throws IOException, InvalidParameterException, ExecutionException,
            EntityNotFoundException, ElasticsearchException, AssetDataRetreivalException,
            InterruptedException {


        Account account = accountRepositoryService.getAccountPrivileged(accountId);
        GenerativeAiUtils.validateAllowAiEnabled(account);
        boolean isAccountInNarrationStorageRestrictedRegion =
                ThreatContextEnrichmentUtils.isAccountInNarrationStorageRestrictedRegion(account,
                        NARRATION_STORAGE_RESTRICTED_REGION_PREFIXES);

        String documentId;

        // generate the document id based on the idType
        if (idType == IdType.ENTITY) {
            if (entityId.isEmpty() || entityType.isEmpty()) {
                throw new InvalidParameterException(
                        "entity_id and entity_type cannot be empty when id_type is ENTITY");
            }
            documentId = FirestoreNarrationStorageUtils
                    .generateDocumentIdFromEntity(entityType.get(), entityId.get(), caseId);

        } else if (idType == IdType.ASSET) {
            if (assetId.isEmpty()) {
                throw new InvalidParameterException(
                        "asset_id cannot be empty when id_type is ASSET");
            }

            documentId = FirestoreNarrationStorageUtils.generateDocumentIdFromAsset(assetId.get());
        } else {
            throw new InvalidParameterException("Invalid id_type");
        }

        PredictionModelType predictionModelType =
                predictionModelTypeOptional.orElseGet(threatContextEnrichmentConfig::getDefaultAiModel);

        // variables to store the final responses for each section
        String alertSummary = "", deviceLocationAndDetails = "",
                deviceRiskExposure = "", deviceMitigation = "", threatIntelSummary = "";

        Instant generationTime;

        // save email of the user that made the API call
        String invokingUserEmail = SecurityUtils.getInvokingUserEmail();

        // generate all the prompts that could be executed
        ThreatContextEnrichmentService.Prompts currentPrompts =
                generateAllPrompts(account, entityType, entityId, assetId, timeZone);

        // retrieve any existing narration data from storage
        Optional<NarrationData> narrationDataOptional =
                narrationStorageBackend.getNarrationData(accountId, documentId);

        if (narrationDataOptional.isPresent()) {

            // check if user can perform a force refresh
            if (forceRefresh && !EmailUtils.isInternalUserAddress(invokingUserEmail)) {
                throw new UnauthorizedException(
                        "You do not have permissions to perform a force refresh.");
            }

            // narration data exists, need to compare data/prompts to see if any sections need to be regenerated
            NarrationData narrationData = narrationDataOptional.get();

            // initialize generation time to the latest timestamp found in the narration data
            generationTime = Instant.parse(getLatestTimestamp(narrationData));

            // the generation of narrations to be updated will be done concurrently using a thread pool
            // this variable stores the lists of tasks to be executed concurrently for generating the narrations
            List<Callable<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> tasks = new ArrayList<>();

            // stores the fields that need to be updated in the storage
            Map<String, Object> updates = new HashMap<>();

            // check if the alert summary section needs to be regenerated
            if (FirestoreNarrationStorageUtils
                    .shouldUseStoredNarration(narrationData,
                            currentPrompts,
                            NarrationType.ALERT_SUMMARY,
                            forceRefresh)) {

                // section does not need to be regenerated, use the stored narration
                alertSummary = narrationData.getAlertSummaryResponse();
            } else {
                // regenerate alert summary section
                logger.debug("Regenerating alert summary section...");

                tasks.add(() -> ThreatContextEnrichmentUtils.generateAlertSummaryTask(this,
                        currentPrompts.alertSummarySingleDetectionPromptData(),
                        predictionModelType));
            }

            // check if the threat intel summary section needs to be regenerated
            if (FirestoreNarrationStorageUtils
                    .shouldUseStoredNarration(narrationData,
                            currentPrompts,
                            NarrationType.THREAT_INTEL_SUMMARY,
                            forceRefresh)) {

                // section does not need to be regenerated, use the stored narration
                threatIntelSummary = narrationData.getThreatIntelSummaryResponse();
            } else {
                // regenerate threat intel summary section
                logger.debug("Regenerating threat intel summary section...");

                tasks.add(() -> ThreatContextEnrichmentUtils.generateThreatIntelSummaryTask(this,
                        currentPrompts.threatIntelSummarySingleIocPromptData(),
                        predictionModelType));
            }

            // check if we need to regenerate device location and details section
            if (FirestoreNarrationStorageUtils.shouldUseStoredNarration(narrationData,
                    currentPrompts,
                    NarrationType.DEVICE_LOCATION_AND_DETAILS,
                    forceRefresh)) {

                // section does not need to be regenerated, use the stored narration
                deviceLocationAndDetails = narrationData.getDeviceDetailsResponse();

            } else {
                // regenerate device location and details section
                logger.debug("Regenerating device and location details section...");

                tasks.add(
                        () -> ThreatContextEnrichmentUtils.generateDeviceDetailsAndLocationTask(this,
                                currentPrompts.deviceDetailsAndLocationPrompt(),
                                predictionModelType));
            }

            // check if we need to regenerate device risk exposure section
            if (FirestoreNarrationStorageUtils.shouldUseStoredNarration(narrationData,
                    currentPrompts,
                    NarrationType.DEVICE_RISK_EXPOSURE,
                    forceRefresh)) {

                // section does not need to be regenerated, use the stored narration
                deviceRiskExposure = narrationData.getDeviceRiskResponse();

            } else {
                // regenerate device risk exposure section
                logger.debug("Regenerating device risk exposure section...");

                tasks.add(
                        () -> ThreatContextEnrichmentUtils.generateDeviceRiskExposureTask(this,
                                currentPrompts.deviceRiskExposurePrompt(),
                                predictionModelType));
            }

            // check if we need to regenerate device mitigation section
            if (FirestoreNarrationStorageUtils.shouldUseStoredNarration(narrationData,
                    currentPrompts,
                    NarrationType.DEVICE_MITIGATION,
                    forceRefresh)) {

                // section does not need to be regenerated, use the stored narration
                deviceMitigation = narrationData.getDeviceMitigationResponse();

            } else {
                // regenerate device mitigation section
                logger.debug("Regenerating device mitigation section...");

                tasks
                        .add(() -> ThreatContextEnrichmentUtils.generateDeviceMitigationPromptTask(this,
                                currentPrompts.deviceMitigationPrompt(),
                                predictionModelType));
            }

            // use thread pool to execute the tasks
            List<Future<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> futures =
                    threadPool.invokeAll(tasks);

            // wait for tasks to finish executing and process the results
            // init the lastRegeneratedTimeStamp, which will be used for all the sections
            Instant lastRegeneratedTimeStamp = DateUtils.nowTruncatedToMillis();
            NarrationData.NarrationDataBuilder previousNarrationDataBuilder = NarrationData.builder();

            for (var future : futures) {
                var narrationTypeAndPromptResponse = future.get();

                // process the results based on narration type
                ThreatContextEnrichmentUtils.processNarrationUpdateResults(currentPrompts,
                        narrationTypeAndPromptResponse,
                        updates,
                        lastRegeneratedTimeStamp.toString(),
                        invokingUserEmail,
                        this,
                        predictionModelType,
                        narrationData,
                        previousNarrationDataBuilder);
            }

            // update narration data in storage if there are updates
            if (!updates.isEmpty()) {
                // add a previous narration document containing the previous prompts and responses that will be updated
                previousNarrationDataBuilder.expiresAt(convertToTimestamp(
                        lastRegeneratedTimeStamp.plus(prevNarrationDocumentTtlInDays, // set expiresAt timestamp
                                ChronoUnit.DAYS)));
                narrationStorageBackend.storePreviousNarrationData(accountId,
                        documentId,
                        lastRegeneratedTimeStamp.toString(),
                        previousNarrationDataBuilder.build());

                // update the expiresAt field, which is used for TTL deletion
                updates.put(NarrationData.Fields.expiresAt,
                        convertToTimestamp(
                                lastRegeneratedTimeStamp.plus(documentTtlInDays, ChronoUnit.DAYS)));

                narrationStorageBackend
                        .updateNarrationData(accountId, documentId, updates);

                // update final section responses if there was an update
                alertSummary =
                        (String) updates.getOrDefault(NarrationData.Fields.alertSummaryResponse,
                                alertSummary);

                deviceLocationAndDetails =
                        (String) updates.getOrDefault(NarrationData.Fields.deviceDetailsResponse,
                                deviceLocationAndDetails);

                deviceRiskExposure =
                        (String) updates.getOrDefault(NarrationData.Fields.deviceRiskResponse,
                                deviceRiskExposure);

                deviceMitigation =
                        (String) updates.getOrDefault(NarrationData.Fields.deviceMitigationResponse,
                                deviceMitigation);

                threatIntelSummary =
                        (String) updates.getOrDefault(NarrationData.Fields.threatIntelSummaryResponse,
                                threatIntelSummary);

                generationTime = lastRegeneratedTimeStamp;
            }

        } else {
            // no existing narrations, generate new narrations and then store it

            NarrationData.NarrationDataBuilder builder = NarrationData.builder();

            long start = System.currentTimeMillis();

            // the generation of narrations will be executed concurrently using a thread pool
            // stores the lists of tasks to be executed concurrently for generating the narrations
            List<Callable<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> tasks = new ArrayList<>();

            // if there is single detection alert summary data, generate alert summary
            if (!currentPrompts.alertSummarySingleDetectionPromptData().isEmpty()) {
                tasks.add(() -> ThreatContextEnrichmentUtils.generateAlertSummaryTask(this,
                        currentPrompts.alertSummarySingleDetectionPromptData(),
                        predictionModelType));
            }

            // if device and location details prompt is not blank, generate narration
            if (StringUtils.isNotBlank(currentPrompts.deviceDetailsAndLocationPrompt())) {
                tasks.add(
                        () -> ThreatContextEnrichmentUtils.generateDeviceDetailsAndLocationTask(this,
                                currentPrompts.deviceDetailsAndLocationPrompt(),
                                predictionModelType));
            }

            // if device risk exposure prompt is not blank, generate narration
            if (StringUtils.isNotBlank(currentPrompts.deviceRiskExposurePrompt())) {
                tasks.add(() -> ThreatContextEnrichmentUtils.generateDeviceRiskExposureTask(this,
                        currentPrompts.deviceRiskExposurePrompt(),
                        predictionModelType));
            }

            // if device mitigation prompt is not blank, generate narration
            if (StringUtils.isNotBlank(currentPrompts.deviceMitigationPrompt())) {
                tasks
                        .add(() -> ThreatContextEnrichmentUtils.generateDeviceMitigationPromptTask(this,
                                currentPrompts.deviceMitigationPrompt(),
                                predictionModelType));
            }

            // if there is ioc data, generate threat intel summary
            if (!currentPrompts.threatIntelSummarySingleIocPromptData().isEmpty()) {
                tasks.add(() -> ThreatContextEnrichmentUtils.generateThreatIntelSummaryTask(this,
                        currentPrompts.threatIntelSummarySingleIocPromptData(),
                        predictionModelType));
            }

            // use thread pool to execute the tasks
            List<Future<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> futures =
                    threadPool.invokeAll(tasks);

            // wait for tasks to finish executing and process the results
            for (var future : futures) {
                var narrationTypeAndPromptResponse = future.get();

                // process the results based on narration type
                ThreatContextEnrichmentUtils.processNarrationGenerationResults(currentPrompts,
                        narrationTypeAndPromptResponse,
                        builder, // the NarrationDataBuilder, which is used to store data that we want to send
                        this,
                        predictionModelType);
            }

            // populate creation fields and build NarrationData
            generationTime = DateUtils.nowTruncatedToMillis();
            NarrationData newNarrationData =
                    builder.created(generationTime.toString())
                            .createdBy(invokingUserEmail)
                            .expiresAt(convertToTimestamp(generationTime
                                    .plus(documentTtlInDays, ChronoUnit.DAYS)))
                            .build();

            // only store narration data if it is not in a restricted region
            if (!isAccountInNarrationStorageRestrictedRegion) {
                // store narration data
                narrationStorageBackend.storeNarrationData(accountId, documentId, newNarrationData);
            } else {
                logger.debug(
                        "Not storing narration data with document id {} for account {}, which is in a restricted region",
                        documentId,
                        accountId);
                // setting documentId to null so the UI knows that the narration is not stored
                documentId = null;
            }

            // save final section responses to be returned
            alertSummary = newNarrationData.getAlertSummaryResponse();
            deviceLocationAndDetails = newNarrationData.getDeviceDetailsResponse();
            deviceRiskExposure = newNarrationData.getDeviceRiskResponse();
            deviceMitigation = newNarrationData.getDeviceMitigationResponse();
            threatIntelSummary = newNarrationData.getThreatIntelSummaryResponse();

            logger.debug("Generate and store narration took {}ms",
                    System.currentTimeMillis() - start);
        }

        // sections that were not generated will be an empty string
        return new AllNarrationsResponse(documentId,
                alertSummary,
                deviceLocationAndDetails,
                deviceRiskExposure,
                deviceMitigation,
                threatIntelSummary,
                generationTime.toString());

    }

    /**
     * Queries the Elasticsearch backend directly for the detections of an entity.
     *
     * <p>This is the uncached lookup; {@code retrieveDetectionsFromCache} is the cached path.
     *
     * @param account    the account whose detections are queried
     * @param entityType the type of the entity the detections belong to
     * @param entityId   the identifier of the entity
     * @return the detections returned by the backend
     * @throws EntityNotFoundException propagated from the backend
     * @throws ElasticsearchException  propagated from the backend
     */
    private List<Detection> retrieveDetections(Account account,
                                               String entityType,
                                               String entityId)
            throws EntityNotFoundException, ElasticsearchException {
        List<Detection> fetchedDetections = elasticSearchBackend.getDetectionsForThreatContextEnrichment(
                account,
                entityType,
                entityId);
        return fetchedDetections;
    }

    /**
     * Looks up the detections for an entity through {@code detectionsCache}.
     *
     * <p>Cache-load failures arrive wrapped in an {@link ExecutionException}. When the cause is an
     * {@link EntityNotFoundException} or an {@link ElasticsearchException} it is unwrapped and
     * rethrown so callers see the original exception type; any other cause is rethrown as the
     * original {@code ExecutionException}.
     *
     * <p>The {@link Account} is supplied by the caller and becomes part of the cache key.
     * NOTE(review): an earlier comment explained that the cache-loading threads that end up
     * calling {@code retrieveDetections} lack the authorization information required by the
     * {@code @SecureDao} annotation, which is why the account is not resolved inside the loader;
     * confirm this is still accurate.
     *
     * @param entityType the type of the entity the detections belong to
     * @param entityId   the identifier of the entity
     * @param account    the account whose detections are requested
     * @return the (possibly cached) detections for the entity
     */
    private List<Detection> retrieveDetectionsFromCache(String entityType,
                                                        String entityId,
                                                        Account account)
            throws EntityNotFoundException, ElasticsearchException, ExecutionException {
        var cacheKey = new ThreatContextEnrichmentAdminApiService.DetectionsCacheKey(account,
                entityType,
                entityId);
        try {
            return detectionsCache.get(cacheKey);
        } catch (ExecutionException wrapped) {
            Throwable cause = wrapped.getCause();
            if (cause instanceof EntityNotFoundException notFound) {
                throw notFound;
            }
            if (cause instanceof ElasticsearchException searchFailure) {
                throw searchFailure;
            }
            throw wrapped;
        }
    }

    /**
     * Builds every prompt input needed to generate the narrations.
     *
     * <p>When both {@code entityType} and {@code entityId} are present, the entity's detections are
     * loaded through the detections cache and used to build the per-detection alert summary data
     * and the per-IoC threat intel data. If the cleaned detections yield a non-blank asset id and
     * the caller did not supply one, that asset id is used for the device sections.
     *
     * <p>When an asset id is available and the asset data (loaded through the asset data cache)
     * is not empty, the device details/location, risk exposure and mitigation prompts are built.
     * Sections without input stay as empty lists or empty strings.
     *
     * @param account    the account the data belongs to
     * @param entityType optional entity type used to look up detections
     * @param entityId   optional entity id used to look up detections
     * @param assetId    optional asset id; may be filled in from the detections when absent
     * @param timeZone   optional time zone passed to the data clean-up helpers
     * @return the prompt templates together with the generated prompt data
     */
    private ThreatContextEnrichmentService.Prompts generateAllPrompts(Account account,
                                                                      Optional<String> entityType,
                                                                      Optional<String> entityId,
                                                                      Optional<String> assetId,
                                                                      Optional<ZoneId> timeZone)
            throws ElasticsearchException, EntityNotFoundException, ExecutionException,
            IOException, AssetDataRetreivalException {

        List<String> detectionPromptData = new ArrayList<>();
        List<String> iocPromptData = new ArrayList<>();
        Optional<String> effectiveAssetId = assetId;

        boolean hasEntity = entityType.isPresent() && entityId.isPresent();
        if (hasEntity) {
            List<Detection> detections =
                    retrieveDetectionsFromCache(entityType.get(), entityId.get(), account);

            ThreatContextEnrichmentService.CleanedDetectionsAndAssetId cleaned =
                    retrieveCleanedDetectionsAndAssetId(detections,
                            threatContextEnrichmentConfig.getAlertSummary(),
                            timeZone);

            // per-detection data used for the individual alert summaries
            detectionPromptData = generateAlertSummarySingleDetectionData(cleaned.cleanedDetections());

            // per-IoC data used for the individual threat intel summaries;
            // IoCs are taken from the raw (not cleaned) detections
            List<VLIoC> iocs = getIocsFromDetections(detections);
            iocPromptData = generateThreatSummarySingleIocData(iocs);

            // fall back to the asset id found in the detections only when none was supplied
            String detectedAssetId = cleaned.assetId();
            if (effectiveAssetId.isEmpty() && StringUtils.isNotBlank(detectedAssetId)) {
                effectiveAssetId = Optional.of(detectedAssetId);
            }
        }

        String deviceDetailsPrompt = "";
        String riskExposurePrompt = "";
        String mitigationPrompt = "";

        if (effectiveAssetId.isPresent()) {
            JsonNode assetData = retrieveAssetDataFromCache(account.getId(), effectiveAssetId.get());

            // device sections are only generated when there is asset data to describe
            if (!assetData.isEmpty()) {
                // each section starts from its own freshly cleaned copy: the cleanup helpers are
                // invoked only for their effect on the map passed to them
                Map<String, Object> detailsData = retrieveCleanedAssetData(assetData, timeZone);
                cleanupAssetDataForDeviceAndLocationDetails(detailsData);
                deviceDetailsPrompt = generateDeviceSectionPrompt(detailsData,
                        DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE);

                Map<String, Object> riskData = retrieveCleanedAssetData(assetData, timeZone);
                cleanupAssetDataForRiskExposureAndMitigation(riskData);
                riskExposurePrompt = generateDeviceSectionPrompt(riskData,
                        DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE);

                Map<String, Object> mitigationData = retrieveCleanedAssetData(assetData, timeZone);
                cleanupAssetDataForRiskExposureAndMitigation(mitigationData);
                mitigationPrompt = generateDeviceSectionPrompt(mitigationData,
                        DEVICE_MITIGATION_PROMPT_TEMPLATE);
            }
        }

        return new ThreatContextEnrichmentService.Prompts(
                ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE,
                ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE,
                detectionPromptData,
                THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE,
                THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE,
                iocPromptData,
                deviceDetailsPrompt,
                riskExposurePrompt,
                mitigationPrompt);
    }

    /**
     * Sends a single, fully rendered prompt to the generative AI prediction service.
     *
     * @param prompt              the prompt text
     * @param predictionModelType the model to run the prompt against
     * @return the model's text response
     * @throws InvalidParameterException propagated from the prediction service
     * @throws IOException               propagated from the prediction service
     */
    public String executePrompt(String prompt, PredictionModelType predictionModelType)
            throws InvalidParameterException, IOException {
        String modelOutput = generativeAiPredictionService.predictTextPrompt(prompt,
                predictionModelType);
        return modelOutput;
    }

    /**
     * Generates the threat intel summary with a single model call over all of the IoC data.
     *
     * <p>The IoC data is substituted into {@code THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE} and the
     * rendered prompt is always run against {@code PredictionModelType.TEXT_BISON_32K}.
     *
     * @param iocData the IoC data substituted for the {@code IOC_DATA} placeholder
     * @return the model's threat intel summary
     */
    private String generateThreatIntelSummaryUsingSingleInvocationMode(String iocData)
            throws IOException,
            InvalidParameterException {
        String renderedPrompt = StringSubstitutor.replace(THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE,
                Map.of(IOC_DATA, iocData));

        logger.debug("ThreatIntelSummary prompt is:\n{}", renderedPrompt);

        return generativeAiPredictionService.predictTextPrompt(renderedPrompt,
                PredictionModelType.TEXT_BISON_32K);
    }

    /**
     * Collects the threat intel records for the IoCs referenced by the given detections.
     *
     * <p>IoCs are extracted from each detection with {@code extractIocsFromLogs}, de-duplicated,
     * and then looked up one at a time with {@code vlThreatIntel.getIoCInfo}. IoCs without a
     * threat intel record (a {@code null} lookup result) are skipped.
     *
     * <p>De-duplication keeps first-seen order (detection order, then the order in which
     * {@code extractIocsFromLogs} yields them). The returned records, and the prompt data built
     * from them, therefore follow the detections instead of arbitrary hash order.
     *
     * @param detections the detections to extract IoCs from
     * @return a mutable list of the threat intel records found; empty if none were found
     */
    private List<VLIoC> getIocsFromDetections(List<Detection> detections) {
        // extract the distinct IoCs from the detection logs, preserving first-seen order
        Set<String> extractedIocs = new java.util.LinkedHashSet<>();
        for (Detection detection : detections) {
            extractedIocs.addAll(extractIocsFromLogs(detection));
        }

        // look up the threat intel for each IoC
        // NOTE(review): previously documented as reading from a BigQuery table; confirm against
        // the vlThreatIntel implementation
        List<VLIoC> feedIocs = new ArrayList<>(extractedIocs.size());
        for (String ioc : extractedIocs) {
            VLIoC feedIoc = vlThreatIntel.getIoCInfo(ioc);
            if (feedIoc != null) {
                feedIocs.add(feedIoc);
            }
        }
        return feedIocs;
    }

    /**
     * Generates the threat intel summary for a single IoC.
     *
     * @param iocData             the IoC data substituted for the {@code IOC_DATA} placeholder in
     *                            {@code THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE}
     * @param predictionModelType the model to run the prompt against
     * @return the model's summary for this IoC
     * @throws InvalidParameterException propagated from the prediction service
     * @throws IOException               propagated from the prediction service
     */
    private String generateThreatIntelSummaryForSingleIoc(String iocData,
                                                          PredictionModelType predictionModelType)
            throws InvalidParameterException, IOException {
        String singleIocPrompt = StringSubstitutor.replace(
                THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE,
                Map.of(IOC_DATA, iocData));

        logger.debug("generateThreatIntelSummaryForSingleIoc prompt is:\n{}", singleIocPrompt);

        return generativeAiPredictionService.predictTextPrompt(singleIocPrompt, predictionModelType);
    }

    /**
     * Generates the threat intel summary in two steps: one model call per IoC, followed by one
     * model call that combines the individual IoC summaries into an overall summary.
     *
     * <p>The per-IoC calls are submitted together to {@code threadPool} via {@code invokeAll}, and
     * their responses are concatenated in the same order as {@code iocsForPrompt}.
     *
     * @param iocsForPrompt       the prompt data for each IoC, one entry per IoC
     * @param predictionModelType the model used for both steps
     * @return the concatenated per-IoC responses (each followed by a blank line) together with the
     *         combined summary response
     * @throws ExecutionException        if any per-IoC call failed; the failure is its cause
     * @throws InterruptedException      if interrupted while waiting for the per-IoC calls
     * @throws IOException               propagated from the combining model call
     * @throws InvalidParameterException propagated from the combining model call
     */
    public ThreatContextEnrichmentService.PromptDataAndResponse generateThreatIntelSummaryUsingMultipleInvocationMode(List<String> iocsForPrompt,
                                                                                                                      PredictionModelType predictionModelType)
            throws IOException, InvalidParameterException, InterruptedException,
            ExecutionException {

        // STEP 1: summarize each IoC individually, in parallel on the shared thread pool
        List<Callable<String>> tasks = new ArrayList<>(iocsForPrompt.size());
        for (String iocDataForPrompt : iocsForPrompt) {
            tasks.add(() -> {
                String modelResponse =
                        generateThreatIntelSummaryForSingleIoc(iocDataForPrompt, predictionModelType);

                // parameterized so the message is only built when debug logging is enabled
                logger.debug("Model response is: {}", modelResponse);
                return modelResponse;
            });
        }

        // invokeAll waits for every task; futures come back in the same order as the tasks
        StringBuilder modelResponsesStr = new StringBuilder();
        for (Future<String> future : threadPool.invokeAll(tasks)) {
            modelResponsesStr.append(future.get()).append("\n\n");
        }
        String individualSummaries = modelResponsesStr.toString();

        // STEP 2: generate an overall summary of the individual IoC summaries.
        // Note: step 2 runs even when there is only one IoC summary so that the summary style and
        // tone stay consistent.
        String prompt = StringSubstitutor.replace(
                THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE,
                Map.of(IOC_SUMMARIES, individualSummaries));

        logger.debug("threat intel summary combine responses prompt is: {}", prompt);

        String response = generativeAiPredictionService.predictTextPrompt(prompt, predictionModelType);

        return new ThreatContextEnrichmentService.PromptDataAndResponse(individualSummaries, response);
    }

    /**
     * Looks up the asset data for an asset through {@code assetDataCache}.
     *
     * <p>If the cache load fails with an {@link AssetDataRetreivalException}, that exception is
     * unwrapped from the {@link ExecutionException} and rethrown. Any other failure is rethrown as
     * the original {@code ExecutionException}.
     *
     * @param accountId the account that owns the asset
     * @param assetId   the asset to look up
     * @return the (possibly cached) asset data
     */
    private JsonNode retrieveAssetDataFromCache(String accountId, String assetId)
            throws AssetDataRetreivalException, ExecutionException {
        var cacheKey = new ThreatContextEnrichmentAdminApiService.AssetDataCacheKey(accountId, assetId);
        try {
            return assetDataCache.get(cacheKey);
        } catch (ExecutionException wrapped) {
            Throwable cause = wrapped.getCause();
            if (cause instanceof AssetDataRetreivalException retrievalFailure) {
                throw retrievalFailure;
            }
            throw wrapped;
        }
    }

    /**
     * Fetches the asset data for an asset from REM.
     *
     * <p>The request covers the time window from {@code assetDataLastSeenOnlineTimeFilterInHours}
     * hours ago up to now. NOTE(review): per the field name this presumably filters on the asset's
     * last-seen-online time; confirm against the REM API.
     *
     * <p>A 404 from REM means the asset data was not found and yields an empty object node. A
     * successful response without an entity is treated the same way. Any other client or server
     * error is raised as an {@link AssetDataRetreivalException}.
     *
     * <p>The {@code events} field is removed from object-shaped asset data because it is not
     * needed for generating the narrations and should not be cached.
     *
     * @param accountId the account that owns the asset
     * @param assetId   the asset to fetch
     * @return the asset data, or an empty object node when none was found
     * @throws AssetDataRetreivalException if REM returns a client or server error other than 404
     */
    private JsonNode retrieveAssetData(String accountId, String assetId)
            throws AssetDataRetreivalException {
        // capture "now" once so the window is exactly the configured number of hours wide
        Instant now = Instant.now();
        Response response = remProxyApiService.getAsset(accountId,
                assetId,
                now.minus(assetDataLastSeenOnlineTimeFilterInHours, ChronoUnit.HOURS).toString(),
                now.toString());

        // check if there was an error retrieving the asset data
        Response.Status.Family family = response.getStatusInfo().getFamily();
        if (family == Response.Status.Family.CLIENT_ERROR
                || family == Response.Status.Family.SERVER_ERROR) {
            // treat 404 Not Found as "no asset data" because that is what the REM API returns
            // when asset data is not found
            if (response.getStatusInfo().getStatusCode() == HttpStatus.SC_NOT_FOUND) {
                logger.debug("Asset data not found for account {} asset {}", accountId, assetId);
                return JsonNodeFactory.instance.objectNode();
            }
            logger.error("Failed to retrieve asset data from REM: {}", response);
            throw new AssetDataRetreivalException(response.getStatusInfo().toEnum(),
                    "Failed to retrieve asset data from REM: " + response.getEntity());
        }

        // a successful response without a body (e.g. 204 No Content) carries no asset data
        Object entity = response.getEntity();
        if (entity == null) {
            logger.debug("Asset data response has no entity for account {} asset {}",
                    accountId,
                    assetId);
            return JsonNodeFactory.instance.objectNode();
        }
        JsonNode assetData = (JsonNode) entity;

        // remove fields that are not needed for generating the narrations and that we don't want
        // to cache; only object nodes have named fields to remove
        if (assetData instanceof ObjectNode objectNode) {
            objectNode.remove("events");
        }

        return assetData;
    }

    /**
     * Returns the prompt template used to summarize a single detection.
     *
     * @return the single-detection alert summary prompt template
     */
    public String getAlertSummarySingleDetectionPromptTemplate() {
        final String template = ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE;
        return template;
    }

    /**
     * Returns the prompt template used to combine the per-detection summaries into the alert
     * summary.
     *
     * @return the alert summary combine-responses prompt template
     */
    public String getAlertSummaryCombineResponsesPromptTemplate() {
        final String template = ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE;
        return template;
    }

    /**
     * Returns the prompt template used to summarize the threat intel of a single IoC.
     *
     * @return the single-IoC threat intel summary prompt template
     */
    public String getThreatIntelSummarySingleIocPromptTemplate() {
        final String template = THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE;
        return template;
    }

    /**
     * Returns the prompt template used to combine the per-IoC summaries into the threat intel
     * summary.
     *
     * @return the threat intel summary combine-responses prompt template
     */
    public String getThreatIntelSummaryCombineResponsesPromptTemplate() {
        final String template = THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE;
        return template;
    }
}
Leave a Comment