Untitled
unknown
plain_text
a year ago
51 kB
11
Indexable
package com.trendmicro.serapis.admin.apiservice.api.generativeai.threatcontextenrichment.definition;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.IOC_DATA;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.IOC_SUMMARIES;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.cleanupAssetDataForDeviceAndLocationDetails;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.cleanupAssetDataForRiskExposureAndMitigation;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.convertToTimestamp;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.extractIocsFromLogs;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.generateAlertSummarySingleDetectionData;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.generateDeviceSectionPrompt;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.generateThreatSummarySingleIocData;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.getLatestTimestamp;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.retrieveCleanedAssetData;
import static com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils.retrieveCleanedDetectionsAndAssetId;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import javax.ws.rs.core.Response;
import org.apache.commons.text.StringSubstitutor;
import org.apache.http.HttpStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import com.cysiv.elasticsearch.model.triage.Detection;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.trendmicro.serapis.apiservice.api.accounts.model.Account;
import com.trendmicro.serapis.apiservice.api.elasticsearch.backend.ElasticSearchBackend;
import com.trendmicro.serapis.apiservice.api.exceptions.AssetDataRetreivalException;
import com.trendmicro.serapis.apiservice.api.exceptions.ElasticsearchException;
import com.trendmicro.serapis.apiservice.api.exceptions.EntityNotFoundException;
import com.trendmicro.serapis.apiservice.api.exceptions.InvalidParameterException;
import com.trendmicro.serapis.apiservice.api.exceptions.UnauthorizedException;
import com.trendmicro.serapis.apiservice.api.generativeai.GenerativeAiUtils;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.ThreatContextEnrichmentUtils;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.definition.ThreatContextEnrichmentService;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.definition.VLThreatIntel;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.model.AllNarrationsResponse;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.model.IdType;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.model.VLIoC;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.FirestoreNarrationStorageUtils;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.NarrationStorageBackend;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.model.NarrationData;
import com.trendmicro.serapis.apiservice.api.generativeai.threatcontextenrichment.storage.model.NarrationType;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.backend.GenerativeAiPredictionService;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.backend.VertexAiGeminiPredictionServiceBackend;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.backend.VertexAiPalmPredictionServiceBackend;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.model.request.ModelParameters;
import com.trendmicro.serapis.apiservice.api.generativeai.vertexai.prediction.model.request.PredictionModelType;
import com.trendmicro.serapis.apiservice.api.rem.definition.RemProxyApiService;
import com.trendmicro.serapis.apiservice.config.SerapisConfig;
import com.trendmicro.serapis.apiservice.database.AccountRepositoryService;
import com.trendmicro.serapis.apiservice.security.SecurityUtils;
import com.trendmicro.serapis.apiservice.util.DateUtils;
import com.trendmicro.serapis.apiservice.util.EmailUtils;
import io.micrometer.core.instrument.util.StringUtils;
import lombok.Getter;
@Service
public class ThreatContextEnrichmentAdminApiService {
@Autowired
private NarrationStorageBackend narrationStorageBackend;
@Autowired
private AccountRepositoryService accountRepositoryService;
// Logger category is this class, so admin-service log output can be filtered and configured
// independently of ThreatContextEnrichmentService.
private static final Logger logger =
    LoggerFactory.getLogger(ThreatContextEnrichmentAdminApiService.class);
/**
 * Cache key for {@code detectionsCache}: one entity (type + id) within one account.
 *
 * <p>NOTE(review): record equality delegates to {@code Account#equals}; if Account does not
 * override it, lookups only hit for the identical Account instance — confirm.
 */
record DetectionsCacheKey(Account account, String entityType, String entityId) {}
/** Cache key for {@code assetDataCache}: one asset within one account, both by id. */
record AssetDataCacheKey(String accountId, String assetId) {}
/**
 * Prompt templates and prompt data for the alert summary, threat intel summary and device
 * narration sections.
 *
 * <p>NOTE(review): the prompt generation code in this class builds
 * {@code ThreatContextEnrichmentService.Prompts}, not this record — confirm this one is needed.
 *
 * @param alertSummarySingleDetectionPromptTemplate template applied to each detection
 * @param alertSummaryCombineResponsesPromptTemplate template that combines per-detection responses
 * @param alertSummarySingleDetectionPromptData per-detection prompt data
 * @param threatIntelSummarySingleIocPromptTemplate template applied to each IoC
 * @param threatIntelSummaryCombineResponsesPromptTemplate template that combines per-IoC responses
 * @param threatIntelSummarySingleIocPromptData per-IoC prompt data
 * @param deviceDetailsAndLocationPrompt fully rendered device details and location prompt
 * @param deviceRiskExposurePrompt fully rendered device risk exposure prompt
 * @param deviceMitigationPrompt fully rendered device mitigation prompt
 */
public record Prompts(
    String alertSummarySingleDetectionPromptTemplate,
    String alertSummaryCombineResponsesPromptTemplate,
    List<String> alertSummarySingleDetectionPromptData,
    String threatIntelSummarySingleIocPromptTemplate,
    String threatIntelSummaryCombineResponsesPromptTemplate,
    List<String> threatIntelSummarySingleIocPromptData,
    String deviceDetailsAndLocationPrompt,
    String deviceRiskExposurePrompt,
    String deviceMitigationPrompt) {}
/**
 * Prompt data sent to the model paired with the response text it produced.
 *
 * @param promptData the prompt data
 * @param response the model response
 */
public record PromptDataAndResponse(String promptData, String response) {}
/**
 * Pairs a narration section type with the prompt data and response produced for that section.
 *
 * <p>NOTE(review): the component type is {@code ThreatContextEnrichmentService.PromptDataAndResponse},
 * not the {@code PromptDataAndResponse} record declared in this class, and the generation code here
 * uses {@code ThreatContextEnrichmentService.NarrationTypeAndPromptResponse} — confirm this record
 * is needed.
 *
 * @param narrationType the narration section
 * @param promptDataAndResponse prompt data and response for that section
 */
public record NarrationTypeAndPromptResponse(
    NarrationType narrationType,
    ThreatContextEnrichmentService.PromptDataAndResponse promptDataAndResponse) {}
// ---- Prompt template file names, resolved against the configured prompt templates directory ----

// Alert summary templates
public static final String ALERT_SUMMARY_PROMPT_TEMPLATE_FILE_NAME = "alert-summary";
public static final String ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE_FILE_NAME =
    "alert-summary-single-detection";
public static final String ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME =
    "alert-summary-combine-responses";

// Device summary templates
public static final String DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE_FILE_NAME =
    "device-and-location-details";
public static final String DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE_FILE_NAME = "device-risk-exposure";
public static final String DEVICE_MITIGATION_PROMPT_TEMPLATE_FILE_NAME = "device-mitigation";

// Threat intel summary templates.
// NOTE(review): THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE_FILE_NAME has the same value as
// THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE_FILE_NAME, so both resolve to the same file.
// The alert-summary naming pattern suggests "threat-intel-summary" may have been intended —
// confirm before changing, since the constructor fails if the file is missing.
public static final String THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE_FILE_NAME =
    "threat-intel-summary-single";
public static final String THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE_FILE_NAME =
    "threat-intel-summary-single";
public static final String THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME =
    "threat-intel-summary-combine-responses";

// Region-name prefixes for which newly generated narrations are not persisted
// (checked through ThreatContextEnrichmentUtils.isAccountInNarrationStorageRestrictedRegion).
public static final List<String> NARRATION_STORAGE_RESTRICTED_REGION_PREFIXES = List.of("europe");
// Prompt template contents, read from the prompt templates directory by the constructor.
// Despite their constant-style names these are per-instance (non-static) fields.

// alert summary
final String ALERT_SUMMARY_PROMPT_TEMPLATE;
final String ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE;
final String ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE;
// device summaries
final String DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE;
final String DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE;
final String DEVICE_MITIGATION_PROMPT_TEMPLATE;
// threat intel summary
final String THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE;
final String THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE;
final String THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE;
// Source of detections for retrieveDetections().
@Autowired
private ElasticSearchBackend elasticSearchBackend;
@Autowired
private RemProxyApiService remProxyApiService;

// Wraps the PaLM and Gemini prediction backends built in the constructor; exposed via Lombok.
@Getter
final GenerativeAiPredictionService generativeAiPredictionService;

// Detections per (account, entity type, entity id), populated by retrieveDetections().
private final LoadingCache<ThreatContextEnrichmentAdminApiService.DetectionsCacheKey, List<Detection>> detectionsCache;
// Asset data per (account id, asset id), populated by retrieveAssetData().
private final LoadingCache<ThreatContextEnrichmentAdminApiService.AssetDataCacheKey, JsonNode> assetDataCache;

// Runs the per-section narration generation tasks concurrently.
private final ExecutorService threadPool;

// specifies how far back to search for asset data
private final int assetDataLastSeenOnlineTimeFilterInHours;

private final SerapisConfig.ThreatContextEnrichmentConfig threatContextEnrichmentConfig;

@Getter
private final VLThreatIntel vlThreatIntel;

// TTL (days) used to compute expiresAt on current narration documents.
private final int documentTtlInDays;
// TTL (days) used to compute expiresAt on "previous narration" documents.
private final int prevNarrationDocumentTtlInDays;
/**
 * Creates the service from application configuration.
 *
 * <p>Builds the Vertex AI PaLM and Gemini prediction backends, the detections and asset-data
 * loading caches, the thread pool used for concurrent narration generation, and the VL threat
 * intel client. It also reads every prompt template file from the configured prompt templates
 * directory.
 *
 * @param appConfig application configuration supplying the Google Cloud project and the threat
 *     context enrichment settings (models, caches, narration storage TTLs, prompt templates)
 * @throws IOException if any prompt template file cannot be read
 */
@Autowired
public ThreatContextEnrichmentAdminApiService(SerapisConfig appConfig) throws IOException {
  SerapisConfig.GoogleCloud googleCloudConfig = appConfig.getGoogleCloud();
  threatContextEnrichmentConfig =
      appConfig.getThreatContextEnrichment();
  SerapisConfig.VertexAiModelConfig vertexAiPalmModelConfig =
      threatContextEnrichmentConfig.getVertexAiPalmModel();
  SerapisConfig.VertexAiModelConfig vertexAiGeminiModelConfig =
      threatContextEnrichmentConfig.getVertexAiGeminiModel();
  // the alert summary config supplies the detections cache settings
  SerapisConfig.AlertSummaryConfig alertSummaryConfig =
      threatContextEnrichmentConfig.getAlertSummary();
  SerapisConfig.ThreatContextEnrichmentConfig.CacheConfig detectionsCacheConfig =
      alertSummaryConfig.getDetectionsCache();
  // the device summaries config supplies the asset data cache settings and time filter
  SerapisConfig.DeviceSummariesConfig deviceSummariesConfig =
      threatContextEnrichmentConfig.getDeviceSummaries();
  SerapisConfig.ThreatContextEnrichmentConfig.CacheConfig assetDataCacheConfig =
      deviceSummariesConfig.getAssetDataCache();

  // ---- prediction backends ----
  VertexAiPalmPredictionServiceBackend vertexAiPalmModelPredictionService =
      new VertexAiPalmPredictionServiceBackend(googleCloudConfig.getProject(),
          vertexAiPalmModelConfig.getApiHostname(),
          vertexAiPalmModelConfig.getApiLocation(),
          vertexAiPalmModelConfig.getModelPublisher(),
          vertexAiPalmModelConfig.getModelName(),
          toModelParameters(vertexAiPalmModelConfig.getModelParameters()));
  VertexAiGeminiPredictionServiceBackend vertexAiGeminiModelPredictionService =
      new VertexAiGeminiPredictionServiceBackend(googleCloudConfig.getProject(),
          vertexAiGeminiModelConfig.getApiLocation(),
          vertexAiGeminiModelConfig.getModelName(),
          toModelParameters(vertexAiGeminiModelConfig.getModelParameters()));
  this.generativeAiPredictionService =
      new GenerativeAiPredictionService(vertexAiPalmModelPredictionService,
          vertexAiGeminiModelPredictionService);

  // ---- narration storage TTLs ----
  this.documentTtlInDays =
      threatContextEnrichmentConfig.getNarrationStorage().getDocumentTtlDays();
  this.prevNarrationDocumentTtlInDays =
      threatContextEnrichmentConfig.getNarrationStorage()
          .getPreviousNarrationDocumentTtlDays();

  // ---- caches ----
  // CacheLoader.asyncReloading runs refreshAfterWrite reloads on the supplied executor; Guava
  // keeps returning the old value while the refresh is in progress.
  this.detectionsCache = CacheBuilder.newBuilder()
      .maximumSize(detectionsCacheConfig.getMaximumSize())
      .refreshAfterWrite(
          Duration.ofSeconds(detectionsCacheConfig.getRefreshAfterWriteSeconds()))
      .expireAfterWrite(
          Duration.ofSeconds(detectionsCacheConfig.getExpireAfterWriteSeconds()))
      .build(CacheLoader.asyncReloading(new CacheLoader<>() {
        @Override
        public List<Detection> load(ThreatContextEnrichmentAdminApiService.DetectionsCacheKey key)
            throws Exception {
          logger.debug("Loading detections for key {}", key);
          return retrieveDetections(key.account(), key.entityType(), key.entityId());
        }
      }, Executors.newCachedThreadPool()));
  this.assetDataCache = CacheBuilder.newBuilder()
      .maximumSize(assetDataCacheConfig.getMaximumSize())
      .refreshAfterWrite(
          Duration.ofSeconds(assetDataCacheConfig.getRefreshAfterWriteSeconds()))
      .expireAfterWrite(
          Duration.ofSeconds(assetDataCacheConfig.getExpireAfterWriteSeconds()))
      .build(CacheLoader.asyncReloading(new CacheLoader<>() {
        @Override
        public JsonNode load(ThreatContextEnrichmentAdminApiService.AssetDataCacheKey key) throws AssetDataRetreivalException {
          logger.debug("Loading asset data for key {}", key);
          return retrieveAssetData(key.accountId(), key.assetId());
        }
      }, Executors.newCachedThreadPool()));
  this.threadPool = Executors.newCachedThreadPool();
  this.assetDataLastSeenOnlineTimeFilterInHours =
      deviceSummariesConfig.getAssetDataLastSeenOnlineTimeFilterInHours();

  // ---- prompt templates ----
  String promptTemplatesDir = threatContextEnrichmentConfig
      .getPromptTemplatesDir();
  logger.debug("promptTemplatesDir is {}", promptTemplatesDir);
  this.ALERT_SUMMARY_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
      ALERT_SUMMARY_PROMPT_TEMPLATE_FILE_NAME, "ALERT_SUMMARY_PROMPT_TEMPLATE");
  this.ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
      ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE_FILE_NAME,
      "ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE");
  this.ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
      ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME,
      "ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE");
  this.DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
      DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE_FILE_NAME,
      "DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE");
  this.DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
      DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE_FILE_NAME, "DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE");
  this.DEVICE_MITIGATION_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
      DEVICE_MITIGATION_PROMPT_TEMPLATE_FILE_NAME, "DEVICE_MITIGATION_PROMPT_TEMPLATE");
  this.THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
      THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE_FILE_NAME,
      "THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE");
  // NOTE(review): THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE_FILE_NAME has the same value as the
  // single-IoC file name, so this field holds the same contents as the one above — confirm.
  this.THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE = readPromptTemplate(promptTemplatesDir,
      THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE_FILE_NAME, "THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE");
  this.THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE =
      readPromptTemplate(promptTemplatesDir,
          THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE_FILE_NAME,
          "THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE");

  this.vlThreatIntel = new VLThreatIntel(appConfig);
}

