[Onnx简化库深度剖析] OnnxSimplifier和OnnxOptimizer解读-(1)

2023-12-18 16:53:12

[Onnx简化库深度剖析] OnnxSimplifier和OnnxOptimizer解读-(1)

简介

OnnxSimplifier是一个用于简化onnx模型的工具,主要工具就是:拥有折叠常量(FoldConstant)的功能、自动调度OnnxOptimizer,最为重要而且核心的是FixedPointFn这个简化调度算法。

OnnxOptimizer是一个onnx官方的一个onnx模型优化库,内部包含很多模型简化/优化的功能。用户也可直接通过python/c++/c api执行调用,但是需要比较了解内部的opt优化手段,才能够得到理想的结果。
  • 依赖情况
40% 24% 24% 12% OnnxSimplifier OnnxOptimizer Onnx OnnxRuntime pybind11

目的

从上述的描述来看,似乎OnnxSimplifier也没有干什么事情,因为OnnxOptimizer才是干简化模型的主要工具。但是OnnxSimplifier主要有以下的几点主要优点和必要性让其比较突出:

  • OnnxSimplifier接口参数较为简单,不需要了过多了解OnnxOptimizer的内部参数和优化手段
  • FixedPointFn简化调度算法让模型能够尽可能优化到最简的模型结果上,这主要因为这个迭代算法在交替使用FoldConstant和OnnxOptimizer进行优化。

OnnxSimplifier基本原理

FoldConstant功能

  • 目的:去除掉模型中那些跟输入数据流无关的叶子节点,也就是constant_node。通过单独运行constant_node,可以得到常量的output tensor,这些output tensor将被加入到模型中作为常量数据而存在,而该constant_node也将会从模型中移除。
  • constant_node条件:
    • node的domain应该属于以下的一种:[?]
      • ai.onnx
      • ai.onnx.ml
    • node的op_type不属于以下的任何一种:[?]
      • RandomUniform
      • RandomNormal
      • RandomUniformLike
      • RandomNormalLike
      • Multinomial
    • node不应该是以下的节点: [?]
      • QuantizeLinear
      • DequantizeLinear
    • node不存在子图 [?]
    • node不会产生超过threshold大小的tensor [?]
    • node的所有输入应该都在model.graph.initializer中 [?]

FixedPointFn迭代优化函数

  • 基本原理:就是通过两个优化函数,反复迭代优化中得到了最终无法继续优化的最终模型。

  • FixedPointFn的原始代码如下:

    template <typename T>
     std::function<T(const T&)> FixedPointFn(const std::function<T(const T&)>& f1,
                                             const std::function<T(const T&)>& f2,
                                             size_t max_iters, bool* converged) {
     return [f1, f2, max_iters, converged](const T& x) {
         size_t _max_iters = max_iters;
         T tmp1 = f1(x);
         T tmp2 = f2(tmp1);
         T& y1 = tmp1;
         T& y2 = tmp2;
         while (_max_iters-- > 0) {
         // 超出迭代次数则跳出
         if (google::protobuf::util::MessageDifferencer::Equals(y1, y2)) {
             // f1(x) == f2(f1(x))时,则无法继续优化,直接返回f2(f1(x))
             if (converged) {
             *converged = true;
             }
             return y2;
         }
         y1 = f1(y2);
         if (google::protobuf::util::MessageDifferencer::Equals(y1, y2)) {
             if (converged) {
             *converged = true;
             }
             return y1;
         }
         y2 = f2(y1);
         }
    
         if (converged) {
         *converged = false;
         }
         return y2;
     };
     }
    
  • FixedPointFn的流程图如下所示:

y1
y2
yes
no
yes
no
yes
no
x
f1(x)
f2(y1)
max_iters-- > 0 ?
y2
y1 == y2
y1=f1(y2)
y1 == y2
y1
y2=f2(y1)
return
  • FixedPointFn的实际函数如下所示:
    • f1: OptAndShape,FixedPointFn优化
      • OptAndShape.f1:_InferShapes形状推导(可选,不使用的时候为Identity)
      • OptAndShape.f2:OptimizeFixed优化,调用了OnnxOptimizer函数库进行优化
      • OptAndShape.max_iters:默认是50(可通过ONNXSIM_FIXED_POINT_ITERS设置)
    • f2: FoldConstant函数(可选,不使用的时候为Identity)
    • max_iters:默认是50
    • 综上:因此,实际的优化函数为FixedPointFn(FixedPointFn(_InferShapes, OptimizeFixed), FoldConstant)

总结

这次主要是介绍了OnnxSimplifier简化原理,重点介绍了FoldConstant功能和FixedPointFn迭代优化函数,这是该简化包的核心部分了。但是对于其他的OptimizeFixed,也就是OnnxOptimizer函数库的内部简化细节却没有具体的说明。后续将会具体介绍OnnxOptimizer的模型优化细节。

文章来源:https://blog.csdn.net/Pengcode/article/details/135065148
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。