Untitled

 avatar
unknown
plain_text
8 months ago
10 kB
3
Indexable
import customtkinter as ctk
import tflite_runtime.interpreter as tflite
import numpy as np
import cv2
from picamera2 import Picamera2
import os
from datetime import datetime
from PIL import Image
import threading
from pathlib import Path

# Set appearance mode and default color theme
ctk.set_appearance_mode("dark")
ctk.set_default_color_theme("blue")

class TimerProgressBar(ctk.CTkFrame):
    """A progress bar that drains from full to empty over ``duration`` ms.

    When the countdown reaches zero, the optional callback passed to
    :meth:`start` is invoked once. The bar is redrawn every ``TICK_MS``
    milliseconds via Tk's ``after`` scheduler.
    """

    TICK_MS = 50  # redraw interval in milliseconds

    def __init__(self, master, duration=5000, **kwargs):
        super().__init__(master, **kwargs)
        self.duration = duration  # total countdown length in ms
        self.remaining = duration  # ms left in the current countdown
        self.is_running = False
        # Id of the pending ``after`` tick, kept so start()/stop() can cancel
        # it. Without this, a stop()+start() within one tick (or a double
        # start()) leaves two update chains running and the bar drains at
        # double speed.
        self._after_id = None

        self.progress_bar = ctk.CTkProgressBar(self)
        self.progress_bar.pack(fill="x", padx=10, pady=5)
        self.progress_bar.set(1.0)

    def start(self, callback=None):
        """(Re)start the countdown from ``duration``.

        Args:
            callback: optional zero-argument callable invoked when the
                countdown reaches zero (not invoked if stop() is called).
        """
        self._cancel_pending()
        self.is_running = True
        self.remaining = self.duration
        self.update_progress(callback)

    def update_progress(self, callback=None):
        """Redraw the bar and schedule the next tick, or fire ``callback`` at zero."""
        # Either this tick has just fired or we were called directly after a
        # cancel; in both cases there is no outstanding scheduled tick.
        self._after_id = None
        if not self.is_running:
            return

        if self.remaining <= 0:
            self.progress_bar.set(0)
            self.is_running = False
            if callback:
                callback()
            return

        self.progress_bar.set(self.remaining / self.duration)
        self.remaining -= self.TICK_MS
        self._after_id = self.after(
            self.TICK_MS, lambda: self.update_progress(callback)
        )

    def stop(self):
        """Halt the countdown without invoking the callback."""
        self.is_running = False
        self._cancel_pending()

    def _cancel_pending(self):
        """Cancel the scheduled tick, if there is one."""
        if self._after_id is not None:
            self.after_cancel(self._after_id)
            self._after_id = None

