# Untitled
#
# avatar
# unknown
# plain_text
# 25 days ago
# 5.0 kB
# 3
# Indexable
import asyncio
import time
from typing import Any, Callable, Dict, List, Optional, Tuple

from pymilvus import Collection, connections, utility

class MilvusProgressTracker:
    """Track insert and flush progress for a bulk load into one collection.

    The caller sets ``total_records`` before inserting, then reports every
    inserted batch and every flush-progress reading. ``overall_progress`` is
    a weighted blend of the two phases (insert 80%, flush 20%) and is only
    recomputed while ``total_records`` is greater than zero.
    """

    # Share of the overall percentage attributed to each phase.
    INSERT_WEIGHT = 0.8
    FLUSH_WEIGHT = 0.2

    def __init__(self, collection_name: str):
        """Start a tracker for ``collection_name`` with all counters at zero."""
        self.collection_name = collection_name
        self.total_records = 0
        self.inserted_records = 0
        self.flush_progress = 0
        self.overall_progress = 0
        self.start_time = time.time()

    def update_insert_progress(self, batch_size: int) -> None:
        """Record that ``batch_size`` more records have been inserted."""
        self.inserted_records += batch_size
        self._calculate_overall_progress()

    def update_flush_progress(self, progress: float) -> None:
        """Record the flush percentage, clamped to the range 0..100.

        Callers derive this value from segment counts, which can briefly
        yield a negative or over-100 figure; clamping keeps the overall
        percentage from running backwards or past 100.
        """
        self.flush_progress = max(0, min(100, progress))
        self._calculate_overall_progress()

    def _calculate_overall_progress(self) -> None:
        """Recompute ``overall_progress`` from the insert and flush phases."""
        if self.total_records <= 0:
            # Without a known total there is no insert percentage to blend,
            # so the previous overall value is left untouched.
            return

        insert_progress = (self.inserted_records / self.total_records) * 100
        self.overall_progress = (
            insert_progress * self.INSERT_WEIGHT
            + self.flush_progress * self.FLUSH_WEIGHT
        )

    def get_progress_stats(self) -> Dict[str, Any]:
        """Return a snapshot of progress, throughput and remaining time.

        Returns:
            A dict with the keys ``overall_progress``, ``insert_progress``,
            ``flush_progress``, ``elapsed_time``, ``records_per_second`` and
            ``estimated_remaining_time``. Times are in seconds and never
            negative.
        """
        # time.time() is a wall clock: it can step backwards and, on coarse
        # timers, return the same value twice. Clamp so the rate below is
        # never negative and never divides by zero.
        elapsed_time = max(time.time() - self.start_time, 0.0)

        if self.inserted_records > 0 and elapsed_time > 0:
            records_per_second = self.inserted_records / elapsed_time
            estimated_total_time = self.total_records / records_per_second
            estimated_remaining_time = max(estimated_total_time - elapsed_time, 0.0)
        else:
            records_per_second = 0
            estimated_remaining_time = 0

        if self.total_records > 0:
            insert_progress = self.inserted_records / self.total_records * 100
        else:
            insert_progress = 0

        return {
            "overall_progress": self.overall_progress,
            "insert_progress": insert_progress,
            "flush_progress": self.flush_progress,
            "elapsed_time": elapsed_time,
            "records_per_second": records_per_second,
            "estimated_remaining_time": estimated_remaining_time
        }

