Untitled

mail@pastecode.io avatar
unknown
python
a year ago
20 kB
4
Indexable
Never
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import argparse
import glob
import logging
import os
import pickle
import sys
from typing import Any, ClassVar, Dict, List
import torch
import cv2

from detectron2.config import CfgNode, get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from detectron2.utils.logger import setup_logger

from densepose import add_densepose_config
from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
from densepose.utils.logger import verbosity_to_level
from densepose.vis.base import CompoundVisualizer
from densepose.vis.bounding_box import ScoredBoundingBoxVisualizer
from densepose.vis.densepose_outputs_vertex import (
    DensePoseOutputsTextureVisualizer,
    DensePoseOutputsVertexVisualizer,
    get_texture_atlases,
)
from densepose.vis.densepose_results import (
    DensePoseResultsContourVisualizer,
    DensePoseResultsFineSegmentationVisualizer,
    DensePoseResultsUVisualizer,
    DensePoseResultsVVisualizer,
)
from densepose.vis.densepose_results_textures import (
    DensePoseResultsVisualizerWithTexture,
    get_texture_atlas,
)
from densepose.vis.extractor import (
    CompoundExtractor,
    DensePoseOutputsExtractor,
    DensePoseResultExtractor,
    create_extractor,
)

from pdb import set_trace as bb
import tqdm

# Description text passed to the top-level argparse parser in create_argument_parser().
DOC = """Apply Net - a tool to print / visualize DensePose results
"""

# Logger name used by this script; main() rebinds `logger` through setup_logger().
LOGGER_NAME = "apply_net"
logger = logging.getLogger(LOGGER_NAME)
# Maps an action's COMMAND string to its class; populated by the register_action decorator.
_ACTION_REGISTRY: Dict[str, "Action"] = {}

