ZeroCopyTensor 方法

ZeroCopyTensor 是 Paddle Inference 的数据组织形式,用于对底层数据进行封装并提供接口对数据进行操作,包括设置 Shape、数据、LoD 信息等。

注意: 应使用 PredictorGetInputTensorsGetOutputTensors 接口获取输入输出 ZeroCopyTensor

ZeroCopyTensor 的API定义如下:

// 获取 ZeroCopyTensor 维度信息
// 参数:tensor - ZeroCopyTensor 对象指针
// 返回:[]int32 - 包含 ZeroCopyTensor 维度信息的int数组
func (tensor *ZeroCopyTensor) Shape() []int32

// 设置 ZeroCopyTensor 维度信息
// 参数:tensor - ZeroCopyTensor 对象指针
//      shape - 包含维度信息的int数组
// 返回:None
func (tensor *ZeroCopyTensor) Reshape(shape []int32)

// 获取 ZeroCopyTensor 名称
// 参数:tensor - ZeroCopyTensor 对象指针
// 返回:string - ZeroCopyTensor 名称
func (tensor *ZeroCopyTensor) Name() string

// 设置 ZeroCopyTensor 名称
// 参数:tensor - ZeroCopyTensor 对象指针
//      name - ZeroCopyTensor 名称
// 返回:None
func (tensor *ZeroCopyTensor) Rename(name string)

// 获取 ZeroCopyTensor 数据类型
// 参数:tensor - ZeroCopyTensor 对象指针
// 返回:PaddleDType - ZeroCopyTensor 数据类型
func (tensor *ZeroCopyTensor) DataType() PaddleDType

// 设置 ZeroCopyTensor 数据
// 参数:tensor - ZeroCopyTensor 对象指针
//      value - ZeroCopyTensor 数据
// 返回:None
func (tensor *ZeroCopyTensor) SetValue(value interface{})

// 获取 ZeroCopyTensor 数据
// 参数:tensor - ZeroCopyTensor 对象指针
// 返回:interface{} - ZeroCopyTensor 数据
func (tensor *ZeroCopyTensor) Value() interface{}

代码示例:

package main

// 引入 Paddle Golang Package
import "/pathto/Paddle/go/paddle"
import "reflect"

func main() {
    // 创建 AnalysisConfig 对象
    config := paddle.NewAnalysisConfig()
    config.SwitchUseFeedFetchOps(false)

    // 设置预测模型路径,这里为非 Combined 模型
    config.SetModel("data/mobilenet_v1", "")
    // config.SetModel("data/model/__model__", "data/model/__params__")

    // 根据 Config 构建预测执行对象 Predictor
    predictor := paddle.NewPredictor(config)

    // 获取输入输出 Tensor 信息
    println("input num: ", predictor.GetInputNum())
    println("input name: ", predictor.GetInputNames()[0])
    println("output num: ", predictor.GetOutputNum())
    println("output name: ", predictor.GetInputNames()[0])

    // 获取输入输出 Tensor 指针
    input := predictor.GetInputTensors()[0]
    output := predictor.GetOutputTensors()[0]

    input_data := make([]float32, 1 * 3 * 224 * 224)
    for i := 0; i < 1 * 3 * 224 * 224; i++ {
      input_data[i] = 1.0
    }
    input.SetValue(input_data)
    input.Reshape([]int32{1, 3, 224, 224})

    // 设置输入 Tensor
    predictor.SetZeroCopyInput(input)
    // 执行预测
    predictor.ZeroCopyRun()
    // 获取输出 Tensor
    predictor.GetZeroCopyOutput(output)

    // 获取输出 Tensor 信息
    output_val := output.Value()
    value := reflect.ValueOf(output_val)
    shape, dtype := paddle.ShapeAndTypeOf(value)
    v := value.Interface().([][]float32)
    println("Ouptut Shape is: ", shape[0], "x", shape[1])
    println("Ouptut Dtype is: ", dtype)
    println("Output Data is: ", v[0][0], v[0][1], v[0][2], v[0][3], v[0][4], "...")
    
    // 删除 Predictor 对象
    paddle.DeletePredictor(predictor)
}