Welcome to Paddle-Inference’s documentation!

概述

Paddle Inference为飞桨核心框架推理引擎。Paddle Inference功能特性丰富,性能优异,针对不同平台不同的应用场景进行了深度的适配优化,做到高吞吐、低时延,保证了飞桨模型在服务器端即训即用,快速部署。

特性

  • 通用性。支持对Paddle训练出的所有模型进行预测。

  • 内存/显存复用。在推理初始化阶段,对模型中的OP输出Tensor 进行依赖分析,将两两互不依赖的Tensor在内存/显存空间上进行复用,进而增大计算并行量,提升服务吞吐量。

  • 细粒度OP融合。在推理初始化阶段,按照已有的融合模式将模型中的多个OP融合成一个OP,减少了模型的计算量的同时,也减少了 Kernel Launch的次数,从而能提升推理性能。目前Paddle Inference支持的融合模式多达几十个。

  • 高性能CPU/GPU Kernel。内置同Intel、Nvidia共同打造的高性能kernel,保证了模型推理高性能的执行。

  • 子图集成 TensorRT。Paddle Inference采用子图的形式集成TensorRT,针对GPU推理场景,TensorRT可对一些子图进行优化,包括OP的横向和纵向融合,过滤冗余的OP,并为OP自动选择最优的kernel,加快推理速度。

  • 集成MKLDNN

  • 支持加载PaddleSlim量化压缩后的模型。 PaddleSlim 是飞桨深度学习模型压缩工具,Paddle Inference可联动PaddleSlim,支持加载量化、裁剪和蒸馏后的模型并部署,由此减小模型存储空间、减少计算占用内存、加快模型推理速度。其中在模型量化方面,Paddle Inference在X86 CPU上做了深度优化 ,常见分类模型的单线程性能可提升近3倍,ERNIE模型的单线程性能可提升2.68倍。

支持系统及硬件

支持服务器端X86 CPU、NVIDIA GPU芯片,兼容Linux/macOS/Windows系统。

同时也支持NVIDIA Jetson嵌入式平台。

语言支持

  • 支持Pyhton语言

  • 支持C++ 语言

  • 支持Go语言

  • 支持R语言

下一步

  • 如果您刚接触Paddle Inference, 请访问 Quick start

Quick Start

前提准备 接下来我们会通过几段Python代码的方式对Paddle Inference使用进行介绍, 为了能够成功运行代码,请您在环境中(Mac, Windows,Linux)安装不低于1.7版本的Paddle, 安装Paddle 请参考 飞桨官网主页

导出预测模型文件

在模型训练期间,我们通常使用Python来构建模型结构,比如:

import paddle.fluid as fluid
res = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu", param_attr=param_attr)

在模型部署时,我们需要提前将这种Python表示的结构以及参数序列化到磁盘中。那是如何做到的呢?

在模型训练过程中或者模型训练结束后,我们可以通过save_inference_model接口来导出标准化的模型文件。

我们用一个简单的代码例子来展示下导出模型文件的这一过程。

import paddle
import paddle.fluid as fluid
# 建立一个简单的网络,网络的输入的shape为[batch, 3, 28, 28]
image_shape = [3, 28, 28]

img = fluid.layers.data(name='image', shape=image_shape, dtype='float32', append_batch_size=True)
# 模型包含两个Conv层
conv1 = fluid.layers.conv2d(
        input=img,
        num_filters=8,
        filter_size=3,
        stride=2,
        padding=1,
        groups=1,
        act=None,
        bias_attr=True)

out = fluid.layers.conv2d(
        input=conv1,
        num_filters=8,
        filter_size=3,
        stride=2,
        padding=1,
        groups=1,
        act=None,
        bias_attr=True)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
# 创建网络中的参数变量,并初始化参数变量
exe.run(fluid.default_startup_program())

# 如果存在预训练模型
# def if_exist(var):
#            return os.path.exists(os.path.join("./ShuffleNet", var.name))
#    fluid.io.load_vars(exe, "./pretrained_model", predicate=if_exist)
# 保存模型到model目录中,只保存与输入image和输出与推理相关的部分网络
fluid.io.save_inference_model(dirname='./sample_model', feeded_var_names=['image'], target_vars = [out], executor=exe, model_filename='model', params_filename='params')

该程序运行结束后,会在本目录中生成一个sample_model目录,目录中包含model, params 两个文件,model文件表示模型的结构文件,params表示所有参数的融合文件。

飞桨提供了 两种标准 的模型文件,一种为Combined方式, 一种为No-Combined的方式。

  • Combined的方式

fluid.io.save_inference_model(dirname='./sample_model', feeded_var_names=['image'], target_vars = [out], executor=exe, model_filename='model', params_filename='params')

model_filename,params_filename表示要生成的模型结构文件、融合参数文件的名字。

  • No-Combined的方式

fluid.io.save_inference_model(dirname='./sample_model', feeded_var_names=['image'], target_vars = [out], executor=exe)

如果不指定model_filename,params_filename,会在sample_model目录下生成__model__ 模型结构文件,以及一系列的参数文件。

在模型部署期间,我们更推荐使用Combined的方式,因为涉及模型上线加密的场景时,这种方式会更友好一些。

加载模型预测

1)使用load_inference方式

我们可以使用load_inference_model接口加载训练好的模型(以sample_model模型举例),并复用训练框架的前向计算,直接完成推理。 示例程序如下所示:

import paddle.fluid as fluid
import numpy as np

data = np.ones((1, 3, 28, 28)).astype(np.float32)
exe = fluid.Executor(fluid.CPUPlace())

# 加载Combined的模型需要指定model_filename, params_filename
# 加载No-Combined的模型不需要指定model_filename, params_filename
[inference_program, feed_target_names, fetch_targets] = \
        fluid.io.load_inference_model(dirname='sample_model', executor=exe, model_filename='model', params_filename='params')

with fluid.program_guard(inference_program):
results = exe.run(inference_program,
        feed={feed_target_names[0]: data},
        fetch_list=fetch_targets, return_numpy=False)

print (np.array(results[0]).shape)
# (1, 8, 7, 7)

在上述方式中,在模型加载后会按照执行顺序将所有的OP进行拓扑排序,在运行期间Op会按照排序一一运行,整个过程中运行的为训练中前向的OP,期间不会有任何的优化(OP融合,显存优化,预测Kernel针对优化)。 因此,load_inference_model的方式预测期间很可能不会有很好的性能表现,此方式比较适合用来做实验(测试模型的效果、正确性等)使用,并不适用于真正的部署上线。接下来我们会重点介绍Paddle Inference的使用。

2)使用Paddle Inference API方式

不同于 load_inference_model方式,Paddle Inference 在模型加载后会进行一系列的优化,包括: Kernel优化,OP横向,纵向融合,显存/内存优化,以及MKLDNN,TensorRT的集成等,性能和吞吐会得到大幅度的提升。这些优化会在之后的文档中进行详细的介绍。

那我们先用一个简单的代码例子来介绍Paddle Inference 的使用。

from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor

import numpy as np

# 配置运行信息
# config = AnalysisConfig("./sample_model") # 加载non-combined 模型格式
config = AnalysisConfig("./sample_model/model", "./sample_model/params") # 加载combine的模型格式

config.switch_use_feed_fetch_ops(False)
config.enable_memory_optim()
config.enable_use_gpu(1000, 0)

# 根据config创建predictor
predictor = create_paddle_predictor(config)

img = np.ones((1, 3, 28, 28)).astype(np.float32)

# 准备输入
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0])
input_tensor.reshape(img.shape)
input_tensor.copy_from_cpu(img.copy())

# 运行
predictor.zero_copy_run()

# 获取输出
output_names = predictor.get_output_names()
output_tensor = predictor.get_output_tensor(output_names[0])
output_data = output_tensor.copy_to_cpu()

print (output_data)

上述的代码例子,我们通过加载一个简答模型以及随机输入的方式,展示了如何使用Paddle Inference进行模型预测。可能对于刚接触Paddle Inferenece同学来说,代码中会有一些陌生名词出现,比如AnalysisConfig, Predictor 等。先不要着急,接下来的文章中会对这些概念进行详细的介绍。

相关链接

Python API 使用介绍 C++ API使用介绍 Python 使用样例 C++ 使用样例

使用流程

一: 模型准备

Paddle Inference目前支持的模型结构为PaddlePaddle深度学习框架产出的模型格式。因此,在您开始使用 Paddle Inference框架前您需要准备一个由PaddlePaddle框架保存的模型。 如果您手中的模型是由诸如Caffe2、Tensorflow等框架产出的,那么我们推荐您使用 X2Paddle 工具进行模型格式转换。

二: 环境准备

1) Python 环境

安装Python环境有以下三种方式:

  1. 参照 官方主页 的引导进行pip安装。

  2. 参照接下来的 预测库编译 页面进行自行编译。

  3. 使用docker镜像

# 拉取镜像,该镜像预装Paddle 1.8 Python环境
docker pull hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6

export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
export NVIDIA_SMI="-v /usr/bin/nvidia-smi:/usr/bin/nvidia-smi"

docker run $CUDA_SO $DEVICES $NVIDIA_SMI --name trt_open --privileged --security-opt seccomp=unconfined --net=host -v $PWD:/paddle -it hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6 /bin/bash

2) C++ 环境

获取c++预测库有以下三种方式:

  1. 官网 下载预编译库

  2. 使用docker镜像

# 拉取镜像,在容器内主目录~/下存放c++预编译库。
docker pull hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6

export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
export NVIDIA_SMI="-v /usr/bin/nvidia-smi:/usr/bin/nvidia-smi"

docker run $CUDA_SO $DEVICES $NVIDIA_SMI --name trt_open --privileged --security-opt seccomp=unconfined --net=host -v $PWD:/paddle -it hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6 /bin/bash
  1. 参照接下来的 `预测库编译 <./source_compile.html>`_页面进行自行编译。

三:使用Paddle Inference执行预测

使用Paddle Inference进行推理部署的流程如下所示。

https://ai-studio-static-online.cdn.bcebos.com/10d5cee239374bd59e41283b3233f49dc306109da9d540b48285980810ab4e36
  1. 配置推理选项。 AnalysisConfig 是飞桨提供的配置管理器API。在使用Paddle Inference进行推理部署过程中,需要使用 AnalysisConfig 详细地配置推理引擎参数,包括但不限于在何种设备(CPU/GPU)上部署( config.EnableUseGPU )、加载模型路径、开启/关闭计算图分析优化、使用MKLDNN/TensorRT进行部署的加速等。参数的具体设置需要根据实际需求来定。

  2. 创建 AnalysisPredictorAnalysisPredictor 是Paddle Inference提供的推理引擎。你只需要简单的执行一行代码即可完成预测引擎的初始化 std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config) ,config为1步骤中创建的 AnalysisConfig

  3. 准备输入数据。执行 auto input_names = predictor->GetInputNames() ,您会获取到模型所有输入tensor的名字,同时通过执行 auto tensor = predictor->GetInputTensor(input_names[i]) ; 您可以获取第i个输入的tensor,通过 tensor->copy_from_cpu(data) 方式,将data中的数据拷贝到tensor中。

  4. 调用predictor->ZeroCopyRun()执行推理。

  5. 获取推理输出。执行 auto out_names = predictor->GetOutputNames() ,您会获取到模型所有输出tensor的名字,同时通过执行 auto tensor = predictor->GetOutputTensor(out_names[i]) ; 您可以获取第i个输出的tensor。通过 tensor->copy_to_cpu(data) 将tensor中的数据copy到data指针上