async def insert_with_detailed_progress(
    collection: Collection,
    data: List[Dict[str, Any]],
    batch_size: int = 1000,
    progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None
) -> List[int]:
    """Insert ``data`` in batches, flush, and report progress along the way.

    Args:
        collection: Target collection; each batch is passed to its
            ``insert`` method.
        data: Row dicts to insert.
        batch_size: Maximum number of rows per insert call; must be >= 1.
        progress_callback: Optional callable invoked with the dict returned
            by ``MilvusProgressTracker.get_progress_stats`` after every
            batch and during the flush.

    Returns:
        The primary keys reported for every inserted batch, in insert order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # range() would reject a zero step with an unrelated message, and a
        # negative step would silently insert nothing.
        raise ValueError("batch_size must be a positive integer")

    tracker = MilvusProgressTracker(collection.name)
    tracker.total_records = len(data)
    all_ids = []
    loop = asyncio.get_running_loop()

    for start in range(0, tracker.total_records, batch_size):
        batch = data[start:start + batch_size]

        # Collection.insert is a blocking call whose result is not
        # awaitable, so run it in the default executor; this also keeps the
        # event loop responsive while the request is in flight.
        mr = await loop.run_in_executor(None, collection.insert, batch)

        tracker.update_insert_progress(len(batch))

        if progress_callback:
            progress_callback(tracker.get_progress_stats())

        all_ids.extend(mr.primary_keys)

    # Flush data with progress tracking
    await flush_with_detailed_progress(collection, tracker, progress_callback)

    return all_ids

# NOTE(review): 2 is believed to be the value of ``Growing`` in Milvus'
# SegmentState protobuf enum — confirm against the installed pymilvus.
_SEGMENT_STATE_GROWING = 2


def _count_growing_segments(segments) -> int:
    """Count the segments whose ``state`` marks them as still growing.

    Accepts either the string ``"Growing"`` or the integer enum value, since
    the segment info objects are presumably protobuf messages whose enum
    fields compare as integers rather than names.
    """
    return sum(
        1
        for seg in segments
        if seg.state == "Growing" or seg.state == _SEGMENT_STATE_GROWING
    )


async def flush_with_detailed_progress(
    collection: Collection,
    tracker: MilvusProgressTracker,
    progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None
) -> None:
    """Flush ``collection`` while reporting estimated flush progress.

    Progress is estimated from how many of the segments that were growing
    before the flush are no longer growing. When no segment was growing at
    the start, only the final 100% update is reported.

    Args:
        collection: Collection to flush.
        tracker: Tracker that receives the flush percentage.
        progress_callback: Optional callable invoked with the tracker's
            stats dict after every progress update.

    Raises:
        Whatever ``collection.flush`` or the segment-info query raises.
    """
    loop = asyncio.get_running_loop()

    # Segment queries are blocking calls; run them off the event loop.
    initial_segments = await loop.run_in_executor(
        None, utility.get_query_segment_info, collection.name
    )
    initial_growing = _count_growing_segments(initial_segments)

    # Run the blocking flush in the default executor. This yields a real
    # asyncio future that supports both polling and ``await``.
    flush_future = loop.run_in_executor(None, collection.flush)

    while True:
        # Wake up as soon as the flush finishes, or after 0.5 s to poll.
        done, _pending = await asyncio.wait({flush_future}, timeout=0.5)
        if done:
            break

        if initial_growing > 0:
            current_segments = await loop.run_in_executor(
                None, utility.get_query_segment_info, collection.name
            )
            current_growing = _count_growing_segments(current_segments)

            # New segments may start growing during the flush, so clamp to
            # keep the reported percentage within 0..100.
            flush_progress = ((initial_growing - current_growing) / initial_growing) * 100
            flush_progress = max(0.0, min(100.0, flush_progress))
            tracker.update_flush_progress(flush_progress)

            if progress_callback:
                progress_callback(tracker.get_progress_stats())

    # Re-raises any exception the flush hit in the worker thread.
    await flush_future

    tracker.update_flush_progress(100)
    if progress_callback:
        progress_callback(tracker.get_progress_stats())

# Example usage
async def main():
    """Connect to a local Milvus, insert 10000 sample rows and print progress."""
    connections.connect("default", host="localhost", port="19530")
    target = Collection("my_collection")

    # Sample rows: an id, a 128-dimensional constant vector and a text label.
    rows = []
    for n in range(10000):
        rows.append({"id": n, "vector": [n/100]*128, "text": f"text_{n}"})

    def detailed_progress_update(stats):
        """Print one line summarising the current progress snapshot."""
        summary = (
            f"Overall: {stats['overall_progress']:.2f}%, "
            f"Insert: {stats['insert_progress']:.2f}%, "
            f"Flush: {stats['flush_progress']:.2f}%, "
            f"Speed: {stats['records_per_second']:.2f} records/sec, "
            f"ETA: {stats['estimated_remaining_time']:.2f} sec"
        )
        print(summary)

    inserted_ids = await insert_with_detailed_progress(
        target,
        rows,
        progress_callback=detailed_progress_update,
    )
    print(f"Inserted {len(inserted_ids)} records")

if __name__ == "__main__":
    # Run the example only when executed as a script, not on import.
    example = main()
    asyncio.run(example)
# Editor is loading...
# Leave a Comment