Welcome to Paddle-Inference’s documentation!¶
概述¶
Paddle Inference为飞桨核心框架推理引擎。Paddle Inference功能特性丰富,性能优异,针对不同平台不同的应用场景进行了深度的适配优化,做到高吞吐、低时延,保证了飞桨模型在服务器端即训即用,快速部署。
特性¶
通用性。支持对Paddle训练出的所有模型进行预测。
内存/显存复用。在推理初始化阶段,对模型中的OP输出Tensor 进行依赖分析,将两两互不依赖的Tensor在内存/显存空间上进行复用,进而增大计算并行量,提升服务吞吐量。
细粒度OP融合。在推理初始化阶段,按照已有的融合模式将模型中的多个OP融合成一个OP,减少了模型的计算量的同时,也减少了 Kernel Launch的次数,从而能提升推理性能。目前Paddle Inference支持的融合模式多达几十个。
高性能CPU/GPU Kernel。内置同Intel、Nvidia共同打造的高性能kernel,保证了模型推理高性能的执行。
子图集成 TensorRT。Paddle Inference采用子图的形式集成TensorRT,针对GPU推理场景,TensorRT可对一些子图进行优化,包括OP的横向和纵向融合,过滤冗余的OP,并为OP自动选择最优的kernel,加快推理速度。
集成MKLDNN
支持加载PaddleSlim量化压缩后的模型。 PaddleSlim 是飞桨深度学习模型压缩工具,Paddle Inference可联动PaddleSlim,支持加载量化、裁剪和蒸馏后的模型并部署,由此减小模型存储空间、减少计算占用内存、加快模型推理速度。其中在模型量化方面,Paddle Inference在X86 CPU上做了深度优化 ,常见分类模型的单线程性能可提升近3倍,ERNIE模型的单线程性能可提升2.68倍。
Quick Start¶
前提准备 接下来我们会通过几段Python代码的方式对Paddle Inference使用进行介绍, 为了能够成功运行代码,请您在环境中(Mac, Windows,Linux)安装不低于1.7版本的Paddle, 安装Paddle 请参考 飞桨官网主页。
导出预测模型文件¶
在模型训练期间,我们通常使用Python来构建模型结构,比如:
import paddle.fluid as fluid
res = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu", param_attr=param_attr)
在模型部署时,我们需要提前将这种Python表示的结构以及参数序列化到磁盘中。那是如何做到的呢?
在模型训练过程中或者模型训练结束后,我们可以通过save_inference_model接口来导出标准化的模型文件。
我们用一个简单的代码例子来展示下导出模型文件的这一过程。
import paddle
import paddle.fluid as fluid
# 建立一个简单的网络,网络的输入的shape为[batch, 3, 28, 28]
image_shape = [3, 28, 28]
img = fluid.layers.data(name='image', shape=image_shape, dtype='float32', append_batch_size=True)
# 模型包含两个Conv层
conv1 = fluid.layers.conv2d(
input=img,
num_filters=8,
filter_size=3,
stride=2,
padding=1,
groups=1,
act=None,
bias_attr=True)
out = fluid.layers.conv2d(
input=conv1,
num_filters=8,
filter_size=3,
stride=2,
padding=1,
groups=1,
act=None,
bias_attr=True)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# 创建网络中的参数变量,并初始化参数变量
exe.run(fluid.default_startup_program())
# 如果存在预训练模型
# def if_exist(var):
# return os.path.exists(os.path.join("./ShuffleNet", var.name))
# fluid.io.load_vars(exe, "./pretrained_model", predicate=if_exist)
# 保存模型到model目录中,只保存与输入image和输出与推理相关的部分网络
fluid.io.save_inference_model(dirname='./sample_model', feeded_var_names=['image'], target_vars = [out], executor=exe, model_filename='model', params_filename='params')
该程序运行结束后,会在本目录中生成一个sample_model目录,目录中包含model, params 两个文件,model文件表示模型的结构文件,params表示所有参数的融合文件。
飞桨提供了 两种标准 的模型文件,一种为Combined方式, 一种为No-Combined的方式。
Combined的方式
fluid.io.save_inference_model(dirname='./sample_model', feeded_var_names=['image'], target_vars = [out], executor=exe, model_filename='model', params_filename='params')
model_filename,params_filename表示要生成的模型结构文件、融合参数文件的名字。
No-Combined的方式
fluid.io.save_inference_model(dirname='./sample_model', feeded_var_names=['image'], target_vars = [out], executor=exe)
如果不指定model_filename,params_filename,会在sample_model目录下生成__model__ 模型结构文件,以及一系列的参数文件。
在模型部署期间,我们更推荐使用Combined的方式,因为涉及模型上线加密的场景时,这种方式会更友好一些。
加载模型预测¶
1)使用load_inference方式
我们可以使用load_inference_model接口加载训练好的模型(以sample_model模型举例),并复用训练框架的前向计算,直接完成推理。 示例程序如下所示:
import paddle.fluid as fluid
import numpy as np
data = np.ones((1, 3, 28, 28)).astype(np.float32)
exe = fluid.Executor(fluid.CPUPlace())
# 加载Combined的模型需要指定model_filename, params_filename
# 加载No-Combined的模型不需要指定model_filename, params_filename
[inference_program, feed_target_names, fetch_targets] = \
fluid.io.load_inference_model(dirname='sample_model', executor=exe, model_filename='model', params_filename='params')
with fluid.program_guard(inference_program):
results = exe.run(inference_program,
feed={feed_target_names[0]: data},
fetch_list=fetch_targets, return_numpy=False)
print (np.array(results[0]).shape)
# (1, 8, 7, 7)
在上述方式中,在模型加载后会按照执行顺序将所有的OP进行拓扑排序,在运行期间Op会按照排序一一运行,整个过程中运行的为训练中前向的OP,期间不会有任何的优化(OP融合,显存优化,预测Kernel针对优化)。 因此,load_inference_model的方式预测期间很可能不会有很好的性能表现,此方式比较适合用来做实验(测试模型的效果、正确性等)使用,并不适用于真正的部署上线。接下来我们会重点介绍Paddle Inference的使用。
2)使用Paddle Inference API方式
不同于 load_inference_model方式,Paddle Inference 在模型加载后会进行一系列的优化,包括: Kernel优化,OP横向,纵向融合,显存/内存优化,以及MKLDNN,TensorRT的集成等,性能和吞吐会得到大幅度的提升。这些优化会在之后的文档中进行详细的介绍。
那我们先用一个简单的代码例子来介绍Paddle Inference 的使用。
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
import numpy as np
# 配置运行信息
# config = AnalysisConfig("./sample_model") # 加载non-combined 模型格式
config = AnalysisConfig("./sample_model/model", "./sample_model/params") # 加载combine的模型格式
config.switch_use_feed_fetch_ops(False)
config.enable_memory_optim()
config.enable_use_gpu(1000, 0)
# 根据config创建predictor
predictor = create_paddle_predictor(config)
img = np.ones((1, 3, 28, 28)).astype(np.float32)
# 准备输入
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0])
input_tensor.reshape(img.shape)
input_tensor.copy_from_cpu(img.copy())
# 运行
predictor.zero_copy_run()
# 获取输出
output_names = predictor.get_output_names()
output_tensor = predictor.get_output_tensor(output_names[0])
output_data = output_tensor.copy_to_cpu()
print (output_data)
上述的代码例子,我们通过加载一个简答模型以及随机输入的方式,展示了如何使用Paddle Inference进行模型预测。可能对于刚接触Paddle Inferenece同学来说,代码中会有一些陌生名词出现,比如AnalysisConfig, Predictor 等。先不要着急,接下来的文章中会对这些概念进行详细的介绍。
相关链接
使用流程¶
一: 模型准备¶
Paddle Inference目前支持的模型结构为PaddlePaddle深度学习框架产出的模型格式。因此,在您开始使用 Paddle Inference框架前您需要准备一个由PaddlePaddle框架保存的模型。 如果您手中的模型是由诸如Caffe2、Tensorflow等框架产出的,那么我们推荐您使用 X2Paddle 工具进行模型格式转换。
二: 环境准备¶
1) Python 环境
安装Python环境有以下三种方式:
# 拉取镜像,该镜像预装Paddle 1.8 Python环境
docker pull hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6
export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
export NVIDIA_SMI="-v /usr/bin/nvidia-smi:/usr/bin/nvidia-smi"
docker run $CUDA_SO $DEVICES $NVIDIA_SMI --name trt_open --privileged --security-opt seccomp=unconfined --net=host -v $PWD:/paddle -it hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6 /bin/bash
2) C++ 环境
获取c++预测库有以下三种方式:
官网 下载预编译库
使用docker镜像
# 拉取镜像,在容器内主目录~/下存放c++预编译库。
docker pull hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6
export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
export NVIDIA_SMI="-v /usr/bin/nvidia-smi:/usr/bin/nvidia-smi"
docker run $CUDA_SO $DEVICES $NVIDIA_SMI --name trt_open --privileged --security-opt seccomp=unconfined --net=host -v $PWD:/paddle -it hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6 /bin/bash
参照接下来的 `预测库编译 <./source_compile.html>`_页面进行自行编译。
三:使用Paddle Inference执行预测¶
使用Paddle Inference进行推理部署的流程如下所示。
配置推理选项。 AnalysisConfig 是飞桨提供的配置管理器API。在使用Paddle Inference进行推理部署过程中,需要使用 AnalysisConfig 详细地配置推理引擎参数,包括但不限于在何种设备(CPU/GPU)上部署( config.EnableUseGPU )、加载模型路径、开启/关闭计算图分析优化、使用MKLDNN/TensorRT进行部署的加速等。参数的具体设置需要根据实际需求来定。
创建 AnalysisPredictor 。 AnalysisPredictor 是Paddle Inference提供的推理引擎。你只需要简单的执行一行代码即可完成预测引擎的初始化 std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config) ,config为1步骤中创建的 AnalysisConfig。
准备输入数据。执行 auto input_names = predictor->GetInputNames() ,您会获取到模型所有输入tensor的名字,同时通过执行 auto tensor = predictor->GetInputTensor(input_names[i]) ; 您可以获取第i个输入的tensor,通过 tensor->copy_from_cpu(data) 方式,将data中的数据拷贝到tensor中。
调用predictor->ZeroCopyRun()执行推理。
获取推理输出。执行 auto out_names = predictor->GetOutputNames() ,您会获取到模型所有输出tensor的名字,同时通过执行 auto tensor = predictor->GetOutputTensor(out_names[i]) ; 您可以获取第i个输出的tensor。通过 tensor->copy_to_cpu(data) 将tensor中的数据copy到data指针上
源码编译¶
什么时候需要源码编译?¶
深度学习的发展十分迅速,对科研或工程人员来说,可能会遇到一些需要自己开发op的场景,可以在python层面编写op,但如果对性能有严格要求的话则必须在C++层面开发op,对于这种情况,需要用户源码编译飞桨,使之生效。 此外对于绝大多数使用C++将模型部署上线的工程人员来说,您可以直接通过飞桨官网下载已编译好的预测库,快捷开启飞桨使用之旅。飞桨官网 提供了多个不同环境下编译好的预测库。如果用户环境与官网提供环境不一致(如cuda 、cudnn、tensorrt版本不一致等),或对飞桨源代码有修改需求,或希望进行定制化构建,可查阅本文档自行源码编译得到预测库。
编译原理¶
一:目标产物
飞桨框架的源码编译包括源代码的编译和链接,最终生成的目标产物包括:
含有 C++ 接口的头文件及其二进制库:用于C++环境,将文件放到指定路径即可开启飞桨使用之旅。
Python Wheel 形式的安装包:用于Python环境,此安装包需要参考 飞桨安装教程 进行安装操作。也就是说,前面讲的pip安装属于在线安装,这里属于本地安装。
二:基础概念
飞桨主要由C++语言编写,通过pybind工具提供了Python端的接口,飞桨的源码编译主要包括编译和链接两步。 * 编译过程由编译器完成,编译器以编译单元(后缀名为 .cc 或 .cpp 的文本文件)为单位,将 C++ 语言 ASCII 源代码翻译为二进制形式的目标文件。一个工程通常由若干源码文件组织得到,所以编译完成后,将生成一组目标文件。 * 链接过程使分离编译成为可能,由链接器完成。链接器按一定规则将分离的目标文件组合成一个能映射到内存的二进制程序文件,并解析引用。由于这个二进制文件通常包含源码中指定可被外部用户复用的函数接口,所以也被称作函数库。根据链接规则不同,链接可分为静态和动态链接。静态链接对目标文件进行归档;动态链接使用地址无关技术,将链接放到程序加载时进行。 配合包含声明体的头文件(后缀名为 .h 或 .hpp),用户可以复用程序库中的代码开发应用。静态链接构建的应用程序可独立运行,而动态链接程序在加载运行时需到指定路径下搜寻其依赖的二进制库。
三:编译方式
飞桨框架的设计原则之一是满足不同平台的可用性。然而,不同操作系统惯用的编译和链接器是不一样的,使用它们的命令也不一致。比如,Linux 一般使用 GNU 编译器套件(GCC),Windows 则使用 Microsoft Visual C++(MSVC)。为了统一编译脚本,飞桨使用了支持跨平台构建的 CMake,它可以输出上述编译器所需的各种 Makefile 或者 Project 文件。 为方便编译,框架对常用的CMake命令进行了封装,如仿照 Bazel工具封装了 cc_binary 和 cc_library ,分别用于可执行文件和库文件的产出等,对CMake感兴趣的同学可在 cmake/generic.cmake 中查看具体的实现逻辑。Paddle的CMake中集成了生成python wheel包的逻辑,对如何生成wheel包感兴趣的同学可参考 相关文档 。
编译步骤¶
飞桨分为 CPU 版本和 GPU 版本。如果您的计算机没有 Nvidia GPU,请选择 CPU 版本构建安装。如果您的计算机含有 Nvidia GPU( 1.0 且预装有 CUDA / CuDNN,也可选择 GPU 版本构建安装。本节简述飞桨在常用环境下的源码编译方式,欢迎访问飞桨官网获取更详细内容。请阅读本节内容。
推荐配置及依赖项
1、稳定的互联网连接,主频 1 GHz 以上的多核处理器,9 GB 以上磁盘空间。 2、Python 版本 2.7 或 3.5 以上,pip 版本 9.0 及以上;CMake v3.5 及以上;Git 版本 2.17 及以上。请将可执行文件放入系统环境变量中以方便运行。 3、GPU 版本额外需要 Nvidia CUDA 9 / 10,CuDNN v7 及以上版本。根据需要还可能依赖 NCCL 和 TensorRT。
基于Ubuntu 18.04¶
一:环境准备
除了本节开头提到的依赖,在 Ubuntu 上进行飞桨的源码编译,您还需要准备 GCC8 编译器等工具,可使用下列命令安装:
sudo apt-get install gcc g++ make cmake git vim unrar python3 python3-dev python3-pip swig wget patchelf libopencv-dev
pip3 install numpy protobuf wheel setuptools
若需启用 cuda 加速,需准备 cuda、cudnn、nccl。上述工具的安装请参考 nvidia 官网,以 cuda10.1,cudnn7.6 为例配置 cuda 环境。
# cuda
sh cuda_10.1.168_418.67_linux.run
export PATH=/usr/local/cuda-10.1/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-10.1/${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
# cudnn
tar -xzvf cudnn-10.1-linux-x64-v7.6.4.38.tgz
sudo cp -a cuda/include/cudnn.h /usr/local/cuda/include/
sudo cp -a cuda/lib64/libcudnn* /usr/local/cuda/lib64/
# nccl
# install nccl local deb 参考https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html
sudo dpkg -i nccl-repo-ubuntu1804-2.5.6-ga-cuda10.1_1-1_amd64.deb
# 根据安装提示,还需要执行sudo apt-key add /var/nccl-repo-2.5.6-ga-cuda10.1/7fa2af80.pub
sudo apt update
sudo apt install libnccl2 libnccl-dev
sudo ldconfig
编译飞桨过程中可能会打开很多文件,Ubuntu 18.04 默认设置最多同时打开的文件数是1024(参见 ulimit -a),需要更改这个设定值。
在 /etc/security/limits.conf 文件中添加两行。
* hard noopen 102400
* soft noopen 102400
重启计算机,重启后执行以下指令,请将${user}切换成当前用户名。
su ${user}
ulimit -n 102400
二:编译命令
使用 Git 将飞桨代码克隆到本地,并进入目录,切换到稳定版本(git tag显示的标签名,如v1.7.1)。 飞桨使用 develop 分支进行最新特性的开发,使用 release 分支发布稳定版本。在 GitHub 的 Releases 选项卡中,可以看到飞桨版本的发布记录。
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
git checkout v1.7.1
下面以 GPU 版本为例说明编译命令。其他环境可以参考“CMake编译选项表”修改对应的cmake选项。比如,若编译 CPU 版本,请将 WITH_GPU 设置为 OFF。
# 创建并进入 build 目录
mkdir build_cuda && cd build_cuda
# 执行cmake指令
cmake -DPY_VERSION=3 \
-DWITH_TESTING=OFF \
-DWITH_MKL=ON \
-DWITH_GPU=ON \
-DON_INFER=ON \
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
..
使用make编译
make -j4
编译成功后可在dist目录找到生成的.whl包
pip3 install python/dist/paddlepaddle-1.7.1-cp36-cp36m-linux_x86_64.whl
预测库编译
make inference_lib_dist -j4
cmake编译环境表
以下介绍的编译方法都是通用步骤,根据环境对应修改cmake选项即可。
选项 |
说明 |
默认值 |
---|---|---|
WITH_GPU |
是否支持GPU |
ON |
WITH_AVX |
是否编译含有AVX指令集的飞桨二进制文件 |
ON |
WITH_PYTHON |
是否内嵌PYTHON解释器并编译Wheel安装包 |
ON |
WITH_TESTING |
是否开启单元测试 |
OFF |
WITH_MKL |
是否使用MKL数学库,如果为否,将使用OpenBLAS |
ON |
WITH_SYSTEM_BLAS |
是否使用系统自带的BLAS |
OFF |
WITH_DISTRIBUTE |
是否编译带有分布式的版本 |
OFF |
WITH_BRPC_RDMA |
是否使用BRPC,RDMA作为RPC协议 |
OFF |
ON_INFER |
是否打开预测优化 |
OFF |
CUDA_ARCH_NAME |
是否只针对当前CUDA架构编译 |
All:编译所有可支持的CUDA架构;Auto:自动识别当前环境的架构编译 |
TENSORRT_ROOT |
TensorRT_lib的路径,该路径指定后会编译TRT子图功能eg:/paddle/nvidia/TensorRT/ |
/usr |
基于Windows 10¶
一:环境准备
除了本节开头提到的依赖,在 Windows 10 上编译飞桨,您还需要准备 Visual Studio 2015 Update3 以上版本。本节以 Visual Studio 企业版 2019(C++ 桌面开发,含 MSVC 14.24)、Python 3.8 为例介绍编译过程。
在命令提示符输入下列命令,安装必需的 Python 组件。
pip3 install numpy protobuf wheel`
二:编译命令
使用 Git 将飞桨代码克隆到本地,并进入目录,切换到稳定版本(git tag显示的标签名,如v1.7.1)。 飞桨使用 develop 分支进行最新特性的开发,使用 release 分支发布稳定版本。在 GitHub 的 Releases 选项卡中,可以看到 Paddle 版本的发布记录。
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
git checkout v1.7.1
创建一个构建目录,并在其中执行 CMake,生成解决方案文件 Solution File,以编译 CPU 版本为例说明编译命令,其他环境可以参考“CMake编译选项表”修改对应的cmake选项。
mkdir build
cd build
cmake .. -G "Visual Studio 16 2019" -A x64 -DWITH_GPU=OFF -DWITH_TESTING=OFF
-DCMAKE_BUILD_TYPE=Release -DPY_VERSION=3
使用 Visual Studio 打开解决方案文件,在窗口顶端的构建配置菜单中选择 Release x64,单击生成解决方案,等待构建完毕即可。
cmake编译环境表
选项 |
说明 |
默认值 |
---|---|---|
WITH_GPU |
是否支持GPU |
ON |
WITH_AVX |
是否编译含有AVX指令集的飞桨二进制文件 |
ON |
WITH_PYTHON |
是否内嵌PYTHON解释器并编译Wheel安装包 |
ON |
WITH_TESTING |
是否开启单元测试 |
OFF |
WITH_MKL |
是否使用MKL数学库,如果为否,将使用OpenBLAS |
ON |
WITH_SYSTEM_BLAS |
是否使用系统自带的BLAS |
OFF |
WITH_DISTRIBUTE |
是否编译带有分布式的版本 |
OFF |
WITH_BRPC_RDMA |
是否使用BRPC,RDMA作为RPC协议 |
OFF |
ON_INFER |
是否打开预测优化 |
OFF |
CUDA_ARCH_NAME |
是否只针对当前CUDA架构编译 |
All:编译所有可支持的CUDA架构;Auto:自动识别当前环境的架构编译 |
TENSORRT_ROOT |
TensorRT_lib的路径,该路径指定后会编译TRT子图功能eg:/paddle/nvidia/TensorRT/ |
/usr |
结果验证
一:python whl包
编译完毕后,会在 python/dist 目录下生成一个文件名类似 paddlepaddle-1.7.1-cp36-cp36m-linux_x86_64.whl 的 Python Wheel 安装包,安装测试的命令为:
pip3 install python/dist/paddlepaddle-1.7.1-cp36-cp36m-linux_x86_64.whl
安装完成后,可以使用 python3 进入python解释器,输入以下指令,出现 `Your Paddle Fluid is installed succesfully! ` ,说明安装成功。
import paddle.fluid as fluid
fluid.install_check.run_check()
二:c++ lib
预测库编译后,所有产出均位于build目录下的fluid_inference_install_dir目录内,目录结构如下。version.txt 中记录了该预测库的版本信息,包括Git Commit ID、使用OpenBlas或MKL数学库、CUDA/CUDNN版本号。
build/fluid_inference_install_dir
├── CMakeCache.txt
├── paddle
│ ├── include
│ │ ├── paddle_anakin_config.h
│ │ ├── paddle_analysis_config.h
│ │ ├── paddle_api.h
│ │ ├── paddle_inference_api.h
│ │ ├── paddle_mkldnn_quantizer_config.h
│ │ └── paddle_pass_builder.h
│ └── lib
│ ├── libpaddle_fluid.a (Linux)
│ ├── libpaddle_fluid.so (Linux)
│ └── libpaddle_fluid.lib (Windows)
├── third_party
│ ├── boost
│ │ └── boost
│ ├── eigen3
│ │ ├── Eigen
│ │ └── unsupported
│ └── install
│ ├── gflags
│ ├── glog
│ ├── mkldnn
│ ├── mklml
│ ├── protobuf
│ ├── xxhash
│ └── zlib
└── version.txt
Include目录下包括了使用飞桨预测库需要的头文件,lib目录下包括了生成的静态库和动态库,third_party目录下包括了预测库依赖的其它库文件。
您可以编写应用代码,与预测库联合编译并测试结果。请参 C++ 预测库 API 使用 一节。
使用Python预测¶
Paddle Inference提供了高度优化的Python 和C++ API预测接口,本篇文档主要介绍Python API,使用C++ API进行预测的文档可以参考可以参考 这里 。
下面是详细的使用说明。
使用Python预测API预测包含以下几个主要步骤:
配置推理选项
创建Predictor
准备模型输入
模型推理
获取模型输出
我们先从一个简单程序入手,介绍这一流程:
def create_predictor():
# 通过AnalysisConfig配置推理选项
config = AnalysisConfig("./resnet50/model", "./resnet50/params")
config.switch_use_feed_fetch_ops(False)
config.enable_use_gpu(100, 0)
config.enable_mkldnn()
config.enable_memory_optim()
predictor = create_paddle_predictor(config)
return predictor
def run(predictor, data):
# 准备模型输入
input_names = predictor.get_input_names()
for i, name in enumerate(input_names):
input_tensor = predictor.get_input_tensor(name)
input_tensor.reshape(data[i].shape)
input_tensor.copy_from_cpu(data[i].copy())
# 执行模型推理
predictor.zero_copy_run()
results = []
# 获取模型输出
output_names = predictor.get_output_names()
for i, name in enumerate(output_names):
output_tensor = predictor.get_output_tensor(name)
output_data = output_tensor.copy_to_cpu()
results.append(output_data)
return results
以上的程序中 create_predictor 函数对推理过程进行了配置以及创建了Predictor。 run 函数进行了输入数据的准备、模型推理以及输出数据的获取过程。
在接下来的部分中,我们会依次对程序中出现的AnalysisConfig,Predictor,模型输入,模型输出进行详细的介绍。
一、推理配置管理器AnalysisConfig¶
AnalysisConfig管理AnalysisPredictor的推理配置,提供了模型路径设置、推理引擎运行设备选择以及多种优化推理流程的选项。配置中包括了必选配置以及可选配置。
1. 必选配置¶
a.设置模型和参数路径
Non-combined形式:模型文件夹 model_dir 下存在一个模型文件和多个参数文件时,传入模型文件夹路径,模型文件名默认为__model__。 使用方式为: config.set_model(“./model_dir”)
Combined形式:模型文件夹 model_dir 下只有一个模型文件 model 和一个参数文件params时,传入模型文件和参数文件路径。使用方式为: config.set_model(“./model_dir/model”, “./model_dir/params”)
内存加载模式:如果模型是从内存加载,可以使用:
import os model_buffer = open('./resnet50/model','rb') params_buffer = open('./resnet50/params','rb') model_size = os.fstat(model_buffer.fileno()).st_size params_size = os.fstat(params_buffer.fileno()).st_size config.set_model_buffer(model_buffer.read(), model_size, params_buffer.read(), params_size)
关于 non-combined 以及 combined 模型介绍,请参照 这里。
b. 关闭feed与fetch OP
config.switch_use_feed_fetch_ops(False) # 关闭feed和fetch OP
2. 可选配置¶
a. 加速CPU推理
# 开启MKLDNN,可加速CPU推理,要求预测库带MKLDNN功能。
config.enable_mkldnn()
# 可以设置CPU数学库线程数math_threads,可加速推理。
# 注意:math_threads * 外部线程数 需要小于总的CPU的核心数目,否则会影响预测性能。
config.set_cpu_math_library_num_threads(10)
b. 使用GPU推理
# enable_use_gpu后,模型将运行在GPU上。
# 第一个参数表示预先分配显存数目,第二个参数表示设备的ID。
config.enable_use_gpu(100, 0)
如果使用的预测lib带Paddle-TRT子图功能,可以打开TRT选项进行加速:
# 开启TensorRT推理,可提升GPU推理性能,需要使用带TensorRT的推理库
config.enable_tensorrt_engine(1 << 30, # workspace_size
batch_size, # max_batch_size
3, # min_subgraph_size
AnalysisConfig.Precision.Float32, # precision
False, # use_static
False, # use_calib_mode
)
通过计算图分析,Paddle可以自动将计算图中部分子图融合,并调用NVIDIA的 TensorRT 来进行加速。 使用Paddle-TensorRT 预测的完整方法可以参考 这里。
c. 内存/显存优化
config.enable_memory_optim() # 开启内存/显存复用
该配置设置后,在模型图分析阶段会对图中的变量进行依赖分类,两两互不依赖的变量会使用同一块内存/显存空间,缩减了运行时的内存/显存占用(模型较大或batch较大时效果显著)。
d. debug开关
# 该配置设置后,会关闭模型图分析阶段的任何图优化,预测期间运行同训练前向代码一致。
config.switch_ir_optim(False)
# 该配置设置后,会在模型图分析的每个阶段后保存图的拓扑信息到.dot文件中,该文件可用graphviz可视化。
config.switch_ir_debug(True)
二、预测器PaddlePredictor¶
PaddlePredictor 是在模型上执行推理的预测器,根据AnalysisConfig中的配置进行创建。
predictor = create_paddle_predictor(config)
create_paddle_predictor 期间首先对模型进行加载,并且将模型转换为由变量和运算节点组成的计算图。接下来将进行一系列的图优化,包括OP的横向纵向融合,删除无用节点,内存/显存优化,以及子图(Paddle-TRT)的分析,加速推理性能,提高吞吐。
三:输入/输出¶
1.准备输入¶
a. 获取模型所有输入的Tensor名字
input_names = predictor.get_input_names()
b. 获取对应名字下的Tensor
# 获取第0个输入
input_tensor = predictor.get_input_tensor(input_names[0])
c. 将输入数据copy到Tensor中
# 在copy前需要设置Tensor的shape
input_tensor.reshape((batch_size, channels, height, width))
# Tensor会根据上述设置的shape从input_data中拷贝对应数目的数据。input_data为numpy数组。
input_tensor.copy_from_cpu(input_data)
使用C++预测¶
为了简单方便地进行推理部署,飞桨提供了一套高度优化的C++ API推理接口。下面对各主要API使用方法进行详细介绍。
在 使用流程 一节中,我们了解到Paddle Inference预测包含了以下几个方面:
配置推理选项
创建predictor
准备模型输入
模型推理
获取模型输出
那我们先用一个简单的程序介绍这一过程:
std::unique_ptr<paddle::PaddlePredictor> CreatePredictor() {
// 通过AnalysisConfig配置推理选项
AnalysisConfig config;
config.SetModel(“./resnet50/model”,
"./resnet50/params");
config.EnableUseGpu(100, 0);
config.SwitchUseFeedFetchOps(false);
config.EnableMKLDNN();
config.EnableMemoryOptim();
// 创建predictor
return CreatePaddlePredictor(config);
}
void Run(paddle::PaddlePredictor *predictor,
const std::vector<float>& input,
const std::vector<int>& input_shape,
std::vector<float> *out_data) {
// 准备模型的输入
int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape(input_shape);
input_t->copy_from_cpu(input.data());
// 模型推理
CHECK(predictor->ZeroCopyRun());
// 获取模型的输出
auto output_names = predictor->GetOutputNames();
// there is only one output of Resnet50
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
out_data->resize(out_num);
output_t->copy_to_cpu(out_data->data());
}
以上的程序中 CreatePredictor 函数对推理过程进行了配置以及创建了Predictor。 Run 函数进行了输入数据的准备、模型推理以及输出数据的获取过程。
接下来我们依次对程序中出现的AnalysisConfig,Predictor,模型输入,模型输出做一个详细的介绍。
一:关于AnalysisConfig¶
AnalysisConfig管理AnalysisPredictor的推理配置,提供了模型路径设置、推理引擎运行设备选择以及多种优化推理流程的选项。配置中包括了必选配置以及可选配置。
1. 必选配置¶
a. 设置模型和参数路径
从磁盘加载模型时,根据模型和参数文件存储方式不同,设置AnalysisConfig加载模型和参数的路径有两种形式:
non-combined形式 :模型文件夹model_dir下存在一个模型文件和多个参数文件时,传入模型文件夹路径,模型文件名默认为__model__。 使用方式为: config->SetModel(“./model_dir”);。
combined形式 :模型文件夹model_dir下只有一个模型文件`model`和一个参数文件params时,传入模型文件和参数文件路径。 使用方式为: config->SetModel(“./model_dir/model”, “./model_dir/params”);。
内存加载模式:如果模型是从内存加载(模型必须为combined形式),可以使用
std::ifstream in_m(FLAGS_dirname + "/model");
std::ifstream in_p(FLAGS_dirname + "/params");
std::ostringstream os_model, os_param;
os_model << in_m.rdbuf();
os_param << in_p.rdbuf();
config.SetModelBuffer(os_model.str().data(), os_model.str().size(), os_param.str().data(), os_param.str().size());
Paddle Inference有两种格式的模型,分别为 non-combined 以及 combined 。这两种类型我们在 Quick Start 一节中提到过,忘记的同学可以回顾下。
b. 关闭Feed,Fetch op
config->SwitchUseFeedFetchOps(false); // 关闭feed和fetch OP使用,使用ZeroCopy接口必须设置此项`
我们用一个小的例子来说明我们为什么要关掉它们。 假设我们有一个模型,模型运行的序列为: input -> FEED_OP -> feed_out -> CONV_OP -> conv_out -> FETCH_OP -> output
序列中大些字母的FEED_OP, CONV_OP, FETCH_OP 为模型中的OP, 小写字母的input,feed_out,output 为模型中的变量。
在ZeroCopy模式下,我们通过 predictor->GetInputTensor(input_names[0]) 获取的模型输入为FEED_OP的输出, 即feed_out,我们通过 predictor->GetOutputTensor(output_names[0]) 接口获取的模型输出为FETCH_OP的输入,即conv_out,这种情况下,我们在运行期间就没有必要运行feed和fetch OP了,因此需要设置 config->SwitchUseFeedFetchOps(false) 来关闭feed和fetch op。
2. 可选配置¶
a. 加速CPU推理
// 开启MKLDNN,可加速CPU推理,要求预测库带MKLDNN功能。
config->EnableMKLDNN();
// 可以设置CPU数学库线程数math_threads,可加速推理。
// 注意:math_threads * 外部线程数 需要小于总的CPU的核心数目,否则会影响预测性能。
config->SetCpuMathLibraryNumThreads(10);
b. 使用GPU推理
// EnableUseGpu后,模型将运行在GPU上。
// 第一个参数表示预先分配显存数目,第二个参数表示设备的ID。
config->EnableUseGpu(100, 0);
如果使用的预测lib带Paddle-TRT子图功能,可以打开TRT选项进行加速, 详细的请访问 Paddle-TensorRT文档:
// 开启TensorRT推理,可提升GPU推理性能,需要使用带TensorRT的推理库
config->EnableTensorRtEngine(1 << 30 /*workspace_size*/,
batch_size /*max_batch_size*/,
3 /*min_subgraph_size*/,
AnalysisConfig::Precision::kFloat32 /*precision*/,
false /*use_static*/,
false /*use_calib_mode*/);
通过计算图分析,Paddle可以自动将计算图中部分子图融合,并调用NVIDIA的 TensorRT 来进行加速。
c. 内存/显存优化
config->EnableMemoryOptim(); // 开启内存/显存复用
该配置设置后,在模型图分析阶段会对图中的变量进行依赖分类,两两互不依赖的变量会使用同一块内存/显存空间,缩减了运行时的内存/显存占用(模型较大或batch较大时效果显著)。
d. debug开关
// 该配置设置后,会关闭模型图分析阶段的任何图优化,预测期间运行同训练前向代码一致。
config->SwitchIrOptim(false);
// 该配置设置后,会在模型图分析的每个阶段后保存图的拓扑信息到.dot文件中,该文件可用graphviz可视化。
config->SwitchIrDebug();
二:关于PaddlePredictor¶
PaddlePredictor 是在模型上执行推理的预测器,根据AnalysisConfig中的配置进行创建。
std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config);
CreatePaddlePredictor 期间首先对模型进行加载,并且将模型转换为由变量和运算节点组成的计算图。接下来将进行一系列的图优化,包括OP的横向纵向融合,删除无用节点,内存/显存优化,以及子图(Paddle-TRT)的分析,加速推理性能,提高吞吐。
三:输入输出¶
1. 准备输入¶
a. 获取模型所有输入的tensor名字
std::vector<std::string> input_names = predictor->GetInputNames();
b. 获取对应名字下的tensor
// 获取第0个输入
auto input_t = predictor->GetInputTensor(input_names[0]);
c. 将数据copy到tensor中
// 在copy前需要设置tensor的shape
input_t->Reshape({batch_size, channels, height, width});
// tensor会根据上述设置的shape从input_data中拷贝对应数目的数据到tensor中。
input_t->copy_from_cpu<float>(input_data /*数据指针*/);
当然我们也可以用mutable_data获取tensor的数据指针:
// 参数可为PaddlePlace::kGPU, PaddlePlace::kCPU
float *input_d = input_t->mutable_data<float>(PaddlePlace::kGPU);
2. 获取输出¶
a. 获取模型所有输出的tensor名字
std::vector<std::string> out_names = predictor->GetOutputNames();
b. 获取对应名字下的tensor
// 获取第0个输出
auto output_t = predictor->GetOutputTensor(out_names[0]);
c. 将数据copy到tensor中
std::vector<float> out_data;
// 获取输出的shpae
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
out_data->resize(out_num);
output_t->copy_to_cpu(out_data->data());
我们可以用data接口获取tensor的数据指针:
// 参数可为PaddlePlace::kGPU, PaddlePlace::kCPU
int output_size;
float *output_d = output_t->data<float>(PaddlePlace::kGPU, &output_size);
下一步
看到这里您是否已经对Paddle Inference的C++使用有所了解了呢?请访问 这里 进行样例测试。
使用Paddle-TensorRT库预测¶
NVIDIA TensorRT 是一个高性能的深度学习预测库,可为深度学习推理应用程序提供低延迟和高吞吐量。PaddlePaddle 采用子图的形式对TensorRT进行了集成,即我们可以使用该模块来提升Paddle模型的预测性能。在这篇文章中,我们会介绍如何使用Paddle-TRT子图加速预测。
概述¶
当模型加载后,神经网络可以表示为由变量和运算节点组成的计算图。如果我们打开TRT子图模式,在图分析阶段,Paddle会对模型图进行分析同时发现图中可以使用TensorRT优化的子图,并使用TensorRT节点替换它们。在模型的推断期间,如果遇到TensorRT节点,Paddle会调用TensorRT库对该节点进行优化,其他的节点调用Paddle的原生实现。TensorRT除了有常见的OP融合以及显存/内存优化外,还针对性的对OP进行了优化加速实现,降低预测延迟,提升推理吞吐。
目前Paddle-TRT支持静态shape模式以及/动态shape模式。在静态shape模式下支持图像分类,分割,检测模型,同时也支持Fp16, Int8的预测加速。在动态shape模式下,除了对动态shape的图像模型(FCN, Faster rcnn)支持外,同时也对NLP的Bert/Ernie模型也进行了支持。
Paddle-TRT的现有能力:
1)静态shape:
支持模型:
分类模型 |
检测模型 |
分割模型 |
---|---|---|
Mobilenetv1 |
yolov3 |
ICNET |
Resnet50 |
SSD |
UNet |
Vgg16 |
Mask-rcnn |
FCN |
Resnext |
Faster-rcnn |
|
AlexNet |
Cascade-rcnn |
|
Se-ResNext |
Retinanet |
|
GoogLeNet |
Mobilenet-SSD |
|
DPN |
Fp16:
Calib Int8:
优化信息序列化:
加载PaddleSlim Int8模型:
2)动态shape:
支持模型:
图像 |
NLP |
---|---|
FCN |
Bert |
Faster_RCNN |
Ernie |
Fp16:
Calib Int8:
优化信息序列化:
加载PaddleSlim Int8模型:
Note:
从源码编译时,TensorRT预测库目前仅支持使用GPU编译,且需要设置编译选项TENSORRT_ROOT为TensorRT所在的路径。
Windows支持需要TensorRT 版本5.0以上。
使用Paddle-TRT的动态shape输入功能要求TRT的版本在6.0以上。
一:环境准备¶
使用Paddle-TRT功能,我们需要准备带TRT的Paddle运行环境,我们提供了以下几种方式:
1)linux下通过pip安装
# 该whl包依赖cuda10.1, cudnnv7.6, tensorrt6.0 的lib, 需自行下载安装并设置lib路径到LD_LIBRARY_PATH中
wget https://paddle-inference-dist.bj.bcebos.com/libs/paddlepaddle_gpu-1.8.0-cp27-cp27mu-linux_x86_64.whl
pip install -U paddlepaddle_gpu-1.8.0-cp27-cp27mu-linux_x86_64.whl
如果您想在Nvidia Jetson平台上使用,请点击此 链接 下载whl包,然后通过pip 安装。
2)使用docker镜像
# 拉取镜像,该镜像预装Paddle 1.8 Python环境,并包含c++的预编译库,lib存放在主目录~/ 下。
docker pull hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6
export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
export NVIDIA_SMI="-v /usr/bin/nvidia-smi:/usr/bin/nvidia-smi"
docker run $CUDA_SO $DEVICES $NVIDIA_SMI --name trt_open --privileged --security-opt seccomp=unconfined --net=host -v $PWD:/paddle -it hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6 /bin/bash
3)手动编译 编译的方式请参照 编译文档
Note1: cmake 期间请设置 TENSORRT_ROOT (即TRT lib的路径), WITH_PYTHON (是否产出python whl包, 设置为ON)选项。
Note2: 编译期间会出现TensorRT相关的错误。
需要手动在 NvInfer.h (trt5) 或 NvInferRuntime.h (trt6) 文件中为 class IPluginFactory 和 class IGpuAllocator 分别添加虚析构函数:
virtual ~IPluginFactory() {};
virtual ~IGpuAllocator() {};
需要将 NvInferRuntime.h (trt6)中的 protected: ~IOptimizationProfile() noexcept = default;
改为
virtual ~IOptimizationProfile() noexcept = default;
二:API使用介绍¶
在 使用流程 一节中,我们了解到Paddle Inference预测包含了以下几个方面:
配置推理选项
创建predictor
准备模型输入
模型推理
获取模型输出
使用Paddle-TRT 也是遵照这样的流程。我们先用一个简单的例子来介绍这一流程(我们假设您已经对Paddle Inference有一定的了解,如果您刚接触Paddle Inference,请访问 这里 对Paddle Inference有个初步认识。):
import numpy as np
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
def create_predictor():
# config = AnalysisConfig("")
config = AnalysisConfig("./resnet50/model", "./resnet50/params")
config.switch_use_feed_fetch_ops(False)
config.enable_memory_optim()
config.enable_use_gpu(1000, 0)
# 打开TensorRT。此接口的详细介绍请见下文
config.enable_tensorrt_engine(workspace_size = 1<<30,
max_batch_size=1, min_subgraph_size=5,
precision_mode=AnalysisConfig.Precision.Float32,
use_static=False, use_calib_mode=False)
predictor = create_paddle_predictor(config)
return predictor
def run(predictor, img):
# 准备输入
input_names = predictor.get_input_names()
for i, name in enumerate(input_names):
input_tensor = predictor.get_input_tensor(name)
input_tensor.reshape(img[i].shape)
input_tensor.copy_from_cpu(img[i].copy())
# 预测
predictor.zero_copy_run()
results = []
# 获取输出
output_names = predictor.get_output_names()
for i, name in enumerate(output_names):
output_tensor = predictor.get_output_tensor(name)
output_data = output_tensor.copy_to_cpu()
results.append(output_data)
return results
if __name__ == '__main__':
pred = create_predictor()
img = np.ones((1, 3, 224, 224)).astype(np.float32)
result = run(pred, [img])
print ("class index: ", np.argmax(result[0][0]))
通过例子我们可以看出,我们通过 enable_tensorrt_engine 接口来打开TensorRT选项的。
config.enable_tensorrt_engine(
workspace_size = 1<<30,
max_batch_size=1, min_subgraph_size=5,
precision_mode=AnalysisConfig.Precision.Float32,
use_static=False, use_calib_mode=False)
接下来让我们看下该接口中各个参数的作用:
workspace_size,类型:int,默认值为1 << 30 (1G)。指定TensorRT使用的工作空间大小,TensorRT会在该大小限制下筛选最优的kernel执行预测运算。
max_batch_size,类型:int,默认值为1。需要提前设置最大的batch大小,运行时batch大小不得超过此限定值。
min_subgraph_size,类型:int,默认值为3。Paddle-TRT是以子图的形式运行,为了避免性能损失,当子图内部节点个数大于 min_subgraph_size 的时候,才会使用Paddle-TRT运行。
precision_mode,类型:AnalysisConfig.Precision, 默认值为 AnalysisConfig.Precision.Float32。指定使用TRT的精度,支持FP32(Float32),FP16(Half),Int8(Int8)。若需要使用Paddle-TRT int8离线量化校准,需设定precision为 AnalysisConfig.Precision.Int8 , 且设置 use_calib_mode 为True。
use_static,类型:bool, 默认值为False。如果指定为True,在初次运行程序的时候会将TRT的优化信息进行序列化到磁盘上,下次运行时直接加载优化的序列化信息而不需要重新生成。
use_calib_mode,类型:bool, 默认值为False。若要运行Paddle-TRT int8离线量化校准,需要将此选项设置为True。
Int8量化预测¶
神经网络的参数在一定程度上是冗余的,在很多任务上,我们可以在保证模型精度的前提下,将Float32的模型转换成Int8的模型,从而达到减小计算量降低运算耗时、降低计算内存、降低模型大小的目的。使用Int8量化预测的流程可以分为两步:1)产出量化模型;2)加载量化模型进行Int8预测。下面我们对使用Paddle-TRT进行Int8量化预测的完整流程进行详细介绍。
1. 产出量化模型
目前,我们支持通过两种方式产出量化模型:
使用TensorRT自带Int8离线量化校准功能。校准即基于训练好的FP32模型和少量校准数据(如500~1000张图片)生成校准表(Calibration table),预测时,加载FP32模型和此校准表即可使用Int8精度预测。生成校准表的方法如下:
指定TensorRT配置时,将 precision_mode 设置为 AnalysisConfig.Precision.Int8 并且设置 use_calib_mode 为 True。
config.enable_tensorrt_engine( workspace_size=1<<30, max_batch_size=1, min_subgraph_size=5, precision_mode=AnalysisConfig.Precision.Int8, use_static=False, use_calib_mode=True)准备500张左右的真实输入数据,在上述配置下,运行模型。(Paddle-TRT会统计模型中每个tensor值的范围信息,并将其记录到校准表中,运行结束后,会将校准表写入模型目录下的 _opt_cache 目录中)
如果想要了解使用TensorRT自带Int8离线量化校准功能生成校准表的完整代码,请参考 这里 的demo。
使用模型压缩工具库PaddleSlim产出量化模型。PaddleSlim支持离线量化和在线量化功能,其中,离线量化与TensorRT离线量化校准原理相似;在线量化又称量化训练(Quantization Aware Training, QAT),是基于较多数据(如>=5000张图片)对预训练模型进行重新训练,使用模拟量化的思想,在训练阶段更新权重,实现减小量化误差的方法。使用PaddleSlim产出量化模型可以参考文档:
离线量化的优点是无需重新训练,简单易用,但量化后精度可能受影响;量化训练的优点是模型精度受量化影响较小,但需要重新训练模型,使用门槛稍高。在实际使用中,我们推荐先使用TRT离线量化校准功能生成量化模型,若精度不能满足需求,再使用PaddleSlim产出量化模型。
2. 加载量化模型进行Int8预测
加载量化模型进行Int8预测,需要在指定TensorRT配置时,将 precision_mode 设置为 AnalysisConfig.Precision.Int8 。
若使用的量化模型为TRT离线量化校准产出的,需要将 use_calib_mode 设为 True :
config.enable_tensorrt_engine( workspace_size=1<<30, max_batch_size=1, min_subgraph_size=5, precision_mode=AnalysisConfig.Precision.Int8, use_static=False, use_calib_mode=True)完整demo请参考 这里 。
若使用的量化模型为PaddleSlim量化产出的,需要将 use_calib_mode 设为 False :
config.enable_tensorrt_engine( workspace_size=1<<30, max_batch_size=1, min_subgraph_size=5, precision_mode=AnalysisConfig.Precision.Int8, use_static=False, use_calib_mode=False)完整demo请参考 这里 。
运行Dynamic shape¶
从1.8 版本开始, Paddle对TRT子图进行了Dynamic shape的支持。 使用接口如下:
config.enable_tensorrt_engine(
workspace_size = 1<<30,
max_batch_size=1, min_subgraph_size=5,
precision_mode=AnalysisConfig.Precision.Float32,
use_static=False, use_calib_mode=False)
min_input_shape = {"image":[1,3, 10, 10]}
max_input_shape = {"image":[1,3, 224, 224]}
opt_input_shape = {"image":[1,3, 100, 100]}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, opt_input_shape)
从上述使用方式来看,在 config.enable_tensorrt_engine 接口的基础上,新加了一个config.set_trt_dynamic_shape_info 的接口。
该接口用来设置模型输入的最小,最大,以及最优的输入shape。 其中,最优的shape处于最小最大shape之间,在预测初始化期间,会根据opt shape对op选择最优的kernel。
调用了 config.set_trt_dynamic_shape_info 接口,预测器会运行TRT子图的动态输入模式,运行期间可以接受最小,最大shape间的任意的shape的输入数据。
四:Paddle-TRT子图运行原理¶
PaddlePaddle采用子图的形式对TensorRT进行集成,当模型加载后,神经网络可以表示为由变量和运算节点组成的计算图。Paddle TensorRT实现的功能是对整个图进行扫描,发现图中可以使用TensorRT优化的子图,并使用TensorRT节点替换它们。在模型的推断期间,如果遇到TensorRT节点,Paddle会调用TensorRT库对该节点进行优化,其他的节点调用Paddle的原生实现。TensorRT在推断期间能够进行Op的横向和纵向融合,过滤掉冗余的Op,并对特定平台下的特定的Op选择合适的kernel等进行优化,能够加快模型的预测速度。
下图使用一个简单的模型展示了这个过程:
原始网络
转换的网络
![]()
我们可以在原始模型网络中看到,绿色节点表示可以被TensorRT支持的节点,红色节点表示网络中的变量,黄色表示Paddle只能被Paddle原生实现执行的节点。那些在原始网络中的绿色节点被提取出来汇集成子图,并由一个TensorRT节点代替,成为转换后网络中的 block-25 节点。在网络运行过程中,如果遇到该节点,Paddle将调用TensorRT库来对其执行。
模型可视化¶
通过 Quick Start 一节中,我们了解到,预测模型包含了两个文件,一部分为模型结构文件,通常以 model 或 __model__ 文件存在;另一部分为参数文件,通常以params 文件或一堆分散的文件存在。
模型结构文件,顾名思义,存储了模型的拓扑结构,其中包括模型中各种OP的计算顺序以及OP的详细信息。很多时候,我们希望能够将这些模型的结构以及内部信息可视化,方便我们进行模型分析。接下来将会通过两种方式来讲述如何对Paddle 预测模型进行可视化。
一: 通过 VisualDL 可视化¶
1) 安装
VisualDL是飞桨可视化分析工具,以丰富的图表呈现训练参数变化趋势、模型结构、数据样本、高维数据分布等,帮助用户更清晰直观地理解深度学习模型训练过程及模型结构,实现高效的模型优化。 我们可以进入 GitHub主页 进行下载安装。
2)可视化
点击 下载测试模型。
支持两种启动方式:
前端拖拽上传模型文件:
无需添加任何参数,在命令行执行 visualdl 后启动界面上传文件即可:

后端透传模型文件:
在命令行加入参数 –model 并指定 模型文件 路径(非文件夹路径),即可启动:
visualdl --model ./log/model --port 8080

Graph功能详细使用,请见 Graph使用指南 。
二: 通过代码方式生成dot文件¶
1)pip 安装Paddle
2)生成dot文件
点击 下载测试模型。
#!/usr/bin/env python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
def get_graph(program_path):
with open(program_path, 'rb') as f:
binary_str = f.read()
program = fluid.framework.Program.parse_from_string(binary_str)
return IrGraph(core.Graph(program.desc), for_test=True)
if __name__ == '__main__':
program_path = './lecture_model/__model__'
offline_graph = get_graph(program_path)
offline_graph.draw('.', 'test_model', [])
3)生成svg
Note:需要环境中安装graphviz
dot -Tsvg ./test_mode.dot -o test_model.svg
然后将test_model.svg以浏览器打开预览即可。

模型转换工具 X2Paddle¶
X2Paddle可以将caffe、tensorflow、onnx模型转换成Paddle支持的模型。
X2Paddle 支持将Caffe/TensorFlow模型转换为PaddlePaddle模型。目前X2Paddle支持的模型参考 x2paddle_model_zoo 。
多框架支持¶
模型 |
caffe |
tensorflow |
onnx |
---|---|---|---|
mobilenetv1 |
Y |
Y |
F |
mobilenetv2 |
Y |
Y |
Y |
resnet18 |
Y |
Y |
F |
resnet50 |
Y |
Y |
Y |
mnasnet |
Y |
Y |
F |
efficientnet |
Y |
Y |
Y |
squeezenetv1.1 |
Y |
Y |
Y |
shufflenet |
Y |
Y |
F |
mobilenet_ssd |
Y |
Y |
F |
mobilenet_yolov3 |
F |
Y |
F |
inceptionv4 |
F |
F |
F |
mtcnn |
Y |
Y |
F |
facedetection |
Y |
F |
F |
unet |
Y |
Y |
F |
ocr_attention |
F |
F |
F |
vgg16 |
F |
F |
F |
安装¶
pip install x2paddle
安装最新版本,可使用如下安装方式
pip install git+https://github.com/PaddlePaddle/X2Paddle.git@develop
使用¶
Caffe¶
x2paddle --framework caffe \
--prototxt model.proto \
--weight model.caffemodel \
--save_dir paddle_model
TensorFlow¶
x2paddle --framework tensorflow \
--model model.pb \
--save_dir paddle_model
转换结果说明¶
在指定的 save_dir 下生成两个目录
inference_model : 模型结构和参数均序列化保存的模型格式
model_with_code : 保存了模型参数文件和模型的python代码
问题反馈
X2Paddle使用时存在问题时,欢迎您将问题或Bug报告以 Github Issues 的形式提交给我们,我们会实时跟进。
Library API¶
Class Hierarchy¶
-
- Namespace paddle
- Struct AnalysisConfig
- Struct NativeConfig
- Struct PaddleTensor
- Class CpuPassStrategy
- Class GpuPassStrategy
- Class MkldnnQuantizerConfig
- Class PaddleBuf
- Class PaddlePassBuilder
- Class PaddlePredictor
- Struct PaddlePredictor::Config
- Class PassStrategy
- Class ZeroCopyTensor
- Enum PaddleDType
- Enum PaddleEngineKind
- Enum PaddlePlace
- Enum ScaleAlgo
- Namespace paddle
Full API¶
Classes and Structs¶
Struct AnalysisConfig¶
Defined in File paddle_analysis_config.h
Struct Documentation¶
-
struct
paddle
::
AnalysisConfig
¶ configuration manager for AnalysisPredictor.
AnalysisConfig manages configurations of AnalysisPredictor. During inference procedure, there are many parameters(model/params path, place of inference, etc.) to be specified, and various optimizations(subgraph fusion, memory optimazation, TensorRT engine, etc.) to be done. Users can manage these settings by creating and modifying an AnalysisConfig, and loading it into AnalysisPredictor.
- Since
1.7.0
Public Types
Public Functions
-
AnalysisConfig
() = default¶
-
AnalysisConfig
(const AnalysisConfig &other)¶ Construct a new AnalysisConfig from another AnalysisConfig.
- Parameters
[in] other
: another AnalysisConfig
-
AnalysisConfig
(const std::string &model_dir)¶ Construct a new AnalysisConfig from a no-combined model.
- Parameters
[in] model_dir
: model directory of the no-combined model.
-
AnalysisConfig
(const std::string &prog_file, const std::string ¶ms_file)¶ Construct a new AnalysisConfig from a combined model.
- Parameters
[in] prog_file
: model file path of the combined model.[in] params_file
: params file path of the combined model.
-
void
SetModel
(const std::string &model_dir)¶ Set the no-combined model dir path.
- Parameters
model_dir
: model dir path.
-
void
SetModel
(const std::string &prog_file_path, const std::string ¶ms_file_path)¶ Set the combined model with two specific pathes for program and parameters.
- Parameters
prog_file_path
: model file path of the combined model.params_file_path
: params file path of the combined model.
-
void
SetProgFile
(const std::string &x)¶ Set the model file path of a combined model.
- Parameters
x
: model file path.
-
void
SetParamsFile
(const std::string &x)¶ Set the params file path of a combined model.
- Parameters
x
: params file path.
-
void
SetOptimCacheDir
(const std::string &opt_cache_dir)¶ Set the path of optimization cache directory.
- Parameters
opt_cache_dir
: the path of optimization cache directory.
-
const std::string &
model_dir
() const¶ Get the model directory path.
- Return
const std::string& The model directory path.
-
const std::string &
prog_file
() const¶ Get the program file path.
- Return
const std::string& The program file path.
-
const std::string &
params_file
() const¶ Get the combined parameters file.
- Return
const std::string& The combined parameters file.
-
void
DisableFCPadding
()¶ Turn off FC Padding.
-
bool
use_fc_padding
() const¶ A boolean state telling whether fc padding is used.
- Return
bool Whether fc padding is used.
-
void
EnableUseGpu
(uint64_t memory_pool_init_size_mb, int device_id = 0)¶ Turn on GPU.
- Parameters
memory_pool_init_size_mb
: initial size of the GPU memory pool in MB.device_id
: device_id the GPU card to use (default is 0).
-
void
DisableGpu
()¶ Turn off GPU.
-
bool
use_gpu
() const¶ A boolean state telling whether the GPU is turned on.
- Return
bool Whether the GPU is turned on.
-
int
gpu_device_id
() const¶ Get the GPU device id.
- Return
int The GPU device id.
-
int
memory_pool_init_size_mb
() const¶ Get the initial size in MB of the GPU memory pool.
- Return
int The initial size in MB of the GPU memory pool.
-
float
fraction_of_gpu_memory_for_pool
() const¶ Get the proportion of the initial memory pool size compared to the device.
- Return
float The proportion of the initial memory pool size.
-
void
EnableCUDNN
()¶ Turn on CUDNN.
-
bool
cudnn_enabled
() const¶ A boolean state telling whether to use CUDNN.
- Return
bool Whether to use CUDNN.
-
void
SwitchIrOptim
(int x = true)¶ Control whether to perform IR graph optimization. If turned off, the AnalysisConfig will act just like a NativeConfig.
- Parameters
x
: Whether the ir graph optimization is actived.
-
bool
ir_optim
() const¶ A boolean state telling whether the ir graph optimization is actived.
- Return
bool Whether to use ir graph optimization.
-
void
SwitchUseFeedFetchOps
(int x = true)¶ INTERNAL Determine whether to use the feed and fetch operators. Just for internal development, not stable yet. When ZeroCopyTensor is used, this should be turned off.
- Parameters
x
: Whether to use the feed and fetch operators.
-
bool
use_feed_fetch_ops_enabled
() const¶ A boolean state telling whether to use the feed and fetch operators.
- Return
bool Whether to use the feed and fetch operators.
-
void
SwitchSpecifyInputNames
(bool x = true)¶ Control whether to specify the inputs’ names. The ZeroCopyTensor type has a name member, assign it with the corresponding variable name. This is used only when the input ZeroCopyTensors passed to the AnalysisPredictor.ZeroCopyRun() cannot follow the order in the training phase.
- Parameters
x
: Whether to specify the inputs’ names.
-
bool
specify_input_name
() const¶ A boolean state tell whether the input ZeroCopyTensor names specified should be used to reorder the inputs in AnalysisPredictor.ZeroCopyRun().
- Return
bool Whether to specify the inputs’ names.
-
void
EnableTensorRtEngine
(int workspace_size = 1 << 20, int max_batch_size = 1, int min_subgraph_size = 3, Precision precision = Precision::kFloat32, bool use_static = false, bool use_calib_mode = true)¶ Turn on the TensorRT engine. The TensorRT engine will accelerate some subgraphes in the original Fluid computation graph. In some models such as resnet50, GoogleNet and so on, it gains significant performance acceleration.
- Parameters
workspace_size
: The memory size(in byte) used for TensorRT workspace.max_batch_size
: The maximum batch size of this prediction task, better set as small as possible for less performance loss.min_subgrpah_size
: The minimum TensorRT subgraph size needed, if a subgraph is smaller than this, it will not be transferred to TensorRT engine.precision
: The precision used in TensorRT.use_static
: Serialize optimization information to disk for reusing.use_calib_mode
: Use TRT int8 calibration(post training quantization).
-
bool
tensorrt_engine_enabled
() const¶ A boolean state telling whether the TensorRT engine is used.
- Return
bool Whether the TensorRT engine is used.
-
void
SetTRTDynamicShapeInfo
(std::map<std::string, std::vector<int>> min_input_shape, std::map<std::string, std::vector<int>> max_input_shape, std::map<std::string, std::vector<int>> optim_input_shape, bool disable_trt_plugin_fp16 = false)¶ Set min, max, opt shape for TensorRT Dynamic shape mode.
- Parameters
min_input_shape
: The min input shape of the subgraph input.max_input_shape
: The max input shape of the subgraph input.opt_input_shape
: The opt input shape of the subgraph input.disable_trt_plugin_fp16
: Setting this parameter to true means that TRT plugin will not run fp16.
-
void
EnableLiteEngine
(AnalysisConfig::Precision precision_mode = Precision::kFloat32, const std::vector<std::string> &passes_filter = {}, const std::vector<std::string> &ops_filter = {})¶ Turn on the usage of Lite sub-graph engine.
- Parameters
precision_mode
: Precion used in Lite sub-graph engine.passes_filter
: Set the passes used in Lite sub-graph engine.ops_filter
: Operators not supported by Lite.
-
bool
lite_engine_enabled
() const¶ A boolean state indicating whether the Lite sub-graph engine is used.
- Return
bool whether the Lite sub-graph engine is used.
-
void
SwitchIrDebug
(int x = true)¶ Control whether to debug IR graph analysis phase. This will generate DOT files for visualizing the computation graph after each analysis pass applied.
- Parameters
x
: whether to debug IR graph analysis phase.
-
void
EnableMKLDNN
()¶ Turn on MKLDNN.
-
void
SetMkldnnCacheCapacity
(int capacity)¶ Set the cache capacity of different input shapes for MKLDNN. Default value 0 means not caching any shape.
- Parameters
capacity
: The cache capacity.
-
bool
mkldnn_enabled
() const¶ A boolean state telling whether to use the MKLDNN.
- Return
bool Whether to use the MKLDNN.
-
void
SetCpuMathLibraryNumThreads
(int cpu_math_library_num_threads)¶ Set the number of cpu math library threads.
- Parameters
cpu_math_library_num_threads
: The number of cpu math library threads.
-
int
cpu_math_library_num_threads
() const¶ An int state telling how many threads are used in the CPU math library.
- Return
int The number of threads used in the CPU math library.
-
NativeConfig
ToNativeConfig
() const¶ Transform the AnalysisConfig to NativeConfig.
- Return
NativeConfig The NativeConfig transformed.
-
void
SetMKLDNNOp
(std::unordered_set<std::string> op_list)¶ Specify the operator type list to use MKLDNN acceleration.
- Parameters
op_list
: The operator type list.
-
void
EnableMkldnnQuantizer
()¶ Turn on MKLDNN quantization.
-
bool
mkldnn_quantizer_enabled
() const¶ A boolean state telling whether the MKLDNN quantization is enabled.
- Return
bool Whether the MKLDNN quantization is enabled.
-
MkldnnQuantizerConfig *
mkldnn_quantizer_config
() const¶ Get MKLDNN quantizer config.
- Return
MkldnnQuantizerConfig* MKLDNN quantizer config.
-
void
SetModelBuffer
(const char *prog_buffer, size_t prog_buffer_size, const char *params_buffer, size_t params_buffer_size)¶ Specify the memory buffer of program and parameter. Used when model and params are loaded directly from memory.
- Parameters
prog_buffer
: The memory buffer of program.prog_buffer_size
: The size of the model data.params_buffer
: The memory buffer of the combined parameters file.params_buffer_size
: The size of the combined parameters data.
-
bool
model_from_memory
() const¶ A boolean state telling whether the model is set from the CPU memory.
- Return
bool Whether model and params are loaded directly from memory.
-
void
EnableMemoryOptim
()¶ Turn on memory optimize NOTE still in development.
-
bool
enable_memory_optim
() const¶ A boolean state telling whether the memory optimization is activated.
- Return
bool Whether the memory optimization is activated.
-
void
EnableProfile
()¶ Turn on profiling report. If not turned on, no profiling report will be generated.
-
bool
profile_enabled
() const¶ A boolean state telling whether the profiler is activated.
- Return
bool Whether the profiler is activated.
-
void
DisableGlogInfo
()¶ Mute all logs in Paddle inference.
-
bool
glog_info_disabled
() const¶ A boolean state telling whether logs in Paddle inference are muted.
- Return
bool Whether logs in Paddle inference are muted.
-
void
SetInValid
() const¶ Set the AnalysisConfig to be invalid. This is to ensure that an AnalysisConfig can only be used in one AnalysisPredictor.
-
bool
is_valid
() const¶ A boolean state telling whether the AnalysisConfig is valid.
- Return
bool Whether the AnalysisConfig is valid.
-
PassStrategy *
pass_builder
() const¶ Get a pass builder for customize the passes in IR analysis phase. NOTE: Just for developer, not an official API, easy to be broken.
-
void
PartiallyRelease
()¶
Protected Attributes
-
std::string
model_dir_
¶
-
std::string
prog_file_
¶
-
std::string
params_file_
¶
-
bool
use_gpu_
= {false}¶
-
int
device_id_
= {0}¶
-
uint64_t
memory_pool_init_size_mb_
= {100}¶
-
bool
use_cudnn_
= {false}¶
-
bool
use_fc_padding_
= {true}¶
-
bool
use_tensorrt_
= {false}¶
-
int
tensorrt_workspace_size_
= {1 << 30}¶
-
int
tensorrt_max_batchsize_
= {1}¶
-
int
tensorrt_min_subgraph_size_
= {3}¶
-
bool
trt_use_static_engine_
= {false}¶
-
bool
trt_use_calib_mode_
= {true}¶
-
std::map<std::string, std::vector<int>>
min_input_shape_
= {}¶
-
std::map<std::string, std::vector<int>>
max_input_shape_
= {}¶
-
std::map<std::string, std::vector<int>>
optim_input_shape_
= {}¶
-
bool
disable_trt_plugin_fp16_
= {false}¶
-
bool
enable_memory_optim_
= {false}¶
-
bool
use_mkldnn_
= {false}¶
-
std::unordered_set<std::string>
mkldnn_enabled_op_types_
¶
-
bool
model_from_memory_
= {false}¶
-
bool
enable_ir_optim_
= {true}¶
-
bool
use_feed_fetch_ops_
= {true}¶
-
bool
ir_debug_
= {false}¶
-
bool
specify_input_name_
= {false}¶
-
int
cpu_math_library_num_threads_
= {1}¶
-
bool
with_profile_
= {false}¶
-
bool
with_glog_info_
= {true}¶
-
std::string
serialized_info_cache_
¶
-
std::unique_ptr<PassStrategy>
pass_builder_
¶
-
bool
use_lite_
= {false}¶
-
std::vector<std::string>
lite_passes_filter_
¶
-
std::vector<std::string>
lite_ops_filter_
¶
-
int
mkldnn_cache_capacity_
= {0}¶
-
bool
use_mkldnn_quantizer_
= {false}¶
-
std::shared_ptr<MkldnnQuantizerConfig>
mkldnn_quantizer_config_
¶
-
bool
is_valid_
= {true}¶
-
std::string
opt_cache_dir_
¶
Friends
- friend class ::paddle::AnalysisPredictor
Struct NativeConfig¶
Defined in File paddle_api.h
Inheritance Relationships¶
public paddle::PaddlePredictor::Config
(Struct PaddlePredictor::Config)
Struct Documentation¶
-
struct
paddle
::
NativeConfig
: public paddle::PaddlePredictor::Config¶ configuration manager for
NativePredictor
.AnalysisConfig
manages configurations ofNativePredictor
. During inference procedure, there are many parameters(model/params path, place of inference, etc.)Public Functions
-
void
SetCpuMathLibraryNumThreads
(int cpu_math_library_num_threads)¶ Set and get the number of cpu math library threads.
-
int
cpu_math_library_num_threads
() const¶
Public Members
-
bool
use_gpu
= {false}¶ GPU related fields.
-
int
device
= {0}¶
-
float
fraction_of_gpu_memory
{-1.f}¶ Change to a float in (0,1] if needed.
-
std::string
prog_file
¶
-
std::string
param_file
¶ Specify the exact path of program and parameter files.
-
bool
specify_input_name
= {false}¶ Specify the variable’s name of each input if input tensors don’t follow the
feeds
andfetches
of the phasesave_inference_model
.
Protected Attributes
-
int
cpu_math_library_num_threads_
= {1}¶ number of cpu math library (such as MKL, OpenBlas) threads for each instance.
-
void
Struct PaddlePredictor::Config¶
Defined in File paddle_api.h
Nested Relationships¶
This struct is a nested type of Class PaddlePredictor.
Inheritance Relationships¶
public paddle::NativeConfig
(Struct NativeConfig)
Struct Documentation¶
-
struct
paddle::PaddlePredictor
::
Config
Base class for NativeConfig and AnalysisConfig.
Subclassed by paddle::NativeConfig
Public Members
-
std::string
model_dir
path to the model directory.
-
std::string
Struct PaddleTensor¶
Defined in File paddle_api.h
Struct Documentation¶
-
struct
paddle
::
PaddleTensor
¶ Basic input and output data structure for PaddlePredictor.
Public Functions
-
PaddleTensor
() = default¶
-
Class CpuPassStrategy¶
Defined in File paddle_pass_builder.h
Inheritance Relationships¶
public paddle::PassStrategy
(Class PassStrategy)
Class Documentation¶
-
class
paddle
::
CpuPassStrategy
: public paddle::PassStrategy¶ The CPU passes controller, it is used in AnalysisPredictor with CPU mode.
Public Functions
-
CpuPassStrategy
()¶ Default constructor of CpuPassStrategy.
-
CpuPassStrategy
(const CpuPassStrategy &other)¶ Construct by copying another CpuPassStrategy object.
- Parameters
[in] other
: The CpuPassStrategy object we want to copy.
-
~CpuPassStrategy
() = default¶ Default destructor.
-
void
EnableCUDNN
() override¶ Enable the use of cuDNN kernel.
-
void
EnableMKLDNN
() override¶ Enable the use of MKLDNN.
-
void
EnableMkldnnQuantizer
() override¶ Enable MKLDNN quantize optimization.
-
Class GpuPassStrategy¶
Defined in File paddle_pass_builder.h
Inheritance Relationships¶
public paddle::PassStrategy
(Class PassStrategy)
Class Documentation¶
-
class
paddle
::
GpuPassStrategy
: public paddle::PassStrategy¶ The GPU passes controller, it is used in AnalysisPredictor with GPU mode.
Public Functions
-
GpuPassStrategy
()¶ Default constructor of GpuPassStrategy.
-
GpuPassStrategy
(const GpuPassStrategy &other)¶ Construct by copying another GpuPassStrategy object.
- Parameters
[in] other
: The GpuPassStrategy object we want to copy.
-
void
EnableCUDNN
() override¶ Enable the use of cuDNN kernel.
-
void
EnableMKLDNN
() override¶ Not supported in GPU mode yet.
-
void
EnableMkldnnQuantizer
() override¶ Not supported in GPU mode yet.
-
~GpuPassStrategy
() = default¶ Default destructor.
-
Class MkldnnQuantizerConfig¶
Defined in File paddle_mkldnn_quantizer_config.h
Class Documentation¶
-
class
paddle
::
MkldnnQuantizerConfig
¶ Config for mkldnn quantize.
The MkldnnQuantizerConfig is used to configure Mkldnn’s quantization parameters, including scale algorithm, warmup data, warmup batch size, quantized op list, etc.
It is not recommended to use this config directly, please refer to AnalysisConfig::mkldnn_quantizer_config()
Public Functions
-
MkldnnQuantizerConfig
()¶ Construct a new Mkldnn Quantizer Config object.
-
void
SetScaleAlgo
(std::string op_type_name, std::string conn_name, ScaleAlgo algo)¶ Set the scale algo.
Specify a quantization algorithm for a connection (input/output) of the operator type.
- Parameters
[in] op_type_name
: the operator’s name.[in] conn_name
: name of the connection (input/output) of the operator.[in] algo
: the algorithm for computing scale.
-
ScaleAlgo
scale_algo
(const std::string &op_type_name, const std::string &conn_name) const¶ Get the scale algo.
Get the quantization algorithm for a connection (input/output) of the operator type.
- Return
the scale algo.
- Parameters
[in] op_type_name
: the operator’s name.[in] conn_name
: name of the connection (input/output) of the operator.
Set the warmup data.
Set the batch of data to be used for warm-up iteration.
- Parameters
[in] data
: batch of data.
-
std::shared_ptr<std::vector<PaddleTensor>>
warmup_data
() const¶ Get the warmup data.
Get the batch of data used for warm-up iteration.
- Return
the warm up data
-
void
SetWarmupBatchSize
(int batch_size)¶ Set the warmup batch size.
Set the batch size for warm-up iteration.
- Parameters
[in] batch_size
: warm-up batch size
-
int
warmup_batch_size
() const¶ Get the warmup batch size.
Get the batch size for warm-up iteration.
- Return
the warm up batch size
-
void
SetEnabledOpTypes
(std::unordered_set<std::string> op_list)¶ Set quantized op list.
In the quantization process, set the op list that supports quantization
- Parameters
[in] op_list
: List of quantized ops
-
const std::unordered_set<std::string> &
enabled_op_types
() const¶ Get quantized op list.
- Return
list of quantized ops
-
void
SetExcludedOpIds
(std::unordered_set<int> op_ids_list)¶ Set the excluded op ids.
- Parameters
[in] op_ids_list
: excluded op ids
-
const std::unordered_set<int> &
excluded_op_ids
() const¶ Get the excluded op ids.
- Return
exclude op ids
Protected Attributes
-
std::unordered_set<std::string>
enabled_op_types_
¶
-
std::unordered_set<int>
excluded_op_ids_
¶
-
std::shared_ptr<std::vector<PaddleTensor>>
warmup_data_
¶
-
int
warmup_bs_
= {1}¶
-
Class PaddleBuf¶
Defined in File paddle_api.h
Class Documentation¶
-
class
paddle
::
PaddleBuf
¶ Memory manager for PaddleTensor.
The PaddleBuf holds a buffer for data input or output. The memory can be allocated by user or by PaddleBuf itself, but in any case, the PaddleBuf should be reused for better performance.
For user allocated memory, the following API can be used:
PaddleBuf(void* data, size_t length) to set an external memory by specifying the memory address and length.
Reset(void* data, size_t length) to reset the PaddleBuf with an external memory. ATTENTION, for user allocated memory, deallocation should be done by users externally after the program finished. The PaddleBuf won’t do any allocation or deallocation.
To have the PaddleBuf allocate and manage the memory:
PaddleBuf(size_t length) will allocate a memory of size
length
.Resize(size_t length) resize the memory to no less than
length
, ATTENTION if the allocated memory is larger thanlength
, nothing will done.
Usage:
Let PaddleBuf manage the memory internally.
const int num_elements = 128; PaddleBuf buf(num_elements/// sizeof(float));
Or
Works the exactly the same.PaddleBuf buf; buf.Resize(num_elements/// sizeof(float));
One can also make the
PaddleBuf
use the external memory.PaddleBuf buf; void* external_memory = new float[num_elements]; buf.Reset(external_memory, num_elements*sizeof(float)); ... delete[] external_memory; // manage the memory lifetime outside.
Public Functions
-
PaddleBuf
(size_t length)¶ PaddleBuf allocate memory internally, and manage it.
- Parameters
[in] length
: The length of data.
-
PaddleBuf
(void *data, size_t length)¶ Set external memory, the PaddleBuf won’t manage it.
- Parameters
[in] data
: The start address of the external memory.[in] length
: The length of data.
-
PaddleBuf
(const PaddleBuf &other)¶ Copy only available when memory is managed externally.
- Parameters
[in] other
: anotherPaddleBuf
-
void
Resize
(size_t length)¶ Resize the memory.
- Parameters
[in] length
: The length of data.
-
void
Reset
(void *data, size_t length)¶ Reset to external memory, with address and length set.
- Parameters
[in] data
: The start address of the external memory.[in] length
: The length of data.
-
bool
empty
() const¶ Tell whether the buffer is empty.
-
void *
data
() const¶ Get the data’s memory address.
-
size_t
length
() const¶ Get the memory length.
-
~PaddleBuf
()¶
-
PaddleBuf
() = default¶
Class PaddlePassBuilder¶
Defined in File paddle_pass_builder.h
Inheritance Relationships¶
public paddle::PassStrategy
(Class PassStrategy)
Class Documentation¶
-
class
paddle
::
PaddlePassBuilder
¶ This class build passes based on vector<string> input. It is part of inference API. Users can build passes, insert new passes, delete passes using this class and its functions.
Example Usage: Build a new pass.
const vector<string> passes(1, "conv_relu_mkldnn_fuse_pass"); PaddlePassBuilder builder(passes);
Subclassed by paddle::PassStrategy
Public Functions
-
PaddlePassBuilder
(const std::vector<std::string> &passes)¶ Constructor of the class. It stores the input passes.
- Parameters
[in] passes
: passes’ types.
-
void
SetPasses
(std::initializer_list<std::string> passes)¶ Stores the input passes.
- Parameters
[in] passes
: passes’ types.
-
void
AppendPass
(const std::string &pass_type)¶ Append a pass to the end of the passes.
- Parameters
[in] pass_type
: the type of the new pass.
-
void
InsertPass
(size_t idx, const std::string &pass_type)¶ Insert a pass to a specific position.
- Parameters
[in] idx
: the position to insert.[in] pass_type
: the type of insert pass.
-
void
DeletePass
(size_t idx)¶ Delete the pass at certain position ‘idx’.
- Parameters
[in] idx
: the position to delete.
-
void
DeletePass
(const std::string &pass_type)¶ Delete all passes that has a certain type ‘pass_type’.
- Parameters
[in] pass_type
: the certain pass type to be deleted.
-
void
ClearPasses
()¶ Delete all the passes.
-
void
AppendAnalysisPass
(const std::string &pass)¶ Append an analysis pass.
- Parameters
[in] pass
: the type of the new analysis pass.
-
void
TurnOnDebug
()¶ Visualize the computation graph after each pass by generating a DOT language file, one can draw them with the Graphviz toolkit.
-
std::string
DebugString
()¶ Human-readable information of the passes.
-
const std::vector<std::string> &
AllPasses
() const¶ Get information of passes.
- Return
Return list of the passes.
-
std::vector<std::string>
AnalysisPasses
() const¶ Get information of analysis passes.
- Return
Return list of analysis passes.
-
Class PaddlePredictor¶
Defined in File paddle_api.h
Class Documentation¶
-
class
paddle
::
PaddlePredictor
¶ A Predictor for executing inference on a model. Base class for AnalysisPredictor and NativePaddlePredictor.
Public Functions
-
PaddlePredictor
() = default¶
-
PaddlePredictor
(const PaddlePredictor&) = delete¶
-
PaddlePredictor &
operator=
(const PaddlePredictor&) = delete¶
-
bool
Run
(const std::vector<PaddleTensor> &inputs, std::vector<PaddleTensor> *output_data, int batch_size = -1) = 0¶ This interface takes input and runs the network. There are redundant copies of data between hosts in this operation, so it is more recommended to use the zecopyrun interface.
- Return
Whether the run is successful
- Parameters
[in] inputs
: An list of PaddleTensor as the input to the network.[out] output_data
: Pointer to the tensor list, which holds the output paddletensor[in] batch_size
: This setting has been discarded and can be ignored.
-
std::vector<std::string>
GetInputNames
()¶ Used to get the name of the network input. Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.
- Return
Input tensor names.
-
std::map<std::string, std::vector<int64_t>>
GetInputTensorShape
()¶ Get the input shape of the model.
- Return
A map contains all the input names and shape defined in the model.
-
std::vector<std::string>
GetOutputNames
()¶ Used to get the name of the network output. Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.
- Return
Output tensor names.
-
std::unique_ptr<ZeroCopyTensor>
GetInputTensor
(const std::string &name)¶ Get the input ZeroCopyTensor by name. Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios. The name is obtained from the GetInputNames() interface.
- Return
Return the corresponding input ZeroCopyTensor.
- Parameters
name
: The input tensor name.
-
std::unique_ptr<ZeroCopyTensor>
GetOutputTensor
(const std::string &name)¶ Get the output ZeroCopyTensor by name. Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios. The name is obtained from the GetOutputNames() interface.
- Return
Return the corresponding output ZeroCopyTensor.
- Parameters
name
: The output tensor name.
-
bool
ZeroCopyRun
()¶ Run the network with zero-copied inputs and outputs. Be inherited by AnalysisPredictor and only used in ZeroCopy scenarios. This will save the IO copy for transfering inputs and outputs to predictor workspace and get some performance improvement. To use it, one should call the AnalysisConfig.SwitchUseFeedFetchOp(true) and then use the
GetInputTensor
andGetOutputTensor
to directly write or read the input/output tensors.- Return
Whether the run is successful
-
std::unique_ptr<PaddlePredictor>
Clone
() = 0¶ Clone an existing predictor When using clone, the same network will be created, and the parameters between them are shared.
- Return
unique_ptr which contains the pointer of predictor
-
~PaddlePredictor
() = default¶ Destroy the Predictor.
-
std::string
GetSerializedProgram
() const¶
-
struct
Config
¶ Base class for NativeConfig and AnalysisConfig.
Subclassed by paddle::NativeConfig
Public Members
-
std::string
model_dir
¶ path to the model directory.
-
std::string
-
Class PassStrategy¶
Defined in File paddle_pass_builder.h
Inheritance Relationships¶
public paddle::PaddlePassBuilder
(Class PaddlePassBuilder)
public paddle::CpuPassStrategy
(Class CpuPassStrategy)public paddle::GpuPassStrategy
(Class GpuPassStrategy)
Class Documentation¶
-
class
paddle
::
PassStrategy
: public paddle::PaddlePassBuilder¶ This class defines the pass strategies like whether to use gpu/cuDNN kernel/MKLDNN.
Subclassed by paddle::CpuPassStrategy, paddle::GpuPassStrategy
Public Functions
-
PassStrategy
(const std::vector<std::string> &passes)¶ Constructor of PassStrategy class. It works the same as PaddlePassBuilder class.
- Parameters
[in] passes
: passes’ types.
-
void
EnableCUDNN
()¶ Enable the use of cuDNN kernel.
-
void
EnableMKLDNN
()¶ Enable the use of MKLDNN. The MKLDNN control exists in both CPU and GPU mode, because there can still be some CPU kernels running in GPU mode.
-
void
EnableMkldnnQuantizer
()¶ Enable MKLDNN quantize optimization.
-
bool
use_gpu
() const¶ Check if we are using gpu.
- Return
A bool variable implying whether we are in gpu mode.
-
~PassStrategy
() = default¶ Default destructor.
-
Class ZeroCopyTensor¶
Defined in File paddle_api.h
Class Documentation¶
-
class
paddle
::
ZeroCopyTensor
¶ Represents an n-dimensional array of values. The ZeroCopyTensor is used to store the input or output of the network. Zero copy means that the tensor supports direct copy of host or device data to device, eliminating additional CPU copy. ZeroCopyTensor is only used in the AnalysisPredictor. It is obtained through PaddlePredictor::GetinputTensor() and PaddlePredictor::GetOutputTensor() interface.
Public Functions
-
void
Reshape
(const std::vector<int> &shape)¶ Reset the shape of the tensor. Generally it’s only used for the input tensor. Reshape must be called before calling mutable_data() or copy_from_cpu()
- Parameters
shape
: The shape to set.
-
template<typename
T
>
T *mutable_data
(PaddlePlace place)¶ Get the memory pointer in CPU or GPU with specific data type. Please Reshape the tensor first before call this. It’s usually used to get input data pointer.
- Parameters
place
: The place of the tensor.
-
template<typename
T
>
T *data
(PaddlePlace *place, int *size) const¶ Get the memory pointer directly. It’s usually used to get the output data pointer.
- Return
The tensor data buffer pointer.
- Parameters
[out] place
: To get the device type of the tensor.[out] size
: To get the data size of the tensor.
-
template<typename
T
>
voidcopy_from_cpu
(const T *data)¶ Copy the host memory to tensor data. It’s usually used to set the input tensor data.
- Parameters
data
: The pointer of the data, from which the tensor will copy.
-
template<typename
T
>
voidcopy_to_cpu
(T *data)¶ Copy the tensor data to the host memory. It’s usually used to get the output tensor data.
- Parameters
[out] data
: The tensor will copy the data to the address.
-
std::vector<int>
shape
() const¶ Return the shape of the Tensor.
-
void
SetLoD
(const std::vector<std::vector<size_t>> &x)¶ Set lod info of the tensor. More about LOD can be seen here: https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/lod_tensor.html#lodtensor.
- Parameters
x
: the lod info.
-
std::vector<std::vector<size_t>>
lod
() const¶ Return the lod info of the tensor.
-
const std::string &
name
() const¶ Return the name of the tensor.
-
void
SetPlace
(PaddlePlace place, int device = -1)¶
-
PaddleDType
type
() const¶ Return the data type of the tensor. It’s usually used to get the output tensor data type.
- Return
The data type of the tensor.
-
void
Enums¶
Enum ScaleAlgo¶
Defined in File paddle_mkldnn_quantizer_config.h
Enum Documentation¶
-
enum
paddle
::
ScaleAlgo
¶ Algorithms for finding scale of quantized Tensors.
Values:
-
enumerator
NONE
¶ Do not compute scale.
-
enumerator
MAX
¶ Find scale based on the max absolute value.
-
enumerator
MAX_CH
¶ Find scale based on the max absolute value per output channel.
-
enumerator
MAX_CH_T
¶ Find scale based on the max absolute value per output channel of a transposed tensor
-
enumerator
KL
¶ Find scale based on KL Divergence.
-
enumerator
Functions¶
Template Function paddle::CreatePaddlePredictor(const ConfigT&)¶
Defined in File paddle_api.h
Function Documentation¶
Warning
doxygenfunction: Unable to resolve multiple matches for function “paddle::CreatePaddlePredictor” with arguments (const ConfigT&) in doxygen xml output for project “My Project” from directory: ./doxyoutput/xml. Potential matches:
- template<typename ConfigT, PaddleEngineKind engine> std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &config)
- template<typename ConfigT> std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &config)
Template Function paddle::CreatePaddlePredictor(const ConfigT&)¶
Defined in File paddle_api.h
Function Documentation¶
Warning
doxygenfunction: Unable to resolve multiple matches for function “paddle::CreatePaddlePredictor” with arguments (const ConfigT&) in doxygen xml output for project “My Project” from directory: ./doxyoutput/xml. Potential matches:
- template<typename ConfigT, PaddleEngineKind engine> std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &config)
- template<typename ConfigT> std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &config)
Function paddle::get_version¶
Defined in File paddle_api.h
Function paddle::PaddleDtypeSize¶
Defined in File paddle_api.h
Function Documentation¶
-
int
paddle
::
PaddleDtypeSize
(PaddleDType dtype)¶
Variables¶
Variable paddle::kLiteSubgraphPasses¶
Defined in File paddle_pass_builder.h
Variable paddle::kTRTSubgraphPasses¶
Defined in File paddle_pass_builder.h