源码编译

什么时候需要源码编译?

深度学习的发展十分迅速,对科研或工程人员来说,可能会遇到一些需要自己开发op的场景,可以在python层面编写op,但如果对性能有严格要求的话则必须在C++层面开发op,对于这种情况,需要用户源码编译飞桨,使之生效。 此外对于绝大多数使用C++将模型部署上线的工程人员来说,您可以直接通过飞桨官网下载已编译好的预测库,快捷开启飞桨使用之旅。飞桨官网 提供了多个不同环境下编译好的预测库。如果用户环境与官网提供环境不一致(如cuda 、cudnn、tensorrt版本不一致等),或对飞桨源代码有修改需求,或希望进行定制化构建,可查阅本文档自行源码编译得到预测库。

编译原理

一:目标产物

飞桨框架的源码编译包括源代码的编译和链接,最终生成的目标产物包括:

  • 含有 C++ 接口的头文件及其二进制库:用于C++环境,将文件放到指定路径即可开启飞桨使用之旅。

  • Python Wheel 形式的安装包:用于Python环境,此安装包需要参考 飞桨安装教程 进行安装操作。也就是说,前面讲的pip安装属于在线安装,这里属于本地安装。

二:基础概念

飞桨主要由C++语言编写,通过pybind工具提供了Python端的接口,飞桨的源码编译主要包括编译和链接两步。 * 编译过程由编译器完成,编译器以编译单元(后缀名为 .cc 或 .cpp 的文本文件)为单位,将 C++ 语言 ASCII 源代码翻译为二进制形式的目标文件。一个工程通常由若干源码文件组织得到,所以编译完成后,将生成一组目标文件。 * 链接过程使分离编译成为可能,由链接器完成。链接器按一定规则将分离的目标文件组合成一个能映射到内存的二进制程序文件,并解析引用。由于这个二进制文件通常包含源码中指定可被外部用户复用的函数接口,所以也被称作函数库。根据链接规则不同,链接可分为静态和动态链接。静态链接对目标文件进行归档;动态链接使用地址无关技术,将链接放到程序加载时进行。 配合包含声明体的头文件(后缀名为 .h 或 .hpp),用户可以复用程序库中的代码开发应用。静态链接构建的应用程序可独立运行,而动态链接程序在加载运行时需到指定路径下搜寻其依赖的二进制库。

三:编译方式

飞桨框架的设计原则之一是满足不同平台的可用性。然而,不同操作系统惯用的编译和链接器是不一样的,使用它们的命令也不一致。比如,Linux 一般使用 GNU 编译器套件(GCC),Windows 则使用 Microsoft Visual C++(MSVC)。为了统一编译脚本,飞桨使用了支持跨平台构建的 CMake,它可以输出上述编译器所需的各种 Makefile 或者 Project 文件。 为方便编译,框架对常用的CMake命令进行了封装,如仿照 Bazel工具封装了 cc_binary 和 cc_library ,分别用于可执行文件和库文件的产出等,对CMake感兴趣的同学可在 cmake/generic.cmake 中查看具体的实现逻辑。Paddle的CMake中集成了生成python wheel包的逻辑,对如何生成wheel包感兴趣的同学可参考 相关文档

编译步骤

