Untitled

 avatar
unknown
plain_text
20 days ago
13 kB
8
Indexable
import numpy as np
import open3d as o3d
import cv2
import mediapipe as mp
from pathlib import Path
import logging
import traceback
import os

# Suppress TensorFlow warnings
# NOTE(review): this assignment runs after the third-party imports above; if
# the variable is only consulted while those libraries load, it would need to
# be set before them to have any effect — confirm.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

#Open3d window customization
def custom_draw_geometries(geometries, window_name="Open3D", width=1024, height=768):
    """Show geometries in a blocking Open3D window with fixed render settings.

    Args:
        geometries: Iterable of Open3D geometries to add to the scene.
        window_name: Title of the window.
        width: Window width in pixels.
        height: Window height in pixels.

    Raises:
        RuntimeError: If the window cannot be created.
    """
    viewer = o3d.visualization.Visualizer()
    # Fail with a clear message instead of tripping over a half-initialised
    # visualizer further down.
    if not viewer.create_window(window_name, width=width, height=height):
        raise RuntimeError(f"Failed to create Open3D window: {window_name}")

    try:
        for geometry in geometries:
            viewer.add_geometry(geometry)

        opt = viewer.get_render_option()
        opt.background_color = np.asarray([0.5, 0.5, 0.5])
        opt.point_size = 2.0
        opt.line_width = 1.0

        ctr = viewer.get_view_control()
        ctr.set_zoom(0.8)
        ctr.set_front([0, 0, -1])
        ctr.set_up([0, -1, 0])
        ctr.set_lookat([0, 0, 0])

        viewer.run()
    finally:
        # Always release the window, even if setup or rendering raised.
        viewer.destroy_window()

