// Paste-site export header (not part of the source):
// Untitled · avatar · unknown · c_cpp · 12 days ago · 7.5 kB · 5 · Indexable
#pragma once

#include <numeric>
#include <unordered_set>

#include "onnx/common/assertions.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

struct ConvertMeanIntoConv final : public PredicateBasedPass {
  explicit ConvertMeanIntoConv()
      : PredicateBasedPass(PassType::Replace, PassEfficiency::Complete,
                           PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "convert_mean_into_conv";
  }

  bool patternMatchPredicate(Node* node) override {
    if (!CheckKind(node, kReduceMean)) return false;
    auto axes_attr = node->hasAttribute(kaxes) ? node->is(kaxes) : std::vector<int64_t>{};
    if (axes_attr.size() != 1) return false;
    int64_t axis = axes_attr[0];
    return axis == 1 || axis == 2 || axis == 3 || axis == -1;  // <-- modified
  }

  bool IsElementWiseOp(Node* node) {
    static const std::unordered_set<std::string> ew_ops = {
      "Add", "Sub", "Mul", "Div", "Pow", "Sqrt", "Reciprocal", "Tanh", "Sigmoid"
    };
    return ew_ops.count(node->kind().toString()) > 0;
  }

  bool runTransform(Node* node, Graph& graph, NodeDestroyType& destroy_current) override {
    Value* input = node->inputs()[0];
    auto input_shape = input->sizes();

    auto axes_attr = node->hasAttribute(kaxes) ? node->is(kaxes) : std::vector<int64_t>{};
    int64_t axis = axes_attr.empty() ? -1 : axes_attr[0];

    if (axis == -1 && input_shape.size() > 0)
      axis = static_cast<int64_t>(input_shape.size()) - 1;
    std::cout<<axis<<std::endl;

    // Handle axis = 3 case directly (no elementwise, no squeeze/unsqueeze)
    if (axis == 3) {
      if (input_shape.size() <= 3 || !input_shape[3].is_int) {
        destroy_current = NodeDestroyType::DestroyZero;
        return false;
      }

      int64_t C = input_shape[3].dim;
      if (C <= 0) {
        destroy_current = NodeDestroyType::DestroyZero;
        return false;
      }

      Node* conv = graph.create(kConv, 1);
      conv->insertAfter(node);

      std::vector<float> weight_data(C, 1.0f / C);
      Tensor weight_tensor;
      weight_tensor.elem_type() = TensorProto_DataType_FLOAT;
      weight_tensor.sizes() = {1, 1, 1, C};
      weight_tensor.floats() = weight_data;
      weight_tensor.setName(node->name() + "/DepthwiseConv_weight");
      Value* weight_val = graph.addInitializerAndCreateValue(weight_tensor);

      conv->addInput(input);
      conv->addInput(weight_val);
      conv->i_(kgroup, 1);
      conv->is_(kkernel_shape, std::vector<int64_t>{1, 1});
      conv->is_(kpads, std::vector<int64_t>{0, 0, 0, 0});
      conv->is_(kstrides, std::vector<int64_t>{1, 1});
      conv->setName(node->name() + "/" + conv->kind().toString());

      conv->output()->copyMetadata(node->output());
      node->output()->replaceAllUsesWith(conv->output());
      node->destroy();
      destroy_current = NodeDestroyType::DestroyZero;
      return true;
    }

    // For axis=1 or 2 (including -1 mapped to 1/2)
    if (axis != 1 && axis != 2) {
      destroy_current = NodeDestroyType::DestroyZero;
      return false;
    }

    int64_t C = (axis == 1 && input_shape.size() > 1 && input_shape[1].is_int) ? input_shape[1].dim :
                (axis == 2 && input_shape.size() > 2 && input_shape[2].is_int) ? input_shape[2].dim : -1;
    if (C <= 0) {
      destroy_current = NodeDestroyType::DestroyZero;
      return false;
    }

    Node* current = node;
    std::vector<Node*> chain;

    while (current->outputs()[0]->uses().size() == 1) {
      Node* next = current->outputs()[0]->uses()[0].user;
      if (!IsElementWiseOp(next)) break;
      if (next->inputs()[0]->uses().size() != 1) break;
      chain.push_back(next);
      if (next->outputs()[0]->uses().size() != 1) break;
      current = next;
    }

    if (chain.empty() || node->outputs()[0]->uses().size() != 1) {
      destroy_current = NodeDestroyType::DestroyZero;
      return false;
    }

    destroy_current = NodeDestroyType::DestroyZero;

    Node* unsqueeze = graph.create(kUnsqueeze, 1);
    unsqueeze->addInput(input);
    unsqueeze->insertBefore(node);

    int opset = getOpsetVersion(graph);
    std::vector<int64_t> unsq_axes = (axis == 1) ? std::vector<int64_t>{0, 1} : std::vector<int64_t>{0};
    unsqueeze->setName(node->name() + "/unsqueeze");
    if (opset < 13 && opset != 0) {
      unsqueeze->is_(kaxes, std::move(unsq_axes));
    } else {
      Tensor t;
      t.sizes().push_back(unsq_axes.size());
      t.int64s() = unsq_axes;
      t.elem_type() = TensorProto_DataType_INT64;
      t.setName("Unsqueeze_axes_" + ONNX_NAMESPACE::to_string(graph.getNextUnique()));
      Value* axes_val = graph.addInitializerAndCreateValue(t);
      unsqueeze->addInput(axes_val);
    }

    Node* conv = graph.create(kConv, 1);
    conv->insertAfter(unsqueeze);
    std::vector<float> weight_data(C, 1.0f / C);
    Tensor weight_tensor;
    weight_tensor.elem_type() = TensorProto_DataType_FLOAT;
    weight_tensor.sizes() = {1, 1, 1, C};
    weight_tensor.floats() = weight_data;
    weight_tensor.setName(node->name() + "/DepthwiseConv_weight");
    Value* weight_val = graph.addInitializerAndCreateValue(weight_tensor);
    conv->addInput(unsqueeze->output());
    conv->addInput(weight_val);
    conv->i_(kgroup, 1);
    conv->is_(kkernel_shape, std::vector<int64_t>{1, C});
    conv->is_(kpads, std::vector<int64_t>{0, 0, 0, 0});
    conv->is_(kstrides, std::vector<int64_t>{1, 1});
    conv->setName(node->name() + "/conv");

    Value* prev_out = conv->output();
    Node* last_node = conv;
    for (Node* ew : chain) {
      Node* new_ew = graph.create(ew->kind(), 1);
      for (Value* in : ew->inputs()) {
        new_ew->addInput(in == ew->inputs()[0] ? prev_out : in);
      }
      new_ew->insertAfter(last_node);
      new_ew->setName(ew->name());
      prev_out = new_ew->output();
      last_node = new_ew;
    }

    Node* squeeze = graph.create(kSqueeze, 1);
    squeeze->addInput(prev_out);
    squeeze->insertAfter(last_node);
    std::vector<int64_t> squeeze_axes = (axis == 1) ? std::vector<int64_t>{0, 1} : std::vector<int64_t>{0};
    squeeze->setName(node->name() + "/squeeze");
    if (opset < 13 && opset != 0) {
      squeeze->is_(kaxes, std::move(squeeze_axes));
    } else {
      Tensor t;
      t.sizes().push_back(squeeze_axes.size());
      t.int64s() = squeeze_axes;
      t.elem_type() = TensorProto_DataType_INT64;
      t.setName("Squeeze_axes_" + ONNX_NAMESPACE::to_string(graph.getNextUnique()));
      Value* axes_val = graph.addInitializerAndCreateValue(t);
      squeeze->addInput(axes_val);
    }

    Value* target_output = chain.empty() ? node->outputs()[0] : chain.back()->outputs()[0];
    Value* new_output = squeeze->output();
    new_output->copyMetadata(target_output);
    std::string old_name = target_output->uniqueName();
    std::string new_internal_name = old_name + "/deprecated";
    target_output->setUniqueName(new_internal_name);
    target_output->replaceAllUsesWith(new_output);

    for (auto it = chain.rbegin(); it != chain.rend(); ++it) {
      (*it)->destroy();
    }

    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }
};

} // namespace optimization
} // namespace ONNX_NAMESPACE
// Paste-site export footer (not part of the source): Editor is loading... · Leave a Comment