飞桨分为 CPU 版本和 GPU 版本。如果您的计算机没有 Nvidia GPU,请选择 CPU 版本构建安装。如果您的计算机含有 Nvidia GPU( 1.0 且预装有 CUDA / CuDNN,也可选择 GPU 版本构建安装。本节简述飞桨在常用环境下的源码编译方式,欢迎访问飞桨官网获取更详细内容。请阅读本节内容。

推荐配置及依赖项

1、稳定的互联网连接,主频 1 GHz 以上的多核处理器,9 GB 以上磁盘空间。 2、Python 版本 2.7 或 3.5 以上,pip 版本 9.0 及以上;CMake v3.5 及以上;Git 版本 2.17 及以上。请将可执行文件放入系统环境变量中以方便运行。 3、GPU 版本额外需要 Nvidia CUDA 9 / 10,CuDNN v7 及以上版本。根据需要还可能依赖 NCCL 和 TensorRT。

基于Ubuntu 18.04

一:环境准备

除了本节开头提到的依赖,在 Ubuntu 上进行飞桨的源码编译,您还需要准备 GCC8 编译器等工具,可使用下列命令安装:

sudo apt-get install gcc g++ make cmake git vim unrar python3 python3-dev python3-pip swig wget patchelf libopencv-dev
pip3 install numpy protobuf wheel setuptools

若需启用 cuda 加速,需准备 cuda、cudnn、nccl。上述工具的安装请参考 nvidia 官网,以 cuda10.1,cudnn7.6 为例配置 cuda 环境。

# cuda
sh cuda_10.1.168_418.67_linux.run
export PATH=/usr/local/cuda-10.1/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-10.1/${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

# cudnn
tar -xzvf cudnn-10.1-linux-x64-v7.6.4.38.tgz
sudo cp -a cuda/include/cudnn.h /usr/local/cuda/include/
sudo cp -a cuda/lib64/libcudnn* /usr/local/cuda/lib64/

# nccl
# install nccl local deb 参考https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html
sudo dpkg -i nccl-repo-ubuntu1804-2.5.6-ga-cuda10.1_1-1_amd64.deb
# 根据安装提示,还需要执行sudo apt-key add /var/nccl-repo-2.5.6-ga-cuda10.1/7fa2af80.pub
sudo apt update
sudo apt install libnccl2 libnccl-dev

sudo ldconfig

编译飞桨过程中可能会打开很多文件,Ubuntu 18.04 默认设置最多同时打开的文件数是1024(参见 ulimit -a),需要更改这个设定值。

在 /etc/security/limits.conf 文件中添加两行。

* hard noopen 102400
* soft noopen 102400

重启计算机,重启后执行以下指令,请将${user}切换成当前用户名。

su ${user}
ulimit -n 102400

二:编译命令

使用 Git 将飞桨代码克隆到本地,并进入目录,切换到稳定版本(git tag显示的标签名,如v1.7.1)。 飞桨使用 develop 分支进行最新特性的开发,使用 release 分支发布稳定版本。在 GitHub 的 Releases 选项卡中,可以看到飞桨版本的发布记录。

git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
git checkout v1.7.1

下面以 GPU 版本为例说明编译命令。其他环境可以参考“CMake编译选项表”修改对应的cmake选项。比如,若编译 CPU 版本,请将 WITH_GPU 设置为 OFF。

# 创建并进入 build 目录
mkdir build_cuda && cd build_cuda
# 执行cmake指令
cmake -DPY_VERSION=3 \
        -DWITH_TESTING=OFF \
        -DWITH_MKL=ON \
        -DWITH_GPU=ON \
        -DON_INFER=ON \
        -DCMAKE_BUILD_TYPE=RelWithDebInfo \
        ..

使用make编译

make -j4

编译成功后可在dist目录找到生成的.whl包

pip3 install python/dist/paddlepaddle-1.7.1-cp36-cp36m-linux_x86_64.whl

预测库编译

make inference_lib_dist -j4

cmake编译环境表

以下介绍的编译方法都是通用步骤,根据环境对应修改cmake选项即可。

选项

说明

默认值

WITH_GPU

是否支持GPU

ON

WITH_AVX

是否编译含有AVX指令集的飞桨二进制文件

ON

WITH_PYTHON

是否内嵌PYTHON解释器并编译Wheel安装包

ON

WITH_TESTING

是否开启单元测试

OFF

WITH_MKL

是否使用MKL数学库,如果为否,将使用OpenBLAS

ON

WITH_SYSTEM_BLAS

是否使用系统自带的BLAS

OFF

WITH_DISTRIBUTE

是否编译带有分布式的版本

OFF

WITH_BRPC_RDMA

是否使用BRPC,RDMA作为RPC协议

OFF

ON_INFER

是否打开预测优化

OFF

CUDA_ARCH_NAME

是否只针对当前CUDA架构编译

All:编译所有可支持的CUDA架构;Auto:自动识别当前环境的架构编译

TENSORRT_ROOT

TensorRT_lib的路径,该路径指定后会编译TRT子图功能eg:/paddle/nvidia/TensorRT/

/usr

基于Windows 10

一:环境准备

除了本节开头提到的依赖,在 Windows 10 上编译飞桨,您还需要准备 Visual Studio 2015 Update3 以上版本。本节以 Visual Studio 企业版 2019(C++ 桌面开发,含 MSVC 14.24)、Python 3.8 为例介绍编译过程。

在命令提示符输入下列命令,安装必需的 Python 组件。

pip3 install numpy protobuf wheel`

二:编译命令

使用 Git 将飞桨代码克隆到本地,并进入目录,切换到稳定版本(git tag显示的标签名,如v1.7.1)。 飞桨使用 develop 分支进行最新特性的开发,使用 release 分支发布稳定版本。在 GitHub 的 Releases 选项卡中,可以看到 Paddle 版本的发布记录。

git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
git checkout v1.7.1

创建一个构建目录,并在其中执行 CMake,生成解决方案文件 Solution File,以编译 CPU 版本为例说明编译命令,其他环境可以参考“CMake编译选项表”修改对应的cmake选项。

mkdir build
cd build
cmake .. -G "Visual Studio 16 2019" -A x64 -DWITH_GPU=OFF -DWITH_TESTING=OFF
        -DCMAKE_BUILD_TYPE=Release -DPY_VERSION=3
https://agroup-bos.cdn.bcebos.com/1b21aff9424cb33a98f2d1e018d8301614caedda

使用 Visual Studio 打开解决方案文件,在窗口顶端的构建配置菜单中选择 Release x64,单击生成解决方案,等待构建完毕即可。

cmake编译环境表

选项

说明

默认值

WITH_GPU

是否支持GPU

ON

WITH_AVX

是否编译含有AVX指令集的飞桨二进制文件

ON

WITH_PYTHON

是否内嵌PYTHON解释器并编译Wheel安装包

ON

WITH_TESTING

是否开启单元测试

OFF

WITH_MKL

是否使用MKL数学库,如果为否,将使用OpenBLAS

ON

WITH_SYSTEM_BLAS

是否使用系统自带的BLAS

OFF

WITH_DISTRIBUTE

是否编译带有分布式的版本

OFF

WITH_BRPC_RDMA

是否使用BRPC,RDMA作为RPC协议

OFF

ON_INFER

是否打开预测优化

OFF

CUDA_ARCH_NAME

是否只针对当前CUDA架构编译

All:编译所有可支持的CUDA架构;Auto:自动识别当前环境的架构编译

TENSORRT_ROOT

TensorRT_lib的路径,该路径指定后会编译TRT子图功能eg:/paddle/nvidia/TensorRT/

/usr

结果验证

一:python whl包

编译完毕后,会在 python/dist 目录下生成一个文件名类似 paddlepaddle-1.7.1-cp36-cp36m-linux_x86_64.whl 的 Python Wheel 安装包,安装测试的命令为:

pip3 install python/dist/paddlepaddle-1.7.1-cp36-cp36m-linux_x86_64.whl

安装完成后,可以使用 python3 进入python解释器,输入以下指令,出现 `Your Paddle Fluid is installed succesfully! ` ,说明安装成功。

import paddle.fluid as fluid
fluid.install_check.run_check()

二:c++ lib

预测库编译后,所有产出均位于build目录下的fluid_inference_install_dir目录内,目录结构如下。version.txt 中记录了该预测库的版本信息,包括Git Commit ID、使用OpenBlas或MKL数学库、CUDA/CUDNN版本号。

build/fluid_inference_install_dir
├── CMakeCache.txt
├── paddle
│   ├── include
│   │   ├── paddle_anakin_config.h
│   │   ├── paddle_analysis_config.h
│   │   ├── paddle_api.h
│   │   ├── paddle_inference_api.h
│   │   ├── paddle_mkldnn_quantizer_config.h
│   │   └── paddle_pass_builder.h
│   └── lib
│       ├── libpaddle_fluid.a (Linux)
│       ├── libpaddle_fluid.so (Linux)
│       └── libpaddle_fluid.lib (Windows)
├── third_party
│   ├── boost
│   │   └── boost
│   ├── eigen3
│   │   ├── Eigen
│   │   └── unsupported
│   └── install
│       ├── gflags
│       ├── glog
│       ├── mkldnn
│       ├── mklml
│       ├── protobuf
│       ├── xxhash
│       └── zlib
└── version.txt

Include目录下包括了使用飞桨预测库需要的头文件,lib目录下包括了生成的静态库和动态库,third_party目录下包括了预测库依赖的其它库文件。

您可以编写应用代码,与预测库联合编译并测试结果。请参 C++ 预测库 API 使用 一节。

使用Python预测

Paddle Inference提供了高度优化的Python 和C++ API预测接口,本篇文档主要介绍Python API,使用C++ API进行预测的文档可以参考可以参考 这里

下面是详细的使用说明。

使用Python预测API预测包含以下几个主要步骤:

  • 配置推理选项

  • 创建Predictor

  • 准备模型输入

  • 模型推理

  • 获取模型输出

我们先从一个简单程序入手,介绍这一流程:

def create_predictor():
        # 通过AnalysisConfig配置推理选项
        config = AnalysisConfig("./resnet50/model", "./resnet50/params")
        config.switch_use_feed_fetch_ops(False)
        config.enable_use_gpu(100, 0)
        config.enable_mkldnn()
        config.enable_memory_optim()
        predictor = create_paddle_predictor(config)
        return predictor

def run(predictor, data):
        # 准备模型输入
        input_names = predictor.get_input_names()
        for i,  name in enumerate(input_names):
                input_tensor = predictor.get_input_tensor(name)
                input_tensor.reshape(data[i].shape)
                input_tensor.copy_from_cpu(data[i].copy())

        # 执行模型推理
        predictor.zero_copy_run()

        results = []
        # 获取模型输出
        output_names = predictor.get_output_names()
        for i, name in enumerate(output_names):
                output_tensor = predictor.get_output_tensor(name)
                output_data = output_tensor.copy_to_cpu()
                results.append(output_data)

        return results

以上的程序中 create_predictor 函数对推理过程进行了配置以及创建了Predictor。 run 函数进行了输入数据的准备、模型推理以及输出数据的获取过程。

在接下来的部分中,我们会依次对程序中出现的AnalysisConfig,Predictor,模型输入,模型输出进行详细的介绍。

一、推理配置管理器AnalysisConfig

AnalysisConfig管理AnalysisPredictor的推理配置,提供了模型路径设置、推理引擎运行设备选择以及多种优化推理流程的选项。配置中包括了必选配置以及可选配置。

1. 必选配置

a.设置模型和参数路径

  • Non-combined形式:模型文件夹 model_dir 下存在一个模型文件和多个参数文件时,传入模型文件夹路径,模型文件名默认为__model__。 使用方式为: config.set_model(“./model_dir”)

  • Combined形式:模型文件夹 model_dir 下只有一个模型文件 model 和一个参数文件params时,传入模型文件和参数文件路径。使用方式为: config.set_model(“./model_dir/model”, “./model_dir/params”)

  • 内存加载模式:如果模型是从内存加载,可以使用:

    import os
    model_buffer = open('./resnet50/model','rb')
    params_buffer = open('./resnet50/params','rb')
    model_size = os.fstat(model_buffer.fileno()).st_size
    params_size = os.fstat(params_buffer.fileno()).st_size
    config.set_model_buffer(model_buffer.read(), model_size, params_buffer.read(), params_size)
    

关于 non-combined 以及 combined 模型介绍,请参照 这里

b. 关闭feed与fetch OP

config.switch_use_feed_fetch_ops(False) # 关闭feed和fetch OP

2. 可选配置

a. 加速CPU推理

# 开启MKLDNN,可加速CPU推理,要求预测库带MKLDNN功能。
config.enable_mkldnn()
# 可以设置CPU数学库线程数math_threads,可加速推理。
# 注意:math_threads * 外部线程数 需要小于总的CPU的核心数目,否则会影响预测性能。
config.set_cpu_math_library_num_threads(10)

b. 使用GPU推理

# enable_use_gpu后,模型将运行在GPU上。
# 第一个参数表示预先分配显存数目,第二个参数表示设备的ID。
config.enable_use_gpu(100, 0)

如果使用的预测lib带Paddle-TRT子图功能,可以打开TRT选项进行加速:

# 开启TensorRT推理,可提升GPU推理性能,需要使用带TensorRT的推理库
config.enable_tensorrt_engine(1 << 30,    # workspace_size
                batch_size,    # max_batch_size
                3,    # min_subgraph_size
                AnalysisConfig.Precision.Float32,    # precision
                False,    # use_static
                False,    # use_calib_mode
                )

通过计算图分析,Paddle可以自动将计算图中部分子图融合,并调用NVIDIA的 TensorRT 来进行加速。 使用Paddle-TensorRT 预测的完整方法可以参考 这里

c. 内存/显存优化

config.enable_memory_optim()  # 开启内存/显存复用

该配置设置后,在模型图分析阶段会对图中的变量进行依赖分类,两两互不依赖的变量会使用同一块内存/显存空间,缩减了运行时的内存/显存占用(模型较大或batch较大时效果显著)。

d. debug开关

# 该配置设置后,会关闭模型图分析阶段的任何图优化,预测期间运行同训练前向代码一致。
config.switch_ir_optim(False)
# 该配置设置后,会在模型图分析的每个阶段后保存图的拓扑信息到.dot文件中,该文件可用graphviz可视化。
config.switch_ir_debug(True)

二、预测器PaddlePredictor

PaddlePredictor 是在模型上执行推理的预测器,根据AnalysisConfig中的配置进行创建。

predictor = create_paddle_predictor(config)

create_paddle_predictor 期间首先对模型进行加载,并且将模型转换为由变量和运算节点组成的计算图。接下来将进行一系列的图优化,包括OP的横向纵向融合,删除无用节点,内存/显存优化,以及子图(Paddle-TRT)的分析,加速推理性能,提高吞吐。

三:输入/输出

1.准备输入

a. 获取模型所有输入的Tensor名字

input_names = predictor.get_input_names()

b. 获取对应名字下的Tensor

# 获取第0个输入
input_tensor = predictor.get_input_tensor(input_names[0])

c. 将输入数据copy到Tensor中

# 在copy前需要设置Tensor的shape
input_tensor.reshape((batch_size, channels, height, width))
# Tensor会根据上述设置的shape从input_data中拷贝对应数目的数据。input_data为numpy数组。
input_tensor.copy_from_cpu(input_data)

2.获取输出

a. 获取模型所有输出的Tensor名字

b. 获取对应名字下的Tensor

# 获取第0个输出
output_tensor = predictor.get_output_tensor(ouput_names[0])

c. 将数据copy到Tensor中

# output_data为numpy数组
output_data = output_tensor.copy_to_cpu()

下一步

看到这里您是否已经对 Paddle Inference 的 Python API 使用有所了解了呢?请访问 这里 进行样例测试。

使用C++预测

为了简单方便地进行推理部署,飞桨提供了一套高度优化的C++ API推理接口。下面对各主要API使用方法进行详细介绍。

使用流程 一节中,我们了解到Paddle Inference预测包含了以下几个方面:

  • 配置推理选项

  • 创建predictor

  • 准备模型输入

  • 模型推理

  • 获取模型输出

那我们先用一个简单的程序介绍这一过程:

std::unique_ptr<paddle::PaddlePredictor> CreatePredictor() {
        // 通过AnalysisConfig配置推理选项
        AnalysisConfig config;
        config.SetModel(“./resnet50/model”,
                     "./resnet50/params");
        config.EnableUseGpu(100, 0);
        config.SwitchUseFeedFetchOps(false);
        config.EnableMKLDNN();
        config.EnableMemoryOptim();
        // 创建predictor
        return CreatePaddlePredictor(config);
}

void Run(paddle::PaddlePredictor *predictor,
                const std::vector<float>& input,
                const std::vector<int>& input_shape,
                std::vector<float> *out_data) {
        // 准备模型的输入
        int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int>());

        auto input_names = predictor->GetInputNames();
        auto input_t = predictor->GetInputTensor(input_names[0]);
        input_t->Reshape(input_shape);
        input_t->copy_from_cpu(input.data());
        // 模型推理
        CHECK(predictor->ZeroCopyRun());

        // 获取模型的输出
        auto output_names = predictor->GetOutputNames();
        // there is only one output of Resnet50
        auto output_t = predictor->GetOutputTensor(output_names[0]);
        std::vector<int> output_shape = output_t->shape();
        int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
        out_data->resize(out_num);
        output_t->copy_to_cpu(out_data->data());
}

以上的程序中 CreatePredictor 函数对推理过程进行了配置以及创建了Predictor。 Run 函数进行了输入数据的准备、模型推理以及输出数据的获取过程。

接下来我们依次对程序中出现的AnalysisConfig,Predictor,模型输入,模型输出做一个详细的介绍。

一:关于AnalysisConfig

AnalysisConfig管理AnalysisPredictor的推理配置,提供了模型路径设置、推理引擎运行设备选择以及多种优化推理流程的选项。配置中包括了必选配置以及可选配置。

1. 必选配置

a. 设置模型和参数路径

从磁盘加载模型时,根据模型和参数文件存储方式不同,设置AnalysisConfig加载模型和参数的路径有两种形式:

  • non-combined形式 :模型文件夹model_dir下存在一个模型文件和多个参数文件时,传入模型文件夹路径,模型文件名默认为__model__。 使用方式为: config->SetModel(“./model_dir”);。

  • combined形式 :模型文件夹model_dir下只有一个模型文件`model`和一个参数文件params时,传入模型文件和参数文件路径。 使用方式为: config->SetModel(“./model_dir/model”, “./model_dir/params”);

  • 内存加载模式:如果模型是从内存加载(模型必须为combined形式),可以使用

std::ifstream in_m(FLAGS_dirname + "/model");
std::ifstream in_p(FLAGS_dirname + "/params");
std::ostringstream os_model, os_param;
os_model << in_m.rdbuf();
os_param << in_p.rdbuf();
config.SetModelBuffer(os_model.str().data(), os_model.str().size(), os_param.str().data(), os_param.str().size());

Paddle Inference有两种格式的模型,分别为 non-combined 以及 combined 。这两种类型我们在 Quick Start 一节中提到过,忘记的同学可以回顾下。

b. 关闭Feed,Fetch op

config->SwitchUseFeedFetchOps(false); // 关闭feed和fetch OP使用,使用ZeroCopy接口必须设置此项`

我们用一个小的例子来说明我们为什么要关掉它们。 假设我们有一个模型,模型运行的序列为: input -> FEED_OP -> feed_out -> CONV_OP -> conv_out -> FETCH_OP -> output

序列中大些字母的FEED_OP, CONV_OP, FETCH_OP 为模型中的OP, 小写字母的input,feed_out,output 为模型中的变量。

在ZeroCopy模式下,我们通过 predictor->GetInputTensor(input_names[0]) 获取的模型输入为FEED_OP的输出, 即feed_out,我们通过 predictor->GetOutputTensor(output_names[0]) 接口获取的模型输出为FETCH_OP的输入,即conv_out,这种情况下,我们在运行期间就没有必要运行feed和fetch OP了,因此需要设置 config->SwitchUseFeedFetchOps(false) 来关闭feed和fetch op。

2. 可选配置

a. 加速CPU推理

// 开启MKLDNN,可加速CPU推理,要求预测库带MKLDNN功能。
config->EnableMKLDNN();
// 可以设置CPU数学库线程数math_threads,可加速推理。
// 注意:math_threads * 外部线程数 需要小于总的CPU的核心数目,否则会影响预测性能。
config->SetCpuMathLibraryNumThreads(10);

b. 使用GPU推理

// EnableUseGpu后,模型将运行在GPU上。
// 第一个参数表示预先分配显存数目,第二个参数表示设备的ID。
config->EnableUseGpu(100, 0);

如果使用的预测lib带Paddle-TRT子图功能,可以打开TRT选项进行加速, 详细的请访问 Paddle-TensorRT文档

// 开启TensorRT推理,可提升GPU推理性能,需要使用带TensorRT的推理库
config->EnableTensorRtEngine(1 << 30      /*workspace_size*/,
                                                        batch_size        /*max_batch_size*/,
                                                        3                 /*min_subgraph_size*/,
                                                        AnalysisConfig::Precision::kFloat32 /*precision*/,
                                                        false             /*use_static*/,
                                                        false             /*use_calib_mode*/);

通过计算图分析,Paddle可以自动将计算图中部分子图融合,并调用NVIDIA的 TensorRT 来进行加速。

c. 内存/显存优化

config->EnableMemoryOptim();  // 开启内存/显存复用

该配置设置后,在模型图分析阶段会对图中的变量进行依赖分类,两两互不依赖的变量会使用同一块内存/显存空间,缩减了运行时的内存/显存占用(模型较大或batch较大时效果显著)。

d. debug开关

// 该配置设置后,会关闭模型图分析阶段的任何图优化,预测期间运行同训练前向代码一致。
config->SwitchIrOptim(false);
// 该配置设置后,会在模型图分析的每个阶段后保存图的拓扑信息到.dot文件中,该文件可用graphviz可视化。
config->SwitchIrDebug();

二:关于PaddlePredictor

PaddlePredictor 是在模型上执行推理的预测器,根据AnalysisConfig中的配置进行创建。

std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config);

CreatePaddlePredictor 期间首先对模型进行加载,并且将模型转换为由变量和运算节点组成的计算图。接下来将进行一系列的图优化,包括OP的横向纵向融合,删除无用节点,内存/显存优化,以及子图(Paddle-TRT)的分析,加速推理性能,提高吞吐。

三:输入输出

1. 准备输入

a. 获取模型所有输入的tensor名字

std::vector<std::string> input_names = predictor->GetInputNames();

b. 获取对应名字下的tensor

// 获取第0个输入
auto input_t = predictor->GetInputTensor(input_names[0]);

c. 将数据copy到tensor中

// 在copy前需要设置tensor的shape
input_t->Reshape({batch_size, channels, height, width});
// tensor会根据上述设置的shape从input_data中拷贝对应数目的数据到tensor中。
input_t->copy_from_cpu<float>(input_data /*数据指针*/);

当然我们也可以用mutable_data获取tensor的数据指针:

// 参数可为PaddlePlace::kGPU, PaddlePlace::kCPU
float *input_d = input_t->mutable_data<float>(PaddlePlace::kGPU);

2. 获取输出

a. 获取模型所有输出的tensor名字

std::vector<std::string> out_names = predictor->GetOutputNames();

b. 获取对应名字下的tensor

// 获取第0个输出
auto output_t = predictor->GetOutputTensor(out_names[0]);

c. 将数据copy到tensor中

std::vector<float> out_data;
// 获取输出的shpae
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,      std::multiplies<int>());
out_data->resize(out_num);
output_t->copy_to_cpu(out_data->data());