class SingleImageFaceReconstructor:
    """Reconstruct a coloured 3D face mesh from a single image.

    MediaPipe Face Mesh supplies the landmarks, the landmarks are densified by
    barycentric interpolation over the tessellation triangles, and Open3D
    Poisson reconstruction turns the resulting point cloud into a mesh.
    """

    # Colour given to points/vertices that fall outside the image.
    _FALLBACK_COLOR = (0.7, 0.7, 0.7)

    def __init__(self):
        self.mp_face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=1,
            min_detection_confidence=0.7,
            refine_landmarks=True
        )

        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_drawing_styles = mp.solutions.drawing_styles
        self.init_face_triangles()

    def init_face_triangles(self):
        """Build ``self.FACE_TRIANGLES`` from MediaPipe's tessellation edges."""
        self.FACE_TRIANGLES = self._triangles_from_edges(
            mp.solutions.face_mesh.FACEMESH_TESSELATION
        )

    @staticmethod
    def _triangles_from_edges(connections):
        """Return every 3-clique of an undirected edge list exactly once.

        Args:
            connections: Iterable of ``(v1, v2)`` vertex-index pairs.

        Returns:
            ``(n, 3)`` int32 array; each row is sorted ascending and rows are
            in lexicographic order.
        """
        neighbours = {}
        for v1, v2 in connections:
            neighbours.setdefault(v1, set()).add(v2)
            neighbours.setdefault(v2, set()).add(v1)

        triangles = []
        for v1 in sorted(neighbours):
            for v2 in sorted(neighbours[v1]):
                if v2 <= v1:
                    continue
                # Requiring v1 < v2 < v3 emits each triangle once, regardless
                # of edge direction or container iteration order.
                for v3 in sorted(neighbours[v1] & neighbours[v2]):
                    if v3 > v2:
                        triangles.append([v1, v2, v3])

        return np.array(triangles, dtype=np.int32).reshape(-1, 3)

    def interpolate_points(self, points, colors, subdivisions=4):
        """Densify the landmark cloud by sampling every face triangle.

        Each triangle in ``self.FACE_TRIANGLES`` is sampled on a regular
        barycentric grid; a small Gaussian jitter is added to the third
        coordinate of every sampled point.

        Args:
            points: ``(n, d)`` array of vertex positions (``d >= 3``).
            colors: ``(n, c)`` array of per-vertex colours.
            subdivisions: Number of grid steps along each triangle edge.

        Returns:
            Tuple ``(dense_points, dense_colors)`` with one row per sample,
            ordered triangle by triangle.

        Raises:
            ValueError: If ``subdivisions`` is smaller than 1.
        """
        if subdivisions < 1:
            raise ValueError("subdivisions must be at least 1")

        points = np.asarray(points, dtype=np.float64)
        colors = np.asarray(colors, dtype=np.float64)
        triangles = np.asarray(self.FACE_TRIANGLES)

        # Barycentric weights (a, b, c) with a + b + c == 1; the third weight
        # is fully determined by the first two.
        weights = np.array(
            [
                (i, j, subdivisions - i - j)
                for i in range(subdivisions + 1)
                for j in range(subdivisions + 1 - i)
            ],
            dtype=np.float64,
        ) / subdivisions

        # (m, 3) @ (t, 3, d) broadcasts to (t, m, d): every weight row applied
        # to the three corners of every triangle.
        dense_points = np.matmul(weights, points[triangles]).reshape(-1, points.shape[-1])
        dense_colors = np.matmul(weights, colors[triangles]).reshape(-1, colors.shape[-1])

        dense_points[:, 2] += np.random.normal(0, 0.0001, size=len(dense_points))

        return dense_points, dense_colors

    def _write_debug_image(self, image, face_landmarks, debug_path):
        """Draw the tessellation over a copy of ``image`` and save it."""
        debug_image = image.copy()
        self.mp_drawing.draw_landmarks(
            image=debug_image,
            landmark_list=face_landmarks,
            connections=mp.solutions.face_mesh.FACEMESH_TESSELATION,
            landmark_drawing_spec=None,
            connection_drawing_spec=self.mp_drawing_styles.get_default_face_mesh_tesselation_style()
        )
        # cv2.imwrite does not create missing directories, so create them first.
        Path(debug_path).parent.mkdir(parents=True, exist_ok=True)
        if not cv2.imwrite(str(debug_path), debug_image):
            logging.warning("Failed to write debug image: %s", debug_path)

    def _sample_landmark_colors(self, image_rgb, landmarks):
        """Return one RGB colour in [0, 1] per landmark.

        The colour is the landmark's own pixel after bilateral filtering of
        the (up to) 3x3 neighbourhood around it; landmarks outside the image
        get the fallback grey.
        """
        h, w = image_rgb.shape[:2]
        colors = []
        for landmark in landmarks:
            img_x = int(landmark.x * w)
            img_y = int(landmark.y * h)
            if not (0 <= img_y < h and 0 <= img_x < w):
                colors.append(self._FALLBACK_COLOR)
                continue

            y0, x0 = max(0, img_y - 1), max(0, img_x - 1)
            patch = np.ascontiguousarray(
                image_rgb[y0:min(h, img_y + 2), x0:min(w, img_x + 2)]
            )
            filtered = cv2.bilateralFilter(patch, 3, 75, 75)
            # Index the landmark's pixel inside the patch, not the patch corner.
            colors.append(filtered[img_y - y0, img_x - x0] / 255.0)

        return np.array(colors, dtype=np.float64)

    def reconstruct_face(self, image_path, output_path, debug_path=None):
        """Detect a face in an image and reconstruct it as a mesh.

        Args:
            image_path: Path of the input image.
            output_path: Path the resulting mesh is written to.
            debug_path: Optional path for an image with the landmark
                tessellation drawn on it.

        Returns:
            Tuple ``(mesh, point_cloud)`` as returned by ``create_mesh``.

        Raises:
            RuntimeError: If the image cannot be loaded or no face is found.
        """
        try:
            image = cv2.imread(str(image_path))
            if image is None:
                raise RuntimeError(f"Failed to load image: {image_path}")

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            h, w = image.shape[:2]

            results = self.mp_face_mesh.process(image_rgb)
            if not results.multi_face_landmarks:
                raise RuntimeError("No face detected in the image")

            face_landmarks = results.multi_face_landmarks[0]
            landmarks = face_landmarks.landmark

            if debug_path:
                self._write_debug_image(image, face_landmarks, debug_path)

            coords = np.array(
                [[lm.x, lm.y, lm.z] for lm in landmarks], dtype=np.float64
            )

            # Depth is scaled by how much of the image width the face occupies.
            face_points = coords[:, :2] * np.array([w, h], dtype=np.float64)
            face_bbox = cv2.boundingRect(face_points.astype(np.float32))
            face_width = face_bbox[2]
            depth_scale = (face_width / w) * 0.5

            # Landmark 1 is used as the nose tip, 234 and 454 as the two ears.
            nose_z = coords[1, 2]
            avg_ear_z = (coords[234, 2] + coords[454, 2]) / 2
            depth_range = abs(nose_z - avg_ear_z)

            # Centre x/y on the image middle, flip y so that up is positive,
            # and express z relative to the ear plane.
            relative_depth = (coords[:, 2] - avg_ear_z) / (depth_range + 1e-6)
            points = np.column_stack((
                coords[:, 0] - 0.5,
                -(coords[:, 1] - 0.5),
                relative_depth * depth_scale,
            ))

            colors = self._sample_landmark_colors(image_rgb, landmarks)

        except Exception as e:
            logging.exception("Face reconstruction failed: %s", e)
            raise

        # create_mesh logs its own failures, so it stays outside the try above
        # to avoid reporting the same traceback twice.
        return self.create_mesh(points, colors, output_path, image_path)

    @classmethod
    def _sample_vertex_colors(cls, vertices, image_rgb):
        """Bilinearly sample ``image_rgb`` at each vertex's x/y position.

        Vertex x/y are expected in the centred, y-up coordinates produced by
        ``reconstruct_face``. Vertices that project outside the area where a
        2x2 pixel neighbourhood exists get the fallback grey.

        Returns:
            ``(n, 3)`` float array of colours in [0, 1].
        """
        h, w = image_rgb.shape[:2]
        vertices = np.asarray(vertices, dtype=np.float64).reshape(-1, 3)
        vertex_colors = np.tile(
            np.asarray(cls._FALLBACK_COLOR, dtype=np.float64), (len(vertices), 1)
        )

        x = (vertices[:, 0] + 0.5) * w
        y = (-vertices[:, 1] + 0.5) * h
        inside = (x >= 0) & (x < w - 1) & (y >= 0) & (y < h - 1)

        xi, yi = x[inside], y[inside]
        x0 = xi.astype(np.intp)
        y0 = yi.astype(np.intp)
        x1 = np.minimum(x0 + 1, w - 1)
        y1 = np.minimum(y0 + 1, h - 1)
        wx = (xi - x0)[:, None]
        wy = (yi - y0)[:, None]

        blended = (
            (1 - wx) * (1 - wy) * image_rgb[y0, x0]
            + wx * (1 - wy) * image_rgb[y0, x1]
            + (1 - wx) * wy * image_rgb[y1, x0]
            + wx * wy * image_rgb[y1, x1]
        )
        vertex_colors[inside] = blended / 255.0
        return vertex_colors

    def create_mesh(self, points, colors, output_path, image_path):
        """Turn coloured landmark points into a textured, normalised mesh.

        Args:
            points: ``(n, 3)`` landmark positions.
            colors: ``(n, 3)`` landmark colours in [0, 1].
            output_path: Path the mesh is written to.
            image_path: Path of the source image, re-read for vertex colours.

        Returns:
            Tuple ``(mesh, pcd)``: the final triangle mesh and the filtered
            dense point cloud it was reconstructed from.

        Raises:
            RuntimeError: If the image cannot be loaded or reconstruction
                yields no triangles.
        """
        try:
            image = cv2.imread(str(image_path))
            if image is None:
                raise RuntimeError(f"Failed to load image: {image_path}")
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Increasing point density
            dense_points, dense_colors = self.interpolate_points(points, colors, subdivisions=18)

            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(dense_points)
            pcd.colors = o3d.utility.Vector3dVector(dense_colors)

            pcd, _ = pcd.remove_statistical_outlier(nb_neighbors=50, std_ratio=2.5)

            pcd.estimate_normals(
                search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.015, max_nn=50)
            )
            pcd.orient_normals_consistent_tangent_plane(50)

            mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
                pcd, depth=9, width=0, scale=1.1, linear_fit=True
            )

            # Remove the 10% of vertices with the lowest density.
            densities = np.asarray(densities)
            vertices_to_remove = densities < np.quantile(densities, 0.1)
            mesh.remove_vertices_by_mask(vertices_to_remove)

            mesh.remove_degenerate_triangles()
            mesh.remove_duplicated_triangles()
            mesh.remove_duplicated_vertices()
            mesh.compute_vertex_normals()

            # Keep only the largest connected component.
            triangle_clusters, cluster_n_triangles, _ = mesh.cluster_connected_triangles()
            triangle_clusters = np.asarray(triangle_clusters)
            cluster_n_triangles = np.asarray(cluster_n_triangles)
            if cluster_n_triangles.size == 0:
                raise RuntimeError("Reconstruction produced a mesh without triangles")
            triangles_to_remove = triangle_clusters != cluster_n_triangles.argmax()

            if len(triangles_to_remove) == len(np.asarray(mesh.triangles)):
                mesh.remove_triangles_by_mask(triangles_to_remove)
                # Drop vertices orphaned by the removal so they cannot skew the
                # centring and scaling below.
                mesh.remove_unreferenced_vertices()
            else:
                # Only the component filter is skipped; the mesh is still
                # smoothed, coloured, normalised and written.
                logging.warning("Skipping triangle removal due to size mismatch")

            mesh = mesh.filter_smooth_simple(number_of_iterations=2)
            mesh.compute_vertex_normals()

            # Texture mapping: colour vertices from the image before the
            # coordinates are re-normalised.
            vertices = np.asarray(mesh.vertices)
            mesh.vertex_colors = o3d.utility.Vector3dVector(
                self._sample_vertex_colors(vertices, image_rgb)
            )

            # Centre the mesh, scale each axis to unit extent (x to 0.8) and
            # push it 1.5 units along z.
            vertices = vertices - np.mean(vertices, axis=0)
            scale_factors = np.max(np.abs(vertices), axis=0)
            scale_factors[scale_factors == 0] = 1.0  # flat axis: leave as is
            vertices = vertices / scale_factors
            vertices[:, 0] *= 0.8
            vertices[:, 2] += 1.5
            mesh.vertices = o3d.utility.Vector3dVector(vertices)

            # Open3D does not create missing directories, so create them first.
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            if not o3d.io.write_triangle_mesh(str(output_path), mesh):
                logging.warning("Failed to write mesh: %s", output_path)

            return mesh, pcd

        except Exception as e:
            logging.exception("Failed to create mesh: %s", e)
            raise
def main():
    """Reconstruct the face from the configured image and show the results.

    Any failure is logged together with its traceback instead of propagating.
    """
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    source_image = Path(r"Image_Path.png")
    mesh_file = Path(r"Output_3d_Model\face.ply")
    landmark_image = Path(r"OutputFaceLandmarks\face_landmarks.png")

    try:
        logging.info("Starting face reconstruction...")
        mesh, point_cloud = SingleImageFaceReconstructor().reconstruct_face(
            source_image, mesh_file, landmark_image
        )
        if mesh is None:
            return

        axes = o3d.geometry.TriangleMesh.create_coordinate_frame(
            size=1.0, origin=[0, 0, 0]
        )

        # Show the point cloud first, then the mesh, each next to the axes.
        stages = (
            ("Displaying point cloud...", point_cloud),
            ("Displaying mesh...", mesh),
        )
        for announcement, geometry in stages:
            logging.info(announcement)
            custom_draw_geometries([geometry, axes])

        logging.info("Face reconstruction completed successfully!")

    except Exception as err:
        logging.error(f"Program failed: {err}")
        logging.error(traceback.format_exc())


if __name__ == "__main__":
    main()
Leave a Comment