def box_overlaps(box1, box2):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Args:
        box1: ``(x_min, y_min, x_max, y_max)`` of the first box.
        box2: ``(x_min, y_min, x_max, y_max)`` of the second box.

    Returns:
        Intersection area divided by union area. When the union area is not
        positive (e.g. both boxes are degenerate points or lines) the result
        is ``0.0`` rather than a ZeroDivisionError.
    """
    x1_min, y1_min, x1_max, y1_max = box1
    x2_min, y2_min, x2_max, y2_max = box2

    area1 = (x1_max - x1_min) * (y1_max - y1_min)
    area2 = (x2_max - x2_min) * (y2_max - y2_min)

    # Overlap extent along each axis; clamped at 0 for disjoint boxes.
    inter_w = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
    inter_h = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
    intersection_area = inter_w * inter_h

    union_area = area1 + area2 - intersection_area
    if union_area <= 0:
        # Nothing to overlap with: avoid dividing by zero.
        return 0.0

    return intersection_area / union_area
    

def get_box(mask):
    """Derive a padded, roughly square crop window around a region of ``mask``.

    Args:
        mask: 2-D array; pixels with value above 127 are treated as foreground
            (assumed to be uint8 as the OpenCV calls require -- TODO confirm).

    Returns:
        ``(-1, -1)`` when the mask is empty, has no pixel above the threshold,
        or the region is too small (longest side < 80 px or < 5% of the longest
        mask side). Otherwise ``(crop_box, tight_box)``, each an
        ``[x1, y1, x2, y2]`` list of ints: ``tight_box`` is the bounding
        rectangle of the contour and ``crop_box`` is a window of side
        ``3.5 * d`` (``d`` = half the longest side, at least 74) placed around it.
    """
    if np.sum(mask) == 0:
        return -1, -1

    _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        # Non-zero mask whose values are all <= 127: nothing survives the threshold.
        return -1, -1

    # NOTE(review): only the first contour returned is measured; other
    # disconnected regions of the mask are ignored.
    x, y, w, h = cv2.boundingRect(contours[0])

    longest = max(w, h)
    if longest < 80 or longest / max(mask.shape[0], mask.shape[1]) < 0.05:
        return -1, -1

    center = (x + w // 2, y + h // 2)
    d = max(longest // 2, 74)
    side = 3.5 * d

    # Vertical placement: 1.5*d below the centre, shifted down if it would
    # start above the image. y2 may exceed the image height in that case;
    # slicing with it is clipped by numpy.
    y2 = min(center[1] + 1.5 * d, mask.shape[0] - 1)
    y1 = y2 - side
    if y1 < 0:
        y1 = 0
        y2 = y1 + side

    # Horizontal placement: centred, shifted left if it would cross the right edge.
    x1 = max(0, center[0] - 1.75 * d)
    x2 = x1 + side
    if x2 >= mask.shape[1]:
        x2 = mask.shape[1] - 1
        # Clamp so a window wider than the image cannot yield a negative
        # (wrap-around) start index.
        x1 = max(0, x2 - side)

    return [int(x1), int(y1), int(x2), int(y2)], [int(x), int(y), int(x + w), int(y + h)]



    
def get_arm_box(img, mask):
    """Derive a tall, narrow crop window around an arm region of ``mask``.

    Relies on ``get_skin_color_mask``, which is not defined in the visible part
    of this file and must be provided elsewhere.

    Args:
        img: image the mask refers to; a copy is passed to ``get_skin_color_mask``.
        mask: 2-D array of the same height/width; pixels above 127 are treated
            as foreground (assumed uint8 -- TODO confirm).

    Returns:
        ``(-1, -1)`` when the region is rejected: empty mask, fewer than 30% of
        the mask pixels flagged as skin, no pixel above the threshold, bounding
        height below 500 px, or height/width ratio below 2. Otherwise
        ``(crop_box, tight_box)``, each an ``[x1, y1, x2, y2]`` list of ints:
        ``tight_box`` is the contour's bounding rectangle and ``crop_box`` is a
        window ``0.5 * d`` wide and ``2.05 * d`` tall (``d`` = half the height).
    """
    if np.sum(mask) == 0:
        return -1, -1

    # Keep only regions that are mostly bare skin.
    skin_mask = get_skin_color_mask(img.copy())
    skin_mask[mask == 0] = 0
    skin_pixels = np.sum(skin_mask > 0)
    arm_pixels = np.sum(mask > 0)
    if skin_pixels / arm_pixels < 0.3:
        return -1, -1

    _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        # Non-zero mask whose values are all <= 127: nothing survives the threshold.
        return -1, -1

    # NOTE(review): only the first contour returned is measured; other
    # disconnected regions of the mask are ignored.
    x, y, w, h = cv2.boundingRect(contours[0])

    if h < 500:
        return -1, -1
    if h / w < 2:
        return -1, -1

    center = (x + w // 2, y + h // 2)
    d = h // 2
    box_h = 2.05 * d
    box_w = d * 0.5

    # Vertical placement, shifted down if it would start above the image.
    # y2 may exceed the image height in that case; slicing with it is clipped
    # by numpy.
    y2 = min(center[1] + 1.05 * d, mask.shape[0] - 1)
    y1 = y2 - box_h
    if y1 < 0:
        y1 = 0
        y2 = y1 + box_h

    # Horizontal placement: centred, shifted left if it would cross the right edge.
    x1 = max(0, center[0] - d * 0.25)
    x2 = x1 + box_w
    if x2 >= mask.shape[1]:
        x2 = mask.shape[1] - 1
        # Clamp so a window wider than the image cannot yield a negative
        # (wrap-around) start index.
        x1 = max(0, x2 - box_w)

    return [int(x1), int(y1), int(x2), int(y2)], [int(x), int(y), int(x + w), int(y + h)]


class Action:
    """Base class for command-line actions; supplies the options common to all of them."""

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """Register the repeatable ``-v/--verbosity`` counter on ``parser``."""
        verbosity_help = "Verbose mode. Multiple -v options increase the verbosity."
        parser.add_argument("-v", "--verbosity", action="count", help=verbosity_help)


def register_action(cls: type):
    """
    Class decorator that files ``cls`` in the action registry under ``cls.COMMAND``
    and hands the class back unchanged.
    """
    command = cls.COMMAND
    # Item assignment mutates the module-level dict in place; no rebinding needed.
    _ACTION_REGISTRY[command] = cls
    return cls


class InferenceAction(Action):
    """
    Base class for actions that run a DensePose predictor over a list of images.

    ``execute`` calls ``create_context``, ``execute_on_outputs`` and
    ``postexecute``, which are not defined here and must be supplied by subclasses.
    """

    # Directory prefix prepended to every entry read from the image list file.
    IMAGE_ROOT: ClassVar[str] = "/media/bharat/ssd/downloads/fashionova/zalando/"
    # Test-time resize bounds applied to the config before the predictor is built.
    MIN_SIZE_TEST: ClassVar[int] = 400
    MAX_SIZE_TEST: ClassVar[int] = 1200

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """Add the positional config/model/input arguments and ``--opts`` overrides."""
        super(InferenceAction, cls).add_arguments(parser)
        parser.add_argument("cfg", metavar="<config>", help="Config file")
        parser.add_argument("model", metavar="<model>", help="Model file")
        parser.add_argument("input", metavar="<input>", help="Input data")
        parser.add_argument(
            "--opts",
            help="Modify config options using the command-line 'KEY VALUE' pairs",
            default=[],
            nargs=argparse.REMAINDER,
        )

    @classmethod
    def execute(cls: type, args: argparse.Namespace):
        """
        Run the predictor on every image named in the list file and pass the
        outputs to ``execute_on_outputs``.

        The first path resolved from ``args.input`` is read as a text file of
        whitespace-separated image names relative to ``IMAGE_ROOT``. Images are
        visited in random order; an image is skipped when its ``.txt`` marker
        file already exists, and the marker is created once the image has been
        processed (presumably so several runs can share the work -- TODO confirm).
        """
        import time

        logger.info(f"Loading config from {args.cfg}")
        # The resize bounds were previously assigned on the frozen cfg after the
        # predictor had already been constructed from it, so they never reached
        # the predictor. Pass them as overrides before construction instead.
        opts = [
            "INPUT.MIN_SIZE_TEST",
            str(cls.MIN_SIZE_TEST),
            "INPUT.MAX_SIZE_TEST",
            str(cls.MAX_SIZE_TEST),
        ]
        cfg = cls.setup_config(args.cfg, args.model, args, opts)
        logger.info(f"Loading model from {args.model}")
        predictor = DefaultPredictor(cfg)
        logger.info(f"Loading data from {args.input}")
        file_list = cls._get_input_file_list(args.input)
        if len(file_list) == 0:
            logger.warning(f"No input images for {args.input}")
            return
        context = cls.create_context(args, cfg)

        with open(file_list[0], 'r') as list_file:
            image_names = list_file.read().split()

        # Time-based seed so each run walks the list in a different order.
        np.random.seed(int(time.time()))
        order = np.random.permutation(len(image_names))

        for idx in tqdm.tqdm(order):
            file_name = cls.IMAGE_ROOT + image_names[idx]
            marker = file_name.split('.jpg')[0] + '.txt'
            if os.path.exists(marker):
                continue
            try:
                img = read_image(file_name, format="BGR")  # predictor expects BGR image.
            except Exception:
                # Best effort: unreadable or missing images are skipped, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                logger.warning(f"Could not read {file_name}; skipping")
                continue
            print(marker)
            with torch.no_grad():
                outputs = predictor(img)["instances"]
                cls.execute_on_outputs(context, {"file_name": file_name, "image": img}, outputs)
            # Create the marker without going through a shell, so names with
            # spaces or shell metacharacters are handled safely.
            with open(marker, 'a'):
                pass
        cls.postexecute(context)

    @classmethod
    def setup_config(
        cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
    ):
        """
        Build a frozen config from the config file, then ``args.opts``, then
        ``opts`` (later sources override earlier ones), with the model weights
        set to ``model_fpath``.
        """
        cfg = get_cfg()
        add_densepose_config(cfg)
        cfg.merge_from_file(config_fpath)
        cfg.merge_from_list(args.opts)
        if opts:
            cfg.merge_from_list(opts)
        cfg.MODEL.WEIGHTS = model_fpath
        cfg.freeze()
        return cfg

    @classmethod
    def _get_input_file_list(cls: type, input_spec: str):
        """
        Resolve ``input_spec`` (directory, single file, or glob pattern) to a
        list of paths, dropping every path that contains ``.png``.
        """
        if os.path.isdir(input_spec):
            file_list = [
                os.path.join(input_spec, fname)
                for fname in os.listdir(input_spec)
                if os.path.isfile(os.path.join(input_spec, fname))
            ]
        elif os.path.isfile(input_spec):
            file_list = [input_spec]
        else:
            file_list = glob.glob(input_spec)

        return [path for path in file_list if '.png' not in path]


@register_action
class DumpAction(InferenceAction):
    """
    Dump action that outputs results to a pickle file
    """

    COMMAND: ClassVar[str] = "dump"

    @classmethod
    def add_parser(cls: type, subparsers: argparse._SubParsersAction):
        """Create the ``dump`` sub-command and bind it to ``execute``."""
        parser = subparsers.add_parser(cls.COMMAND, help="Dump model outputs to a file.")
        cls.add_arguments(parser)
        parser.set_defaults(func=cls.execute)

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """Add ``--output`` on top of the inherited inference arguments."""
        super(DumpAction, cls).add_arguments(parser)
        parser.add_argument(
            "--output",
            metavar="<dump_file>",
            default="results.pkl",
            help="File name to save dump to",
        )

    @classmethod
    def execute_on_outputs(
        cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
    ):
        """
        Append one result dict for ``entry`` to ``context["results"]``.

        The dict always holds ``file_name`` and, when present on ``outputs``,
        ``scores``, ``pred_boxes_XYXY`` and ``pred_densepose``.
        """
        image_fpath = entry["file_name"]
        logger.info(f"Processing {image_fpath}")
        result = {"file_name": image_fpath}
        if outputs.has("scores"):
            result["scores"] = outputs.get("scores").cpu()
        if outputs.has("pred_boxes"):
            result["pred_boxes_XYXY"] = outputs.get("pred_boxes").tensor.cpu()
            if outputs.has("pred_densepose"):
                extractor = None
                if isinstance(outputs.pred_densepose, DensePoseChartPredictorOutput):
                    extractor = DensePoseResultExtractor()
                elif isinstance(outputs.pred_densepose, DensePoseEmbeddingPredictorOutput):
                    extractor = DensePoseOutputsExtractor()
                if extractor is None:
                    # Previously this path failed with an unbound local variable.
                    logger.warning(
                        f"Unsupported DensePose output type "
                        f"{type(outputs.pred_densepose).__name__} for {image_fpath}; "
                        f"pred_densepose not dumped"
                    )
                else:
                    result["pred_densepose"] = extractor(outputs)[0]
        context["results"].append(result)

    @classmethod
    def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode):
        """Return the accumulator for results plus the output file name."""
        context = {"results": [], "out_fname": args.output}
        return context

    @classmethod
    def postexecute(cls: type, context: Dict[str, Any]):
        """Pickle the accumulated results to ``context["out_fname"]``."""
        out_fname = context["out_fname"]
        out_dir = os.path.dirname(out_fname)
        if out_dir:
            # exist_ok avoids a failure if the directory appears between check and creation.
            os.makedirs(out_dir, exist_ok=True)
        with open(out_fname, "wb") as hFile:
            pickle.dump(context["results"], hFile)
            logger.info(f"Output saved to {out_fname}")


@register_action
class ShowAction(InferenceAction):
    """
    Show action, repurposed here to cut arm crops out of each image.

    For every detected instance the fine-segmentation labels are turned into a
    right and a left part mask; each mask accepted by ``get_arm_box`` is saved
    as a cropped image plus a cropped mask under ``OUTPUT_ROOT``. The original
    visualization output is disabled.
    """

    COMMAND: ClassVar[str] = "show"
    VISUALIZERS: ClassVar[Dict[str, object]] = {
        "dp_contour": DensePoseResultsContourVisualizer,
        "dp_segm": DensePoseResultsFineSegmentationVisualizer,
        "dp_u": DensePoseResultsUVisualizer,
        "dp_v": DensePoseResultsVVisualizer,
        "dp_iuv_texture": DensePoseResultsVisualizerWithTexture,
        "dp_cse_texture": DensePoseOutputsTextureVisualizer,
        "dp_vertex": DensePoseOutputsVertexVisualizer,
        "bbox": ScoredBoundingBoxVisualizer,
    }
    # Root directory under which crops are written.
    OUTPUT_ROOT: ClassVar[str] = '/media/bharat/ssd/arm_data/'
    # Label ids merged into the right / left part masks (presumably the DensePose
    # hand and arm part ids for each side -- TODO confirm against the part list).
    RIGHT_ARM_LABELS: ClassVar[tuple] = (3, 16, 18, 20, 22)
    LEFT_ARM_LABELS: ClassVar[tuple] = (4, 19, 21, 15, 17)

    @classmethod
    def add_parser(cls: type, subparsers: argparse._SubParsersAction):
        """Create the ``show`` sub-command and bind it to ``execute``."""
        parser = subparsers.add_parser(cls.COMMAND, help="Visualize selected entries")
        cls.add_arguments(parser)
        parser.set_defaults(func=cls.execute)

    @classmethod
    def add_arguments(cls: type, parser: argparse.ArgumentParser):
        """Add the visualization, score/NMS, texture and output arguments."""
        super(ShowAction, cls).add_arguments(parser)
        parser.add_argument(
            "visualizations",
            metavar="<visualizations>",
            help="Comma separated list of visualizations, possible values: "
            "[{}]".format(",".join(sorted(cls.VISUALIZERS.keys()))),
        )
        parser.add_argument(
            "--min_score",
            metavar="<score>",
            default=0.9,
            type=float,
            help="Minimum detection score to visualize",
        )
        parser.add_argument(
            "--nms_thresh", metavar="<threshold>", default=None, type=float, help="NMS threshold"
        )
        parser.add_argument(
            "--texture_atlas",
            metavar="<texture_atlas>",
            default=None,
            help="Texture atlas file (for IUV texture transfer)",
        )
        parser.add_argument(
            "--texture_atlases_map",
            metavar="<texture_atlases_map>",
            default=None,
            help="JSON string of a dict containing texture atlas files for each mesh",
        )
        parser.add_argument(
            "--output",
            metavar="<image_file>",
            default="outputres.png",
            help="File name to save output to",
        )

    @classmethod
    def setup_config(
        cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
    ):
        """Append the score (and optional NMS) thresholds to ``opts`` and build the config."""
        opts.append("MODEL.ROI_HEADS.SCORE_THRESH_TEST")
        opts.append(str(args.min_score))
        if args.nms_thresh is not None:
            opts.append("MODEL.ROI_HEADS.NMS_THRESH_TEST")
            opts.append(str(args.nms_thresh))
        cfg = super(ShowAction, cls).setup_config(config_fpath, model_fpath, args, opts)
        return cfg

    @classmethod
    def execute_on_outputs(
        cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
    ):
        """
        Extract and save arm crops for one image.

        Reads ``data[0][0]`` as the per-instance DensePose results and
        ``data[1][0]`` as the per-instance boxes, so the first requested
        visualization must yield results and the second boxes (presumably
        something like ``dp_segm,bbox`` -- TODO confirm with how the script is
        invoked). Any failure is logged and the image is skipped.
        """
        extractor = context["extractor"]
        image_fpath = entry["file_name"]
        logger.info(f"Processing {image_fpath}")

        data = extractor(outputs)
        if data[0][0] is None:
            return

        try:
            # Read the colour image once; it is reused for every crop.
            img = cv2.imread(image_fpath)
            if img is None:
                logger.warning(f"Could not re-read {image_fpath}; skipping")
                return
            height, width = entry["image"].shape[:2]
            crops = cls._collect_arm_crops(data, img, height, width)
            cls._save_arm_crops(image_fpath, img, crops)
        except Exception:
            # Best effort per image, but leave a trace instead of failing silently.
            logger.exception(f"Failed to extract arm crops from {image_fpath}")
            return

        # The visualizer output (visualizer.visualize + _get_out_fname) is
        # intentionally not produced by this variant of the script.

    @classmethod
    def _full_frame_mask(cls: type, labels, box, label_ids, height: int, width: int):
        """Paste the 0/255 mask of ``label_ids`` into a ``height`` x ``width`` canvas at the box origin."""
        part = np.where(np.isin(labels, label_ids), 255, 0).astype(np.uint8)
        canvas = np.zeros((height, width), dtype=np.uint8)
        top, left = int(box[1]), int(box[0])
        canvas[top:top + part.shape[0], left:left + part.shape[1]] = part
        return canvas

    @classmethod
    def _collect_arm_crops(cls: type, data, img, height: int, width: int):
        """Return ``(tight_box, crop_box, mask)`` for each part mask accepted by ``get_arm_box``."""
        results, boxes = data[0][0], data[1][0]
        crops = []
        for i, result in enumerate(results):
            box = boxes[i].cpu().numpy()
            labels = result.labels.cpu().numpy()
            # Right side first, then left, for every instance.
            for label_ids in (cls.RIGHT_ARM_LABELS, cls.LEFT_ARM_LABELS):
                part_mask = cls._full_frame_mask(labels, box, label_ids, height, width)
                crop_box, tight_box = get_arm_box(img, part_mask)
                if crop_box != -1:
                    crops.append((tight_box, crop_box, part_mask))
        return crops

    @classmethod
    def _save_arm_crops(cls: type, image_fpath: str, img, crops):
        """Write each crop and its mask; the tight box relative to the crop is encoded in the file name."""
        if not crops:
            return
        path_parts = image_fpath.split('/')
        # Mirror the last two directory levels of the input path under OUTPUT_ROOT.
        out_dir = cls.OUTPUT_ROOT + path_parts[-3] + '/' + path_parts[-2] + '/'
        os.makedirs(out_dir, exist_ok=True)
        stem = out_dir + image_fpath.split('.jpg')[0].split('/')[-1]
        for i, (tight_box, crop_box, part_mask) in enumerate(crops):
            x1, y1, x2, y2 = crop_box
            bbox = [tight_box[0] - x1, tight_box[1] - y1, tight_box[2] - x1, tight_box[3] - y1]
            cv2.imwrite(
                stem + '_{}_{}_{}_{}_{}.jpg'.format(i, bbox[0], bbox[1], bbox[2], bbox[3]),
                img[y1:y2, x1:x2, :],
            )
            cv2.imwrite(stem + '_{}_mask.png'.format(i), part_mask[y1:y2, x1:x2])

    @classmethod
    def postexecute(cls: type, context: Dict[str, Any]):
        """Nothing to finalize: crops are written as each image is processed."""
        pass

    @classmethod
    def _get_out_fname(cls: type, entry_idx: int, fname_base: str):
        """Insert a zero-padded 4-digit ``entry_idx`` before the extension of ``fname_base``."""
        base, ext = os.path.splitext(fname_base)
        return base + ".{0:04d}".format(entry_idx) + ext

    @classmethod
    def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode) -> Dict[str, Any]:
        """Instantiate the requested visualizers and their extractors and bundle them into the context."""
        vis_specs = args.visualizations.split(",")
        # Same for every visualizer, so load once instead of once per spec.
        texture_atlas = get_texture_atlas(args.texture_atlas)
        texture_atlases_dict = get_texture_atlases(args.texture_atlases_map)
        visualizers = []
        extractors = []
        for vis_spec in vis_specs:
            vis = cls.VISUALIZERS[vis_spec](
                cfg=cfg,
                texture_atlas=texture_atlas,
                texture_atlases_dict=texture_atlases_dict,
            )
            visualizers.append(vis)
            extractor = create_extractor(vis)
            extractors.append(extractor)
        visualizer = CompoundVisualizer(visualizers)
        extractor = CompoundExtractor(extractors)
        context = {
            "extractor": extractor,
            "visualizer": visualizer,
            "out_fname": args.output,
            "entry_idx": 0,
        }
        return context


def create_argument_parser() -> argparse.ArgumentParser:
    """Build the top-level parser with one sub-command per registered action.

    When no sub-command is given, the parser's default ``func`` prints the help text.
    """

    def _wide_help(prog):
        return argparse.HelpFormatter(prog, max_help_position=120)

    parser = argparse.ArgumentParser(description=DOC, formatter_class=_wide_help)

    def _print_help(_):
        return parser.print_help(sys.stdout)

    parser.set_defaults(func=_print_help)
    subparsers = parser.add_subparsers(title="Actions")
    for action in _ACTION_REGISTRY.values():
        action.add_parser(subparsers)
    return parser


def main():
    """Parse the command line, configure the script logger, and run the chosen action."""
    global logger
    args = create_argument_parser().parse_args()
    # The bare parser (no sub-command) defines no verbosity attribute.
    verbosity = getattr(args, "verbosity", None)
    logger = setup_logger(name=LOGGER_NAME)
    logger.setLevel(verbosity_to_level(verbosity))
    args.func(args)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()