Untitled
unknown
c_cpp
8 months ago
7.5 kB
8
Indexable
#pragma once
#include <numeric>
#include <unordered_set>
#include "onnx/common/assertions.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"
namespace ONNX_NAMESPACE {
namespace optimization {
struct ConvertMeanIntoConv final : public PredicateBasedPass {
explicit ConvertMeanIntoConv()
: PredicateBasedPass(PassType::Replace, PassEfficiency::Complete,
PassOptimizationType::Compute) {}
std::string getPassName() const override {
return "convert_mean_into_conv";
}
bool patternMatchPredicate(Node* node) override {
if (!CheckKind(node, kReduceMean)) return false;
auto axes_attr = node->hasAttribute(kaxes) ? node->is(kaxes) : std::vector<int64_t>{};
if (axes_attr.size() != 1) return false;
int64_t axis = axes_attr[0];
return axis == 1 || axis == 2 || axis == 3 || axis == -1; // <-- modified
}
bool IsElementWiseOp(Node* node) {
static const std::unordered_set<std::string> ew_ops = {
"Add", "Sub", "Mul", "Div", "Pow", "Sqrt", "Reciprocal", "Tanh", "Sigmoid"
};
return ew_ops.count(node->kind().toString()) > 0;
}
bool runTransform(Node* node, Graph& graph, NodeDestroyType& destroy_current) override {
Value* input = node->inputs()[0];
auto input_shape = input->sizes();
auto axes_attr = node->hasAttribute(kaxes) ? node->is(kaxes) : std::vector<int64_t>{};
int64_t axis = axes_attr.empty() ? -1 : axes_attr[0];
if (axis == -1 && input_shape.size() > 0)
axis = static_cast<int64_t>(input_shape.size()) - 1;
std::cout<<axis<<std::endl;
// Handle axis = 3 case directly (no elementwise, no squeeze/unsqueeze)
if (axis == 3) {
if (input_shape.size() <= 3 || !input_shape[3].is_int) {
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
int64_t C = input_shape[3].dim;
if (C <= 0) {
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
Node* conv = graph.create(kConv, 1);
conv->insertAfter(node);
std::vector<float> weight_data(C, 1.0f / C);
Tensor weight_tensor;
weight_tensor.elem_type() = TensorProto_DataType_FLOAT;
weight_tensor.sizes() = {1, 1, 1, C};
weight_tensor.floats() = weight_data;
weight_tensor.setName(node->name() + "/DepthwiseConv_weight");
Value* weight_val = graph.addInitializerAndCreateValue(weight_tensor);
conv->addInput(input);
conv->addInput(weight_val);
conv->i_(kgroup, 1);
conv->is_(kkernel_shape, std::vector<int64_t>{1, 1});
conv->is_(kpads, std::vector<int64_t>{0, 0, 0, 0});
conv->is_(kstrides, std::vector<int64_t>{1, 1});
conv->setName(node->name() + "/" + conv->kind().toString());
conv->output()->copyMetadata(node->output());
node->output()->replaceAllUsesWith(conv->output());
node->destroy();
destroy_current = NodeDestroyType::DestroyZero;
return true;
}
// For axis=1 or 2 (including -1 mapped to 1/2)
if (axis != 1 && axis != 2) {
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
int64_t C = (axis == 1 && input_shape.size() > 1 && input_shape[1].is_int) ? input_shape[1].dim :
(axis == 2 && input_shape.size() > 2 && input_shape[2].is_int) ? input_shape[2].dim : -1;
if (C <= 0) {
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
Node* current = node;
std::vector<Node*> chain;
while (current->outputs()[0]->uses().size() == 1) {
Node* next = current->outputs()[0]->uses()[0].user;
if (!IsElementWiseOp(next)) break;
if (next->inputs()[0]->uses().size() != 1) break;
chain.push_back(next);
if (next->outputs()[0]->uses().size() != 1) break;
current = next;
}
if (chain.empty() || node->outputs()[0]->uses().size() != 1) {
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
destroy_current = NodeDestroyType::DestroyZero;
Node* unsqueeze = graph.create(kUnsqueeze, 1);
unsqueeze->addInput(input);
unsqueeze->insertBefore(node);
int opset = getOpsetVersion(graph);
std::vector<int64_t> unsq_axes = (axis == 1) ? std::vector<int64_t>{0, 1} : std::vector<int64_t>{0};
unsqueeze->setName(node->name() + "/unsqueeze");
if (opset < 13 && opset != 0) {
unsqueeze->is_(kaxes, std::move(unsq_axes));
} else {
Tensor t;
t.sizes().push_back(unsq_axes.size());
t.int64s() = unsq_axes;
t.elem_type() = TensorProto_DataType_INT64;
t.setName("Unsqueeze_axes_" + ONNX_NAMESPACE::to_string(graph.getNextUnique()));
Value* axes_val = graph.addInitializerAndCreateValue(t);
unsqueeze->addInput(axes_val);
}
Node* conv = graph.create(kConv, 1);
conv->insertAfter(unsqueeze);
std::vector<float> weight_data(C, 1.0f / C);
Tensor weight_tensor;
weight_tensor.elem_type() = TensorProto_DataType_FLOAT;
weight_tensor.sizes() = {1, 1, 1, C};
weight_tensor.floats() = weight_data;
weight_tensor.setName(node->name() + "/DepthwiseConv_weight");
Value* weight_val = graph.addInitializerAndCreateValue(weight_tensor);
conv->addInput(unsqueeze->output());
conv->addInput(weight_val);
conv->i_(kgroup, 1);
conv->is_(kkernel_shape, std::vector<int64_t>{1, C});
conv->is_(kpads, std::vector<int64_t>{0, 0, 0, 0});
conv->is_(kstrides, std::vector<int64_t>{1, 1});
conv->setName(node->name() + "/conv");
Value* prev_out = conv->output();
Node* last_node = conv;
for (Node* ew : chain) {
Node* new_ew = graph.create(ew->kind(), 1);
for (Value* in : ew->inputs()) {
new_ew->addInput(in == ew->inputs()[0] ? prev_out : in);
}
new_ew->insertAfter(last_node);
new_ew->setName(ew->name());
prev_out = new_ew->output();
last_node = new_ew;
}
Node* squeeze = graph.create(kSqueeze, 1);
squeeze->addInput(prev_out);
squeeze->insertAfter(last_node);
std::vector<int64_t> squeeze_axes = (axis == 1) ? std::vector<int64_t>{0, 1} : std::vector<int64_t>{0};
squeeze->setName(node->name() + "/squeeze");
if (opset < 13 && opset != 0) {
squeeze->is_(kaxes, std::move(squeeze_axes));
} else {
Tensor t;
t.sizes().push_back(squeeze_axes.size());
t.int64s() = squeeze_axes;
t.elem_type() = TensorProto_DataType_INT64;
t.setName("Squeeze_axes_" + ONNX_NAMESPACE::to_string(graph.getNextUnique()));
Value* axes_val = graph.addInitializerAndCreateValue(t);
squeeze->addInput(axes_val);
}
Value* target_output = chain.empty() ? node->outputs()[0] : chain.back()->outputs()[0];
Value* new_output = squeeze->output();
new_output->copyMetadata(target_output);
std::string old_name = target_output->uniqueName();
std::string new_internal_name = old_name + "/deprecated";
target_output->setUniqueName(new_internal_name);
target_output->replaceAllUsesWith(new_output);
for (auto it = chain.rbegin(); it != chain.rend(); ++it) {
(*it)->destroy();
}
destroy_current = NodeDestroyType::DestroyOne;
return true;
}
};
} // namespace optimization
} // namespace ONNX_NAMESPACE
Editor is loading...
Leave a Comment