/**
 * Builds the {@link ModelParameters} for one Vertex AI model from its configured parameters.
 * The temperature and topP values are converted with {@code doubleValue()}. An empty list is
 * passed for the list-valued argument.
 *
 * @param parametersConfig configured model parameters
 * @return model parameters for a prediction backend
 */
private static ModelParameters toModelParameters(
    SerapisConfig.VertexAiModelConfig.VertexAiModelParameters parametersConfig) {
  return new ModelParameters(parametersConfig.getTemperature().doubleValue(),
      parametersConfig.getTopP().doubleValue(),
      parametersConfig.getTopK(),
      parametersConfig.getMaxOutputTokens(),
      Collections.emptyList(),
      parametersConfig.getCandidateCount());
}

/**
 * Reads one prompt template file from the prompt templates directory.
 * Logs its contents at debug level as {@code "<templateName> is <contents>"}.
 *
 * @param promptTemplatesDir directory holding the prompt template files
 * @param fileName template file name inside {@code promptTemplatesDir}
 * @param templateName name used in the debug log message
 * @return the file contents
 * @throws IOException if the file cannot be read
 */
private static String readPromptTemplate(String promptTemplatesDir, String fileName,
    String templateName) throws IOException {
  String template = Files.readString(Paths.get(promptTemplatesDir, fileName));
  logger.debug("{} is {}", templateName, template);
  return template;
}
/**
 * Asks the narration storage backend to add an {@code expiresAt} field to the narration
 * documents of every account returned by {@code getAccountsPrivileged()}.
 *
 * @param beforeTimestamp forwarded to the backend for each account; presumably selects documents
 *     older than this instant (confirm against NarrationStorageBackend)
 * @param numDaysAfter forwarded to the backend for each account; presumably the TTL offset in days
 * @return the sum of the per-account counts returned by the backend
 * @throws ExecutionException propagated from the storage backend
 * @throws InterruptedException propagated from the storage backend
 */
