Untitled

 avatar
unknown
plain_text
6 months ago
10 kB
3
Indexable
import customtkinter as ctk
import tkinter as tk
from PIL import Image, ImageTk
import cv2
import numpy as np
import os
from datetime import datetime
import json
from typing import Optional
import tflite_runtime.interpreter as tflite
from picamera2 import Picamera2
 
# Global CustomTkinter look-and-feel: "dark" appearance mode with the "blue"
# colour theme. Both calls run at import time, ahead of every class below.
ctk.set_appearance_mode("dark")
ctk.set_default_color_theme("blue")
 
class ProgressBar(ctk.CTkFrame):
    """Frame wrapping a single ``CTkProgressBar`` that starts fully filled.

    Args:
        master: Parent widget.
        **kwargs: Forwarded unchanged to ``ctk.CTkFrame``.
    """

    def __init__(self, master, **kwargs):
        super().__init__(master, **kwargs)
        bar = ctk.CTkProgressBar(self)
        bar.pack(fill="x", padx=10, pady=5)
        bar.set(1.0)
        self.progress_bar = bar

    def update_progress(self, value: float):
        """Set the fill level of the wrapped progress bar to *value*."""
        self.progress_bar.set(value)
 
class FeedbackWindow(ctk.CTkFrame):
    """Screen asking the user whether the displayed prediction was correct.

    Shows a title, the predicted class name in a class-specific colour, a
    thumbs-up and a thumbs-down button, and a progress bar that the owner
    drives through :meth:`update_timer`.

    Args:
        master: Parent widget.
        prediction: Class name to display.
        on_feedback: Callable invoked with ``True`` when thumbs-up is pressed
            and with ``False`` when thumbs-down is pressed.
        **kwargs: Forwarded unchanged to ``ctk.CTkFrame``.
    """

    def __init__(self, master, prediction: str, on_feedback, **kwargs):
        super().__init__(master, **kwargs)

        # One stretchable cell that holds the whole screen.
        self.grid_rowconfigure(0, weight=1)
        self.grid_columnconfigure(0, weight=1)

        # Padded inner frame with three equally weighted columns.
        body = ctk.CTkFrame(self)
        body.grid(row=0, column=0, sticky="nsew", padx=20, pady=20)
        body.grid_columnconfigure((0, 1, 2), weight=1)

        # Row 0: title.
        heading = ctk.CTkLabel(
            body,
            text="Was it correct?",
            font=("Helvetica", 24, "bold"),
        )
        heading.grid(row=0, column=0, columnspan=3, pady=20)

        # Row 1: prediction text; classes missing from the table are white.
        category_colors = {
            "Biomüll": "#8B4513",
            "Gelber Sack": "#FFD700",
            "Papier": "#4169E1",
            "Restmüll": "#808080",
        }
        prediction_label = ctk.CTkLabel(
            body,
            text=prediction,
            font=("Helvetica", 20),
            text_color=category_colors.get(prediction, "#FFFFFF"),
        )
        prediction_label.grid(row=1, column=0, columnspan=3, pady=10)

        # Row 2: (button text, value handed to on_feedback, grid column).
        button_specs = (("👍", True, 0), ("👎", False, 2))
        for symbol, verdict, column in button_specs:
            button = ctk.CTkButton(
                body,
                text=symbol,
                # Bind the verdict as a default so each button keeps its own.
                command=lambda verdict=verdict: on_feedback(verdict),
                width=80,
                height=80,
                font=("Helvetica", 30),
            )
            button.grid(row=2, column=column, padx=10, pady=20)

        # Row 3: progress bar spanning all columns.
        self.progress = ProgressBar(body)
        self.progress.grid(row=3, column=0, columnspan=3, sticky="ew", pady=10)

    def update_timer(self, value: float):
        """Forward *value* to the progress bar of this screen."""
        self.progress.update_progress(value)
 
