Untitled
unknown
c_cpp
12 days ago
7.5 kB
5
Indexable
#pragma once #include <numeric> #include <unordered_set> #include "onnx/common/assertions.h" #include "onnxoptimizer/pass.h" #include "onnxoptimizer/passes/pass_util.h" namespace ONNX_NAMESPACE { namespace optimization { struct ConvertMeanIntoConv final : public PredicateBasedPass { explicit ConvertMeanIntoConv() : PredicateBasedPass(PassType::Replace, PassEfficiency::Complete, PassOptimizationType::Compute) {} std::string getPassName() const override { return "convert_mean_into_conv"; } bool patternMatchPredicate(Node* node) override { if (!CheckKind(node, kReduceMean)) return false; auto axes_attr = node->hasAttribute(kaxes) ? node->is(kaxes) : std::vector<int64_t>{}; if (axes_attr.size() != 1) return false; int64_t axis = axes_attr[0]; return axis == 1 || axis == 2 || axis == 3 || axis == -1; // <-- modified } bool IsElementWiseOp(Node* node) { static const std::unordered_set<std::string> ew_ops = { "Add", "Sub", "Mul", "Div", "Pow", "Sqrt", "Reciprocal", "Tanh", "Sigmoid" }; return ew_ops.count(node->kind().toString()) > 0; } bool runTransform(Node* node, Graph& graph, NodeDestroyType& destroy_current) override { Value* input = node->inputs()[0]; auto input_shape = input->sizes(); auto axes_attr = node->hasAttribute(kaxes) ? node->is(kaxes) : std::vector<int64_t>{}; int64_t axis = axes_attr.empty() ? -1 : axes_attr[0]; if (axis == -1 && input_shape.size() > 0) axis = static_cast<int64_t>(input_shape.size()) - 1; std::cout<<axis<<std::endl; // Handle axis = 3 case directly (no elementwise, no squeeze/unsqueeze) if (axis == 3) { if (input_shape.size() <= 3 || !input_shape[3].is_int) { destroy_current = NodeDestroyType::DestroyZero; return false; } int64_t C = input_shape[3].dim; if (C <= 0) { destroy_current = NodeDestroyType::DestroyZero; return false; } Node* conv = graph.create(kConv, 1); conv->insertAfter(node); std::vector<float> weight_data(C, 1.0f / C); Tensor weight_tensor; weight_tensor.elem_type() = TensorProto_DataType_FLOAT; weight_tensor.sizes() = {1, 1, 1, C}; weight_tensor.floats() = weight_data; weight_tensor.setName(node->name() + "/DepthwiseConv_weight"); Value* weight_val = graph.addInitializerAndCreateValue(weight_tensor); conv->addInput(input); conv->addInput(weight_val); conv->i_(kgroup, 1); conv->is_(kkernel_shape, std::vector<int64_t>{1, 1}); conv->is_(kpads, std::vector<int64_t>{0, 0, 0, 0}); conv->is_(kstrides, std::vector<int64_t>{1, 1}); conv->setName(node->name() + "/" + conv->kind().toString()); conv->output()->copyMetadata(node->output()); node->output()->replaceAllUsesWith(conv->output()); node->destroy(); destroy_current = NodeDestroyType::DestroyZero; return true; } // For axis=1 or 2 (including -1 mapped to 1/2) if (axis != 1 && axis != 2) { destroy_current = NodeDestroyType::DestroyZero; return false; } int64_t C = (axis == 1 && input_shape.size() > 1 && input_shape[1].is_int) ? input_shape[1].dim : (axis == 2 && input_shape.size() > 2 && input_shape[2].is_int) ? input_shape[2].dim : -1; if (C <= 0) { destroy_current = NodeDestroyType::DestroyZero; return false; } Node* current = node; std::vector<Node*> chain; while (current->outputs()[0]->uses().size() == 1) { Node* next = current->outputs()[0]->uses()[0].user; if (!IsElementWiseOp(next)) break; if (next->inputs()[0]->uses().size() != 1) break; chain.push_back(next); if (next->outputs()[0]->uses().size() != 1) break; current = next; } if (chain.empty() || node->outputs()[0]->uses().size() != 1) { destroy_current = NodeDestroyType::DestroyZero; return false; } destroy_current = NodeDestroyType::DestroyZero; Node* unsqueeze = graph.create(kUnsqueeze, 1); unsqueeze->addInput(input); unsqueeze->insertBefore(node); int opset = getOpsetVersion(graph); std::vector<int64_t> unsq_axes = (axis == 1) ? std::vector<int64_t>{0, 1} : std::vector<int64_t>{0}; unsqueeze->setName(node->name() + "/unsqueeze"); if (opset < 13 && opset != 0) { unsqueeze->is_(kaxes, std::move(unsq_axes)); } else { Tensor t; t.sizes().push_back(unsq_axes.size()); t.int64s() = unsq_axes; t.elem_type() = TensorProto_DataType_INT64; t.setName("Unsqueeze_axes_" + ONNX_NAMESPACE::to_string(graph.getNextUnique())); Value* axes_val = graph.addInitializerAndCreateValue(t); unsqueeze->addInput(axes_val); } Node* conv = graph.create(kConv, 1); conv->insertAfter(unsqueeze); std::vector<float> weight_data(C, 1.0f / C); Tensor weight_tensor; weight_tensor.elem_type() = TensorProto_DataType_FLOAT; weight_tensor.sizes() = {1, 1, 1, C}; weight_tensor.floats() = weight_data; weight_tensor.setName(node->name() + "/DepthwiseConv_weight"); Value* weight_val = graph.addInitializerAndCreateValue(weight_tensor); conv->addInput(unsqueeze->output()); conv->addInput(weight_val); conv->i_(kgroup, 1); conv->is_(kkernel_shape, std::vector<int64_t>{1, C}); conv->is_(kpads, std::vector<int64_t>{0, 0, 0, 0}); conv->is_(kstrides, std::vector<int64_t>{1, 1}); conv->setName(node->name() + "/conv"); Value* prev_out = conv->output(); Node* last_node = conv; for (Node* ew : chain) { Node* new_ew = graph.create(ew->kind(), 1); for (Value* in : ew->inputs()) { new_ew->addInput(in == ew->inputs()[0] ? prev_out : in); } new_ew->insertAfter(last_node); new_ew->setName(ew->name()); prev_out = new_ew->output(); last_node = new_ew; } Node* squeeze = graph.create(kSqueeze, 1); squeeze->addInput(prev_out); squeeze->insertAfter(last_node); std::vector<int64_t> squeeze_axes = (axis == 1) ? std::vector<int64_t>{0, 1} : std::vector<int64_t>{0}; squeeze->setName(node->name() + "/squeeze"); if (opset < 13 && opset != 0) { squeeze->is_(kaxes, std::move(squeeze_axes)); } else { Tensor t; t.sizes().push_back(squeeze_axes.size()); t.int64s() = squeeze_axes; t.elem_type() = TensorProto_DataType_INT64; t.setName("Squeeze_axes_" + ONNX_NAMESPACE::to_string(graph.getNextUnique())); Value* axes_val = graph.addInitializerAndCreateValue(t); squeeze->addInput(axes_val); } Value* target_output = chain.empty() ? node->outputs()[0] : chain.back()->outputs()[0]; Value* new_output = squeeze->output(); new_output->copyMetadata(target_output); std::string old_name = target_output->uniqueName(); std::string new_internal_name = old_name + "/deprecated"; target_output->setUniqueName(new_internal_name); target_output->replaceAllUsesWith(new_output); for (auto it = chain.rbegin(); it != chain.rend(); ++it) { (*it)->destroy(); } destroy_current = NodeDestroyType::DestroyOne; return true; } }; } // namespace optimization } // namespace ONNX_NAMESPACE
Editor is loading...
Leave a Comment