Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
3.2 kB
2
Indexable
import logging
import pika
import threading

from flask import Flask, request
from typing import List, Optional

from config import IMAGES_ENDPOINT, DATA_DIR


class Server:
    """Front-end state for the image-description service.

    Publishes processing tasks to the durable RabbitMQ ``task_queue`` and keeps
    an in-memory map from the integer ids it hands out to their descriptions
    (``None`` until a description is recorded).
    """

    def __init__(self, host, port):
        """Connect to RabbitMQ and declare the durable task queue.

        :param host: RabbitMQ host name.
        :param port: RabbitMQ AMQP port.
        """
        self._host = host
        self._port = port
        self._processed_images = []
        self._image_id = 0
        self._from_id_to_description = {}
        # pika's BlockingConnection/channel are not thread-safe and Flask's
        # development server serves requests on several threads, so the shared
        # channel and the id counter are only touched while holding this lock.
        self._lock = threading.Lock()

        self._connection = pika.BlockingConnection(pika.ConnectionParameters(host=self._host, port=self._port))
        self._channel = self._connection.channel()

        # durable=True: the queue definition survives a RabbitMQ node restart.
        self._channel.queue_declare(queue='task_queue', durable=True)

        # TODO: run serve_worker_answers on a background thread once its
        # callback actually records worker results.
        # NOTE(review): a long-idle BlockingConnection may be dropped by the
        # broker (heartbeats are only serviced while pika does I/O); consider
        # reconnect-on-failure in store_image.

    def serve_worker_answers(self):
        """Consume worker replies from ``ans_queque`` until consuming stops.

        Opens a dedicated connection instead of reusing ``self._connection``
        because pika's BlockingConnection must not be shared across threads and
        this method is meant to run on a background thread. The connection is
        closed when consuming ends or fails. Currently every message is
        acknowledged and otherwise discarded.
        """
        connection = pika.BlockingConnection(pika.ConnectionParameters(host=self._host, port=self._port))
        try:
            channel = connection.channel()

            # durable=True: the queue survives a RabbitMQ node restart.
            # NOTE(review): name (including spelling) presumably must match the
            # queue the workers publish to — verify before renaming.
            channel.queue_declare(queue='ans_queque', durable=True)

            def callback(ch, method, properties, body):
                # TODO: parse ``body`` and update _from_id_to_description /
                # _processed_images; the reply format is defined by the worker.
                ch.basic_ack(delivery_tag=method.delivery_tag)

            # Have RabbitMQ deliver at most one unacknowledged message at a time.
            channel.basic_qos(prefetch_count=1)
            channel.basic_consume(queue='ans_queque', on_message_callback=callback)
            channel.start_consuming()
        finally:
            if connection.is_open:
                connection.close()

    def store_image(self, image: str) -> int:
        """Queue ``image`` for processing and return the id assigned to it.

        The published message body is ``image`` immediately followed by the
        decimal id, with no separator, and is marked persistent.

        :param image: image reference (the HTTP layer passes ``image_url``).
        :return: the newly assigned integer id.
        """
        with self._lock:
            image_id = self._image_id
            # NOTE(review): with no delimiter between image and id, a worker
            # cannot split them unambiguously when ``image`` ends in a digit;
            # changing the format requires a matching worker change.
            self._channel.basic_publish(
                exchange='',
                routing_key='task_queue',
                body=''.join([image, str(image_id)]),
                properties=pika.BasicProperties(
                    delivery_mode=2,  # persistent message
                ),
            )
            # Register the id only after a successful publish.
            self._from_id_to_description[image_id] = None
            self._image_id = image_id + 1
        return image_id

    def get_processed_images(self) -> List[int]:
        """Return a copy of the ids recorded as processed."""
        return list(self._processed_images)

    def get_image_description(self, image_id: str) -> Optional[str]:
        """Return the description recorded for ``image_id``.

        Ids are stored as ints, but the HTTP layer passes the URL segment as a
        string, so ``image_id`` may be a str or an int; it is normalized to int
        before the lookup.

        :return: the description, or ``None`` if the id is not an integer, is
            unknown, or has no description yet.
        """
        try:
            key = int(image_id)
        except (TypeError, ValueError):
            return None
        return self._from_id_to_description.get(key)

    def __del__(self):
        """Close the publishing connection if it was opened and is still open."""
        # __init__ may have failed before _connection was assigned.
        connection = getattr(self, '_connection', None)
        if connection is not None and connection.is_open:
            connection.close()

def create_app(rabbitmq_host: str = 'rabbitmq', rabbitmq_port: int = 5672) -> Flask:
    """
    Create the Flask application exposing the image endpoints.

    :param rabbitmq_host: RabbitMQ host handed to ``Server``.
    :param rabbitmq_port: RabbitMQ AMQP port handed to ``Server``.
    :return: the configured application. Building it constructs a ``Server``,
        which opens its RabbitMQ connection immediately.
    """
    app = Flask(__name__)

    server = Server(rabbitmq_host, rabbitmq_port)

    @app.route(IMAGES_ENDPOINT, methods=['POST'])
    def add_image():
        # force=True parses the body as JSON regardless of the Content-Type header.
        body = request.get_json(force=True)
        image_url = body.get('image_url') if isinstance(body, dict) else None
        # Reject malformed requests with 400 instead of failing with a 500.
        if not isinstance(image_url, str) or not image_url:
            return "Field 'image_url' must be a non-empty string.", 400
        image_id = server.store_image(image_url)
        return {"image_id": image_id}

    @app.route(IMAGES_ENDPOINT, methods=['GET'])
    def get_image_ids():
        image_ids = server.get_processed_images()
        return {"image_ids": image_ids}

    @app.route(f'{IMAGES_ENDPOINT}/<string:image_id>', methods=['GET'])
    def get_processing_result(image_id):
        # Server keys ids as ints, but the URL segment arrives as a str.
        try:
            numeric_id = int(image_id)
        except ValueError:
            return "Image not found.", 404
        result = server.get_image_description(numeric_id)
        # NOTE(review): an id that exists but is not processed yet also maps to
        # None here, so it is reported as 404 too.
        if result is None:
            return "Image not found.", 404
        return {'description': result}

    return app


# Module-level application object (e.g. for a WSGI server). Calling
# create_app() here constructs the Server, which opens its RabbitMQ
# connection as soon as this module is imported.
app = create_app()


def _run_development_server():
    """Configure default root logging and serve ``app`` on all interfaces, port 5000."""
    logging.basicConfig()
    app.run(host='0.0.0.0', port=5000)


if __name__ == '__main__':
    _run_development_server()