public int addExpiresAtToDocuments(Instant beforeTimestamp, int numDaysAfter)
    throws ExecutionException, InterruptedException {
  int totalUpdated = 0;
  for (Account acct : accountRepositoryService.getAccountsPrivileged()) {
    totalUpdated += narrationStorageBackend
        .addExpiresAtToDocuments(acct.getId(), beforeTimestamp, numDaysAfter);
  }
  return totalUpdated;
}
/**
 * Returns every narration section for an entity or an asset, reusing stored narrations when
 * possible. The sections are alert summary, device location and details, device risk exposure,
 * device mitigation, and threat intel summary.
 *
 * <p>The storage document id comes from {@code idType}. For {@link IdType#ENTITY} it is built
 * from the entity type, entity id and optional case id. For {@link IdType#ASSET} it is built from
 * the asset id. All prompts are generated up front.
 *
 * <p>If a stored narration exists, each section is regenerated only when
 * {@code FirestoreNarrationStorageUtils.shouldUseStoredNarration} returns false for it. The
 * replaced prompts and responses are then stored as a "previous narration" document, and the
 * changed fields are written back.
 *
 * <p>Otherwise every section that has prompt data is generated and a new document is stored.
 * For accounts in a narration-storage-restricted region nothing is stored and the returned
 * document id is {@code null}.
 *
 * @param accountId id of the account the narrations belong to
 * @param idType whether the narration is keyed by entity or by asset
 * @param entityType entity type; required when {@code idType} is ENTITY
 * @param entityId entity id; required when {@code idType} is ENTITY
 * @param caseId optional case id, used in the entity-based document id
 * @param assetId asset id; required when {@code idType} is ASSET. For ENTITY it may also be
 *     derived from the detections during prompt generation.
 * @param timeZone optional time zone forwarded to prompt generation
 * @param forceRefresh forwarded to shouldUseStoredNarration when a stored narration exists; only
 *     internal users may set it in that case
 * @param predictionModelTypeOptional model to use; defaults to the configured default AI model
 * @return the document id, the section narrations, and the generation time
 * @throws InvalidParameterException if the ids required by {@code idType} are missing, or
 *     {@code idType} is unsupported
 * @throws UnauthorizedException if a stored narration exists and a non-internal user requests
 *     {@code forceRefresh}
 * @throws ExecutionException if a generation task failed
 */
