Untitled

 avatar
unknown
plain_text
11 days ago
20 kB
6
Indexable
# api/request_processor.py
import asyncio
import json
import logging
import re
import uuid
from typing import Awaitable, Callable

import websockets
from asgiref.sync import sync_to_async
from django.conf import settings
from django.utils import timezone

from api.models import DataRequest
from api.utils.remote_api import (
    initiate_copy_async,
    initiate_crop_async,
    initiate_zip_async,
    initiate_delete_async,
    initiate_stop_async,
    pull_file_with_progress,
    RemoteAPIError,
    get_remote_ws_url
)
from api.services.redis_events import publish_request_update
from api.utils.enums import ProcessingAction

logger = logging.getLogger(__name__)


def sanitize_path(path: str) -> str:
    """
    Convert an arbitrary string into a single safe folder name.

    Every run of path separators, characters that are invalid in file names
    (``\\ / : * ? " < > |``) and spaces is collapsed into one underscore.

    A result that is empty or consists only of dots (``""``, ``"."``,
    ``".."``) is not a usable folder name: joined onto a parent directory it
    would resolve to the parent itself or to its parent. Such results are
    neutralised by turning the dots into underscores (``""`` becomes ``"_"``).

    Args:
        path: The raw string, typically a user supplied path.

    Returns:
        A non-empty name containing no path separators that is never ``"."``
        or ``".."``.
    """
    cleaned = re.sub(r'[\\\/:*?"<>| ]+', '_', path)
    if not cleaned.strip('.'):
        # Empty or dot-only: would escape / alias the parent directory.
        return cleaned.replace('.', '_') or '_'
    return cleaned


