Untitled
unknown
plain_text
a month ago
51 kB
3
Indexable
Never
package com.trendmicro.serapis.admin.apiservice.api.generativeai.threatcontextenrichment.definition; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.IOC_DATA; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.IOC_SUMMARIES; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.cleanupAssetDataForDeviceAndLocationDetails; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.cleanupAssetDataForRiskExposureAndMitigation; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.convertToTimestamp; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.extractIocsFromLogs; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.generateAlertSummarySingleDetectionData; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.generateDeviceSectionPrompt; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.generateThreatSummarySingleIocData; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.getLatestTimestamp; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.retrieveCleanedAssetData; import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.retrieveCleanedDetectionsAndAssetId; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; import java.time.ZoneId; import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import javax.ws.rs.core.Response; import org.apache.commons.text.StringSubstitutor; import org.apache.http.HttpStatus; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import com.cysiv.elasticsearch.model.triage.Detection; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.trendmicro.serapis.apiservice.api.accounts.model.Account; import com.trendmicro.serapis.apiservice.api.elasticsearch.backend.ElasticSearchBackend; import com.trendmicro.serapis.apiservice.api.exceptions.AssetDataRetreivalException; import com.trendmicro.serapis.apiservice.api.exceptions.ElasticsearchException; import com.trendmicro.serapis.apiservice.api.exceptions.EntityNotFoundException; import com.trendmicro.serapis.apiservice.api.exceptions.InvalidParameterException; import com.trendmicro.serapis.apiservice.api.exceptions.UnauthorizedException; import com.trendmicro.serapis.apiservice.api.generativeai.GenerativeAiUtils; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.definition.ThreatContextEnrichmentService; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.definition.VLThreatIntel; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.model.AllNarrationsResponse; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.model.IdType; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.model.VLIoC; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.FirestoreNarrationStorageUtils; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.NarrationStorageBackend; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.model.NarrationData; import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.model.NarrationType; import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.backend.GenerativeAiPredictionService; import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.backend.VertexAiGeminiPredictionServiceBackend; import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.backend.VertexAiPalmPredictionServiceBackend; import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.model.request.ModelParameters; import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.model.request.PredictionModelType; import com.trendmicro.serapis.apiservice.api.rem.definition.RemProxyApiService; import com.trendmicro.serapis.apiservice.config.SerapisConfig; import com.trendmicro.serapis.apiservice.database.AccountRepositoryService; import com.trendmicro.serapis.apiservice.security.SecurityUtils; import com.trendmicro.serapis.apiservice.util.DateUtils; import com.trendmicro.serapis.apiservice.util.EmailUtils; import io.micrometer.core.instrument.util.StringUtils; import lombok.Getter; @Service public class ThreatContextEnrichmentAdminApiService { @Autowired private NarrationStorageBackend narrationStorageBackend; @Autowired private AccountRepositoryService accountRepositoryService; private static final Logger logger = LoggerFactory.getLogger(ThreatContextEnrichmentService.class); record DetectionsCacheKey(Account account, String entityType, String entityId) { } record AssetDataCacheKey(String accountId, String assetId) { } public record Prompts( String alertSummarySingleDetectionPromptTemplate, String alertSummaryCombineResponsesPromptTemplate, List<String> alertSummarySingleDetectionPromptData, String threatIntelSummarySingleIocPromptTemplate, String threatIntelSummaryCombineResponsesPromptTemplate, List<String> threatIntelSummarySingleIocPromptData, String deviceDetailsAndLocationPrompt, String deviceRiskExposurePrompt, String deviceMitigationPrompt) { } public record PromptDataAndResponse(String promptData, String response) { } public record NarrationTypeAndPromptResponse(NarrationType narrationType, ThreatContextEnrichmentService.PromptDataAndResponse promptDataAndResponse) { } public static final String ALERT_SUMMARY_PROMPT_TEMPLATE_FILE_NAME = "alert-summary"; public static final String ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE_FILE_NAME = "alert-summary-single-detection"; public static final String ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME = "alert-summary-combine-responses"; public static final String DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE_FILE_NAME = "device-and-location-details"; public static final String DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE_FILE_NAME = "device-risk-exposure"; public static final String DEVICE_MITIGATION_PROMPT_TEMPLATE_FILE_NAME = "device-mitigation"; public static final String THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE_FILE_NAME = "threat-intel-summary-single"; public static final String THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE_FILE_NAME = "threat-intel-summary-single"; public static final String THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME = "threat-intel-summary-combine-responses"; public static final List<String> NARRATION_STORAGE_RESTRICTED_REGION_PREFIXES = List.of("europe"); final String ALERT_SUMMARY_PROMPT_TEMPLATE; final String ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE; final String ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE; final String DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE; final String DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE; final String DEVICE_MITIGATION_PROMPT_TEMPLATE; final String THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE; final String THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE; final String THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE; @Autowired private ElasticSearchBackend elasticSearchBackend; @Autowired private RemProxyApiService remProxyApiService; @Getter final GenerativeAiPredictionService generativeAiPredictionService; private final LoadingCache<ThreatContextEnrichmentAdminApiService.DetectionsCacheKey, List<Detection>> detectionsCache; private final LoadingCache<ThreatContextEnrichmentAdminApiService.AssetDataCacheKey, JsonNode> assetDataCache; private final ExecutorService threadPool; // specifies how far back to search for asset data private final int assetDataLastSeenOnlineTimeFilterInHours; private final SerapisConfig.ThreatContextEnrichmentConfig threatContextEnrichmentConfig; @Getter private final VLThreatIntel vlThreatIntel; private final int documentTtlInDays; private final int prevNarrationDocumentTtlInDays; @Autowired public ThreatContextEnrichmentAdminApiService(SerapisConfig appConfig) throws IOException { SerapisConfig.GoogleCloud googleCloudConfig = appConfig.getGoogleCloud(); threatContextEnrichmentConfig = appConfig.getThreatContextEnrichment(); SerapisConfig.VertexAiModelConfig vertexAiPalmModelConfig = threatContextEnrichmentConfig.getVertexAiPalmModel(); SerapisConfig.VertexAiModelConfig.VertexAiModelParameters palmModelParametersConfig = vertexAiPalmModelConfig.getModelParameters(); SerapisConfig.VertexAiModelConfig vertexAiGeminiModelConfig = threatContextEnrichmentConfig.getVertexAiGeminiModel(); SerapisConfig.VertexAiModelConfig.VertexAiModelParameters geminiModelParametersConfig = vertexAiGeminiModelConfig.getModelParameters(); // load max detections and max indicators per detection limits for alert summary SerapisConfig.AlertSummaryConfig alertSummaryConfig = threatContextEnrichmentConfig.getAlertSummary(); SerapisConfig.ThreatContextEnrichmentConfig.CacheConfig detectionsCacheConfig = alertSummaryConfig.getDetectionsCache(); SerapisConfig.DeviceSummariesConfig deviceSummariesConfig = threatContextEnrichmentConfig.getDeviceSummaries(); SerapisConfig.ThreatContextEnrichmentConfig.CacheConfig assetDataCacheConfig = deviceSummariesConfig.getAssetDataCache(); VertexAiPalmPredictionServiceBackend vertexAiPalmModelPredictionService = new VertexAiPalmPredictionServiceBackend(googleCloudConfig.getProject(), vertexAiPalmModelConfig.getApiHostname(), vertexAiPalmModelConfig.getApiLocation(), vertexAiPalmModelConfig.getModelPublisher(), vertexAiPalmModelConfig.getModelName(), new ModelParameters(palmModelParametersConfig.getTemperature().doubleValue(), palmModelParametersConfig.getTopP().doubleValue(), palmModelParametersConfig.getTopK(), palmModelParametersConfig.getMaxOutputTokens(), Collections.emptyList(), palmModelParametersConfig.getCandidateCount())); VertexAiGeminiPredictionServiceBackend vertexAiGeminiModelPredictionService = new VertexAiGeminiPredictionServiceBackend(googleCloudConfig.getProject(), vertexAiGeminiModelConfig.getApiLocation(), vertexAiGeminiModelConfig.getModelName(), new ModelParameters(geminiModelParametersConfig.getTemperature().doubleValue(), geminiModelParametersConfig.getTopP().doubleValue(), geminiModelParametersConfig.getTopK(), geminiModelParametersConfig.getMaxOutputTokens(), Collections.emptyList(), geminiModelParametersConfig.getCandidateCount())); this.generativeAiPredictionService = new GenerativeAiPredictionService(vertexAiPalmModelPredictionService, vertexAiGeminiModelPredictionService); this.documentTtlInDays = threatContextEnrichmentConfig.getNarrationStorage().getDocumentTtlDays(); this.prevNarrationDocumentTtlInDays = threatContextEnrichmentConfig.getNarrationStorage() .getPreviousNarrationDocumentTtlDays(); this.detectionsCache = CacheBuilder.newBuilder() .maximumSize(detectionsCacheConfig.getMaximumSize()) .refreshAfterWrite( Duration.ofSeconds(detectionsCacheConfig.getRefreshAfterWriteSeconds())) .expireAfterWrite( Duration.ofSeconds(detectionsCacheConfig.getExpireAfterWriteSeconds())) .build(CacheLoader.asyncReloading(new CacheLoader<>() { @Override public List<Detection> load(ThreatContextEnrichmentAdminApiService.DetectionsCacheKey key) throws Exception { logger.debug("Loading detections for key {}", key); return retrieveDetections(key.account(), key.entityType(), key.entityId()); } }, Executors.newCachedThreadPool())); this.assetDataCache = CacheBuilder.newBuilder() .maximumSize(assetDataCacheConfig.getMaximumSize()) .refreshAfterWrite( Duration.ofSeconds(assetDataCacheConfig.getRefreshAfterWriteSeconds())) .expireAfterWrite( Duration.ofSeconds(assetDataCacheConfig.getExpireAfterWriteSeconds())) .build(CacheLoader.asyncReloading(new CacheLoader<>() { @Override public JsonNode load(ThreatContextEnrichmentAdminApiService.AssetDataCacheKey key) throws AssetDataRetreivalException { logger.debug("Loading asset data for key {}", key); return retrieveAssetData(key.accountId(), key.assetId()); } }, Executors.newCachedThreadPool())); this.threadPool = Executors.newCachedThreadPool(); this.assetDataLastSeenOnlineTimeFilterInHours = deviceSummariesConfig.getAssetDataLastSeenOnlineTimeFilterInHours(); String promptTemplatesDir = threatContextEnrichmentConfig .getPromptTemplatesDir(); logger.debug("promptTemplatesDir is {}", promptTemplatesDir); // load alert summary prompt template Path templateFile = Paths.get(promptTemplatesDir, ALERT_SUMMARY_PROMPT_TEMPLATE_FILE_NAME); this.ALERT_SUMMARY_PROMPT_TEMPLATE = Files.readString(templateFile); logger.debug("ALERT_SUMMARY_PROMPT_TEMPLATE is {}", ALERT_SUMMARY_PROMPT_TEMPLATE); // load alert summary single detection prompt template templateFile = Paths.get(promptTemplatesDir, ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE_FILE_NAME); this.ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE = Files.readString(templateFile); logger.debug("ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE is {}", ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE); // load alert summary combine responses prompt template templateFile = Paths.get(promptTemplatesDir, ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME); this.ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE = Files.readString(templateFile); logger.debug("ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE is {}", ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE); // load device and location details prompt template templateFile = Paths.get(promptTemplatesDir, DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE_FILE_NAME); this.DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE = Files.readString(templateFile); logger.debug("DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE is {}", DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE); // load device risk exposure prompt template templateFile = Paths.get(promptTemplatesDir, DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE_FILE_NAME); this.DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE = Files.readString(templateFile); logger.debug("DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE is {}", DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE); // load device mitigation prompt template templateFile = Paths.get(promptTemplatesDir, DEVICE_MITIGATION_PROMPT_TEMPLATE_FILE_NAME); this.DEVICE_MITIGATION_PROMPT_TEMPLATE = Files.readString(templateFile); logger.debug("DEVICE_MITIGATION_PROMPT_TEMPLATE is {}", DEVICE_MITIGATION_PROMPT_TEMPLATE); // load threat intel prompt template templateFile = Paths.get(promptTemplatesDir, THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE_FILE_NAME); this.THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE = Files.readString(templateFile); logger.debug("THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE is {}", THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE); templateFile = Paths.get(promptTemplatesDir, THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE_FILE_NAME); this.THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE = Files.readString(templateFile); logger.debug("THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE is {}", THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE); // load threat intel summary combine responses prompt template templateFile = Paths.get(promptTemplatesDir, THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME); this.THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE = Files.readString(templateFile); logger.debug("THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE is {}", THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE); this.vlThreatIntel = new VLThreatIntel(appConfig); } public int addExpiresAtToDocuments(Instant beforeTimestamp, int numDaysAfter) throws ExecutionException, InterruptedException { List<Account> accounts = accountRepositoryService.getAccountsPrivileged(); int numDocumentsUpdated = 0; for (Account account : accounts) { numDocumentsUpdated += narrationStorageBackend .addExpiresAtToDocuments(account.getId(), beforeTimestamp, numDaysAfter); } return numDocumentsUpdated; } public AllNarrationsResponse generateAllNarrationsWithStorage(String accountId, IdType idType, Optional<String> entityType, Optional<String> entityId, Optional<String> caseId, Optional<String> assetId, Optional<ZoneId> timeZone, boolean forceRefresh, Optional<PredictionModelType> predictionModelTypeOptional) throws IOException, InvalidParameterException, ExecutionException, EntityNotFoundException, ElasticsearchException, AssetDataRetreivalException, InterruptedException { Account account = accountRepositoryService.getAccountPrivileged(accountId); GenerativeAiUtils.validateAllowAiEnabled(account); boolean isAccountInNarrationStorageRestrictedRegion = ThreatContextEnrichmentUtils.isAccountInNarrationStorageRestrictedRegion(account, NARRATION_STORAGE_RESTRICTED_REGION_PREFIXES); String documentId; // generate the document id based on the idType if (idType == IdType.ENTITY) { if (entityId.isEmpty() || entityType.isEmpty()) { throw new InvalidParameterException( "entity_id and entity_type cannot be empty when id_type is ENTITY"); } documentId = FirestoreNarrationStorageUtils .generateDocumentIdFromEntity(entityType.get(), entityId.get(), caseId); } else if (idType == IdType.ASSET) { if (assetId.isEmpty()) { throw new InvalidParameterException( "asset_id cannot be empty when id_type is ASSET"); } documentId = FirestoreNarrationStorageUtils.generateDocumentIdFromAsset(assetId.get()); } else { throw new InvalidParameterException("Invalid id_type"); } PredictionModelType predictionModelType = predictionModelTypeOptional.orElseGet(threatContextEnrichmentConfig::getDefaultAiModel); // variables to store the final responses for each section String alertSummary = "", deviceLocationAndDetails = "", deviceRiskExposure = "", deviceMitigation = "", threatIntelSummary = ""; Instant generationTime; // save email of the user that made the API call String invokingUserEmail = SecurityUtils.getInvokingUserEmail(); // generate all the prompts that could be executed ThreatContextEnrichmentService.Prompts currentPrompts = generateAllPrompts(account, entityType, entityId, assetId, timeZone); // retrieve any existing narration data from storage Optional<NarrationData> narrationDataOptional = narrationStorageBackend.getNarrationData(accountId, documentId); if (narrationDataOptional.isPresent()) { // check if user can perform a force refresh if (forceRefresh && !EmailUtils.isInternalUserAddress(invokingUserEmail)) { throw new UnauthorizedException( "You do not have permissions to perform a force refresh."); } // narration data exists, need to compare data/prompts to see if any sections need to be regenerated NarrationData narrationData = narrationDataOptional.get(); // initialize generation time to the latest timestamp found in the narration data generationTime = Instant.parse(getLatestTimestamp(narrationData)); // the generation of narrations to be updated will be done concurrently using a thread pool // this variable stores the lists of tasks to be executed concurrently for generating the narrations List<Callable<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> tasks = new ArrayList<>(); // stores the fields that need to be updated in the storage Map<String, Object> updates = new HashMap<>(); // check if the alert summary section needs to be regenerated if (FirestoreNarrationStorageUtils .shouldUseStoredNarration(narrationData, currentPrompts, NarrationType.ALERT_SUMMARY, forceRefresh)) { // section does not need to be regenerated, use the stored narration alertSummary = narrationData.getAlertSummaryResponse(); } else { // regenerate alert summary section logger.debug("Regenerating alert summary section..."); tasks.add(() -> ThreatContextEnrichmentUtils.generateAlertSummaryTask(this, currentPrompts.alertSummarySingleDetectionPromptData(), predictionModelType)); } // check if the threat intel summary section needs to be regenerated if (FirestoreNarrationStorageUtils .shouldUseStoredNarration(narrationData, currentPrompts, NarrationType.THREAT_INTEL_SUMMARY, forceRefresh)) { // section does not need to be regenerated, use the stored narration threatIntelSummary = narrationData.getThreatIntelSummaryResponse(); } else { // regenerate threat intel summary section logger.debug("Regenerating threat intel summary section..."); tasks.add(() -> ThreatContextEnrichmentUtils.generateThreatIntelSummaryTask(this, currentPrompts.threatIntelSummarySingleIocPromptData(), predictionModelType)); } // check if we need to regenerate device location and details section if (FirestoreNarrationStorageUtils.shouldUseStoredNarration(narrationData, currentPrompts, NarrationType.DEVICE_LOCATION_AND_DETAILS, forceRefresh)) { // section does not need to be regenerated, use the stored narration deviceLocationAndDetails = narrationData.getDeviceDetailsResponse(); } else { // regenerate device location and details section logger.debug("Regenerating device and location details section..."); tasks.add( () -> ThreatContextEnrichmentUtils.generateDeviceDetailsAndLocationTask(this, currentPrompts.deviceDetailsAndLocationPrompt(), predictionModelType)); } // check if we need to regenerate device risk exposure section if (FirestoreNarrationStorageUtils.shouldUseStoredNarration(narrationData, currentPrompts, NarrationType.DEVICE_RISK_EXPOSURE, forceRefresh)) { // section does not need to be regenerated, use the stored narration deviceRiskExposure = narrationData.getDeviceRiskResponse(); } else { // regenerate device risk exposure section logger.debug("Regenerating device risk exposure section..."); tasks.add( () -> ThreatContextEnrichmentUtils.generateDeviceRiskExposureTask(this, currentPrompts.deviceRiskExposurePrompt(), predictionModelType)); } // check if we need to regenerate device mitigation section if (FirestoreNarrationStorageUtils.shouldUseStoredNarration(narrationData, currentPrompts, NarrationType.DEVICE_MITIGATION, forceRefresh)) { // section does not need to be regenerated, use the stored narration deviceMitigation = narrationData.getDeviceMitigationResponse(); } else { // regenerate device mitigation section logger.debug("Regenerating device mitigation section..."); tasks .add(() -> ThreatContextEnrichmentUtils.generateDeviceMitigationPromptTask(this, currentPrompts.deviceMitigationPrompt(), predictionModelType)); } // use thread pool to execute the tasks List<Future<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> futures = threadPool.invokeAll(tasks); // wait for tasks to finish executing and process the results // init the lastRegeneratedTimeStamp, which will be used for all the sections Instant lastRegeneratedTimeStamp = DateUtils.nowTruncatedToMillis(); NarrationData.NarrationDataBuilder previousNarrationDataBuilder = NarrationData.builder(); for (var future : futures) { var narrationTypeAndPromptResponse = future.get(); // process the results based on narration type ThreatContextEnrichmentUtils.processNarrationUpdateResults(currentPrompts, narrationTypeAndPromptResponse, updates, lastRegeneratedTimeStamp.toString(), invokingUserEmail, this, predictionModelType, narrationData, previousNarrationDataBuilder); } // update narration data in storage if there are updates if (!updates.isEmpty()) { // add a previous narration document containing the previous prompts and responses that will be updated previousNarrationDataBuilder.expiresAt(convertToTimestamp( lastRegeneratedTimeStamp.plus(prevNarrationDocumentTtlInDays, // set expiresAt timestamp ChronoUnit.DAYS))); narrationStorageBackend.storePreviousNarrationData(accountId, documentId, lastRegeneratedTimeStamp.toString(), previousNarrationDataBuilder.build()); // update the expiresAt field, which is used for TTL deletion updates.put(NarrationData.Fields.expiresAt, convertToTimestamp( lastRegeneratedTimeStamp.plus(documentTtlInDays, ChronoUnit.DAYS))); narrationStorageBackend .updateNarrationData(accountId, documentId, updates); // update final section responses if there was an update alertSummary = (String) updates.getOrDefault(NarrationData.Fields.alertSummaryResponse, alertSummary); deviceLocationAndDetails = (String) updates.getOrDefault(NarrationData.Fields.deviceDetailsResponse, deviceLocationAndDetails); deviceRiskExposure = (String) updates.getOrDefault(NarrationData.Fields.deviceRiskResponse, deviceRiskExposure); deviceMitigation = (String) updates.getOrDefault(NarrationData.Fields.deviceMitigationResponse, deviceMitigation); threatIntelSummary = (String) updates.getOrDefault(NarrationData.Fields.threatIntelSummaryResponse, threatIntelSummary); generationTime = lastRegeneratedTimeStamp; } } else { // no existing narrations, generate new narrations and then store it NarrationData.NarrationDataBuilder builder = NarrationData.builder(); long start = System.currentTimeMillis(); // the generation of narrations will be executed concurrently using a thread pool // stores the lists of tasks to be executed concurrently for generating the narrations List<Callable<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> tasks = new ArrayList<>(); // if there is single detection alert summary data, generate alert summary if (!currentPrompts.alertSummarySingleDetectionPromptData().isEmpty()) { tasks.add(() -> ThreatContextEnrichmentUtils.generateAlertSummaryTask(this, currentPrompts.alertSummarySingleDetectionPromptData(), predictionModelType)); } // if device and location details prompt is not blank, generate narration if (StringUtils.isNotBlank(currentPrompts.deviceDetailsAndLocationPrompt())) { tasks.add( () -> ThreatContextEnrichmentUtils.generateDeviceDetailsAndLocationTask(this, currentPrompts.deviceDetailsAndLocationPrompt(), predictionModelType)); } // if device risk exposure prompt is not blank, generate narration if (StringUtils.isNotBlank(currentPrompts.deviceRiskExposurePrompt())) { tasks.add(() -> ThreatContextEnrichmentUtils.generateDeviceRiskExposureTask(this, currentPrompts.deviceRiskExposurePrompt(), predictionModelType)); } // if device mitigation prompt is not blank, generate narration if (StringUtils.isNotBlank(currentPrompts.deviceMitigationPrompt())) { tasks .add(() -> ThreatContextEnrichmentUtils.generateDeviceMitigationPromptTask(this, currentPrompts.deviceMitigationPrompt(), predictionModelType)); } // if there is ioc data, generate threat intel summary if (!currentPrompts.threatIntelSummarySingleIocPromptData().isEmpty()) { tasks.add(() -> ThreatContextEnrichmentUtils.generateThreatIntelSummaryTask(this, currentPrompts.threatIntelSummarySingleIocPromptData(), predictionModelType)); } // use thread pool to execute the tasks List<Future<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> futures = threadPool.invokeAll(tasks); // wait for tasks to finish executing and process the results for (var future : futures) { var narrationTypeAndPromptResponse = future.get(); // process the results based on narration type ThreatContextEnrichmentUtils.processNarrationGenerationResults(currentPrompts, narrationTypeAndPromptResponse, builder, // the NarrationDataBuilder, which is used to store data that we want to send this, predictionModelType); } // populate creation fields and build NarrationData generationTime = DateUtils.nowTruncatedToMillis(); NarrationData newNarrationData = builder.created(generationTime.toString()) .createdBy(invokingUserEmail) .expiresAt(convertToTimestamp(generationTime .plus(documentTtlInDays, ChronoUnit.DAYS))) .build(); // only store narration data if it is not in a restricted region if (!isAccountInNarrationStorageRestrictedRegion) { // store narration data narrationStorageBackend.storeNarrationData(accountId, documentId, newNarrationData); } else { logger.debug( "Not storing narration data with document id {} for account {}, which is in a restricted region", documentId, accountId); // setting documentId to null so the UI knows that the narration is not stored documentId = null; } // save final section responses to be returned alertSummary = newNarrationData.getAlertSummaryResponse(); deviceLocationAndDetails = newNarrationData.getDeviceDetailsResponse(); deviceRiskExposure = newNarrationData.getDeviceRiskResponse(); deviceMitigation = newNarrationData.getDeviceMitigationResponse(); threatIntelSummary = newNarrationData.getThreatIntelSummaryResponse(); logger.debug("Generate and store narration took {}ms", System.currentTimeMillis() - start); } // sections that were not generated will be an empty string return new AllNarrationsResponse(documentId, alertSummary, deviceLocationAndDetails, deviceRiskExposure, deviceMitigation, threatIntelSummary, generationTime.toString()); } private List<Detection> retrieveDetections(Account account, String entityType, String entityId) throws EntityNotFoundException, ElasticsearchException { return elasticSearchBackend .getDetectionsForThreatContextEnrichment(account, entityType, entityId); } private List<Detection> retrieveDetectionsFromCache(String entityType, String entityId, Account account) throws EntityNotFoundException, ElasticsearchException, ExecutionException { try { /** * Note: retrieving account here instead of in {@link #retrieveDetections} because * retrieveDetections() can be invoked by the Loading Cache threads which don't have * access to authorization information needed by the @SecureDao annotation */ return detectionsCache.get(new ThreatContextEnrichmentAdminApiService.DetectionsCacheKey(account, entityType, entityId)); } catch (ExecutionException ex) { if (ex.getCause() instanceof EntityNotFoundException entityNotFoundException) { throw entityNotFoundException; } else if (ex.getCause() instanceof ElasticsearchException elasticsearchException) { throw elasticsearchException; } else { throw ex; } } } private ThreatContextEnrichmentService.Prompts generateAllPrompts(Account account, Optional<String> entityType, Optional<String> entityId, Optional<String> assetId, Optional<ZoneId> timeZone) throws ElasticsearchException, EntityNotFoundException, ExecutionException, IOException, AssetDataRetreivalException { List<String> alertSummarySingleDetectionPromptData = new ArrayList<>(); List<String> threatSummarySingleIocPromptData = new ArrayList<>(); String deviceLocationAndDetailsPrompt = "", deviceRiskExposurePrompt = "", deviceMitigationPrompt = ""; if (entityId.isPresent() && entityType.isPresent()) { // retrieve and clean up threat data List<Detection> detections; detections = retrieveDetectionsFromCache(entityType.get(), entityId.get(), account); ThreatContextEnrichmentService.CleanedDetectionsAndAssetId cleanedDetectionsAndAssetId = retrieveCleanedDetectionsAndAssetId(detections, threatContextEnrichmentConfig.getAlertSummary(), timeZone); // save the data used to generate the alert summaries for individual detections alertSummarySingleDetectionPromptData = generateAlertSummarySingleDetectionData( cleanedDetectionsAndAssetId.cleanedDetections()); // save the data used to generate the threat intel summaries for individual IoCs List<VLIoC> iocs = getIocsFromDetections(detections); threatSummarySingleIocPromptData = generateThreatSummarySingleIocData(iocs); // if retrieved an asset id from the detection and no asset id is supplied, set the asset id if (StringUtils.isNotBlank(cleanedDetectionsAndAssetId.assetId()) && assetId.isEmpty()) { assetId = Optional.of(cleanedDetectionsAndAssetId.assetId()); } } if (assetId.isPresent()) { // retrieve the asset data JsonNode assetData = retrieveAssetDataFromCache(account.getId(), assetId.get()); // only generate sections if asset data is not empty if (!assetData.isEmpty()) { // generate device location and details prompt Map<String, Object> cleanedAssetData = retrieveCleanedAssetData(assetData, timeZone); cleanupAssetDataForDeviceAndLocationDetails(cleanedAssetData); deviceLocationAndDetailsPrompt = generateDeviceSectionPrompt(cleanedAssetData, DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE); // generate device risk exposure prompt cleanedAssetData = retrieveCleanedAssetData(assetData, timeZone); cleanupAssetDataForRiskExposureAndMitigation(cleanedAssetData); deviceRiskExposurePrompt = generateDeviceSectionPrompt(cleanedAssetData, DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE); // generate device mitigation prompt cleanedAssetData = retrieveCleanedAssetData(assetData, timeZone); cleanupAssetDataForRiskExposureAndMitigation(cleanedAssetData); deviceMitigationPrompt = generateDeviceSectionPrompt(cleanedAssetData, DEVICE_MITIGATION_PROMPT_TEMPLATE); } } return new ThreatContextEnrichmentService.Prompts( ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE, ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE, alertSummarySingleDetectionPromptData, THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE, THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE, threatSummarySingleIocPromptData, deviceLocationAndDetailsPrompt, deviceRiskExposurePrompt, deviceMitigationPrompt); } public String executePrompt(String prompt, PredictionModelType predictionModelType) throws InvalidParameterException, IOException { return generativeAiPredictionService.predictTextPrompt(prompt, predictionModelType); } private String generateThreatIntelSummaryUsingSingleInvocationMode(String iocData) throws IOException, InvalidParameterException { Map<String, String> promptData = Map.of(IOC_DATA, iocData); String threatIntelSummaryPrompt = StringSubstitutor.replace(THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE, promptData); logger.debug("ThreatIntelSummary prompt is:\n{}", threatIntelSummaryPrompt); return generativeAiPredictionService.predictTextPrompt(threatIntelSummaryPrompt, PredictionModelType.TEXT_BISON_32K); } private List<VLIoC> getIocsFromDetections(List<Detection> detections) { // Extract the IoCs from the logs Set<String> extractedIoc = new HashSet<>(); for (Detection det : detections) extractedIoc.addAll(extractIocsFromLogs(det)); // Retrieve Threat Intel from the bigQuery Table ArrayList<VLIoC> feedIocList = new ArrayList<>(); for (String ioc : extractedIoc) { VLIoC feedIoc = vlThreatIntel.getIoCInfo(ioc); if (feedIoc != null) { feedIocList.add(feedIoc); } } return feedIocList; } private String generateThreatIntelSummaryForSingleIoc(String iocData, PredictionModelType predictionModelType) throws InvalidParameterException, IOException { Map<String, String> promptData = Map.of(IOC_DATA, iocData); String threatIntelSummaryPrompt = StringSubstitutor.replace(THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE, promptData); logger.debug("generateThreatIntelSummaryForSingleIoc prompt is:\n{}", threatIntelSummaryPrompt); return generativeAiPredictionService.predictTextPrompt(threatIntelSummaryPrompt, predictionModelType); } public ThreatContextEnrichmentService.PromptDataAndResponse generateThreatIntelSummaryUsingMultipleInvocationMode(List<String> iocsForPrompt, PredictionModelType predictionModelType) throws IOException, InvalidParameterException, InterruptedException, ExecutionException { StringBuilder modelResponsesStr = new StringBuilder(); // STEP 1: generate a summary for each individual ioc and save it // generate the summary for each ioc in parallel using an ExecutorService List<Callable<String>> tasks = new ArrayList<>(); for (String iocDataForPrompt : iocsForPrompt) { tasks.add(() -> { String modelResponse = generateThreatIntelSummaryForSingleIoc(iocDataForPrompt, predictionModelType); logger.debug("Model response is: " + modelResponse); return modelResponse; }); } List<Future<String>> futures = threadPool.invokeAll(tasks); for (Future<String> future : futures) { String modelResponse = future.get(); modelResponsesStr.append(modelResponse).append("\n\n"); } //STEP 2: generate an overall summary of the individual summaries // Note: we still perform step 2 even if there is only one detection summary because we want to keep the // summary style and tone consistent Map<String, String> promptData = Map.of(IOC_SUMMARIES, modelResponsesStr.toString()); String prompt = StringSubstitutor.replace(THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE, promptData); logger.debug("threat intel summary combine responses prompt is: {}", prompt); String response = generativeAiPredictionService .predictTextPrompt(prompt, predictionModelType); return new ThreatContextEnrichmentService.PromptDataAndResponse(modelResponsesStr.toString(), response); } private JsonNode retrieveAssetDataFromCache(String accountId, String assetId) throws AssetDataRetreivalException, ExecutionException { try { return assetDataCache.get(new ThreatContextEnrichmentAdminApiService.AssetDataCacheKey(accountId, assetId)); } catch (ExecutionException ex) { if (ex.getCause() instanceof AssetDataRetreivalException assetDataRetreivalException) { throw assetDataRetreivalException; } else { throw ex; } } } private JsonNode retrieveAssetData(String accountId, String assetId) throws AssetDataRetreivalException { // retrieve asset data using assetDataLastSeenOnlineTimeFilterInHours Response response = remProxyApiService.getAsset(accountId, assetId, Instant.now() .minus(assetDataLastSeenOnlineTimeFilterInHours, ChronoUnit.HOURS) .toString(), Instant.now().toString()); // check if there was an error retrieving the asset data if (response.getStatusInfo().getFamily() == Response.Status.Family.CLIENT_ERROR || response.getStatusInfo().getFamily() == Response.Status.Family.SERVER_ERROR) { // treat 404 Not Found errors as no asset data found because this what is returned by the REM API when asset data is not found if (response.getStatusInfo().getStatusCode() == HttpStatus.SC_NOT_FOUND) { logger.debug("Asset data not found for account {} asset {}", accountId, assetId); // return empty JsonNode return JsonNodeFactory.instance.objectNode(); } logger.error("Failed to retrieve asset data from REM: {}", response); throw new AssetDataRetreivalException(response.getStatusInfo().toEnum(), "Failed to retrieve asset data from REM: " + response.getEntity()); } JsonNode assetData = (JsonNode) response.getEntity(); // remove fields in asset data that are not needed for generating the narrations and we don't want to cache ((ObjectNode) assetData).remove("events"); return assetData; } public String getAlertSummarySingleDetectionPromptTemplate() { return ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE; } public String getAlertSummaryCombineResponsesPromptTemplate() { return ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE; } public String getThreatIntelSummarySingleIocPromptTemplate() { return THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE; } public String getThreatIntelSummaryCombineResponsesPromptTemplate() { return THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE; } }
Leave a Comment