public AllNarrationsResponse generateAllNarrationsWithStorage(String accountId,
    IdType idType,
    Optional<String> entityType,
    Optional<String> entityId,
    Optional<String> caseId,
    Optional<String> assetId,
    Optional<ZoneId> timeZone,
    boolean forceRefresh,
    Optional<PredictionModelType> predictionModelTypeOptional)
    throws IOException, InvalidParameterException, ExecutionException,
    EntityNotFoundException, ElasticsearchException, AssetDataRetreivalException,
    InterruptedException {
  Account account = accountRepositoryService.getAccountPrivileged(accountId);
  // presumably rejects accounts that do not allow generative AI — TODO confirm in GenerativeAiUtils
  GenerativeAiUtils.validateAllowAiEnabled(account);
  // consulted below before persisting a newly generated narration
  boolean isAccountInNarrationStorageRestrictedRegion =
      ThreatContextEnrichmentUtils.isAccountInNarrationStorageRestrictedRegion(account,
          NARRATION_STORAGE_RESTRICTED_REGION_PREFIXES);
  String documentId;
  // generate the document id based on the idType
  if (idType == IdType.ENTITY) {
    if (entityId.isEmpty() || entityType.isEmpty()) {
      throw new InvalidParameterException(
          "entity_id and entity_type cannot be empty when id_type is ENTITY");
    }
    documentId = FirestoreNarrationStorageUtils
        .generateDocumentIdFromEntity(entityType.get(), entityId.get(), caseId);
  } else if (idType == IdType.ASSET) {
    if (assetId.isEmpty()) {
      throw new InvalidParameterException(
          "asset_id cannot be empty when id_type is ASSET");
    }
    documentId = FirestoreNarrationStorageUtils.generateDocumentIdFromAsset(assetId.get());
  } else {
    throw new InvalidParameterException("Invalid id_type");
  }
  // fall back to the configured default model when the caller does not pick one
  PredictionModelType predictionModelType =
      predictionModelTypeOptional.orElseGet(threatContextEnrichmentConfig::getDefaultAiModel);
  // variables to store the final responses for each section
  String alertSummary = "", deviceLocationAndDetails = "",
      deviceRiskExposure = "", deviceMitigation = "", threatIntelSummary = "";
  Instant generationTime;
  // save email of the user that made the API call
  String invokingUserEmail = SecurityUtils.getInvokingUserEmail();
  // generate all the prompts that could be executed; in the update path they are compared with
  // the stored narration to decide which sections must be regenerated
  ThreatContextEnrichmentService.Prompts currentPrompts =
      generateAllPrompts(account, entityType, entityId, assetId, timeZone);
  // retrieve any existing narration data from storage
  Optional<NarrationData> narrationDataOptional =
      narrationStorageBackend.getNarrationData(accountId, documentId);
  if (narrationDataOptional.isPresent()) {
    // ---- update path: a stored narration exists ----
    // check if user can perform a force refresh
    if (forceRefresh && !EmailUtils.isInternalUserAddress(invokingUserEmail)) {
      throw new UnauthorizedException(
          "You do not have permissions to perform a force refresh.");
    }
    // narration data exists, need to compare data/prompts to see if any sections need to be regenerated
    NarrationData narrationData = narrationDataOptional.get();
    // initialize generation time to the latest timestamp found in the narration data
    generationTime = Instant.parse(getLatestTimestamp(narrationData));
    // the generation of narrations to be updated will be done concurrently using a thread pool
    // this variable stores the lists of tasks to be executed concurrently for generating the narrations
    // NOTE(review): unlike the new-narration path below, sections are queued here without checking
    // whether their prompt data is empty — confirm the task helpers handle empty input.
    List<Callable<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> tasks = new ArrayList<>();
    // stores the fields that need to be updated in the storage
    Map<String, Object> updates = new HashMap<>();
    // check if the alert summary section needs to be regenerated
    if (FirestoreNarrationStorageUtils
        .shouldUseStoredNarration(narrationData,
            currentPrompts,
            NarrationType.ALERT_SUMMARY,
            forceRefresh)) {
      // section does not need to be regenerated, use the stored narration
      alertSummary = narrationData.getAlertSummaryResponse();
    } else {
      // regenerate alert summary section
      logger.debug("Regenerating alert summary section...");
      tasks.add(() -> ThreatContextEnrichmentUtils.generateAlertSummaryTask(this,
          currentPrompts.alertSummarySingleDetectionPromptData(),
          predictionModelType));
    }
    // check if the threat intel summary section needs to be regenerated
    if (FirestoreNarrationStorageUtils
        .shouldUseStoredNarration(narrationData,
            currentPrompts,
            NarrationType.THREAT_INTEL_SUMMARY,
            forceRefresh)) {
      // section does not need to be regenerated, use the stored narration
      threatIntelSummary = narrationData.getThreatIntelSummaryResponse();
    } else {
      // regenerate threat intel summary section
      logger.debug("Regenerating threat intel summary section...");
      tasks.add(() -> ThreatContextEnrichmentUtils.generateThreatIntelSummaryTask(this,
          currentPrompts.threatIntelSummarySingleIocPromptData(),
          predictionModelType));
    }
    // check if we need to regenerate device location and details section
    if (FirestoreNarrationStorageUtils.shouldUseStoredNarration(narrationData,
        currentPrompts,
        NarrationType.DEVICE_LOCATION_AND_DETAILS,
        forceRefresh)) {
      // section does not need to be regenerated, use the stored narration
      deviceLocationAndDetails = narrationData.getDeviceDetailsResponse();
    } else {
      // regenerate device location and details section
      logger.debug("Regenerating device and location details section...");
      tasks.add(
          () -> ThreatContextEnrichmentUtils.generateDeviceDetailsAndLocationTask(this,
              currentPrompts.deviceDetailsAndLocationPrompt(),
              predictionModelType));
    }
    // check if we need to regenerate device risk exposure section
    if (FirestoreNarrationStorageUtils.shouldUseStoredNarration(narrationData,
        currentPrompts,
        NarrationType.DEVICE_RISK_EXPOSURE,
        forceRefresh)) {
      // section does not need to be regenerated, use the stored narration
      deviceRiskExposure = narrationData.getDeviceRiskResponse();
    } else {
      // regenerate device risk exposure section
      logger.debug("Regenerating device risk exposure section...");
      tasks.add(
          () -> ThreatContextEnrichmentUtils.generateDeviceRiskExposureTask(this,
              currentPrompts.deviceRiskExposurePrompt(),
              predictionModelType));
    }
    // check if we need to regenerate device mitigation section
    if (FirestoreNarrationStorageUtils.shouldUseStoredNarration(narrationData,
        currentPrompts,
        NarrationType.DEVICE_MITIGATION,
        forceRefresh)) {
      // section does not need to be regenerated, use the stored narration
      deviceMitigation = narrationData.getDeviceMitigationResponse();
    } else {
      // regenerate device mitigation section
      logger.debug("Regenerating device mitigation section...");
      tasks
          .add(() -> ThreatContextEnrichmentUtils.generateDeviceMitigationPromptTask(this,
              currentPrompts.deviceMitigationPrompt(),
              predictionModelType));
    }
    // use thread pool to execute the tasks; invokeAll blocks until every task has completed
    List<Future<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> futures =
        threadPool.invokeAll(tasks);
    // all tasks are done; get() returns each result or rethrows a task failure as ExecutionException
    // init the lastRegeneratedTimeStamp, which will be used for all the sections
    Instant lastRegeneratedTimeStamp = DateUtils.nowTruncatedToMillis();
    // collects the prompts/responses being replaced, stored below as the "previous narration" document
    NarrationData.NarrationDataBuilder previousNarrationDataBuilder = NarrationData.builder();
    for (var future : futures) {
      var narrationTypeAndPromptResponse = future.get();
      // process the results based on narration type
      ThreatContextEnrichmentUtils.processNarrationUpdateResults(currentPrompts,
          narrationTypeAndPromptResponse,
          updates,
          lastRegeneratedTimeStamp.toString(),
          invokingUserEmail,
          this,
          predictionModelType,
          narrationData,
          previousNarrationDataBuilder);
    }
    // update narration data in storage if there are updates
    // NOTE(review): this path writes to storage without consulting
    // isAccountInNarrationStorageRestrictedRegion (the new-narration path does) — confirm intended.
    if (!updates.isEmpty()) {
      // add a previous narration document containing the previous prompts and responses that will be updated
      previousNarrationDataBuilder.expiresAt(convertToTimestamp(
          lastRegeneratedTimeStamp.plus(prevNarrationDocumentTtlInDays, // set expiresAt timestamp
              ChronoUnit.DAYS)));
      narrationStorageBackend.storePreviousNarrationData(accountId,
          documentId,
          lastRegeneratedTimeStamp.toString(),
          previousNarrationDataBuilder.build());
      // update the expiresAt field, which is used for TTL deletion
      updates.put(NarrationData.Fields.expiresAt,
          convertToTimestamp(
              lastRegeneratedTimeStamp.plus(documentTtlInDays, ChronoUnit.DAYS)));
      narrationStorageBackend
          .updateNarrationData(accountId, documentId, updates);
      // update final section responses if there was an update; sections absent from
      // `updates` keep the stored value assigned above
      alertSummary =
          (String) updates.getOrDefault(NarrationData.Fields.alertSummaryResponse,
              alertSummary);
      deviceLocationAndDetails =
          (String) updates.getOrDefault(NarrationData.Fields.deviceDetailsResponse,
              deviceLocationAndDetails);
      deviceRiskExposure =
          (String) updates.getOrDefault(NarrationData.Fields.deviceRiskResponse,
              deviceRiskExposure);
      deviceMitigation =
          (String) updates.getOrDefault(NarrationData.Fields.deviceMitigationResponse,
              deviceMitigation);
      threatIntelSummary =
          (String) updates.getOrDefault(NarrationData.Fields.threatIntelSummaryResponse,
              threatIntelSummary);
      generationTime = lastRegeneratedTimeStamp;
    }
  } else {
    // ---- create path: no stored narration ----
    // no existing narrations, generate new narrations and then store it
    NarrationData.NarrationDataBuilder builder = NarrationData.builder();
    long start = System.currentTimeMillis();
    // the generation of narrations will be executed concurrently using a thread pool
    // stores the lists of tasks to be executed concurrently for generating the narrations
    List<Callable<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> tasks = new ArrayList<>();
    // if there is single detection alert summary data, generate alert summary
    if (!currentPrompts.alertSummarySingleDetectionPromptData().isEmpty()) {
      tasks.add(() -> ThreatContextEnrichmentUtils.generateAlertSummaryTask(this,
          currentPrompts.alertSummarySingleDetectionPromptData(),
          predictionModelType));
    }
    // if device and location details prompt is not blank, generate narration
    if (StringUtils.isNotBlank(currentPrompts.deviceDetailsAndLocationPrompt())) {
      tasks.add(
          () -> ThreatContextEnrichmentUtils.generateDeviceDetailsAndLocationTask(this,
              currentPrompts.deviceDetailsAndLocationPrompt(),
              predictionModelType));
    }
    // if device risk exposure prompt is not blank, generate narration
    if (StringUtils.isNotBlank(currentPrompts.deviceRiskExposurePrompt())) {
      tasks.add(() -> ThreatContextEnrichmentUtils.generateDeviceRiskExposureTask(this,
          currentPrompts.deviceRiskExposurePrompt(),
          predictionModelType));
    }
    // if device mitigation prompt is not blank, generate narration
    if (StringUtils.isNotBlank(currentPrompts.deviceMitigationPrompt())) {
      tasks
          .add(() -> ThreatContextEnrichmentUtils.generateDeviceMitigationPromptTask(this,
              currentPrompts.deviceMitigationPrompt(),
              predictionModelType));
    }
    // if there is ioc data, generate threat intel summary
    if (!currentPrompts.threatIntelSummarySingleIocPromptData().isEmpty()) {
      tasks.add(() -> ThreatContextEnrichmentUtils.generateThreatIntelSummaryTask(this,
          currentPrompts.threatIntelSummarySingleIocPromptData(),
          predictionModelType));
    }
    // use thread pool to execute the tasks; invokeAll blocks until every task has completed
    List<Future<ThreatContextEnrichmentService.NarrationTypeAndPromptResponse>> futures =
        threadPool.invokeAll(tasks);
    // all tasks are done; get() returns each result or rethrows a task failure as ExecutionException
    for (var future : futures) {
      var narrationTypeAndPromptResponse = future.get();
      // process the results based on narration type
      ThreatContextEnrichmentUtils.processNarrationGenerationResults(currentPrompts,
          narrationTypeAndPromptResponse,
          builder, // the NarrationDataBuilder, which is used to store data that we want to send
          this,
          predictionModelType);
    }
    // populate creation fields and build NarrationData
    generationTime = DateUtils.nowTruncatedToMillis();
    NarrationData newNarrationData =
        builder.created(generationTime.toString())
            .createdBy(invokingUserEmail)
            .expiresAt(convertToTimestamp(generationTime
                .plus(documentTtlInDays, ChronoUnit.DAYS)))
            .build();
    // only store narration data if it is not in a restricted region
    if (!isAccountInNarrationStorageRestrictedRegion) {
      // store narration data
      narrationStorageBackend.storeNarrationData(accountId, documentId, newNarrationData);
    } else {
      logger.debug(
          "Not storing narration data with document id {} for account {}, which is in a restricted region",
          documentId,
          accountId);
      // setting documentId to null so the UI knows that the narration is not stored
      documentId = null;
    }
    // save final section responses to be returned
    // NOTE(review): sections that were not generated take whatever NarrationData's getters return
    // for unset fields (possibly null rather than "") — confirm against the builder defaults.
    alertSummary = newNarrationData.getAlertSummaryResponse();
    deviceLocationAndDetails = newNarrationData.getDeviceDetailsResponse();
    deviceRiskExposure = newNarrationData.getDeviceRiskResponse();
    deviceMitigation = newNarrationData.getDeviceMitigationResponse();
    threatIntelSummary = newNarrationData.getThreatIntelSummaryResponse();
    logger.debug("Generate and store narration took {}ms",
        System.currentTimeMillis() - start);
  }
  // sections that were not generated will be an empty string
  return new AllNarrationsResponse(documentId,
      alertSummary,
      deviceLocationAndDetails,
      deviceRiskExposure,
      deviceMitigation,
      threatIntelSummary,
      generationTime.toString());
}
/**
 * Loads the detections for one entity directly from Elasticsearch, bypassing the cache.
 * Used as the loader for {@code detectionsCache}.
 *
 * @param account account owning the entity
 * @param entityType type of the entity
 * @param entityId id of the entity
 * @return the detections returned by the Elasticsearch backend
 * @throws EntityNotFoundException propagated from the Elasticsearch backend
 * @throws ElasticsearchException propagated from the Elasticsearch backend
 */
