|
| 1 | +# paddle.sgn 设计文档 |
| 2 | + |
| 3 | +| API名称 | paddle.sgn | |
| 4 | +|----------------------------------------------------------|------------------------------------------------| |
| 5 | +| 提交作者<input type="checkbox" class="rowselector hidden"> | TreeML | |
| 6 | +| 提交时间<input type="checkbox" class="rowselector hidden"> | 2022-07-05 | |
| 7 | +| 版本号 | V1.0 | |
| 8 | +| 依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | develop | |
| 9 | +| 文件名 | 20220705_api_design_for_sgn.md<br> | |
| 10 | + |
| 11 | +# 一、概述 |
| 12 | + |
| 13 | +## 1、相关背景 |
| 14 | + |
| 15 | +对于复数张量,此函数返回一个新的张量,其元素与 input 元素的角度相同且绝对值为1; |
| 16 | +对于非复数张量,此函数返回 input 元素的符号。 |
| 17 | +此任务的目标是在 Paddle 框架中,新增 sgn API,调用路径为:paddle.sgn 和 Tensor.sgn。 |
| 18 | + |
| 19 | + |
| 20 | +## 3、意义 |
| 21 | + |
| 22 | +完善paddle中对于复数的sgn运算。 |
| 23 | + |
| 24 | +# 二、飞桨现状 |
| 25 | + |
| 26 | +目前paddle拥有类似的对于实数进行运算的API:sign, |
| 27 | +sign对输入x中每个元素进行正负判断,并且输出正负判断值:1代表正,-1代表负,0代表零, |
| 28 | +sgn是对sign复数功能的实现。 |
| 29 | + |
| 30 | +# 三、业内方案调研 |
| 31 | + |
| 32 | +## Pytorch |
| 33 | + |
| 34 | +Pytorch中有API`torch.sgn(input, *, out=None)` , 支持复数的符号函数运算: |
| 35 | + |
| 36 | + ``` |
| 37 | + This function is an extension of torch.sign() to complex tensors. It computes a new tensor whose elements have the same |
| 38 | + angles as the corresponding elements of input and absolute values (i.e. magnitudes) of one for complex tensors and is |
| 39 | + equivalent to torch.sign() for non-complex tensors. |
| 40 | + ``` |
| 41 | +官方文档链接为:https://pytorch.org/docs/stable/generated/torch.sgn.html?highlight=sgn#torch.sgn |
| 42 | + |
| 43 | +## Tensorflow |
| 44 | + |
| 45 | +在Tensorflow中sign此API同时支持复数与实数的符号函数运算: |
| 46 | + ``` |
| 47 | +y = sign(x) = -1 if x < 0; 0 if x == 0; 1 if x > 0. |
| 48 | +对于复数,y = sign(x) = x / |x| if x != 0, otherwise y = 0. |
| 49 | + ``` |
| 50 | +官方文档链接为:https://www.tensorflow.org/api_docs/python/tf/math/sign |
| 51 | + |
| 52 | +## Numpy |
| 53 | + |
| 54 | +在Numpy中sign此API同时支持复数与实数的符号函数运算,但其复数运算所得到的结果为sign(x.real) + 0j: |
| 55 | + ``` |
| 56 | +The sign function returns -1 if x < 0, 0 if x==0, 1 if x > 0. nan is returned for nan inputs. |
| 57 | +For complex inputs, the sign function returns sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j. |
| 58 | +complex(nan, 0) is returned for complex nan inputs. |
| 59 | +There is more than one definition of sign in common use for complex numbers. The definition used here is equivalent to |
| 60 | + x/sqrt(x*x) which is different from x/|x|a common alternative, . |
| 61 | + ``` |
| 62 | +官方文档链接为:https://numpy.org/doc/stable/reference/generated/numpy.sign.html?highlight=sign#numpy.sign |
| 63 | + |
| 64 | +### 实现方法 |
| 65 | + |
| 66 | +代码如下: |
| 67 | + |
| 68 | +Pytorch中使用C++来实现复数功能 |
| 69 | + |
| 70 | + ``` |
| 71 | + template<typename T> |
| 72 | +inline c10::complex<T> sgn_impl (c10::complex<T> z) { |
| 73 | + if (z == c10::complex<T>(0, 0)) { |
| 74 | + return c10::complex<T>(0, 0); |
| 75 | + } else { |
| 76 | + return z / zabs(z); |
| 77 | + } |
| 78 | +} |
| 79 | +
|
| 80 | + ``` |
| 81 | +github链接为:https://github.com/pytorch/pytorch/blob/d7fc864f0da461512fb7b972f04e24e296bd266d/aten/src/ATen/native/cpu/zmath.h#L156-L163 |
| 82 | +Tensorflow中使用python实现复数功能 |
| 83 | + |
| 84 | +``` |
| 85 | + if x.dtype.is_complex: |
| 86 | + return gen_math_ops.div_no_nan( |
| 87 | + x, |
| 88 | + cast( |
| 89 | + gen_math_ops.complex_abs( |
| 90 | + x, |
| 91 | + Tout=dtypes.float32 |
| 92 | + if x.dtype == dtypes.complex64 else dtypes.float64), |
| 93 | + dtype=x.dtype), |
| 94 | + name=name) |
| 95 | + return gen_math_ops.sign(x, name=name) |
| 96 | +``` |
| 97 | +github链接为:https://github.com/tensorflow/tensorflow/blob/7272e9f1f52ffe1b5aee67d1af3c2127634ab47d/tensorflow/python/ops/math_ops.py#L746-L790 |
| 98 | + |
| 99 | +# 四、对比分析 |
| 100 | + |
| 101 | +Tensorflow与Pytorch对于实现复数功能部分的代码核心逻辑相同,torch的代码使用C++实现但它将实数和复数拆分成了两个API,类似于paddle的想法; |
| 102 | +Tensorflow的代码使用Python实现但它将两个功能合于一个API中。 |
| 103 | +鉴于两段代码的逻辑类似,故参考Pytorch的代码或参考Tensorflow的代码皆可。 |
| 104 | + |
| 105 | +# 五、方案设计 |
| 106 | + |
| 107 | +## 命名与参数设计 |
| 108 | + |
| 109 | +API设计为`paddle.sgn(x, name=None)`和`paddle.Tensor.sgn(x, name=None)` |
| 110 | +命名与参数顺序为:形参名`input`->`x`, 与paddle其他API保持一致性,不影响实际功能使用。 |
| 111 | + |
| 112 | + |
| 113 | +## 底层OP设计 |
| 114 | + |
| 115 | +使用已有API进行组合,不再单独设计底层OP。 |
| 116 | + |
| 117 | + |
| 118 | +## API实现方案 |
| 119 | + |
| 120 | +使用is_complex判断输入是否为复数、若为实数则使用sign进行运算;若为复数则使用as_real将其转化为实数tensor,将其中的非零部分除以它自己的绝对值 |
| 121 | +,最后再使用as_complex将其转换回复数返回。 |
| 122 | + |
| 123 | +# 六、测试和验收的考量 |
| 124 | + |
| 125 | +测试考虑的case如下: |
| 126 | + |
| 127 | +- 编程范式场景:覆盖静态图和动态图测试场景 |
| 128 | +- 硬件场景:覆盖CPU和GPU测试场景 |
| 129 | +- Tensor精度场景:支持float16, float32 , float64, complex64 , complex128 |
| 130 | +- 参数组合场景 |
| 131 | +- 计算精度:前向计算,和numpy实现的函数对比结果;反向计算,由Python组合的新增API无需验证反向计算 |
| 132 | +- 异常测试:由于使用了已有API:sign,该API不支持整型运算,仅支持float16, float32 或 float64,所以需要做数据类型的异常测试 |
| 133 | + |
| 134 | + |
| 135 | + |
| 136 | +# 七、可行性分析及规划排期 |
| 137 | + |
| 138 | +方案主要依赖paddle现有API组合而成,并自行实现核心算法。 |
| 139 | + |
| 140 | +# 八、影响面 |
| 141 | + |
| 142 | +为独立新增API,对其他模块没有影响。 |
| 143 | + |
| 144 | +# 名词解释 |
| 145 | + |
| 146 | +无 |
| 147 | + |
| 148 | +# 附件及参考资料 |
| 149 | + |
| 150 | +无 |
0 commit comments