1 简介
elementwise_floordiv 算子在输入为 int32/int64 时被直接转换成了 ONNX 的 Div 算子。Div 只执行普通除法（在 ONNX Runtime 中，整数输入的结果向零截断），并不会向下取整，这与 floordiv 的语义不一致：例如 -7 // 2 应为 -4，而截断除法得到 -3。因此转换后的模型无法通过 CI 的校验。
2 实现过程
原核心实现代码如下
// Pre-fix version of the mapper, quoted to illustrate the defect.
// Converts Paddle's elementwise_floordiv op (inputs "X", "Y"; output "Out")
// into ONNX nodes on the Opset7 path.
// NOTE(review): for integer inputs this emits only Div (no Floor), so it does
// not implement floor division -- this is the bug described in the text.
void ElementWiseFloordivMapper::Opset7() {
auto input_x_info = GetInput("X");
auto input_y_info = GetInput("Y");
auto output_info = GetOutput("Out");
// The op is treated as integer when either input has dtype code <= 3 or == 20.
// Presumably these are Paddle's BOOL/INT16/INT32/INT64 (0-3) and UINT8 (20)
// codes -- confirm against P2ODataType.
bool is_int = false;
if (input_x_info[0].dtype <= 3 || input_x_info[0].dtype == 20 ||
input_y_info[0].dtype <= 3 || input_y_info[0].dtype == 20) {
is_int = true;
}
// In this branch Y is fed to Div unchanged; the condition presumably covers
// the cases where ONNX (right-aligned) broadcasting already matches Paddle's
// axis-based alignment.
if (axis_ == -1 || axis_ == input_x_info[0].Rank() - 1 ||
input_x_info[0].Rank() == input_y_info[0].Rank()) {
if (is_int) {
// BUG: plain Div only -- the Floor step is missing for integer inputs.
helper_->MakeNode("Div", {input_x_info[0].name, input_y_info[0].name},
{output_info[0].name});
} else {
// Non-integer path: Out = Floor(X / Y).
auto div_node = helper_->MakeNode(
"Div", {input_x_info[0].name, input_y_info[0].name});
helper_->MakeNode("Floor", {div_node->output(0)}, {output_info[0].name});
}
} else {
// Reshape Y so that its dims start at position axis_, with 1 elsewhere.
// NOTE(review): this vector gets axis_ + Rank(X) entries; a shape meant to
// broadcast against X would normally have exactly Rank(X) -- confirm.
std::vector<int64_t> broadcast_shape;
broadcast_shape.resize(axis_ + input_x_info[0].Rank(), 1);
for (auto i = 0; i < input_y_info[0].Rank(); ++i) {
broadcast_shape[axis_ + i] = input_y_info[0].shape[i];
}
std::string broadcast_shape_node =
helper_->Constant(GetOnnxDtype(P2ODataType::INT64), broadcast_shape);
auto y_node = helper_->MakeNode(
"Reshape", {input_y_info[0].name, broadcast_shape_node});
if (is_int) {
// BUG: same missing Floor as above, on the reshaped-Y path.
helper_->MakeNode("Div", {input_x_info[0].name, y_node->output(0)},
{output_info[0].name});
} else {
auto div_node =
helper_->MakeNode("Div", {input_x_info[0].name, y_node->output(0)});
helper_->MakeNode("Floor", {div_node->output(0)}, {output_info[0].name});
}
}
}
可以看到，在输入为整数的情况下，原转换函数直接将 elementwise_floordiv 算子转换成了 Div 算子，缺少了 Floor（向下取整）操作。修改思路是：先把输入转换为浮点类型，执行 Div 之后再做 Floor，最后把结果转换回输出的数据类型。修改后的代码如下：
// Converts Paddle's elementwise_floordiv op (inputs "X", "Y"; output "Out")
// into ONNX nodes on the Opset7 path: Out = Floor(X / Y).
//
// ONNX Div does not floor integer inputs (ONNX Runtime truncates toward zero,
// e.g. -7 / 2 -> -3 rather than -4). Integer inputs are therefore divided in
// floating point, floored, and cast back to Out's dtype.
//
// Integer inputs are computed in FP64 rather than FP32. FP64 represents every
// int32 value exactly, and int64 values up to 2^53; FP32 is exact only up to
// 2^24. Floating-point inputs are used as-is, so FP64 models are not silently
// narrowed to FP32.
void ElementWiseFloordivMapper::Opset7() {
  auto input_x_info = GetInput("X");
  auto input_y_info = GetInput("Y");
  auto output_info = GetOutput("Out");

  // Same integer test as the pre-fix mapper. Presumably these are Paddle's
  // BOOL/INT16/INT32/INT64 (0-3) and UINT8 (20) codes.
  const bool is_int =
      input_x_info[0].dtype <= 3 || input_x_info[0].dtype == 20 ||
      input_y_info[0].dtype <= 3 || input_y_info[0].dtype == 20;

  std::string x_name = input_x_info[0].name;
  std::string y_name = input_y_info[0].name;
  if (is_int) {
    x_name =
        helper_->AutoCast(x_name, input_x_info[0].dtype, P2ODataType::FP64);
    y_name =
        helper_->AutoCast(y_name, input_y_info[0].dtype, P2ODataType::FP64);
  }

  // Paddle aligns Y's dims with X's dims starting at `axis_`, whereas ONNX
  // broadcasting aligns from the right. Outside these cases, Y is reshaped
  // to Rank(X): Y's dims go at [axis_, axis_ + Rank(Y)), with 1 elsewhere.
  // Assumes axis_ + Rank(Y) <= Rank(X), as Paddle requires for this op.
  const bool aligned = axis_ == -1 || axis_ == input_x_info[0].Rank() - 1 ||
                       input_x_info[0].Rank() == input_y_info[0].Rank();
  if (!aligned) {
    std::vector<int64_t> broadcast_shape;
    broadcast_shape.resize(input_x_info[0].Rank(), 1);
    for (int i = 0; i < input_y_info[0].Rank(); ++i) {
      broadcast_shape[axis_ + i] = input_y_info[0].shape[i];
    }
    std::string broadcast_shape_node =
        helper_->Constant(GetOnnxDtype(P2ODataType::INT64), broadcast_shape);
    y_name =
        helper_->MakeNode("Reshape", {y_name, broadcast_shape_node})->output(0);
  }

  auto div_node = helper_->MakeNode("Div", {x_name, y_name});
  if (is_int) {
    // Floor in FP64, then cast back to Out's integer dtype.
    auto floor_node = helper_->MakeNode("Floor", {div_node->output(0)});
    helper_->AutoCast(floor_node->output(0), output_info[0].name,
                      P2ODataType::FP64, output_info[0].dtype);
  } else {
    helper_->MakeNode("Floor", {div_node->output(0)}, {output_info[0].name});
  }
}
3 参考资料
- [Bug][CI] Fix the bug where elementwise floordiv only performs div without floor by Zheng-Bicheng · Pull Request #1188 · PaddlePaddle/Paddle2ONNX
- ONNX Div Operator
- ONNX Floor Operator
- Paddle 1.8 与 Paddle 2.0 API 映射表-API文档-PaddlePaddle深度学习平台