private List<Detection> retrieveDetections(Account account,
    String entityType,
    String entityId)
    throws EntityNotFoundException, ElasticsearchException {
  List<Detection> detections =
      elasticSearchBackend.getDetectionsForThreatContextEnrichment(account, entityType, entityId);
  return detections;
}
/**
 * Returns the detections for an entity, served from {@code detectionsCache}.
 *
 * <p>The caller passes an already-resolved {@link Account}, which becomes part of the cache key.
 * The account is not resolved inside {@link #retrieveDetections}, because that method can run on
 * the loading cache's own threads. Those threads do not have the authorization information
 * required by the {@code @SecureDao} annotation.
 *
 * @param entityType type of the entity
 * @param entityId id of the entity
 * @param account account owning the entity
 * @return the cached or freshly loaded detections
 * @throws EntityNotFoundException unwrapped from the cache's ExecutionException when loading
 *     failed with it
 * @throws ElasticsearchException unwrapped from the cache's ExecutionException when loading
 *     failed with it
 * @throws ExecutionException if loading failed with any other checked exception
 */
private List<Detection> retrieveDetectionsFromCache(String entityType,
    String entityId,
    Account account)
    throws EntityNotFoundException, ElasticsearchException, ExecutionException {
  var cacheKey =
      new ThreatContextEnrichmentAdminApiService.DetectionsCacheKey(account, entityType, entityId);
  try {
    return detectionsCache.get(cacheKey);
  } catch (ExecutionException loadFailure) {
    Throwable cause = loadFailure.getCause();
    if (cause instanceof EntityNotFoundException notFound) {
      throw notFound;
    }
    if (cause instanceof ElasticsearchException esFailure) {
      throw esFailure;
    }
    throw loadFailure;
  }
}
/**
 * Assembles all prompt templates and prompt data used for threat context enrichment.
 *
 * <p>When both {@code entityType} and {@code entityId} are present, the entity's detections are
 * loaded (through the cache), cleaned, and turned into per-detection alert-summary data and
 * per-IoC threat-intel data. If the caller supplied no {@code assetId}, a non-blank asset id
 * found in the detections is used instead. When an asset id is available and its asset data is
 * non-empty, the three device-section prompts (location/details, risk exposure, mitigation) are
 * generated. Sections that cannot be produced are left as empty lists or empty strings.
 *
 * @param account    account the data belongs to
 * @param entityType optional entity type; detections are only loaded when this and
 *                   {@code entityId} are both present
 * @param entityId   optional entity identifier
 * @param assetId    optional asset identifier; may be filled in from the detections
 * @param timeZone   optional time zone forwarded to the cleaning helpers
 * @return the populated {@code Prompts} value
 */
