PyTorch Tensor的数学运算
转载自
转载自:嘻哈吼嘿呵
本文介绍了 PyTorch 的一些数学运算方法。
基础运算
可以使用 + - * / 四则运算符号(推荐)
也可以使用 torch.add, torch.mul, torch.sub, torch.div
加法运算
def add():
# add +
# 这两个Tensor加减乘除会对b自动进行Broadcasting
a = torch.rand(3,4)
b = torch.rand(4)
print("a = {}".format(a))
print("b = {}".format(b))
# a、b列数相同,行数不同,将a的每行与b对应位置相加
c1 = a + b
c2 = torch.add(a,b)
c3 = torch.eq(c1,c2)
# torch.all()判断每个位置的元素是否相同
c4 = torch.all(c3)
print("a + b = {}".format(c1))
print("a + b = {}".format(c2))
print("torch.eq = {}".format(c3))
print("torch all = {}".format(c4))
# a = tensor([[0.8514, 0.5017, 0.3924, 0.7817],
# [0.0219, 0.7352, 0.5634, 0.7285],
# [0.9187, 0.1628, 0.9236, 0.3603]])
# b = tensor([0.0809, 0.0295, 0.6065, 0.8024])
# a + b = tensor([[0.9322, 0.5312, 0.9989, 1.5841],
# [0.1028, 0.7647, 1.1700, 1.5309],
# [0.9996, 0.1923, 1.5301, 1.1627]])
# a + b = tensor([[0.9322, 0.5312, 0.9989, 1.5841],
# [0.1028, 0.7647, 1.1700, 1.5309],
# [0.9996, 0.1923, 1.5301, 1.1627]])
# torch.eq = tensor([[True, True, True, True],
# [True, True, True, True],
# [True, True, True, True]])
# torch
# all = True
减法运算
def minus():
# 这两个Tensor加减乘除会对b自动进行Broadcasting
a = torch.rand(3,4)
b = torch.rand(4)
print("a = {}".format(a))
print("b = {}".format(b))
# a、b列数相同,行数不同,将a的每行与b对应位置相加
c1 = a - b
c2 = torch.sub(a,b)
# torch.all()判断每个位置的元素是否相同
c3 = torch.eq(c1,c2)
c4 = torch.all(c3)
print("a - b = {}".format(c1))
print("a - b = {}".format(c2))
print("torch.eq = {}".format(c3))
print("torch all = {}".format(c4))
# a = tensor([[0.8499, 0.1003, 0.3179, 0.1217],
# [0.2119, 0.7742, 0.3973, 0.7241],
# [0.8559, 0.3558, 0.1549, 0.4583]])
# b = tensor([0.4750, 0.9261, 0.7107, 0.1397])
# a - b = tensor([[0.3749, -0.8258, -0.3928, -0.0180],
# [-0.2631, -0.1519, -0.3135, 0.5844],
# [0.3809, -0.5703, -0.5558, 0.3186]])
# a - b = tensor([[0.3749, -0.8258, -0.3928, -0.0180],
# [-0.2631, -0.1519, -0.3135, 0.5844],
# [0.3809, -0.5703, -0.5558, 0.3186]])
# torch.eq = tensor([[True, True, True, True],
# [True, True, True, True],
# [True, True, True, True]])
# torch
# all = True
哈达玛积 (element wise,对应元素相乘)
def mul_element():
# 这两个Tensor加减乘除会对b自动进行Broadcasting
a = torch.rand(3,4)
b = torch.rand(4)
print("a = {}".format(a))
print("b = {}".format(b))
# a、b列数相同,行数不同,将a的每行与b对应位置相加
c1 = a * b
c2 = torch.mul(a,b)
# torch.all()判断每个位置的元素是否相同
c3 = torch.eq(c1,c2)
c4 = torch.all(c3)
print("a * b = {}".format(c1))
print("a * b = {}".format(c2))
print("torch.eq = {}".format(c3))
print("torch all = {}".format(c4))
# a = tensor([[0.9678, 0.8896, 0.5657, 0.7644],
# [0.0581, 0.3479, 0.2008, 0.1259],
# [0.4169, 0.9426, 0.1330, 0.5813]])
# b = tensor([0.3827, 0.7139, 0.4547, 0.6798])
# a * b = tensor([[0.3704, 0.6351, 0.2572, 0.5197],
# [0.0222, 0.2484, 0.0913, 0.0856],
# [0.1595, 0.6729, 0.0605, 0.3952]])
# a * b = tensor([[0.3704, 0.6351, 0.2572, 0.5197],
# [0.0222, 0.2484, 0.0913, 0.0856],
# [0.1595, 0.6729, 0.0605, 0.3952]])
# torch.eq = tensor([[True, True, True, True],
# [True, True, True, True],
# [True, True, True, True]])
# torch all = True
除法运算
对应元素相除
def test():
# 这两个Tensor加减乘除会对b自动进行Broadcasting
a = torch.rand(3,4)
b = torch.rand(4)
print("a = {}".format(a))
print("b = {}".format(b))
# a、b列数相同,行数不同,将a的每行与b对应位置相加
c1 = a / b
c2 = torch.div(a,b)
# torch.all()判断每个位置的元素是否相同
c3 = torch.eq(c1,c2)
c4 = torch.all(c3)
print("a / b = {}".format(c1))
print("a / b = {}".format(c2))
print("torch.eq = {}".format(c3))
print("torch all = {}".format(c4))
#a = tensor([[0.6079, 0.2791, 0.0034, 0.6169],
# [0.5279, 0.7804, 0.5960, 0.0359],
# [0.3385, 0.2300, 0.2021, 0.7161]])
# b = tensor([0.5951, 0.8573, 0.7276, 0.8717])
# a * b = tensor([[1.0214, 0.3256, 0.0047, 0.7077],
# [0.8870, 0.9103, 0.8190, 0.0412],
# [0.5687, 0.2682, 0.2778, 0.8215]])
# a * b = tensor([[1.0214, 0.3256, 0.0047, 0.7077],
# [0.8870, 0.9103, 0.8190, 0.0412],
# [0.5687, 0.2682, 0.2778, 0.8215]])
# torch.eq = tensor([[True, True, True, True],
# [True, True, True, True],
# [True, True, True, True]])
# torch all = True
矩阵运算
- matmul 表示 matrix mul
*
表示的是 element-wise, 对应元素相乘torch.mm(a,b)
只能计算 2D 不推荐,矩阵相乘torch.matmul(a,b)
可以计算更高维度,落脚点依旧在行与列。 推荐@
是 matmul 的重载形式
二维矩阵相乘
二维矩阵乘法运算操作包括 torch.mm()、torch.matmul()、@
def test():
a = torch.ones(2,1)
b = torch.ones(1,2)
print("a = {}".format(a))
print("b = {}".format(b))
c1 = torch.mm(a,b)
c2 = torch.matmul(a,b)
c3 = a @ b
print("c1 = {}".format(c1))
print("c2 = {}".format(c2))
print("c3 = {}".format(c3))
# a = tensor([[1.],
# [1.]])
# b = tensor([[1., 1.]])
# c1 = tensor([[1., 1.],
# [1., 1.]])
# c2 = tensor([[1., 1.],
# [1., 1.]])
# c3 = tensor([[1., 1.],
# [1., 1.]])
多维矩阵相乘
对于高维的 Tensor(dim>2),定义其矩阵乘法仅在最后的两个维度上,要求前面的维度必须保持一致,就像矩阵的索引一样并且运算操只有 torch.matmul()。
- 对于 2 维以上的 matrix multiply ,
torch.mm(a,b)
就不行了。 - 运算规则:只取最后的两维做矩阵乘法
- 对于 [b, c, h, w] 来说,b,c 是不变的,图片的大小在改变;并且也并行的计算出了 b,c。也就是支持多个矩阵并行相乘。
- 对于不同的 size,如果符合 broadcast,先执行 broadcast,在进行矩阵相乘。
def test():
# 多维矩阵计算,前两个维度必须一致
c = torch.rand(4, 3, 28, 64)
d = torch.rand(4, 3, 64, 32)
print(torch.matmul(c,d).shape)
# torch.Size([4, 3, 28, 32])
注意,在这种情形下的矩阵相乘,前面的 “矩阵索引维度” 如果符合 Broadcasting 机制,也会自动做广播,然后相乘。
def test():
# 多维矩阵计算,前两个维度必须一致
c = torch.rand(4, 3, 28, 64)
d = torch.rand(4, 1, 64, 32)
print(torch.matmul(c,d).shape)
# torch.Size([4, 3, 28, 32])
幂运算
def test():
# troch.full(size, fill_value)
# 参数:
# size: 生成张量的大小,list, tuple, torch.size
# fill_value: 填充张量的数
a = torch.full([2, 2], 3)
print("a = {}".format(a))
b1 = a.pow(2) # 也可以a**2
b2 = a**2
print("b1 = {}".format(b1))
print("b2 = {}".format(b2))
# a = tensor([[3., 3.],
# [3., 3.]])
# b1 = tensor([[9., 9.],
# [9., 9.]])
# b2 = tensor([[9., 9.],
# [9., 9.]])
#
开方运算
- pow(a, n) a 的 n 次方
**
也表示次方(可以是 2,0.5,0.25,3) 推荐- sqrt() 表示 square root 平方根
- rsqrt() 表示平方根的倒数
def test():
a = torch.full([2, 2], 9)
print("a = {}".format(a))
b1 = a.sqrt() # 也可以a**(0.5)
# 平方根的倒数
b2 = a.rsqrt()
print("b1 = {}".format(b1))
print("b2 = {}".format(b2))
# a = tensor([[9., 9.],
# [9., 9.]])
# b1 = tensor([[3., 3.],
# [3., 3.]])
# b2 = tensor([[0.3333, 0.3333],
# [0.3333, 0.3333]])
指数与对数运算
注意log
是以自然对数为底数的,以 2 为底的用log2
,以 10 为底的用log10
- exp(n) 表示:e 的 n 次方
- log(a) 表示:ln(a)
- log2() 、 log10()
def test():
a = torch.ones(2,2)
print("a = {}".format(a))
# 得到 2*2 矩阵的全是 e 的Tensor,相当于a的所有元素乘以e
b = torch.exp(a)
c = torch.log(a)
print("b = {}".format(b))
print("c = {}".format(c))
# a = tensor([[1., 1.],
# [1., 1.]])
# b = tensor([[2.7183, 2.7183],
# [2.7183, 2.7183]])
# c = tensor([[0., 0.],
# [0., 0.]])
近似值运算
近似相关 1
- floor、ceil 向下取整、向上取整
- round 4 舍 5 入
- trunc、frac 裁剪
def test():
a = torch.tensor(3.14)
b = torch.tensor(3.49)
c = torch.tensor(3.5)
# 取下,取上,取整数,取小数
print("a.floor = {},a.ceil = {},a.trunc = {},a.frac = {}"
.format(a.floor(),a.ceil(),a.trunc(),a.frac()))
# 四舍五入
print("b.rounc = {}, c.round = {}".format(b.round(),c.round()))
# a.floor = 3.0, a.ceil = 4.0, a.trunc = 3.0, a.frac = 0.1400001049041748
# b.rounc = 3.0, c.round = 4.0
裁剪运算
即对 Tensor 中的元素进行范围过滤,不符合条件的可以把它变换到范围内部(边界)上,常用于梯度裁剪(gradient clipping),即在发生梯度离散或者梯度爆炸时对梯度的处理,实际使用时可以查看梯度的(L2 范数)模来看看需不需要做处理:w.grad.norm(2)
。
近似相关 2 (用的更多一些)
- gradient clipping 梯度裁剪
- (min) 小于 min 的都变为某某值
- (min, max) 不在这个区间的都变为某某值
- 梯度爆炸:一般来说,当梯度达到 100 左右的时候,就已经很大了,正常在 10 左右,通过打印梯度的模来查看
w.grad.norm(2)
- 对于 w 的限制叫做 weight clipping,对于 weight gradient clipping 称为 gradient clipping。
def test():
# 两行三列,切元素在0-15之间随机生成
grad = torch.rand(2, 3) * 15 # 0~15随机生成
print("grad = {}".format(grad))
# 最大值最小值平均值
print("grad.max = {}, grad.min = {}, grad.median = {}"
.format(grad.max(), grad.min(), grad.median()))
# 最小是10,小于10的都变成10
print("grad.clamp(10) = {}".format(grad.clamp(10)))
# 最小是3, 小于3的都变成3; 最大是10, 大于10的都变成10
print("grad.clamp(3, 10) = {}".format(grad.clamp(3, 10))) #
# grad = tensor([[7.2015, 13.5902, 3.7276],
# [3.9825, 2.9701, 11.7545]])
# grad.max = 13.590229034423828, grad.min = 2.9700870513916016, grad.median = 3.982494831085205
# grad.clamp(10) = tensor([[10.0000, 13.5902, 10.0000],
# [10.0000, 10.0000, 11.7545]])
# grad.clamp(3, 10) = tensor([[7.2015, 10.0000, 3.7276],
# [3.9825, 3.0000, 10.0000]])