class ImageViewerWindow(ctk.CTkFrame):
    """Gallery screen showing thumbnails of the images stored in a directory.

    Args:
        master: Parent widget. It must provide a ``show_main_window()``
            method, which is wired to the "Back" button.
        image_dir: Directory scanned for image files.
        **kwargs: Forwarded unchanged to ``ctk.CTkFrame``.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
    # Bounding box (width, height) every thumbnail is shrunk into.
    THUMBNAIL_SIZE = (200, 200)
    # Number of thumbnails per grid row.
    COLUMNS = 3

    def __init__(self, master, image_dir: str, **kwargs):
        super().__init__(master, **kwargs)
        self.image_dir = image_dir

        # Configure grid: only the thumbnail area (row 1) stretches.
        self.grid_columnconfigure(0, weight=1)
        self.grid_rowconfigure(1, weight=1)

        # Title
        title = ctk.CTkLabel(self, text="Captured Images", font=("Helvetica", 24, "bold"))
        title.grid(row=0, column=0, pady=20)

        # Image grid
        self.image_grid = ctk.CTkScrollableFrame(self)
        self.image_grid.grid(row=1, column=0, sticky="nsew", padx=20, pady=20)

        # Back button
        back_btn = ctk.CTkButton(self, text="Back",
                                command=self.master.show_main_window)
        back_btn.grid(row=2, column=0, pady=20)

        self.load_images()

    def load_images(self):
        """Rebuild the thumbnail grid from the current content of ``image_dir``.

        Existing thumbnails are destroyed first. Files are shown in reverse
        name order. A missing directory yields an empty gallery and files
        that cannot be read as images are skipped.
        """
        for widget in self.image_grid.winfo_children():
            widget.destroy()

        try:
            entries = os.listdir(self.image_dir)
        except FileNotFoundError:
            entries = []

        names = sorted(
            (name for name in entries
             if name.lower().endswith(self.IMAGE_EXTENSIONS)),
            reverse=True,
        )

        shown = 0
        for name in names:
            path = os.path.join(self.image_dir, name)
            try:
                # The context manager closes the file handle; the copy keeps
                # the (already shrunk) pixel data usable after the close.
                with Image.open(path) as img:
                    img.thumbnail(self.THUMBNAIL_SIZE)
                    thumb = img.copy()
            except OSError:
                # Corrupt or partially written file: leave it out instead of
                # aborting the whole gallery.
                continue

            picture = ctk.CTkImage(light_image=thumb, size=thumb.size)
            # text="" is required: CTkLabel otherwise draws its default
            # label text on top of the image.
            label = ctk.CTkLabel(self.image_grid, image=picture, text="")
            label.grid(row=shown // self.COLUMNS, column=shown % self.COLUMNS,
                       padx=5, pady=5)
            shown += 1
 
class WasteClassifierApp(ctk.CTk):
    """Top-level window of the waste classifier.

    Three screens share the same grid cell of the window and are shown one at
    a time: the main screen (Gallery / Scan buttons), the image gallery and
    the feedback screen displayed after every scan.
    """

    # Delay between two countdown updates on the feedback screen, in ms.
    COUNTDOWN_TICK_MS = 100

    def __init__(self):
        super().__init__()

        # Id of the pending countdown ``after`` callback, or None when no
        # countdown is running.
        self.countdown_id: Optional[str] = None

        # Initialize camera and model
        self.setup_camera()
        self.setup_model()

        # Configure window
        self.title("Waste Classifier")
        self.geometry("1024x600")
        self.resizable(False, False)

        # Every screen is gridded into cell (0, 0) with sticky="nsew"; the
        # cell needs a weight, otherwise it does not grow to the window size.
        self.grid_columnconfigure(0, weight=1)
        self.grid_rowconfigure(0, weight=1)

        # Stop the camera and pending timers when the window is closed.
        self.protocol("WM_DELETE_WINDOW", self._on_close)

        # Create frames for different windows
        self.main_window = self.create_main_window()
        self.image_viewer = None
        self.feedback_window = None

        # Show main window initially
        self.show_main_window()

    def setup_camera(self):
        """Create, configure and start the Picamera2 preview stream."""
        self.camera = Picamera2()
        self.camera.preview_configuration.main.size = (1920, 1440)
        self.camera.preview_configuration.main.format = "RGB888"
        self.camera.configure("preview")
        self.camera.start()

    def setup_model(self):
        """Load the TFLite model and create the output directories.

        Sets ``interpreter``, ``input_details``, ``output_details``,
        ``class_names``, ``image_dir`` and ``feedback_dir``.
        """
        MODEL_PATH = "ei-v2-transfer-learning-tensorflow-lite-float32-model.lite"
        self.interpreter = tflite.Interpreter(model_path=MODEL_PATH)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # Order is used to map the arg-max output index to a class name.
        self.class_names = ["Biomüll", "Gelber Sack", "Papier", "Restmüll"]

        # Create directory for captured images
        self.image_dir = "captured_images"
        os.makedirs(self.image_dir, exist_ok=True)

        # Create directory for feedback data
        self.feedback_dir = "feedback_data"
        os.makedirs(self.feedback_dir, exist_ok=True)

    def create_main_window(self) -> ctk.CTkFrame:
        """Build the main screen (Gallery and Scan buttons) without showing it.

        Returns:
            The frame containing the main screen.
        """
        frame = ctk.CTkFrame(self)
        frame.grid_columnconfigure(0, weight=1)
        frame.grid_rowconfigure(0, weight=1)

        # Main content area
        content = ctk.CTkFrame(frame, fg_color="transparent")
        content.grid(row=0, column=0, sticky="nsew")
        content.grid_columnconfigure((0, 1), weight=1)
        content.grid_rowconfigure(0, weight=1)

        # Button container (bottom)
        btn_container = ctk.CTkFrame(content, fg_color="transparent")
        btn_container.grid(row=1, column=0, columnspan=2, sticky="ew", pady=20)
        btn_container.grid_columnconfigure((0, 1), weight=1)

        # Gallery button
        gallery_btn = ctk.CTkButton(btn_container, text="Gallery",
                                   command=self.show_image_viewer,
                                   width=100, height=100)
        gallery_btn.grid(row=0, column=0, padx=20)

        # Capture button
        capture_btn = ctk.CTkButton(btn_container, text="Scan",
                                   command=self.capture_and_classify,
                                   width=100, height=100)
        capture_btn.grid(row=0, column=1, padx=20)

        return frame

    def _cancel_countdown(self):
        """Cancel the pending countdown callback, if there is one."""
        if self.countdown_id is not None:
            self.after_cancel(self.countdown_id)
            self.countdown_id = None

    def _on_close(self):
        """Stop the timer and the camera, then destroy the window."""
        self._cancel_countdown()
        try:
            self.camera.stop()
        finally:
            # Close the window even if stopping the camera fails.
            self.destroy()

    def show_main_window(self):
        """Hide the gallery and feedback screens and show the main screen.

        Also cancels a running feedback countdown so that it cannot fire
        later and switch screens behind the user's back.
        """
        self._cancel_countdown()
        if self.image_viewer is not None:
            self.image_viewer.grid_remove()
        if self.feedback_window is not None:
            self.feedback_window.grid_remove()
        self.main_window.grid(row=0, column=0, sticky="nsew")

    def show_image_viewer(self):
        """Show the gallery screen, creating it on first use, and reload it."""
        self.main_window.grid_remove()
        if self.image_viewer is None:
            self.image_viewer = ImageViewerWindow(self, self.image_dir)
        self.image_viewer.grid(row=0, column=0, sticky="nsew")
        self.image_viewer.load_images()

    def show_feedback_window(self, prediction: str):
        """Show a fresh feedback screen for *prediction* and start its countdown.

        Args:
            prediction: Class name shown to the user and stored together with
                the feedback.
        """
        self.main_window.grid_remove()

        # A new screen is built for every scan; destroy the previous (hidden)
        # one so that widgets do not pile up over time.
        if self.feedback_window is not None:
            self.feedback_window.destroy()

        self.feedback_window = FeedbackWindow(
            self, prediction,
            on_feedback=lambda feedback: self.handle_feedback(prediction, feedback)
        )
        self.feedback_window.grid(row=0, column=0, sticky="nsew")

        # Start countdown (start_countdown cancels any timer still pending).
        self.start_countdown()

    def start_countdown(self, duration: float = 5.0):
        """Run the feedback countdown and return to the main screen at zero.

        The progress bar of the feedback screen is updated every
        ``COUNTDOWN_TICK_MS`` milliseconds with the remaining fraction.

        Args:
            duration: Countdown length in seconds. A value of zero or less
                returns to the main screen immediately.
        """
        self._cancel_countdown()

        # Count in integer milliseconds so repeated subtraction cannot drift.
        total_ms = int(round(duration * 1000))

        def tick(remaining_ms: int):
            # This callback has fired, so its id is no longer cancellable.
            self.countdown_id = None
            if remaining_ms <= 0 or self.feedback_window is None:
                self.show_main_window()
                return

            self.feedback_window.update_timer(remaining_ms / total_ms)
            self.countdown_id = self.after(
                self.COUNTDOWN_TICK_MS, tick,
                remaining_ms - self.COUNTDOWN_TICK_MS,
            )

        tick(total_ms)

    def capture_and_classify(self):
        """Capture a frame, save it, classify it and show the feedback screen."""
        # Capture image
        image = self.camera.capture_array()
        # NOTE(review): the Picamera2 manual describes the "RGB888" format as
        # storing pixels in [B, G, R] order. If that applies here, the
        # RGB2BGR conversion below swaps red/blue in the saved file and the
        # model receives BGR input - confirm against the training data
        # before changing anything.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        image_path = os.path.join(self.image_dir, f"capture_{timestamp}.jpg")
        cv2.imwrite(image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

        # Preprocess and classify: resize to the model input size, which is
        # read from shape[2] (passed as width) and shape[1] (as height),
        # scale to 0..1 and add the batch dimension.
        input_shape = self.input_details[0]['shape']
        processed = cv2.resize(image, (int(input_shape[2]), int(input_shape[1])))
        processed = processed / 255.0
        processed = np.expand_dims(processed, axis=0).astype(np.float32)

        self.interpreter.set_tensor(self.input_details[0]['index'], processed)
        self.interpreter.invoke()
        output = self.interpreter.get_tensor(self.output_details[0]['index'])

        # Get prediction
        prediction = self.class_names[int(np.argmax(output[0]))]

        # Show feedback window
        self.show_feedback_window(prediction)

    def handle_feedback(self, prediction: str, was_correct: bool):
        """Append one feedback record to ``feedback.json`` and go back to main.

        Args:
            prediction: Class name that was shown to the user.
            was_correct: ``True`` if the user confirmed the prediction.
        """
        # Save feedback
        feedback_data = {
            "timestamp": datetime.now().isoformat(),
            "prediction": prediction,
            "was_correct": was_correct
        }

        feedback_file = os.path.join(self.feedback_dir, "feedback.json")

        try:
            with open(feedback_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            data = []
        if not isinstance(data, list):
            # Valid JSON but not a list of records: treat it like a corrupt
            # file instead of failing on .append().
            data = []

        data.append(feedback_data)

        # Write to a temporary file and swap it in, so that an interrupted
        # write cannot leave a truncated feedback.json behind.
        temp_file = feedback_file + ".tmp"
        with open(temp_file, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(temp_file, feedback_file)

        # Return to main window (this also cancels the running countdown).
        self.show_main_window()
 
if __name__ == "__main__":
    # Script entry point: build the application window and run the Tk event
    # loop until the window is closed.
    WasteClassifierApp().mainloop()
Editor is loading...
Leave a Comment