cal_anchors_heading
接口原型
mx_driving.cal_anchors_heading(Tensor anchors, Tensor origin_pos=None) -> Tensor
功能描述
根据输入的 anchors 和起始点坐标计算方向。
参数说明
anchors(Tensor):每个锚点轨迹的序列坐标,数据类型为float32,shape 为[batch_size, anchors_num, seq_length, 2]。origin_pos(Tensor):可选参数,每个 anchor 的起始位置坐标,数据类型为float32,shape 为[batch_size, 2]。
返回值
heading(Tensor):每个 anchor 的轨迹点坐标方向(弧度),数据类型为float32,shape 为[batch_size, anchors_num, seq_length]。
算子约束
- 1 <= batch_size <= 2048
- 1 <= anchors_num <= 10240
- 1 <= seq_length <= 256
支持的型号
- Atlas A2 训练系列产品
调用示例
import torch
import torch_npu
import mx_driving
batch_size = 2
anchors_num = 64
seq_length = 24
anchors = torch.randn((batch_size, anchors_num, seq_length, 2)).npu()
origin_pos = torch.randn((batch_size, 2)).npu()
heading = mx_driving.cal_anchors_heading(anchors, origin_pos)