我们可以用data接口获取tensor的数据指针:

// 参数可为PaddlePlace::kGPU, PaddlePlace::kCPU
int output_size;
float *output_d = output_t->data<float>(PaddlePlace::kGPU, &output_size);

下一步

看到这里您是否已经对Paddle Inference的C++使用有所了解了呢?请访问 这里 进行样例测试。

使用Paddle-TensorRT库预测

NVIDIA TensorRT 是一个高性能的深度学习预测库,可为深度学习推理应用程序提供低延迟和高吞吐量。PaddlePaddle 采用子图的形式对TensorRT进行了集成,即我们可以使用该模块来提升Paddle模型的预测性能。在这篇文章中,我们会介绍如何使用Paddle-TRT子图加速预测。

概述

当模型加载后,神经网络可以表示为由变量和运算节点组成的计算图。如果我们打开TRT子图模式,在图分析阶段,Paddle会对模型图进行分析同时发现图中可以使用TensorRT优化的子图,并使用TensorRT节点替换它们。在模型的推断期间,如果遇到TensorRT节点,Paddle会调用TensorRT库对该节点进行优化,其他的节点调用Paddle的原生实现。TensorRT除了有常见的OP融合以及显存/内存优化外,还针对性的对OP进行了优化加速实现,降低预测延迟,提升推理吞吐。

目前Paddle-TRT支持静态shape模式以及/动态shape模式。在静态shape模式下支持图像分类,分割,检测模型,同时也支持Fp16, Int8的预测加速。在动态shape模式下,除了对动态shape的图像模型(FCN, Faster rcnn)支持外,同时也对NLP的Bert/Ernie模型也进行了支持。

Paddle-TRT的现有能力:

1)静态shape:

支持模型:

分类模型

检测模型

分割模型

Mobilenetv1

yolov3

ICNET

Resnet50

SSD

UNet

Vgg16

Mask-rcnn

FCN

Resnext

Faster-rcnn

AlexNet

Cascade-rcnn

Se-ResNext

Retinanet

GoogLeNet

Mobilenet-SSD

DPN

Fp16:

Calib Int8:

优化信息序列化:

加载PaddleSlim Int8模型:

2)动态shape:

支持模型:

图像

NLP

FCN

Bert

Faster_RCNN

Ernie

Fp16:

Calib Int8:

优化信息序列化:

加载PaddleSlim Int8模型:

Note:

  1. 从源码编译时,TensorRT预测库目前仅支持使用GPU编译,且需要设置编译选项TENSORRT_ROOT为TensorRT所在的路径。

  2. Windows支持需要TensorRT 版本5.0以上。

  3. 使用Paddle-TRT的动态shape输入功能要求TRT的版本在6.0以上。

一:环境准备

使用Paddle-TRT功能,我们需要准备带TRT的Paddle运行环境,我们提供了以下几种方式:

1)linux下通过pip安装

# 该whl包依赖cuda10.1, cudnnv7.6, tensorrt6.0 的lib, 需自行下载安装并设置lib路径到LD_LIBRARY_PATH中
wget https://paddle-inference-dist.bj.bcebos.com/libs/paddlepaddle_gpu-1.8.0-cp27-cp27mu-linux_x86_64.whl
pip install -U paddlepaddle_gpu-1.8.0-cp27-cp27mu-linux_x86_64.whl

如果您想在Nvidia Jetson平台上使用,请点击此 链接 下载whl包,然后通过pip 安装。

2)使用docker镜像

# 拉取镜像,该镜像预装Paddle 1.8 Python环境,并包含c++的预编译库,lib存放在主目录~/ 下。
docker pull hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6

export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
export NVIDIA_SMI="-v /usr/bin/nvidia-smi:/usr/bin/nvidia-smi"

docker run $CUDA_SO $DEVICES $NVIDIA_SMI --name trt_open --privileged --security-opt seccomp=unconfined --net=host -v $PWD:/paddle -it hub.baidubce.com/paddlepaddle/paddle:1.8.0-gpu-cuda10.0-cudnn7-trt6 /bin/bash

3)手动编译 编译的方式请参照 编译文档

Note1: cmake 期间请设置 TENSORRT_ROOT (即TRT lib的路径), WITH_PYTHON (是否产出python whl包, 设置为ON)选项。

Note2: 编译期间会出现TensorRT相关的错误。

需要手动在 NvInfer.h (trt5) 或 NvInferRuntime.h (trt6) 文件中为 class IPluginFactory 和 class IGpuAllocator 分别添加虚析构函数:

virtual ~IPluginFactory() {};
virtual ~IGpuAllocator() {};

需要将 NvInferRuntime.h (trt6)中的 protected: ~IOptimizationProfile() noexcept = default;

改为

virtual ~IOptimizationProfile() noexcept = default;

二:API使用介绍

使用流程 一节中,我们了解到Paddle Inference预测包含了以下几个方面:

  • 配置推理选项

  • 创建predictor

  • 准备模型输入

  • 模型推理

  • 获取模型输出

使用Paddle-TRT 也是遵照这样的流程。我们先用一个简单的例子来介绍这一流程(我们假设您已经对Paddle Inference有一定的了解,如果您刚接触Paddle Inference,请访问 这里 对Paddle Inference有个初步认识。):

import numpy as np
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor

def create_predictor():
        # config = AnalysisConfig("")
        config = AnalysisConfig("./resnet50/model", "./resnet50/params")
        config.switch_use_feed_fetch_ops(False)
        config.enable_memory_optim()
        config.enable_use_gpu(1000, 0)

        # 打开TensorRT。此接口的详细介绍请见下文
        config.enable_tensorrt_engine(workspace_size = 1<<30,
                max_batch_size=1, min_subgraph_size=5,
                precision_mode=AnalysisConfig.Precision.Float32,
                use_static=False, use_calib_mode=False)

        predictor = create_paddle_predictor(config)
        return predictor

def run(predictor, img):
        # 准备输入
        input_names = predictor.get_input_names()
        for i,  name in enumerate(input_names):
                input_tensor = predictor.get_input_tensor(name)
                input_tensor.reshape(img[i].shape)
                input_tensor.copy_from_cpu(img[i].copy())
        # 预测
        predictor.zero_copy_run()
        results = []
        # 获取输出
        output_names = predictor.get_output_names()
        for i, name in enumerate(output_names):
                output_tensor = predictor.get_output_tensor(name)
                output_data = output_tensor.copy_to_cpu()
                results.append(output_data)
        return results

        if __name__ == '__main__':
                pred = create_predictor()
                img = np.ones((1, 3, 224, 224)).astype(np.float32)
                result = run(pred, [img])
                print ("class index: ", np.argmax(result[0][0]))

通过例子我们可以看出,我们通过 enable_tensorrt_engine 接口来打开TensorRT选项的。

config.enable_tensorrt_engine(
        workspace_size = 1<<30,
        max_batch_size=1, min_subgraph_size=5,
        precision_mode=AnalysisConfig.Precision.Float32,
        use_static=False, use_calib_mode=False)

接下来让我们看下该接口中各个参数的作用:

  • workspace_size,类型:int,默认值为1 << 30 (1G)。指定TensorRT使用的工作空间大小,TensorRT会在该大小限制下筛选最优的kernel执行预测运算。

  • max_batch_size,类型:int,默认值为1。需要提前设置最大的batch大小,运行时batch大小不得超过此限定值。

  • min_subgraph_size,类型:int,默认值为3。Paddle-TRT是以子图的形式运行,为了避免性能损失,当子图内部节点个数大于 min_subgraph_size 的时候,才会使用Paddle-TRT运行。

  • precision_mode,类型:AnalysisConfig.Precision, 默认值为 AnalysisConfig.Precision.Float32。指定使用TRT的精度,支持FP32(Float32),FP16(Half),Int8(Int8)。若需要使用Paddle-TRT int8离线量化校准,需设定precision为 AnalysisConfig.Precision.Int8 , 且设置 use_calib_mode 为True。

  • use_static,类型:bool, 默认值为False。如果指定为True,在初次运行程序的时候会将TRT的优化信息进行序列化到磁盘上,下次运行时直接加载优化的序列化信息而不需要重新生成。

  • use_calib_mode,类型:bool, 默认值为False。若要运行Paddle-TRT int8离线量化校准,需要将此选项设置为True。

Int8量化预测

神经网络的参数在一定程度上是冗余的,在很多任务上,我们可以在保证模型精度的前提下,将Float32的模型转换成Int8的模型,从而达到减小计算量降低运算耗时、降低计算内存、降低模型大小的目的。使用Int8量化预测的流程可以分为两步:1)产出量化模型;2)加载量化模型进行Int8预测。下面我们对使用Paddle-TRT进行Int8量化预测的完整流程进行详细介绍。

1. 产出量化模型

目前,我们支持通过两种方式产出量化模型:

  1. 使用TensorRT自带Int8离线量化校准功能。校准即基于训练好的FP32模型和少量校准数据(如500~1000张图片)生成校准表(Calibration table),预测时,加载FP32模型和此校准表即可使用Int8精度预测。生成校准表的方法如下:

  • 指定TensorRT配置时,将 precision_mode 设置为 AnalysisConfig.Precision.Int8 并且设置 use_calib_modeTrue

    config.enable_tensorrt_engine(
      workspace_size=1<<30,
      max_batch_size=1, min_subgraph_size=5,
      precision_mode=AnalysisConfig.Precision.Int8,
      use_static=False, use_calib_mode=True)
    
  • 准备500张左右的真实输入数据,在上述配置下,运行模型。(Paddle-TRT会统计模型中每个tensor值的范围信息,并将其记录到校准表中,运行结束后,会将校准表写入模型目录下的 _opt_cache 目录中)

如果想要了解使用TensorRT自带Int8离线量化校准功能生成校准表的完整代码,请参考 这里 的demo。

  1. 使用模型压缩工具库PaddleSlim产出量化模型。PaddleSlim支持离线量化和在线量化功能,其中,离线量化与TensorRT离线量化校准原理相似;在线量化又称量化训练(Quantization Aware Training, QAT),是基于较多数据(如>=5000张图片)对预训练模型进行重新训练,使用模拟量化的思想,在训练阶段更新权重,实现减小量化误差的方法。使用PaddleSlim产出量化模型可以参考文档:

离线量化的优点是无需重新训练,简单易用,但量化后精度可能受影响;量化训练的优点是模型精度受量化影响较小,但需要重新训练模型,使用门槛稍高。在实际使用中,我们推荐先使用TRT离线量化校准功能生成量化模型,若精度不能满足需求,再使用PaddleSlim产出量化模型。

2. 加载量化模型进行Int8预测

加载量化模型进行Int8预测,需要在指定TensorRT配置时,将 precision_mode 设置为 AnalysisConfig.Precision.Int8

若使用的量化模型为TRT离线量化校准产出的,需要将 use_calib_mode 设为 True

config.enable_tensorrt_engine(
  workspace_size=1<<30,
  max_batch_size=1, min_subgraph_size=5,
  precision_mode=AnalysisConfig.Precision.Int8,
  use_static=False, use_calib_mode=True)

完整demo请参考 这里

若使用的量化模型为PaddleSlim量化产出的,需要将 use_calib_mode 设为 False

config.enable_tensorrt_engine(
  workspace_size=1<<30,
  max_batch_size=1, min_subgraph_size=5,
  precision_mode=AnalysisConfig.Precision.Int8,
  use_static=False, use_calib_mode=False)

完整demo请参考 这里

运行Dynamic shape

从1.8 版本开始, Paddle对TRT子图进行了Dynamic shape的支持。 使用接口如下:

config.enable_tensorrt_engine(
        workspace_size = 1<<30,
        max_batch_size=1, min_subgraph_size=5,
        precision_mode=AnalysisConfig.Precision.Float32,
        use_static=False, use_calib_mode=False)

min_input_shape = {"image":[1,3, 10, 10]}
max_input_shape = {"image":[1,3, 224, 224]}
opt_input_shape = {"image":[1,3, 100, 100]}

config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, opt_input_shape)

从上述使用方式来看,在 config.enable_tensorrt_engine 接口的基础上,新加了一个config.set_trt_dynamic_shape_info 的接口。

该接口用来设置模型输入的最小,最大,以及最优的输入shape。 其中,最优的shape处于最小最大shape之间,在预测初始化期间,会根据opt shape对op选择最优的kernel。

调用了 config.set_trt_dynamic_shape_info 接口,预测器会运行TRT子图的动态输入模式,运行期间可以接受最小,最大shape间的任意的shape的输入数据。

三:测试样例

我们在github上提供了使用TRT子图预测的更多样例:

  • Python 样例请访问此处 链接

  • C++ 样例地址请访问此处 链接

四:Paddle-TRT子图运行原理

PaddlePaddle采用子图的形式对TensorRT进行集成,当模型加载后,神经网络可以表示为由变量和运算节点组成的计算图。Paddle TensorRT实现的功能是对整个图进行扫描,发现图中可以使用TensorRT优化的子图,并使用TensorRT节点替换它们。在模型的推断期间,如果遇到TensorRT节点,Paddle会调用TensorRT库对该节点进行优化,其他的节点调用Paddle的原生实现。TensorRT在推断期间能够进行Op的横向和纵向融合,过滤掉冗余的Op,并对特定平台下的特定的Op选择合适的kernel等进行优化,能够加快模型的预测速度。

下图使用一个简单的模型展示了这个过程:

原始网络

https://raw.githubusercontent.com/NHZlX/FluidDoc/add_trt_doc/doc/fluid/user_guides/howto/inference/image/model_graph_original.png

转换的网络

https://raw.githubusercontent.com/NHZlX/FluidDoc/add_trt_doc/doc/fluid/user_guides/howto/inference/image/model_graph_trt.png

我们可以在原始模型网络中看到,绿色节点表示可以被TensorRT支持的节点,红色节点表示网络中的变量,黄色表示Paddle只能被Paddle原生实现执行的节点。那些在原始网络中的绿色节点被提取出来汇集成子图,并由一个TensorRT节点代替,成为转换后网络中的 block-25 节点。在网络运行过程中,如果遇到该节点,Paddle将调用TensorRT库来对其执行。

模型可视化

通过 Quick Start 一节中,我们了解到,预测模型包含了两个文件,一部分为模型结构文件,通常以 model__model__ 文件存在;另一部分为参数文件,通常以params 文件或一堆分散的文件存在。

模型结构文件,顾名思义,存储了模型的拓扑结构,其中包括模型中各种OP的计算顺序以及OP的详细信息。很多时候,我们希望能够将这些模型的结构以及内部信息可视化,方便我们进行模型分析。接下来将会通过两种方式来讲述如何对Paddle 预测模型进行可视化。

一: 通过 VisualDL 可视化

1) 安装

VisualDL是飞桨可视化分析工具,以丰富的图表呈现训练参数变化趋势、模型结构、数据样本、高维数据分布等,帮助用户更清晰直观地理解深度学习模型训练过程及模型结构,实现高效的模型优化。 我们可以进入 GitHub主页 进行下载安装。

2)可视化

点击 下载测试模型。

支持两种启动方式:

  • 前端拖拽上传模型文件:

    • 无需添加任何参数,在命令行执行 visualdl 后启动界面上传文件即可:

https://user-images.githubusercontent.com/48054808/88628504-a8b66980-d0e0-11ea-908b-196d02ed1fa2.png
  • 后端透传模型文件:

    • 在命令行加入参数 –model 并指定 模型文件 路径(非文件夹路径),即可启动:

visualdl --model ./log/model --port 8080
https://user-images.githubusercontent.com/48054808/88621327-b664f280-d0d2-11ea-9e76-e3fcfeea4e57.png

Graph功能详细使用,请见 Graph使用指南

二: 通过代码方式生成dot文件

1)pip 安装Paddle

2)生成dot文件

点击 下载测试模型。

#!/usr/bin/env python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

def get_graph(program_path):
    with open(program_path, 'rb') as f:
            binary_str = f.read()
    program =   fluid.framework.Program.parse_from_string(binary_str)
    return IrGraph(core.Graph(program.desc), for_test=True)

if __name__ == '__main__':
    program_path = './lecture_model/__model__'
    offline_graph = get_graph(program_path)
    offline_graph.draw('.', 'test_model', [])

3)生成svg

Note:需要环境中安装graphviz

dot -Tsvg ./test_mode.dot -o test_model.svg