class TrashClassifierGUI:
    """Application controller owning the window, camera, model and image store.

    The screens (frames) call back into this object to capture and classify
    an image, to file the image under its confirmed category, and to switch
    between screens via :meth:`show_frame`.
    """

    def __init__(self):
        self.current_image_path = None  # Path of the capture awaiting feedback
        self.current_prediction = None  # dict returned by capture_and_classify
        # The model must be loaded before the window is built:
        # CategorySelectionFrame reads ``controller.class_names`` while it
        # creates its buttons, so building the frames first raised
        # AttributeError at startup.
        self.setup_model()
        self.create_folders()
        self.setup_camera()
        self.setup_window()

    def setup_window(self):
        """Create the root window and stack every screen in one grid cell."""
        self.root = ctk.CTk()
        self.root.geometry("1024x600")
        self.root.title("Trash Classifier")

        # Main container for all frames
        self.container = ctk.CTkFrame(self.root)
        self.container.pack(fill="both", expand=True)
        # Let the shared cell grow with the container so each stacked frame
        # (gridded with sticky="nsew") fills the window rather than shrinking
        # to its requested size.
        self.container.grid_rowconfigure(0, weight=1)
        self.container.grid_columnconfigure(0, weight=1)

        # All screens share cell (0, 0); show_frame() raises the active one.
        # FeedbackFrame was listed here but is not defined in this file (a
        # NameError at startup) and is never shown, so it is omitted.
        self.frames = {}
        screens = (MainFrame, ClassificationFrame, CategorySelectionFrame, ThankYouFrame)
        for frame_class in screens:
            frame = frame_class(self.container, self)
            self.frames[frame_class] = frame
            frame.grid(row=0, column=0, sticky="nsew")

        self.show_frame(MainFrame)

    def setup_camera(self):
        """Configure and start the Picamera2 preview stream used for captures."""
        self.camera = Picamera2()
        self.camera.preview_configuration.main.size = (1920, 1440)
        self.camera.preview_configuration.main.format = "RGB888"
        self.camera.configure("preview")
        self.camera.set_controls({"AfMode": 2})
        self.camera.start()

    def setup_model(self):
        """Load the TFLite model and cache its input size and class labels."""
        model_path = "ei-v2-transfer-learning-tensorflow-lite-float32-model.lite"
        self.interpreter = tflite.Interpreter(model_path=model_path)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # Assumes an NHWC input shape (1, height, width, channels) — TODO
        # confirm against the exported model.
        self.height = self.input_details[0]['shape'][1]
        self.width = self.input_details[0]['shape'][2]
        # Order must match the model's output indices.
        self.class_names = ["Biomüll", "Gelber Sack", "Papier", "Restmüll"]

    def create_folders(self):
        """Create ``captured_images/<category>`` for every class name."""
        self.base_dir = Path("captured_images")
        for category in self.class_names:
            (self.base_dir / category).mkdir(parents=True, exist_ok=True)

    def show_frame(self, frame_class):
        """Raise the screen for ``frame_class`` and run its ``show`` hook, if any."""
        frame = self.frames[frame_class]
        frame.tkraise()
        # CategorySelectionFrame and ThankYouFrame start their countdown
        # timers in show(); without this call they never time out and the
        # thank-you screen never returns to the main screen.
        on_show = getattr(frame, "show", None)
        if callable(on_show):
            on_show()

    def capture_and_classify(self):
        """Capture a frame, save it under ``base_dir`` and classify it.

        Returns:
            dict with ``class_index`` (int), ``class_name`` (str) and
            ``confidence`` (float, in percent). The same dict is stored in
            ``self.current_prediction``.
        """
        image = self.camera.capture_array()
        # Microsecond resolution so two scans in the same second do not
        # overwrite each other's capture.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        self.current_image_path = self.base_dir / f"capture_{timestamp}.jpg"
        # NOTE(review): Picamera2 documents "RGB888" as storing pixels in
        # [B, G, R] order. If so, this conversion swaps channels in the saved
        # JPEG, and the model receives BGR input. Confirm before changing.
        cv2.imwrite(str(self.current_image_path), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

        # Preprocess and classify
        processed_image = self.preprocess_image(image)
        self.interpreter.set_tensor(self.input_details[0]['index'], processed_image)
        self.interpreter.invoke()
        scores = self.interpreter.get_tensor(self.output_details[0]['index'])[0]

        best = int(np.argmax(scores))
        self.current_prediction = {
            'class_index': best,
            'class_name': self.class_names[best],
            'confidence': float(scores[best] * 100),
        }
        return self.current_prediction

    def preprocess_image(self, image):
        """Resize to the model input size, scale to [0, 1] and add a batch axis."""
        resized = cv2.resize(image, (self.width, self.height))
        normalized = resized / 255.0
        return np.expand_dims(normalized, axis=0).astype(np.float32)

    def save_image_with_feedback(self, correct, new_category=None):
        """Move the pending capture into the folder of its confirmed category.

        Args:
            correct: True if the user confirmed the prediction.
            new_category: category chosen by the user when the prediction was
                wrong; ignored when ``correct`` is True.

        Does nothing if there is no pending capture or no category can be
        determined. The pending capture is consumed, so a repeated call for
        the same image is a no-op instead of failing on the moved file.
        """
        if self.current_image_path is None:
            return

        if correct:
            if self.current_prediction is None:
                return
            category = self.current_prediction['class_name']
        elif new_category:
            category = new_category
        else:
            return

        new_path = self.base_dir / category / self.current_image_path.name
        self.current_image_path.rename(new_path)
        self.current_image_path = None
        
class MainFrame(ctk.CTkFrame):
    """Home screen with a single round "Scan" button.

    Pressing the button asks the controller to capture and classify an image,
    then switches to ClassificationFrame to display the result.
    """

    def __init__(self, parent, controller):
        super().__init__(parent)
        self.controller = controller

        # Two equally weighted columns over one stretchable row; the button is
        # anchored to the bottom-left corner of the left cell.
        for col in (0, 1):
            self.grid_columnconfigure(col, weight=1)
        self.grid_rowconfigure(0, weight=1)

        button_size = 120
        self.scan_button = ctk.CTkButton(
            self,
            text="Scan",
            command=self.start_scan,
            width=button_size,
            height=button_size,
            corner_radius=button_size // 2,  # half the size -> circular button
        )
        self.scan_button.grid(row=0, column=0, padx=20, pady=20, sticky="sw")

    def start_scan(self):
        """Capture and classify an image, then show the result screen."""
        ctrl = self.controller
        result = ctrl.capture_and_classify()
        ctrl.show_frame(ClassificationFrame)
        ctrl.frames[ClassificationFrame].show_prediction(result)

class ClassificationFrame(ctk.CTkFrame):
    """Result screen: shows the predicted class and asks for thumbs up/down.

    A 5-second countdown returns to MainFrame if the user does not respond.
    """

    def __init__(self, parent, controller):
        super().__init__(parent)
        self.controller = controller

        # Countdown bar along the top edge
        self.timer = TimerProgressBar(self, duration=5000)
        self.timer.pack(fill="x", pady=(10, 0))

        # Main content area
        self.content = ctk.CTkFrame(self)
        self.content.pack(expand=True, fill="both", padx=20, pady=20)

        self.prediction_label = ctk.CTkLabel(
            self.content, text="", font=("Helvetica", 24, "bold")
        )
        self.prediction_label.pack(pady=20)

        # Row holding the two feedback buttons, side by side
        self.feedback_frame = ctk.CTkFrame(self.content)
        self.feedback_frame.pack(pady=20)

        self.correct_button = self._make_feedback_button("👍", True)
        self.incorrect_button = self._make_feedback_button("👎", False)

    def _make_feedback_button(self, symbol, verdict):
        """Create, pack and return one feedback button that reports ``verdict``."""
        button = ctk.CTkButton(
            self.feedback_frame,
            text=symbol,
            command=lambda: self.give_feedback(verdict),
            width=80,
            height=80,
        )
        button.pack(side="left", padx=10)
        return button

    def show_prediction(self, prediction):
        """Display ``prediction`` (class_name, confidence) and start the countdown."""
        self.prediction_label.configure(
            text=f"{prediction['class_name']}\n{prediction['confidence']:.1f}%"
        )

        def back_to_main():
            self.controller.show_frame(MainFrame)

        self.timer.start(back_to_main)

    def give_feedback(self, is_correct):
        """Stop the countdown and route to the thank-you or category screen."""
        self.timer.stop()
        if not is_correct:
            self.controller.show_frame(CategorySelectionFrame)
            return
        self.controller.save_image_with_feedback(True)
        self.controller.show_frame(ThankYouFrame)

class CategorySelectionFrame(ctk.CTkFrame):
    """Screen shown after a thumbs-down: lets the user pick the right category.

    One button is created per entry in ``controller.class_names``; a 5-second
    countdown (started by :meth:`show`) returns to MainFrame on timeout.
    """

    def __init__(self, parent, controller):
        super().__init__(parent)
        self.controller = controller

        # Countdown bar along the top edge
        self.timer = TimerProgressBar(self, duration=5000)
        self.timer.pack(fill="x", pady=(10, 0))

        # Main content area
        self.content = ctk.CTkFrame(self)
        self.content.pack(expand=True, fill="both", padx=20, pady=20)

        self.label = ctk.CTkLabel(
            self.content, text="Select correct category:", font=("Helvetica", 20)
        )
        self.label.pack(pady=20)

        for name in controller.class_names:
            self._add_category_button(name)

    def _add_category_button(self, category):
        """Create and pack a button that selects ``category`` when pressed."""
        ctk.CTkButton(
            self.content,
            text=category,
            command=lambda: self.select_category(category),
        ).pack(pady=10)

    def show(self):
        """Start the countdown back to MainFrame."""
        def back_to_main():
            self.controller.show_frame(MainFrame)

        self.timer.start(back_to_main)

    def select_category(self, category):
        """File the pending capture under ``category`` and thank the user."""
        self.timer.stop()
        ctrl = self.controller
        ctrl.save_image_with_feedback(False, category)
        ctrl.show_frame(ThankYouFrame)

class ThankYouFrame(ctk.CTkFrame):
    """Confirmation screen; :meth:`show` starts a 3-second return to MainFrame."""

    def __init__(self, parent, controller):
        super().__init__(parent)
        self.controller = controller

        # Countdown bar along the top edge
        self.timer = TimerProgressBar(self, duration=3000)
        self.timer.pack(fill="x", pady=(10, 0))

        # Main content area
        self.content = ctk.CTkFrame(self)
        self.content.pack(expand=True, fill="both", padx=20, pady=20)

        self.label = ctk.CTkLabel(
            self.content,
            text="Thank you for your feedback!",
            font=("Helvetica", 24, "bold"),
        )
        self.label.pack(expand=True)

    def show(self):
        """Start the countdown back to MainFrame."""
        def back_to_main():
            self.controller.show_frame(MainFrame)

        self.timer.start(back_to_main)

if __name__ == "__main__":
    app = TrashClassifierGUI()
    try:
        app.root.mainloop()
    finally:
        # Shut the camera down cleanly when the window closes (or mainloop
        # raises) instead of leaving it running until process teardown.
        app.camera.close()
Editor is loading...
Leave a Comment