brand

 avatar
unknown
plain_text
9 months ago
2.7 kB
3
Indexable
import boto3
from pydantic import BaseModel
from typing import List
import json

class UserPrompt(BaseModel):
    """Free-text request supplied by the user."""

    # Raw search text; handle_user_prompt_for_Brand forwards it as the query.
    query: str

class BrandSearchResultItem(BaseModel):
    """One matching brand document, reduced to the two fields this module uses."""

    # First value of the hit's ``lookup_value`` field ('' when absent).
    lookup_value: str
    # All values of the hit's ``linked_records`` field (empty when absent).
    linked_records: List[str]

class BrandSearchResults(BaseModel):
    """Parsed hits from one search call together with the total match count."""

    # Items built from the response's ``hits.hit`` list.
    hits: List[BrandSearchResultItem]
    # Total match count copied from the response's ``hits.found``; this can be
    # larger than len(hits) because the request caps the page size.
    found: int

class BrandSearchParameters(BaseModel):
    """Request options for a single brand search."""

    query: str  # text to search for
    # Field names sent to CloudSearch via ``queryOptions``. The list literal is
    # safe as a default because pydantic copies mutable defaults per instance.
    fields: List[str] = ["lookup_value"]
    sort: str = "_score desc"  # order by relevance score, best first
    size: int = 10  # maximum number of hits requested per call

class BrandCloudSearchClient:
    """Wrapper around a boto3 ``cloudsearchdomain`` client for the brand index."""

    def __init__(self, domain_endpoint: str, client=None):
        """Create the wrapper.

        Args:
            domain_endpoint: CloudSearch search endpoint URL, passed to boto3 as
                ``endpoint_url``.
            client: Optional pre-built ``cloudsearchdomain`` client (for example
                a stub in tests). When omitted, one is created from
                ``domain_endpoint``.
        """
        if client is None:
            client = boto3.client('cloudsearchdomain', endpoint_url=domain_endpoint)
        self.client = client

    def search(self, params: BrandSearchParameters) -> BrandSearchResults:
        """Run a simple-parser query and convert the raw response into models.

        Args:
            params: Query text, fields to search, sort expression and page size.

        Returns:
            BrandSearchResults holding one item per returned hit, plus the
            service-reported total in ``found``.
        """
        response = self.client.search(
            query=params.query,
            queryParser='simple',
            sort=params.sort,
            size=params.size,
            # The simple parser takes its search-field list as a JSON document.
            queryOptions=json.dumps({"fields": params.fields}),
        )
        hits = response['hits']
        items = [self._to_item(hit) for hit in hits.get('hit', [])]
        return BrandSearchResults(hits=items, found=hits['found'])

    @staticmethod
    def _to_item(hit: dict) -> BrandSearchResultItem:
        """Convert one raw CloudSearch hit into a BrandSearchResultItem."""
        # A hit may carry no 'fields' map at all, and a multi-valued field may
        # arrive as an empty list; both cases fall back to defaults instead of
        # raising KeyError/IndexError.
        fields = hit.get('fields') or {}
        lookup_values = fields.get('lookup_value') or ['']
        return BrandSearchResultItem(
            lookup_value=lookup_values[0],
            linked_records=fields.get('linked_records', []),
        )
    

def handle_user_prompt_for_Brand(user_prompt: UserPrompt, client: BrandCloudSearchClient) -> BrandSearchResults:
    """Search the brand index for the user's text using default search parameters.

    Args:
        user_prompt: Prompt whose ``query`` text is searched verbatim.
        client: Search client that executes the request.

    Returns:
        The results returned by ``client.search``.
    """
    return client.search(BrandSearchParameters(query=user_prompt.query))
    

def format_brand_search_results(brand_search_results: BrandSearchResults, keyword: str) -> List[dict]:
    """Flatten brand search results into one plain dict per hit.

    Args:
        brand_search_results: Parsed search output.
        keyword: Label stored under ``matched_on`` in every dict.

    Returns:
        One dict per hit, in hit order. Each dict repeats the overall
        ``total_count`` and lists that hit's linked records under ``matches``.
    """
    return [
        {
            "type": "brand",
            "index_name": "fincopilot-dim-brand",
            "total_count": brand_search_results.found,
            "extracted_count": len(entry.linked_records),
            "matched_on": keyword,
            "lookup_value": entry.lookup_value,
            "matches": entry.linked_records,
        }
        for entry in brand_search_results.hits
    ]

if __name__ == "__main__":
    # Ad-hoc smoke run: read the endpoint from config.json, search once and
    # print the formatted results.
    with open('config.json', 'r', encoding='utf-8') as f:
        config = json.load(f)
    client = BrandCloudSearchClient(config['BRAND_SEARCH_DOMAIN_ENDPOINT'])

    # One variable drives both the search and the "matched_on" label. The
    # original searched "marketplace" but labelled the output "Australia".
    query = "marketplace"
    user_prompt = UserPrompt(query=query)

    results = handle_user_prompt_for_Brand(user_prompt=user_prompt, client=client)
    formatted_results = format_brand_search_results(results, query)

    print(formatted_results)
Editor is loading...
Leave a Comment