然后将test_model.svg以浏览器打开预览即可。

https://user-images.githubusercontent.com/5595332/81796500-19b59e80-9540-11ea-8c70-31122e969683.png

模型转换工具 X2Paddle

X2Paddle可以将caffe、tensorflow、onnx模型转换成Paddle支持的模型。

X2Paddle 支持将Caffe/TensorFlow模型转换为PaddlePaddle模型。目前X2Paddle支持的模型参考 x2paddle_model_zoo

多框架支持

模型

caffe

tensorflow

onnx

mobilenetv1

Y

Y

F

mobilenetv2

Y

Y

Y

resnet18

Y

Y

F

resnet50

Y

Y

Y

mnasnet

Y

Y

F

efficientnet

Y

Y

Y

squeezenetv1.1

Y

Y

Y

shufflenet

Y

Y

F

mobilenet_ssd

Y

Y

F

mobilenet_yolov3

F

Y

F

inceptionv4

F

F

F

mtcnn

Y

Y

F

facedetection

Y

F

F

unet

Y

Y

F

ocr_attention

F

F

F

vgg16

F

F

F

安装

pip install x2paddle

安装最新版本,可使用如下安装方式

pip install git+https://github.com/PaddlePaddle/X2Paddle.git@develop

使用

Caffe

x2paddle --framework caffe \
        --prototxt model.proto \
        --weight model.caffemodel \
        --save_dir paddle_model

TensorFlow

x2paddle --framework tensorflow \
        --model model.pb \
        --save_dir paddle_model

转换结果说明

在指定的 save_dir 下生成两个目录

  1. inference_model : 模型结构和参数均序列化保存的模型格式

  2. model_with_code : 保存了模型参数文件和模型的python代码

问题反馈

X2Paddle使用时存在问题时,欢迎您将问题或Bug报告以 Github Issues 的形式提交给我们,我们会实时跟进。

性能数据

GPU性能

测试条件

  • 测试模型
    • Mobilenetv1

    • Resnet50

    • Yolov3

    • Unet

    • Bert/Ernie

  • 测试机器
    • P4

    • T4

  • 测试说明
    • 测试Paddle版本:release/1.8

    • warmup=10, repeats=1000,统计平均时间,单位ms。

性能数据

X86 CPU性能

测试条件

性能数据

Library API

Full API

Classes and Structs

Struct AnalysisConfig
Struct Documentation
struct paddle::AnalysisConfig

configuration manager for AnalysisPredictor.

AnalysisConfig manages configurations of AnalysisPredictor. During inference procedure, there are many parameters(model/params path, place of inference, etc.) to be specified, and various optimizations(subgraph fusion, memory optimazation, TensorRT engine, etc.) to be done. Users can manage these settings by creating and modifying an AnalysisConfig, and loading it into AnalysisPredictor.

Since

1.7.0

Public Types

enum Precision

Precision of inference in TensorRT.

Values:

enumerator kFloat32 = 0

fp32

enumerator kInt8

int8

enumerator kHalf

fp16

Public Functions

AnalysisConfig() = default
AnalysisConfig(const AnalysisConfig &other)

Construct a new AnalysisConfig from another AnalysisConfig.

Parameters

AnalysisConfig(const std::string &model_dir)

Construct a new AnalysisConfig from a no-combined model.

Parameters
  • [in] model_dir: model directory of the no-combined model.

AnalysisConfig(const std::string &prog_file, const std::string &params_file)

Construct a new AnalysisConfig from a combined model.

Parameters
  • [in] prog_file: model file path of the combined model.

  • [in] params_file: params file path of the combined model.

void SetModel(const std::string &model_dir)

Set the no-combined model dir path.

Parameters
  • model_dir: model dir path.

void SetModel(const std::string &prog_file_path, const std::string &params_file_path)

Set the combined model with two specific pathes for program and parameters.

Parameters
  • prog_file_path: model file path of the combined model.

  • params_file_path: params file path of the combined model.

void SetProgFile(const std::string &x)

Set the model file path of a combined model.

Parameters
  • x: model file path.

void SetParamsFile(const std::string &x)

Set the params file path of a combined model.

Parameters
  • x: params file path.

void SetOptimCacheDir(const std::string &opt_cache_dir)

Set the path of optimization cache directory.

Parameters
  • opt_cache_dir: the path of optimization cache directory.

const std::string &model_dir() const

Get the model directory path.

Return

const std::string& The model directory path.

const std::string &prog_file() const

Get the program file path.

Return

const std::string& The program file path.

const std::string &params_file() const

Get the combined parameters file.

Return

const std::string& The combined parameters file.

void DisableFCPadding()

Turn off FC Padding.

bool use_fc_padding() const

A boolean state telling whether fc padding is used.

Return

bool Whether fc padding is used.

void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0)

Turn on GPU.

Parameters
  • memory_pool_init_size_mb: initial size of the GPU memory pool in MB.

  • device_id: device_id the GPU card to use (default is 0).

void DisableGpu()

Turn off GPU.

bool use_gpu() const

A boolean state telling whether the GPU is turned on.

Return

bool Whether the GPU is turned on.

int gpu_device_id() const

Get the GPU device id.

Return

int The GPU device id.

int memory_pool_init_size_mb() const

Get the initial size in MB of the GPU memory pool.

Return

int The initial size in MB of the GPU memory pool.

float fraction_of_gpu_memory_for_pool() const

Get the proportion of the initial memory pool size compared to the device.

Return

float The proportion of the initial memory pool size.

void EnableCUDNN()

Turn on CUDNN.

bool cudnn_enabled() const

A boolean state telling whether to use CUDNN.

Return

bool Whether to use CUDNN.

void SwitchIrOptim(int x = true)

Control whether to perform IR graph optimization. If turned off, the AnalysisConfig will act just like a NativeConfig.

Parameters
  • x: Whether the ir graph optimization is actived.

bool ir_optim() const

A boolean state telling whether the ir graph optimization is actived.

Return

bool Whether to use ir graph optimization.

void SwitchUseFeedFetchOps(int x = true)

INTERNAL Determine whether to use the feed and fetch operators. Just for internal development, not stable yet. When ZeroCopyTensor is used, this should be turned off.

Parameters
  • x: Whether to use the feed and fetch operators.

bool use_feed_fetch_ops_enabled() const

A boolean state telling whether to use the feed and fetch operators.

Return

bool Whether to use the feed and fetch operators.

void SwitchSpecifyInputNames(bool x = true)

Control whether to specify the inputs’ names. The ZeroCopyTensor type has a name member, assign it with the corresponding variable name. This is used only when the input ZeroCopyTensors passed to the AnalysisPredictor.ZeroCopyRun() cannot follow the order in the training phase.

Parameters
  • x: Whether to specify the inputs’ names.

bool specify_input_name() const

A boolean state tell whether the input ZeroCopyTensor names specified should be used to reorder the inputs in AnalysisPredictor.ZeroCopyRun().

Return

bool Whether to specify the inputs’ names.

void EnableTensorRtEngine(int workspace_size = 1 << 20, int max_batch_size = 1, int min_subgraph_size = 3, Precision precision = Precision::kFloat32, bool use_static = false, bool use_calib_mode = true)

Turn on the TensorRT engine. The TensorRT engine will accelerate some subgraphes in the original Fluid computation graph. In some models such as resnet50, GoogleNet and so on, it gains significant performance acceleration.

Parameters
  • workspace_size: The memory size(in byte) used for TensorRT workspace.

  • max_batch_size: The maximum batch size of this prediction task, better set as small as possible for less performance loss.

  • min_subgrpah_size: The minimum TensorRT subgraph size needed, if a subgraph is smaller than this, it will not be transferred to TensorRT engine.

  • precision: The precision used in TensorRT.

  • use_static: Serialize optimization information to disk for reusing.

  • use_calib_mode: Use TRT int8 calibration(post training quantization).

bool tensorrt_engine_enabled() const

A boolean state telling whether the TensorRT engine is used.

Return

bool Whether the TensorRT engine is used.

void SetTRTDynamicShapeInfo(std::map<std::string, std::vector<int>> min_input_shape, std::map<std::string, std::vector<int>> max_input_shape, std::map<std::string, std::vector<int>> optim_input_shape, bool disable_trt_plugin_fp16 = false)

Set min, max, opt shape for TensorRT Dynamic shape mode.

Parameters
  • min_input_shape: The min input shape of the subgraph input.

  • max_input_shape: The max input shape of the subgraph input.

  • opt_input_shape: The opt input shape of the subgraph input.

  • disable_trt_plugin_fp16: Setting this parameter to true means that TRT plugin will not run fp16.

void EnableLiteEngine(AnalysisConfig::Precision precision_mode = Precision::kFloat32, const std::vector<std::string> &passes_filter = {}, const std::vector<std::string> &ops_filter = {})

Turn on the usage of Lite sub-graph engine.

Parameters
  • precision_mode: Precion used in Lite sub-graph engine.

  • passes_filter: Set the passes used in Lite sub-graph engine.

  • ops_filter: Operators not supported by Lite.

bool lite_engine_enabled() const

A boolean state indicating whether the Lite sub-graph engine is used.

Return

bool whether the Lite sub-graph engine is used.

void SwitchIrDebug(int x = true)

Control whether to debug IR graph analysis phase. This will generate DOT files for visualizing the computation graph after each analysis pass applied.

Parameters
  • x: whether to debug IR graph analysis phase.

void EnableMKLDNN()

Turn on MKLDNN.

void SetMkldnnCacheCapacity(int capacity)

Set the cache capacity of different input shapes for MKLDNN. Default value 0 means not caching any shape.

Parameters
  • capacity: The cache capacity.

bool mkldnn_enabled() const

A boolean state telling whether to use the MKLDNN.

Return

bool Whether to use the MKLDNN.

void SetCpuMathLibraryNumThreads(int cpu_math_library_num_threads)

Set the number of cpu math library threads.

Parameters
  • cpu_math_library_num_threads: The number of cpu math library threads.

int cpu_math_library_num_threads() const

An int state telling how many threads are used in the CPU math library.

Return

int The number of threads used in the CPU math library.

NativeConfig ToNativeConfig() const

Transform the AnalysisConfig to NativeConfig.

Return

NativeConfig The NativeConfig transformed.

void SetMKLDNNOp(std::unordered_set<std::string> op_list)

Specify the operator type list to use MKLDNN acceleration.

Parameters
  • op_list: The operator type list.

void EnableMkldnnQuantizer()

Turn on MKLDNN quantization.

bool mkldnn_quantizer_enabled() const

A boolean state telling whether the MKLDNN quantization is enabled.

Return

bool Whether the MKLDNN quantization is enabled.

MkldnnQuantizerConfig *mkldnn_quantizer_config() const

Get MKLDNN quantizer config.

Return

MkldnnQuantizerConfig* MKLDNN quantizer config.

void SetModelBuffer(const char *prog_buffer, size_t prog_buffer_size, const char *params_buffer, size_t params_buffer_size)

Specify the memory buffer of program and parameter. Used when model and params are loaded directly from memory.