private ThreatContextEnrichmentService.Prompts generateAllPrompts(Account account,
        Optional<String> entityType,
        Optional<String> entityId,
        Optional<String> assetId,
        Optional<ZoneId> timeZone)
        throws ElasticsearchException, EntityNotFoundException, ExecutionException,
        IOException, AssetDataRetreivalException {
    List<String> detectionPromptData = new ArrayList<>();
    List<String> iocPromptData = new ArrayList<>();
    Optional<String> effectiveAssetId = assetId;

    final boolean entityKnown = entityId.isPresent() && entityType.isPresent();
    if (entityKnown) {
        final List<Detection> detections =
                retrieveDetectionsFromCache(entityType.get(), entityId.get(), account);
        final ThreatContextEnrichmentService.CleanedDetectionsAndAssetId cleaned =
                retrieveCleanedDetectionsAndAssetId(detections,
                        threatContextEnrichmentConfig.getAlertSummary(),
                        timeZone);
        // data for the per-detection alert summaries
        detectionPromptData = generateAlertSummarySingleDetectionData(cleaned.cleanedDetections());
        // data for the per-IoC threat intel summaries
        iocPromptData = generateThreatSummarySingleIocData(getIocsFromDetections(detections));
        // an explicitly supplied asset id always wins over the one discovered in the detections
        final String discoveredAssetId = cleaned.assetId();
        if (effectiveAssetId.isEmpty() && StringUtils.isNotBlank(discoveredAssetId)) {
            effectiveAssetId = Optional.of(discoveredAssetId);
        }
    }

    String locationAndDetails = "";
    String riskExposure = "";
    String mitigation = "";
    if (effectiveAssetId.isPresent()) {
        final JsonNode assetData = retrieveAssetDataFromCache(account.getId(), effectiveAssetId.get());
        if (!assetData.isEmpty()) {
            // A freshly cleaned map is built for every section: the cleanup helpers' return values
            // are ignored, so they presumably modify the map in place.
            Map<String, Object> sectionData = retrieveCleanedAssetData(assetData, timeZone);
            cleanupAssetDataForDeviceAndLocationDetails(sectionData);
            locationAndDetails = generateDeviceSectionPrompt(sectionData,
                    DEVICE_AND_LOCATION_DETAILS_PROMPT_TEMPLATE);

            sectionData = retrieveCleanedAssetData(assetData, timeZone);
            cleanupAssetDataForRiskExposureAndMitigation(sectionData);
            riskExposure = generateDeviceSectionPrompt(sectionData,
                    DEVICE_RISK_EXPOSURE_PROMPT_TEMPLATE);

            sectionData = retrieveCleanedAssetData(assetData, timeZone);
            cleanupAssetDataForRiskExposureAndMitigation(sectionData);
            mitigation = generateDeviceSectionPrompt(sectionData,
                    DEVICE_MITIGATION_PROMPT_TEMPLATE);
        }
    }

    return new ThreatContextEnrichmentService.Prompts(
            ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE,
            ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE,
            detectionPromptData,
            THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE,
            THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE,
            iocPromptData,
            locationAndDetails,
            riskExposure,
            mitigation);
}
/**
 * Sends a prompt verbatim to the generative AI prediction service.
 *
 * @param prompt              the fully rendered prompt text
 * @param predictionModelType the model that should handle the prompt
 * @return the text returned by {@code predictTextPrompt}
 * @throws InvalidParameterException propagated from the prediction service
 * @throws IOException               propagated from the prediction service
 */
