Stable Diffusion Generator

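# Minimal Flask service that queues Stable Diffusion generation requests,
# processes them one at a time on a background worker thread, and returns the
# generated images as base64-encoded PNGs.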
import base64
import io
import diffusers
import flask
from flask import request, jsonify
from flask_cors import CORS
import torch
import PIL
from PIL import Image
import uuid
import time
import os


# In-memory job queue and store of completed results, keyed by request_id.
queue = []
results = {}

# True while the worker thread is processing a request.
queue_busy = False

# Run the pipelines in half precision to reduce VRAM usage.
torch_dtype_value = torch.float16

class QueueRequest:
    def __init__(self, request_id, data, image_data=None):
        self.request_id = request_id
        self.data = data
        self.image_data = image_data
        self.status = "queued"

app = flask.Flask(__name__)
CORS(app)

pipelines = {
    "realisticvision": {'path': "models/realisticVisionV60B1_v51VAE.safetensors", 'loaded': None},
}
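# Additional single-file checkpoints can be registered here as extra entries;
# each model is loaded lazily on first use and then cached in 'loaded'.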

def process_request(queue_item):
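    # Load the requested model on demand, run one generation on the GPU,
    # save the PNG to disk, and store the base64 result under the request_id.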
    data = queue_item.data
    
    base64_image_array = []
    
    # If the model isn't loaded yet, load it once and cache it:
    if pipelines[data['model']]['loaded'] is None:
        pipeline = diffusers.StableDiffusionPipeline.from_single_file(
            pipelines[data['model']]['path'],
            torch_dtype=torch_dtype_value
        )
        pipeline.enable_vae_slicing()
        pipeline.enable_vae_tiling()
        
        pipelines[data['model']]['loaded'] = pipeline

    pipeline = pipelines[data['model']]['loaded']
    
    # set the scheduler:
    pipeline.scheduler = diffusers.EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
    
    pipeline.to('cuda:1')  # move the pipeline to the second GPU for inference
    
    image = pipeline(
        prompt=data['prompt'],
        negative_prompt=data['negative_prompt'],
        width=data['width'],
        height=data['height'],
        num_inference_steps=20,
        num_images_per_prompt=1,
    ).images[0]
    
    pipeline.to('cpu')  # move back to CPU to free GPU memory between requests
    
    os.makedirs('outputs', exist_ok=True)
    
    image.save(f'outputs/output-{queue_item.request_id}.png')
    
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    base64_image_array.append(img_str)
    
    # append the result to the results list, with the request_id as the key:
    results[queue_item.request_id] = {
        "images": base64_image_array,
    }
        
    queue_item.status = "completed"
    return "processed"

def process_queue():
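    # Background worker loop: drain the queue one request at a time so only a
    # single generation runs on the GPU at once.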
    global queue_busy
    
    while True:
        time.sleep(0.01)  # Short sleep to prevent CPU overutilization
        if not queue_busy:
            if queue:  # Check if the queue is not empty
                queue_item = queue[0]  # Get the first item
                if queue_item.status == "queued":
                    queue_item.status = "waiting"
                    queue_busy = True
                    
                    result = process_request(queue_item)
                    
                    queue_busy = False
                    
                    if result == "processed":
                        queue.remove(queue_item)  # Remove the item from the queue
                    elif queue_item.status in ["completed", "error", "skipped"]:
                        queue.remove(queue_item)  # Remove the item from the queue
            
        # Sleep if no unprocessed request is found
        if not any(item.status == "queued" for item in queue):
            time.sleep(0.5)

# generate endpoint:
@app.route('/generate', methods=['POST'])
def generate():
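    # Sanitize width/height, enqueue the request, and hand back a request_id
    # that the client can poll.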
    
    data = request.json
    print(data)
    
    # sanitize the data:
    data['width'] = int(data['width'])
    data['height'] = int(data['height'])
    
    request_id = str(uuid.uuid4())
    queue_item = QueueRequest(request_id, data)
    
    queue.append(queue_item)
    
    position = len(queue)
    
    return jsonify({"status": "queued", "request_id": request_id, "position": position, "queue_length": position}), 202

@app.route('/queue_position/<request_id>', methods=['GET'])
def check_queue_position(request_id):
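    # Report the queue position while the request is pending, report
    # "completed" once its result is ready, or 404 for an unknown request_id.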

    for index, item in enumerate(queue):
        if item.request_id == request_id:
            return jsonify({"status": "waiting", "request_id": request_id, "position": index + 1, "queue_length": len(queue)}), 200
        
    if request_id in results:
        
        if results[request_id].get("status") == "error":
            return jsonify({"status": "error", "message": results[request_id].get("message")}), 200
        
        return jsonify({"status": "completed", "request_id": request_id}), 200
    
    return jsonify({"status": "not found", "message": "Invalid request_id"}), 404


@app.route('/result/<request_id>', methods=['GET'])
def get_result(request_id):
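    # Return the generated images once ready; the stored result is removed
    # after the first fetch.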
    result = results.get(request_id)
    # remove the result from the results dictionary:
    results.pop(request_id, None)
    if result:
        return jsonify(result)
    else:
        return jsonify({"status": "processing"}), 202 


if __name__ == '__main__':
    import threading
    threading.Thread(target=process_queue, daemon=True).start()
    app.run(port=5678)
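
# Example client flow (a minimal sketch; it assumes the default port above and
# the JSON field names used by the handlers in this file):
#
#   import requests, time
#
#   payload = {
#       "model": "realisticvision",
#       "prompt": "a castle at sunset",           # illustrative prompt
#       "negative_prompt": "blurry, low quality",  # illustrative value
#       "width": 512,
#       "height": 512,
#   }
#   base = "http://127.0.0.1:5678"
#   request_id = requests.post(f"{base}/generate", json=payload).json()["request_id"]
#
#   # Poll until the request leaves the queue, then fetch the base64 PNGs once.
#   while requests.get(f"{base}/queue_position/{request_id}").json()["status"] == "waiting":
#       time.sleep(1)
#   images = requests.get(f"{base}/result/{request_id}").json()["images"]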