Parameters
  • prog_buffer: The memory buffer of program.

  • prog_buffer_size: The size of the model data.

  • params_buffer: The memory buffer of the combined parameters file.

  • params_buffer_size: The size of the combined parameters data.

bool model_from_memory() const

A boolean state telling whether the model is set from the CPU memory.

Return

bool Whether model and params are loaded directly from memory.

void EnableMemoryOptim()

Turn on memory optimize NOTE still in development.

bool enable_memory_optim() const

A boolean state telling whether the memory optimization is activated.

Return

bool Whether the memory optimization is activated.

void EnableProfile()

Turn on profiling report. If not turned on, no profiling report will be generated.

bool profile_enabled() const

A boolean state telling whether the profiler is activated.

Return

bool Whether the profiler is activated.

void DisableGlogInfo()

Mute all logs in Paddle inference.

bool glog_info_disabled() const

A boolean state telling whether logs in Paddle inference are muted.

Return

bool Whether logs in Paddle inference are muted.

void SetInValid() const

Set the AnalysisConfig to be invalid. This is to ensure that an AnalysisConfig can only be used in one AnalysisPredictor.

bool is_valid() const

A boolean state telling whether the AnalysisConfig is valid.

Return

bool Whether the AnalysisConfig is valid.

PassStrategy *pass_builder() const

Get a pass builder for customize the passes in IR analysis phase. NOTE: Just for developer, not an official API, easy to be broken.

void PartiallyRelease()

Protected Functions

void Update()
std::string SerializeInfoCache()

Protected Attributes

std::string model_dir_
std::string prog_file_
std::string params_file_
bool use_gpu_ = {false}
int device_id_ = {0}
uint64_t memory_pool_init_size_mb_ = {100}
bool use_cudnn_ = {false}
bool use_fc_padding_ = {true}
bool use_tensorrt_ = {false}
int tensorrt_workspace_size_ = {1 << 30}
int tensorrt_max_batchsize_ = {1}
int tensorrt_min_subgraph_size_ = {3}
Precision tensorrt_precision_mode_ = {Precision::kFloat32}
bool trt_use_static_engine_ = {false}
bool trt_use_calib_mode_ = {true}
std::map<std::string, std::vector<int>> min_input_shape_ = {}
std::map<std::string, std::vector<int>> max_input_shape_ = {}
std::map<std::string, std::vector<int>> optim_input_shape_ = {}
bool disable_trt_plugin_fp16_ = {false}
bool enable_memory_optim_ = {false}
bool use_mkldnn_ = {false}
std::unordered_set<std::string> mkldnn_enabled_op_types_
bool model_from_memory_ = {false}
bool enable_ir_optim_ = {true}
bool use_feed_fetch_ops_ = {true}
bool ir_debug_ = {false}
bool specify_input_name_ = {false}
int cpu_math_library_num_threads_ = {1}
bool with_profile_ = {false}
bool with_glog_info_ = {true}
std::string serialized_info_cache_
std::unique_ptr<PassStrategy> pass_builder_
bool use_lite_ = {false}
std::vector<std::string> lite_passes_filter_
std::vector<std::string> lite_ops_filter_
Precision lite_precision_mode_
int mkldnn_cache_capacity_ = {0}
bool use_mkldnn_quantizer_ = {false}
std::shared_ptr<MkldnnQuantizerConfig> mkldnn_quantizer_config_
bool is_valid_ = {true}
std::string opt_cache_dir_

Friends

friend class ::paddle::AnalysisPredictor
Struct NativeConfig
Inheritance Relationships
Base Type
Struct Documentation
struct paddle::NativeConfig : public paddle::PaddlePredictor::Config

configuration manager for NativePredictor.

AnalysisConfig manages configurations of NativePredictor. During inference procedure, there are many parameters(model/params path, place of inference, etc.)

Public Functions

void SetCpuMathLibraryNumThreads(int cpu_math_library_num_threads)

Set and get the number of cpu math library threads.

int cpu_math_library_num_threads() const

Public Members

bool use_gpu = {false}

GPU related fields.

int device = {0}
float fraction_of_gpu_memory{-1.f}

Change to a float in (0,1] if needed.

std::string prog_file
std::string param_file

Specify the exact path of program and parameter files.

bool specify_input_name = {false}

Specify the variable’s name of each input if input tensors don’t follow the feeds and fetches of the phase save_inference_model.

Protected Attributes

int cpu_math_library_num_threads_ = {1}

number of cpu math library (such as MKL, OpenBlas) threads for each instance.

Struct PaddlePredictor::Config
Nested Relationships

This struct is a nested type of Class PaddlePredictor.

Inheritance Relationships
Derived Type
Struct Documentation
struct paddle::PaddlePredictor::Config

Base class for NativeConfig and AnalysisConfig.

Subclassed by paddle::NativeConfig

Public Members

std::string model_dir

path to the model directory.

Struct PaddleTensor
Struct Documentation
struct paddle::PaddleTensor

Basic input and output data structure for PaddlePredictor.

Public Functions

PaddleTensor() = default

Public Members

std::string name

variable name.

std::vector<int> shape
PaddleBuf data

blob of data.

PaddleDType dtype
std::vector<std::vector<size_t>> lod

Tensor+LoD equals LoDTensor.

Class CpuPassStrategy
Inheritance Relationships
Base Type
Class Documentation
class paddle::CpuPassStrategy : public paddle::PassStrategy

The CPU passes controller, it is used in AnalysisPredictor with CPU mode.

Public Functions

CpuPassStrategy()

Default constructor of CpuPassStrategy.

CpuPassStrategy(const CpuPassStrategy &other)

Construct by copying another CpuPassStrategy object.

Parameters

~CpuPassStrategy() = default

Default destructor.

void EnableCUDNN() override

Enable the use of cuDNN kernel.

void EnableMKLDNN() override

Enable the use of MKLDNN.

void EnableMkldnnQuantizer() override

Enable MKLDNN quantize optimization.

Class GpuPassStrategy
Inheritance Relationships
Base Type
Class Documentation
class paddle::GpuPassStrategy : public paddle::PassStrategy

The GPU passes controller, it is used in AnalysisPredictor with GPU mode.

Public Functions

GpuPassStrategy()

Default constructor of GpuPassStrategy.

GpuPassStrategy(const GpuPassStrategy &other)

Construct by copying another GpuPassStrategy object.

Parameters

void EnableCUDNN() override

Enable the use of cuDNN kernel.

void EnableMKLDNN() override

Not supported in GPU mode yet.

void EnableMkldnnQuantizer() override

Not supported in GPU mode yet.

~GpuPassStrategy() = default

Default destructor.

Class MkldnnQuantizerConfig
Class Documentation
class paddle::MkldnnQuantizerConfig

Config for mkldnn quantize.

The MkldnnQuantizerConfig is used to configure Mkldnn’s quantization parameters, including scale algorithm, warmup data, warmup batch size, quantized op list, etc.

It is not recommended to use this config directly, please refer to AnalysisConfig::mkldnn_quantizer_config()

Public Functions

MkldnnQuantizerConfig()

Construct a new Mkldnn Quantizer Config object.

void SetScaleAlgo(std::string op_type_name, std::string conn_name, ScaleAlgo algo)

Set the scale algo.

Specify a quantization algorithm for a connection (input/output) of the operator type.

Parameters
  • [in] op_type_name: the operator’s name.

  • [in] conn_name: name of the connection (input/output) of the operator.

  • [in] algo: the algorithm for computing scale.

ScaleAlgo scale_algo(const std::string &op_type_name, const std::string &conn_name) const

Get the scale algo.

Get the quantization algorithm for a connection (input/output) of the operator type.

Return

the scale algo.

Parameters
  • [in] op_type_name: the operator’s name.

  • [in] conn_name: name of the connection (input/output) of the operator.

void SetWarmupData(std::shared_ptr<std::vector<PaddleTensor>> data)

Set the warmup data.

Set the batch of data to be used for warm-up iteration.

Parameters
  • [in] data: batch of data.

std::shared_ptr<std::vector<PaddleTensor>> warmup_data() const

Get the warmup data.

Get the batch of data used for warm-up iteration.

Return

the warm up data

void SetWarmupBatchSize(int batch_size)

Set the warmup batch size.

Set the batch size for warm-up iteration.

Parameters
  • [in] batch_size: warm-up batch size

int warmup_batch_size() const

Get the warmup batch size.

Get the batch size for warm-up iteration.

Return

the warm up batch size

void SetEnabledOpTypes(std::unordered_set<std::string> op_list)

Set quantized op list.

In the quantization process, set the op list that supports quantization

Parameters
  • [in] op_list: List of quantized ops

const std::unordered_set<std::string> &enabled_op_types() const

Get quantized op list.

Return

list of quantized ops

void SetExcludedOpIds(std::unordered_set<int> op_ids_list)

Set the excluded op ids.

Parameters
  • [in] op_ids_list: excluded op ids

const std::unordered_set<int> &excluded_op_ids() const

Get the excluded op ids.

Return

exclude op ids

void SetDefaultScaleAlgo(ScaleAlgo algo)

Set default scale algorithm.

Parameters
  • [in] algo: Method for calculating scale in quantization process

ScaleAlgo default_scale_algo() const

Get default scale algorithm.

Return

Method for calculating scale in quantization process

Protected Attributes

std::map<std::string, std::map<std::string, ScaleAlgo>> rules_
std::unordered_set<std::string> enabled_op_types_
std::unordered_set<int> excluded_op_ids_
std::shared_ptr<std::vector<PaddleTensor>> warmup_data_
int warmup_bs_ = {1}
ScaleAlgo default_scale_algo_ = {ScaleAlgo::MAX}
Class PaddleBuf
Class Documentation
class paddle::PaddleBuf

Memory manager for PaddleTensor.

The PaddleBuf holds a buffer for data input or output. The memory can be allocated by user or by PaddleBuf itself, but in any case, the PaddleBuf should be reused for better performance.

For user allocated memory, the following API can be used:

To have the PaddleBuf allocate and manage the memory:

Usage:

Let PaddleBuf manage the memory internally.