public String executePrompt(String prompt, PredictionModelType predictionModelType)
        throws InvalidParameterException, IOException {
    final String prediction = generativeAiPredictionService.predictTextPrompt(prompt, predictionModelType);
    return prediction;
}
/**
 * Produces a threat intel summary with a single model call over all IoC data at once.
 *
 * <p>{@code iocData} is substituted into {@code THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE} under the
 * {@code IOC_DATA} key, and the result is sent to {@code TEXT_BISON_32K}.
 *
 * @param iocData the IoC data inserted into the template
 * @return the model's summary text
 */
private String generateThreatIntelSummaryUsingSingleInvocationMode(String iocData)
        throws IOException,
        InvalidParameterException {
    final String prompt = StringSubstitutor.replace(
            THREAT_INTEL_SUMMARY_PROMPT_TEMPLATE, Map.of(IOC_DATA, iocData));
    logger.debug("ThreatIntelSummary prompt is:\n{}", prompt);
    return generativeAiPredictionService.predictTextPrompt(prompt, PredictionModelType.TEXT_BISON_32K);
}
/**
 * Collects the threat-intel records for the distinct IoCs that appear in the detections' logs.
 *
 * <p>IoCs are first extracted from every detection and de-duplicated, so each IoC is looked up
 * only once. Each is then resolved through {@code vlThreatIntel.getIoCInfo}; according to the
 * original comment this is backed by a BigQuery table. IoCs with no record (a {@code null}
 * result) are dropped.
 *
 * @param detections detections whose logs are scanned for IoCs
 * @return the threat-intel records found, in the iteration order of the de-duplicated IoC set
 */
private List<VLIoC> getIocsFromDetections(List<Detection> detections) {
    Set<String> extractedIocs = new HashSet<>();
    for (Detection detection : detections) {
        extractedIocs.addAll(extractIocsFromLogs(detection));
    }
    // presized: there can be at most one record per distinct IoC
    List<VLIoC> feedIocs = new ArrayList<>(extractedIocs.size());
    for (String ioc : extractedIocs) {
        VLIoC feedIoc = vlThreatIntel.getIoCInfo(ioc);
        if (feedIoc != null) {
            feedIocs.add(feedIoc);
        }
    }
    return feedIocs;
}
/**
 * Summarizes the threat intel for one IoC.
 *
 * <p>{@code iocData} is rendered into {@code THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE}
 * under the {@code IOC_DATA} key, and the result is sent to the requested model.
 *
 * @param iocData             data describing a single IoC
 * @param predictionModelType the model used for the prediction
 * @return the model's summary for this IoC
 */
private String generateThreatIntelSummaryForSingleIoc(String iocData,
        PredictionModelType predictionModelType)
        throws InvalidParameterException, IOException {
    final String prompt = StringSubstitutor.replace(
            THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE, Map.of(IOC_DATA, iocData));
    logger.debug("generateThreatIntelSummaryForSingleIoc prompt is:\n{}", prompt);
    return generativeAiPredictionService.predictTextPrompt(prompt, predictionModelType);
}
/**
 * Produces a threat intel summary in two stages: one model call per IoC, then one call that
 * combines the individual summaries.
 *
 * <p>Stage 1 runs the per-IoC calls as tasks on {@code threadPool} via {@code invokeAll}. Their
 * responses are joined, in the same order as {@code iocsForPrompt}, with a blank line after each.
 * Stage 2 inserts that joined text into
 * {@code THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE} under {@code IOC_SUMMARIES}
 * and asks the model for the final summary.
 *
 * @param iocsForPrompt       prompt data, one entry per IoC
 * @param predictionModelType the model used for every call
 * @return the joined per-IoC summaries together with the combined response
 * @throws ExecutionException   if any per-IoC task failed (the task's exception is the cause)
 * @throws InterruptedException if interrupted while waiting for the tasks
 */