class RequestProcessor:
    """
    Processes a single DataRequest in stages.

      0: Waiting
      1: Copying to Cache
      2: On Cache
      3: Making Crops
      4: Zipping
      5: Zip Completed
      6: Downloading Zip to Local
      7: On Local (Ready)

    ``process_current_stage`` handles exactly one stage per call; the caller
    is expected to invoke it repeatedly until the request is complete,
    stopped or failed.

    Remote cache layout produced for a request (under
    ``settings.REMOTE_CACHE_PATH/<request id>``):

      snapshots/           copied when data type 0 or 1 is requested
      crops/               produced when data type 1 is requested
      extra_data/<name>/   one folder per entry in ``custom_paths``
    """

    # Total attempts (the first try included) a stage gets before the request
    # is marked as failed.
    MAX_STAGE_ATTEMPTS = 3

    def __init__(self, req: "DataRequest"):
        self.req = req
        # stage number -> number of failed attempts so far
        self.attempts_per_stage = {}
        # Directory that gets zipped. Kept as an attribute for backward
        # compatibility; process_zipping recomputes it from the request.
        self.zip_source_dir = f"{settings.REMOTE_CACHE_PATH}/{self.req.id}"
        # Id of the most recently started remote operation (used by stop()).
        self.current_op_id = None
        # Running download task, if any (used by stop()).
        self.download_task = None
        # Strong references to fire-and-forget listener tasks: the event loop
        # itself only keeps weak references to tasks.
        self._background_tasks = set()

    async def get_data_from_request(self, *attrs) -> dict:
        """Refresh ``self.req`` from the database and return the named attributes."""
        await sync_to_async(self.req.refresh_from_db)()
        return {attr: getattr(self.req, attr) for attr in attrs}

    def __lt__(self, other):
        """
        Ordering used to prioritise processors.

        Requests that are not on the cache yet (stage < 2) sort first, then
        lower ``priority`` values, then older ``created_at``.
        """
        if not isinstance(other, RequestProcessor):
            return NotImplemented
        self_uncached = self.req.stage_int < 2
        other_uncached = other.req.stage_int < 2
        if self_uncached != other_uncached:
            return self_uncached
        if self.req.priority != other.req.priority:
            return self.req.priority < other.req.priority
        return self.req.created_at < other.req.created_at

    # ------------------------------------------------------------------ #
    # Small private helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _select_zip_source_dir(request_root: str, data_types, custom_paths) -> str:
        """
        Return the remote directory that should be zipped.

        Only when crops (data type 1) are the sole requested data type AND no
        custom paths were requested is the ``crops`` folder zipped on its own.
        In every other combination the whole request folder is zipped so that
        the archive keeps one sub-folder per kind of data (``snapshots``,
        ``crops``, ``extra_data``). Previously a crops-only request with
        custom paths zipped just the crops folder and silently dropped
        ``extra_data``.

        NOTE(review): for crops + custom paths the archive therefore also
        contains the ``snapshots`` folder that was copied as crop input.
        """
        if set(data_types or []) == {1} and not custom_paths:
            return f"{request_root}/crops"
        return request_root

    async def _is_stopped(self) -> bool:
        """Re-read the request and report whether it has been stopped."""
        data = await self.get_data_from_request("is_stopped")
        return bool(data["is_stopped"])

    async def _run_remote_operation(self, stage: int, source: str, initiate, **kwargs):
        """
        Start a listener, initiate one remote operation and wait for it.

        ``initiate`` is one of the ``initiate_*_async`` coroutines; it is
        called with ``source``, the generated ``operation_id`` and ``kwargs``.
        If initiating (or waiting) fails, the listener task is cancelled and
        drained so it is not left running with an unretrieved exception.
        """
        op_id, ws_task = await self._start_operation(stage=stage, source=source)
        self.current_op_id = op_id
        try:
            await initiate(source=source, operation_id=op_id, **kwargs)
            await ws_task
        except BaseException:
            # cancel() is a no-op when the listener already finished;
            # gather(..., return_exceptions=True) collects its outcome
            # without raising, so the original error is the one re-raised.
            ws_task.cancel()
            await asyncio.gather(ws_task, return_exceptions=True)
            raise

    def _track_background_task(self, task) -> None:
        """Keep a reference to a fire-and-forget task until it finishes."""
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def _on_background_task_done(self, task) -> None:
        """Drop a finished background task and retrieve its outcome."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()  # marks the exception as retrieved
        if exc is not None:
            logger.warning(f"Background listener for request ID={self.req.id} failed: {exc}")

    # ------------------------------------------------------------------ #
    # Stage dispatch
    # ------------------------------------------------------------------ #

    async def process_current_stage(self):
        """
        Run the handler for the request's current stage.

        Does nothing when the request is stopped. Every exception raised by a
        stage handler is routed to ``handle_stage_error``; nothing propagates
        to the caller (task cancellation excepted).
        """
        data = await self.get_data_from_request("is_stopped", "stage_int", "id")
        if data["is_stopped"]:
            logger.info(f"Request ID={data['id']} is stopped. Aborting current stage.")
            return

        stage = data["stage_int"]
        req_id = data["id"]
        logger.info(f"Processing stage {stage} for request ID={req_id}")
        try:
            if stage == 0:
                await self.advance_stage(1, "Starting copy to cache (In Progress)")
            elif stage == 1:
                await self.process_copy_to_cache()
            elif stage == 2:
                # Decide whether to do cropping based on requested data types.
                data_types = self.req.data_types or []
                if 1 in data_types:
                    await self.advance_stage(3, "Initiating cropping operation (Making Crops)")
                else:
                    await self.advance_stage(4, "Skipping cropping. Initiating zipping operation.")
            elif stage == 3:
                await self.process_making_crops()
            elif stage == 4:
                await self.process_zipping()
            elif stage == 5:
                await self.advance_stage(6, "Initiating ZIP file download (Downloading Zip to Local)")
            elif stage == 6:
                await self.process_downloading()
            elif stage >= 7:
                # Before logging completion, check if stopped.
                if await self._is_stopped():
                    logger.info(f"Request ID={req_id} is stopped; not marking complete.")
                    return
                logger.info(f"Request ID={req_id} is complete.")
            else:
                logger.warning(f"Unknown stage {stage} for request ID={req_id}")
        except RemoteAPIError as e:
            await self.handle_stage_error(stage, str(e))
        except Exception as e:
            logger.exception(f"Unexpected error in stage {stage} for request ID={req_id}: {e}")
            await self.handle_stage_error(stage, str(e))

    # ------------------------------------------------------------------ #
    # Stage handlers
    # ------------------------------------------------------------------ #

    async def process_copy_to_cache(self):
        """
        Stage 1: copy the requested data into the request's cache folder.

        Snapshots go to ``<root>/snapshots`` (needed for data type 0 and as
        crop input for data type 1); each custom path goes to
        ``<root>/extra_data/<sanitized path>``. Advances to stage 2 when all
        copies are done; returns early if the request is stopped.
        """
        data = await self.get_data_from_request("id", "source", "is_stopped")
        if data["is_stopped"]:
            return
        req_id = data["id"]
        source = data["source"]
        request_root = f"{settings.REMOTE_CACHE_PATH}/{req_id}"
        data_types = self.req.data_types or []

        # If snapshots (0) or crops (1) are requested, copy snapshots.
        if (0 in data_types) or (1 in data_types):
            # Only snapshot copies are filtered by modification time.
            mod_time_range = await self._build_mod_time_filter()
            await self._run_remote_operation(
                1,
                source,
                initiate_copy_async,
                source_dir=settings.REMOTE_SNAPSHOTS_PATH,
                dest_dir=f"{request_root}/snapshots",
                mod_time_range=mod_time_range,
                recursive=False,
            )
            # Check if stopped after copying snapshots.
            if await self._is_stopped():
                return

        # Process each custom path.
        for custom_path in self.req.custom_paths or []:
            await self._run_remote_operation(
                1,
                source,
                initiate_copy_async,
                source_dir=custom_path,
                dest_dir=f"{request_root}/extra_data/{sanitize_path(custom_path)}",
                mod_time_range=None,
                recursive=True,
            )
            if await self._is_stopped():
                return

        await publish_request_update(
            self.req,
            action=ProcessingAction.STAGE_CHANGE,
            status="Copy operations complete. Data is now on cache.",
            new_stage=1
        )
        # Set the default zip source directory to the entire request folder.
        self.zip_source_dir = request_root
        await self.advance_stage(2, "Copy complete. Data is now on cache.")

    async def process_making_crops(self):
        """
        Stage 3: create crops from ``<root>/snapshots`` into ``<root>/crops``.

        Advances to stage 4 when cropping is done; returns early if the
        request is stopped.
        """
        data = await self.get_data_from_request("id", "source", "is_stopped")
        if data["is_stopped"]:
            return
        req_id = data["id"]
        source = data["source"]
        request_root = f"{settings.REMOTE_CACHE_PATH}/{req_id}"

        await publish_request_update(
            self.req,
            action=ProcessingAction.STAGE_CHANGE,
            status="Initiating cropping operation (Making Crops)",
            new_stage=3
        )

        await self._run_remote_operation(
            3,
            source,
            initiate_crop_async,
            snapshots_dir=f"{request_root}/snapshots",
            output_dir=f"{request_root}/crops",
        )
        if await self._is_stopped():
            return

        # Determine the zip source directory (see _select_zip_source_dir).
        self.zip_source_dir = self._select_zip_source_dir(
            request_root, self.req.data_types, self.req.custom_paths
        )
        await self.advance_stage(4, "Cropping complete. Ready to zip.")

    async def process_zipping(self):
        """
        Stage 4: zip the request's data into ``<cache>/<request id>.zip``.

        The directory to zip is derived from the request's ``data_types`` and
        ``custom_paths`` here rather than taken from in-memory state, so a
        processor created for a request that is already at stage 4 zips the
        same directory as one that ran the earlier stages itself.
        """
        data = await self.get_data_from_request("id", "source", "is_stopped")
        if data["is_stopped"]:
            return
        req_id = data["id"]
        source = data["source"]
        request_root = f"{settings.REMOTE_CACHE_PATH}/{req_id}"
        self.zip_source_dir = self._select_zip_source_dir(
            request_root, self.req.data_types, self.req.custom_paths
        )

        await publish_request_update(
            self.req,
            action=ProcessingAction.STAGE_CHANGE,
            status="Initiating zipping operation (Zipping)",
            new_stage=4
        )

        await self._run_remote_operation(
            4,
            source,
            initiate_zip_async,
            source_dir=self.zip_source_dir,
            zip_path=f"{request_root}.zip",
        )
        if await self._is_stopped():
            return
        await self.advance_stage(5, "Zipping complete. ZIP file ready.")

    async def process_downloading(self):
        """
        Stage 6: download the zip, then initiate cleanup of the remote cache.

        The download runs in ``self.download_task`` so ``stop()`` can cancel
        it. The two delete operations are only initiated, not awaited; their
        listeners run in the background. Advances to stage 7 and publishes
        PROCESS_COMPLETE unless the request is stopped.
        """
        data = await self.get_data_from_request("id", "source", "is_stopped")
        if data["is_stopped"]:
            return
        req_id = data["id"]
        source = data["source"]
        request_root = f"{settings.REMOTE_CACHE_PATH}/{req_id}"
        local_zip_path = f"{settings.LOCAL_DOWNLOAD_TARGET_PATH}_{req_id}.zip"

        await publish_request_update(
            self.req,
            action=ProcessingAction.STATUS_UPDATE,
            status="Downloading ZIP file..."
        )

        self.download_task = asyncio.create_task(
            pull_file_with_progress(
                source=source,
                remote_path=f"{request_root}.zip",
                local_path=local_zip_path,
                callback=self.generate_progress_callback(req_id, 6)
            )
        )
        try:
            await self.download_task
        except asyncio.CancelledError:
            logger.info(f"Download for request {req_id} was cancelled.")
            return
        finally:
            self.download_task = None

        # Remove the request folder first, then the zip itself.
        # NOTE(review): a failure while initiating cleanup is treated as a
        # stage error, which makes the retry download the zip again.
        for stale_path in (request_root, f"{request_root}.zip"):
            op_id, ws_task = await self._start_operation(stage=6, source=source)
            self._track_background_task(ws_task)
            self.current_op_id = op_id
            try:
                await initiate_delete_async(
                    source=source,
                    path=stale_path,
                    operation_id=op_id
                )
            except BaseException:
                # Nothing will ever be reported for this operation.
                ws_task.cancel()
                raise

        await publish_request_update(
            self.req,
            action=ProcessingAction.STATUS_UPDATE,
            status="Initiated deletion from cache."
        )
        if await self._is_stopped():
            return
        await self.advance_stage(7, "Download complete. Files available locally.")
        # Before publishing complete, check again.
        if await self._is_stopped():
            return
        await publish_request_update(
            self.req,
            action=ProcessingAction.PROCESS_COMPLETE,
            status="All done"
        )

    # ------------------------------------------------------------------ #
    # Remote operation plumbing
    # ------------------------------------------------------------------ #

    async def _listen_ws(self, ws_url: str, stage: int, operation_id: str):
        """
        Follow one remote operation over its websocket until it ends.

        Publishes a request update for each "progress", "complete", "stopped"
        and "error" message; messages with any other status are ignored.
        Returns on "complete" or "stopped".

        Raises:
            RemoteAPIError: when the remote side reports "error" (raised
                as-is), or when connecting, receiving or decoding fails
                (wrapped, with the original exception chained).
        """
        await self.get_data_from_request("id")
        try:
            async with websockets.connect(ws_url) as websocket:
                logger.info(f"Connected to WS for operation_id={operation_id} at {ws_url}")
                while True:
                    message_raw = await websocket.recv()
                    message = json.loads(message_raw)
                    status = message.get("status")
                    msg = message.get("message", "")
                    logger.debug(f"Received WS message for operation_id={operation_id}: {message}")

                    if status == "progress":
                        await publish_request_update(
                            self.req,
                            action=ProcessingAction.STATUS_UPDATE,
                            status=msg
                        )
                    elif status == "complete":
                        await publish_request_update(
                            self.req,
                            action=ProcessingAction.STATUS_UPDATE,
                            status=f"Operation (stage {stage}) completed: {msg}"
                        )
                        break
                    elif status == "stopped":
                        await publish_request_update(
                            self.req,
                            action=ProcessingAction.STATUS_UPDATE,
                            status=f"Operation (stage {stage}) stopped: {msg}"
                        )
                        break
                    elif status == "error":
                        await publish_request_update(
                            self.req,
                            action=ProcessingAction.PROCESS_ERROR,
                            status=f"Operation (stage {stage}) error: {msg}",
                            error_message=msg,
                            is_failed=True
                        )
                        raise RemoteAPIError(f"Operation (stage {stage}) failed: {msg}")
        except RemoteAPIError:
            # An error reported by the remote operation is not a websocket
            # failure: let it through without re-wrapping it.
            raise
        except Exception as e:
            logger.exception(f"Error during WS communication for op_id={operation_id}: {e}")
            raise RemoteAPIError(f"WebSocket communication failed: {e}") from e

    def generate_progress_callback(self, req_id: int, stage: int) -> Callable[[int, int], Awaitable[None]]:
        """
        Build the async progress callback handed to the download helper.

        The returned coroutine function takes ``(downloaded, total)`` in
        bytes. With a known total it publishes one status update per 5%
        milestone; with an unknown total (``total <= 0``) it publishes on
        every call. ``req_id`` and ``stage`` are accepted for interface
        compatibility and are not used.
        """
        last_reported = -1

        async def callback(downloaded: int, total: int):
            nonlocal last_reported
            downloaded_mb = downloaded / (1024 * 1024)
            milestone = None
            if total > 0:
                percent = (downloaded / total) * 100
                milestone = (int(percent) // 5) * 5
                if milestone <= last_reported:
                    # Nothing new to report: skip the database refresh too.
                    return
                total_mb = total / (1024 * 1024)
                status_msg = f"Downloading ZIP: {milestone}% ({downloaded_mb:.1f}MB/{total_mb:.1f}MB)"
            else:
                status_msg = f"Downloading ZIP: {downloaded_mb:.1f}MB downloaded"

            # Refresh only when an update is actually published.
            await self.get_data_from_request("id")
            await publish_request_update(
                self.req,
                action=ProcessingAction.STATUS_UPDATE,
                status=status_msg
            )
            if milestone is not None:
                last_reported = milestone
        return callback

    # ------------------------------------------------------------------ #
    # Error handling and state changes
    # ------------------------------------------------------------------ #

    async def handle_stage_error(self, stage: int, error_msg: str):
        """
        Record a failed attempt of ``stage`` and retry or give up.

        Up to ``MAX_STAGE_ATTEMPTS`` attempts are made per stage, waiting
        ``settings.TIME_BETWEEN_RETRIES`` seconds before each retry. The
        retry goes through ``process_current_stage``, which reports a new
        failure by calling this method again; it never raises stage errors
        itself, so no exception handling is needed around the retry.
        """
        data = await self.get_data_from_request("id")
        req_id = data["id"]
        attempts = self.attempts_per_stage.get(stage, 0) + 1
        self.attempts_per_stage[stage] = attempts
        logger.warning(f"Stage {stage} for request ID={req_id} failed on attempt {attempts}: {error_msg}")
        await publish_request_update(
            self.req,
            action=ProcessingAction.STATUS_UPDATE,
            status=f"Stage {stage} failed on attempt {attempts}: {error_msg}"
        )
        if attempts >= self.MAX_STAGE_ATTEMPTS:
            await self.mark_failed(stage, error_msg)
            return

        logger.info(f"Retrying stage {stage} for request ID={req_id} in {settings.TIME_BETWEEN_RETRIES} seconds (Attempt {attempts + 1})")
        await publish_request_update(
            self.req,
            action=ProcessingAction.STATUS_UPDATE,
            status=f"Retrying stage {stage} for request ID={req_id} in {settings.TIME_BETWEEN_RETRIES} seconds (Attempt {attempts + 1})"
        )
        await asyncio.sleep(settings.TIME_BETWEEN_RETRIES)
        await self.process_current_stage()
        # A failed retry has bumped the counter through a nested call of this
        # method. Reset the retry budget only when the counter is unchanged,
        # i.e. the retry did not fail; otherwise a request that ran out of
        # attempts would get a fresh budget.
        if self.attempts_per_stage.get(stage, 0) == attempts:
            self.attempts_per_stage[stage] = 0

    async def mark_failed(self, stage: int, error_msg: str):
        """Persist ``is_failed=True`` on the request and publish PROCESS_ERROR."""
        data = await self.get_data_from_request("id")
        req_id = data["id"]
        logger.error(f"Request ID={req_id} marked as failed at stage {stage}: {error_msg}")
        self.req.is_failed = True
        # NOTE(review): save() writes every field of the refreshed instance;
        # consider update_fields if other writers touch this row concurrently.
        await sync_to_async(self.req.save)()
        await publish_request_update(
            self.req,
            action=ProcessingAction.PROCESS_ERROR,
            status=f"Stage {stage} failed after retries: {error_msg}",
            error_message=error_msg,
            is_failed=True
        )

    async def advance_stage(self, new_stage: int, status_msg: str):
        """
        Persist ``new_stage`` on the request and publish a STAGE_CHANGE.

        Does nothing when the request is stopped.
        """
        data = await self.get_data_from_request("id", "is_stopped")
        if data["is_stopped"]:
            logger.info(f"Request ID={data['id']} is stopped; not advancing stage.")
            return
        req_id = data["id"]
        logger.info(f"Request ID={req_id} advancing to stage {new_stage}: {status_msg}")
        self.req.stage_int = new_stage
        # NOTE(review): save() writes every field of the refreshed instance;
        # consider update_fields if other writers touch this row concurrently.
        await sync_to_async(self.req.save)()
        await publish_request_update(
            self.req,
            action=ProcessingAction.STAGE_CHANGE,
            status=status_msg,
            new_stage=new_stage
        )

    async def _build_mod_time_filter(self) -> dict:
        """
        Return the request's time window as integer timestamps.

        Returns:
            ``{"mod_time_start": int, "mod_time_end": int}`` computed from
            the request's ``start_date_time`` and ``end_date_time``.
        """
        data = await self.get_data_from_request("start_date_time", "end_date_time", "id")
        start_dt = timezone.localtime(data["start_date_time"])
        end_dt = timezone.localtime(data["end_date_time"])
        mod_time_start = int(start_dt.timestamp())
        mod_time_end = int(end_dt.timestamp())
        logger.debug(f"Request ID={data['id']} generated mod_time_range: {mod_time_start} to {mod_time_end}")
        return {"mod_time_start": mod_time_start, "mod_time_end": mod_time_end}

    async def _start_operation(self, stage: int, source: str):
        """
        Create a new operation id and start its websocket listener task.

        Returns:
            ``(operation_id, ws_task)``. The caller owns ``ws_task`` and is
            responsible for awaiting, tracking or cancelling it.
        """
        operation_id = str(uuid.uuid4())
        ws_url = get_remote_ws_url(source, operation_id)
        ws_task = asyncio.create_task(self._listen_ws(ws_url, stage, operation_id))
        return operation_id, ws_task

    async def stop(self):
        """
        Cancel a running download and ask the remote side to stop the most
        recently started operation. Both steps are best effort: failures are
        logged, not raised.
        """
        data = await self.get_data_from_request("id", "source")
        req_id = data["id"]

        # Work on a local reference: process_downloading clears the
        # attribute as soon as the download task ends.
        task = self.download_task
        if task is not None and not task.done():
            logger.info(f"Cancelling active download for request {req_id}")
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                logger.info(f"Download for request {req_id} cancelled successfully.")
            except Exception as e:
                # Do not let a download error prevent the remote stop below.
                logger.warning(f"Download for request {req_id} ended with an error while being cancelled: {e}")
            self.download_task = None

        if self.current_op_id:
            try:
                await initiate_stop_async(source=data["source"], operation_id=self.current_op_id)
                logger.info(f"Sent stop command for remote operation {self.current_op_id} (Request ID={req_id})")
            except Exception as e:
                logger.exception(f"Failed to stop remote operation {self.current_op_id} for request {req_id}: {e}")

Why is it that when snapshots and crops are both ticked, it creates a nice folder for each (e.g. "snapshots" and "crops"),

but when it's a combination such as snapshots and a custom_path, it doesn't work: it doesn't create, for example, the "crops" and "extra_data" folders?

I want to make those combinations possible. Please give me the full, fixed, working code.
Leave a Comment