const int num_elements = 128;
PaddleBuf buf(num_elements/// sizeof(float));

Or

PaddleBuf buf;
buf.Resize(num_elements/// sizeof(float));
Works the exactly the same.

One can also make the PaddleBuf use the external memory.

PaddleBuf buf;
void* external_memory = new float[num_elements];
buf.Reset(external_memory, num_elements*sizeof(float));
...
delete[] external_memory; // manage the memory lifetime outside.

Public Functions

PaddleBuf(size_t length)

PaddleBuf allocate memory internally, and manage it.

Parameters
  • [in] length: The length of data.

PaddleBuf(void *data, size_t length)

Set external memory, the PaddleBuf won’t manage it.

Parameters
  • [in] data: The start address of the external memory.

  • [in] length: The length of data.

PaddleBuf(const PaddleBuf &other)

Copy only available when memory is managed externally.

Parameters

void Resize(size_t length)

Resize the memory.

Parameters
  • [in] length: The length of data.

void Reset(void *data, size_t length)

Reset to external memory, with address and length set.

Parameters
  • [in] data: The start address of the external memory.

  • [in] length: The length of data.

bool empty() const

Tell whether the buffer is empty.

void *data() const

Get the data’s memory address.

size_t length() const

Get the memory length.

~PaddleBuf()
PaddleBuf &operator=(const PaddleBuf&)
PaddleBuf &operator=(PaddleBuf&&)
PaddleBuf() = default
PaddleBuf(PaddleBuf &&other)
Class PaddlePassBuilder
Inheritance Relationships
Derived Type
Class Documentation
class paddle::PaddlePassBuilder

This class build passes based on vector<string> input. It is part of inference API. Users can build passes, insert new passes, delete passes using this class and its functions.

Example Usage: Build a new pass.

const vector<string> passes(1, "conv_relu_mkldnn_fuse_pass");
PaddlePassBuilder builder(passes);

Subclassed by paddle::PassStrategy

Public Functions

PaddlePassBuilder(const std::vector<std::string> &passes)

Constructor of the class. It stores the input passes.

Parameters
  • [in] passes: passes’ types.

void SetPasses(std::initializer_list<std::string> passes)

Stores the input passes.

Parameters
  • [in] passes: passes’ types.

void AppendPass(const std::string &pass_type)

Append a pass to the end of the passes.

Parameters
  • [in] pass_type: the type of the new pass.

void InsertPass(size_t idx, const std::string &pass_type)

Insert a pass to a specific position.

Parameters
  • [in] idx: the position to insert.

  • [in] pass_type: the type of insert pass.

void DeletePass(size_t idx)

Delete the pass at certain position ‘idx’.

Parameters
  • [in] idx: the position to delete.

void DeletePass(const std::string &pass_type)

Delete all passes that has a certain type ‘pass_type’.

Parameters
  • [in] pass_type: the certain pass type to be deleted.

void ClearPasses()

Delete all the passes.

void AppendAnalysisPass(const std::string &pass)

Append an analysis pass.

Parameters
  • [in] pass: the type of the new analysis pass.

void TurnOnDebug()

Visualize the computation graph after each pass by generating a DOT language file, one can draw them with the Graphviz toolkit.

std::string DebugString()

Human-readable information of the passes.

const std::vector<std::string> &AllPasses() const

Get information of passes.

Return

Return list of the passes.

std::vector<std::string> AnalysisPasses() const

Get information of analysis passes.

Return

Return list of analysis passes.

Class PaddlePredictor
Nested Relationships
Class Documentation
class paddle::PaddlePredictor

A Predictor for executing inference on a model. Base class for AnalysisPredictor and NativePaddlePredictor.

Public Functions

PaddlePredictor() = default
PaddlePredictor(const PaddlePredictor&) = delete
PaddlePredictor &operator=(const PaddlePredictor&) = delete
bool Run(const std::vector<PaddleTensor> &inputs, std::vector<PaddleTensor> *output_data, int batch_size = -1) = 0

This interface takes input and runs the network. There are redundant copies of data between hosts in this operation, so it is more recommended to use the zecopyrun interface.

Return

Whether the run is successful

Parameters
  • [in] inputs: An list of PaddleTensor as the input to the network.

  • [out] output_data: Pointer to the tensor list, which holds the output paddletensor

  • [in] batch_size: This setting has been discarded and can be ignored.

std::vector<std::string> GetInputNames()

Used to get the name of the network input. Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.

Return

Input tensor names.

std::map<std::string, std::vector<int64_t>> GetInputTensorShape()

Get the input shape of the model.

Return

A map contains all the input names and shape defined in the model.

std::vector<std::string> GetOutputNames()

Used to get the name of the network output. Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.

Return

Output tensor names.

std::unique_ptr<ZeroCopyTensor> GetInputTensor(const std::string &name)

Get the input ZeroCopyTensor by name. Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios. The name is obtained from the GetInputNames() interface.

Return

Return the corresponding input ZeroCopyTensor.

Parameters
  • name: The input tensor name.

std::unique_ptr<ZeroCopyTensor> GetOutputTensor(const std::string &name)

Get the output ZeroCopyTensor by name. Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios. The name is obtained from the GetOutputNames() interface.

Return

Return the corresponding output ZeroCopyTensor.

Parameters
  • name: The output tensor name.

bool ZeroCopyRun()

Run the network with zero-copied inputs and outputs. Be inherited by AnalysisPredictor and only used in ZeroCopy scenarios. This will save the IO copy for transfering inputs and outputs to predictor workspace and get some performance improvement. To use it, one should call the AnalysisConfig.SwitchUseFeedFetchOp(true) and then use the GetInputTensor and GetOutputTensor to directly write or read the input/output tensors.

Return

Whether the run is successful

std::unique_ptr<PaddlePredictor> Clone() = 0

Clone an existing predictor When using clone, the same network will be created, and the parameters between them are shared.

Return

unique_ptr which contains the pointer of predictor

~PaddlePredictor() = default

Destroy the Predictor.

std::string GetSerializedProgram() const
struct Config

Base class for NativeConfig and AnalysisConfig.

Subclassed by paddle::NativeConfig

Public Members

std::string model_dir

path to the model directory.

Class PassStrategy
Inheritance Relationships
Base Type
Derived Types
Class Documentation
class paddle::PassStrategy : public paddle::PaddlePassBuilder

This class defines the pass strategies like whether to use gpu/cuDNN kernel/MKLDNN.

Subclassed by paddle::CpuPassStrategy, paddle::GpuPassStrategy

Public Functions

PassStrategy(const std::vector<std::string> &passes)

Constructor of PassStrategy class. It works the same as PaddlePassBuilder class.

Parameters
  • [in] passes: passes’ types.

void EnableCUDNN()

Enable the use of cuDNN kernel.

void EnableMKLDNN()

Enable the use of MKLDNN. The MKLDNN control exists in both CPU and GPU mode, because there can still be some CPU kernels running in GPU mode.

void EnableMkldnnQuantizer()

Enable MKLDNN quantize optimization.

bool use_gpu() const

Check if we are using gpu.

Return

A bool variable implying whether we are in gpu mode.

~PassStrategy() = default

Default destructor.

Class ZeroCopyTensor
Class Documentation
class paddle::ZeroCopyTensor

Represents an n-dimensional array of values. The ZeroCopyTensor is used to store the input or output of the network. Zero copy means that the tensor supports direct copy of host or device data to device, eliminating additional CPU copy. ZeroCopyTensor is only used in the AnalysisPredictor. It is obtained through PaddlePredictor::GetinputTensor() and PaddlePredictor::GetOutputTensor() interface.

Public Functions

void Reshape(const std::vector<int> &shape)

Reset the shape of the tensor. Generally it’s only used for the input tensor. Reshape must be called before calling mutable_data() or copy_from_cpu()

Parameters
  • shape: The shape to set.

template<typename T>
T *mutable_data(PaddlePlace place)

Get the memory pointer in CPU or GPU with specific data type. Please Reshape the tensor first before call this. It’s usually used to get input data pointer.

Parameters
  • place: The place of the tensor.

template<typename T>
T *data(PaddlePlace *place, int *size) const

Get the memory pointer directly. It’s usually used to get the output data pointer.

Return

The tensor data buffer pointer.

Parameters
  • [out] place: To get the device type of the tensor.

  • [out] size: To get the data size of the tensor.

template<typename T>
void copy_from_cpu(const T *data)

Copy the host memory to tensor data. It’s usually used to set the input tensor data.

Parameters
  • data: The pointer of the data, from which the tensor will copy.

template<typename T>
void copy_to_cpu(T *data)

Copy the tensor data to the host memory. It’s usually used to get the output tensor data.

Parameters
  • [out] data: The tensor will copy the data to the address.

std::vector<int> shape() const

Return the shape of the Tensor.

void SetLoD(const std::vector<std::vector<size_t>> &x)

Set lod info of the tensor. More about LOD can be seen here: https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/lod_tensor.html#lodtensor.

Parameters
  • x: the lod info.

std::vector<std::vector<size_t>> lod() const

Return the lod info of the tensor.

const std::string &name() const

Return the name of the tensor.

void SetPlace(PaddlePlace place, int device = -1)
PaddleDType type() const

Return the data type of the tensor. It’s usually used to get the output tensor data type.

Return

The data type of the tensor.

Protected Functions

ZeroCopyTensor(void *scope)
void SetName(const std::string &name)
void *FindTensor() const

Enums

Enum PaddleDType
Enum Documentation
enum paddle::PaddleDType

Paddle data type.

Values:

enumerator FLOAT32
enumerator INT64
enumerator INT32
enumerator UINT8
Enum PaddleEngineKind
Enum Documentation
enum paddle::PaddleEngineKind

NOTE The following APIs are too trivial, we will discard it in the following versions.

Values:

enumerator kNative = 0

Use the native Fluid facility.

enumerator kAutoMixedTensorRT

Automatically mix Fluid with TensorRT.

enumerator kAnalysis

More optimization.

Enum PaddlePlace
Enum Documentation
enum paddle::PaddlePlace

Values:

enumerator kUNK = -1
enumerator kCPU
enumerator kGPU
Enum ScaleAlgo
Enum Documentation
enum paddle::ScaleAlgo

Algorithms for finding scale of quantized Tensors.

Values:

enumerator NONE

Do not compute scale.

enumerator MAX

Find scale based on the max absolute value.

enumerator MAX_CH

Find scale based on the max absolute value per output channel.

enumerator MAX_CH_T

Find scale based on the max absolute value per output channel of a transposed tensor

enumerator KL

Find scale based on KL Divergence.

Functions

Template Function paddle::CreatePaddlePredictor(const ConfigT&)
Function Documentation

Warning

doxygenfunction: Unable to resolve multiple matches for function “paddle::CreatePaddlePredictor” with arguments (const ConfigT&) in doxygen xml output for project “My Project” from directory: ./doxyoutput/xml. Potential matches:

- template<typename ConfigT, PaddleEngineKind engine> std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &config)
- template<typename ConfigT> std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &config)
Template Function paddle::CreatePaddlePredictor(const ConfigT&)
Function Documentation

Warning

doxygenfunction: Unable to resolve multiple matches for function “paddle::CreatePaddlePredictor” with arguments (const ConfigT&) in doxygen xml output for project “My Project” from directory: ./doxyoutput/xml. Potential matches:

- template<typename ConfigT, PaddleEngineKind engine> std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &config)
- template<typename ConfigT> std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &config)
Function paddle::get_version
Function Documentation
std::string paddle::get_version()
Function paddle::PaddleDtypeSize
Function Documentation
int paddle::PaddleDtypeSize(PaddleDType dtype)

Variables

Variable paddle::kLiteSubgraphPasses
Variable Documentation
const std::vector<std::string> paddle::kLiteSubgraphPasses

List of lite subgraph passes.

Variable paddle::kTRTSubgraphPasses
Variable Documentation
const std::vector<std::string> paddle::kTRTSubgraphPasses

List of tensorRT subgraph passes.

FAQ 常见问题