public ThreatContextEnrichmentService.PromptDataAndResponse generateThreatIntelSummaryUsingMultipleInvocationMode(List<String> iocsForPrompt,
        PredictionModelType predictionModelType)
        throws IOException, InvalidParameterException, InterruptedException,
        ExecutionException {
    // STEP 1: summarize every IoC independently and in parallel
    List<Callable<String>> tasks = new ArrayList<>(iocsForPrompt.size());
    for (String iocDataForPrompt : iocsForPrompt) {
        tasks.add(() -> {
            String modelResponse =
                    generateThreatIntelSummaryForSingleIoc(iocDataForPrompt, predictionModelType);
            // parameterized so the (possibly large) response is only formatted when debug is on
            logger.debug("Model response is: {}", modelResponse);
            return modelResponse;
        });
    }
    // invokeAll waits for every task and returns futures in task order, so the joined
    // summaries follow the order of iocsForPrompt
    StringBuilder modelResponsesStr = new StringBuilder();
    for (Future<String> future : threadPool.invokeAll(tasks)) {
        modelResponsesStr.append(future.get()).append("\n\n");
    }
    String individualSummaries = modelResponsesStr.toString();

    // STEP 2: combine the individual summaries into one overall summary.
    // Runs even when there is only one IoC summary so the style and tone stay consistent.
    Map<String, String> promptData = Map.of(IOC_SUMMARIES, individualSummaries);
    String prompt =
            StringSubstitutor.replace(THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE,
                    promptData);
    logger.debug("threat intel summary combine responses prompt is: {}", prompt);
    String response = generativeAiPredictionService.predictTextPrompt(prompt, predictionModelType);
    return new ThreatContextEnrichmentService.PromptDataAndResponse(individualSummaries, response);
}
/**
 * Returns an asset's data via {@code assetDataCache}, rethrowing an
 * {@code AssetDataRetreivalException} raised by the loader instead of the cache's wrapper.
 *
 * @param accountId account that owns the asset
 * @param assetId   identifier of the asset
 * @return the cached (or freshly loaded) asset data
 * @throws AssetDataRetreivalException when the cache loader failed with this exception
 * @throws ExecutionException          when the loader failed with any other checked exception
 */
private JsonNode retrieveAssetDataFromCache(String accountId, String assetId)
        throws AssetDataRetreivalException, ExecutionException {
    final ThreatContextEnrichmentAdminApiService.AssetDataCacheKey key =
            new ThreatContextEnrichmentAdminApiService.AssetDataCacheKey(accountId, assetId);
    try {
        return assetDataCache.get(key);
    } catch (ExecutionException wrapped) {
        if (!(wrapped.getCause() instanceof AssetDataRetreivalException retrievalFailure)) {
            throw wrapped;
        }
        throw retrievalFailure;
    }
}
/**
 * Fetches an asset's data from REM for the last {@code assetDataLastSeenOnlineTimeFilterInHours}
 * hours.
 *
 * <p>A 404 response is treated as "no asset data", because that is what the REM API returns when
 * an asset has no data, and yields an empty {@code ObjectNode}. A successful response with no
 * entity is handled the same way. Any other 4xx/5xx response raises an
 * {@code AssetDataRetreivalException}. For object-shaped data, the {@code "events"} field is
 * removed, since it is not needed for the narrations and should not be cached.
 *
 * @param accountId account that owns the asset
 * @param assetId   identifier of the asset
 * @return the asset data, or an empty object node when none is available
 * @throws AssetDataRetreivalException on any non-404 client or server error from REM
 */
private JsonNode retrieveAssetData(String accountId, String assetId)
        throws AssetDataRetreivalException {
    // Take a single "now" so the query window is exactly the configured number of hours long.
    Instant now = Instant.now();
    Instant windowStart = now.minus(assetDataLastSeenOnlineTimeFilterInHours, ChronoUnit.HOURS);
    Response response = remProxyApiService.getAsset(accountId,
            assetId,
            windowStart.toString(),
            now.toString());

    Response.StatusType status = response.getStatusInfo();
    Response.Status.Family family = status.getFamily();
    if (family == Response.Status.Family.CLIENT_ERROR
            || family == Response.Status.Family.SERVER_ERROR) {
        // treat 404 Not Found as "no asset data", because this is what the REM API returns
        // when asset data is not found
        if (status.getStatusCode() == HttpStatus.SC_NOT_FOUND) {
            logger.debug("Asset data not found for account {} asset {}", accountId, assetId);
            return JsonNodeFactory.instance.objectNode();
        }
        logger.error("Failed to retrieve asset data from REM: {}", response);
        throw new AssetDataRetreivalException(status.toEnum(),
                "Failed to retrieve asset data from REM: " + response.getEntity());
    }

    Object entity = response.getEntity();
    if (entity == null) {
        // previously this caused a NullPointerException; treat a body-less success as "no data"
        logger.debug("Empty asset data response for account {} asset {}", accountId, assetId);
        return JsonNodeFactory.instance.objectNode();
    }
    JsonNode assetData = (JsonNode) entity;
    // drop fields not needed for the narrations that we don't want to cache; the guard avoids a
    // ClassCastException when REM returns a non-object JSON value
    if (assetData instanceof ObjectNode assetObject) {
        assetObject.remove("events");
    }
    return assetData;
}
/**
 * Returns the prompt template used to summarize one detection.
 *
 * @return the {@code ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE} value
 */
public String getAlertSummarySingleDetectionPromptTemplate() {
    final String template = ALERT_SUMMARY_SINGLE_DETECTION_PROMPT_TEMPLATE;
    return template;
}
/**
 * Returns the prompt template used to merge the per-detection alert summaries.
 *
 * @return the {@code ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE} value
 */
public String getAlertSummaryCombineResponsesPromptTemplate() {
    final String template = ALERT_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE;
    return template;
}
/**
 * Returns the prompt template used to summarize the threat intel for one IoC.
 *
 * @return the {@code THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE} value
 */
public String getThreatIntelSummarySingleIocPromptTemplate() {
    final String template = THREAT_INTEL_SUMMARY_SINGLE_IOC_PROMPT_TEMPLATE;
    return template;
}
/**
 * Returns the prompt template used to merge the per-IoC threat intel summaries.
 *
 * @return the {@code THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE} value
 */
public String getThreatIntelSummaryCombineResponsesPromptTemplate() {
    final String template = THREAT_INTEL_SUMMARY_COMBINE_RESPONSES_PROMPT_TEMPLATE;
    return template;
}
}
Editor is loading...
Leave a Comment