目录
- 一、 torch.bmm
- 二、 torch.einsum
- 三、python中变量前面有个*
- 四、numpy.prod
- 五、torch.chunk
一、 torch.bmm
torch.bmm 是 PyTorch 中的一个函数,用于执行批矩阵乘法( matrix multiplication)操作。
它的输入是三维张量,形状为 (batch, n, m) 和 (batch, m, p):
其中 n 是第一个矩阵的列数,m 是两个矩阵共享的维度,p 是第二个矩阵的列数。
torch.bmm 将批中的每对矩阵相乘,返回一个新的三维张量,形状为 (batch, n, p)。
二、 torch.einsum
torch.einsum是pytorch上的一个强大的函数,用于矩阵相关的计算,注意,这里没有限定为矩阵乘法。torch.einsum基于爱因斯坦求和约定执行张量操作,能够用简洁的表达式实现复杂的多维数组操作,从而避免繁琐的张量操作组合(如reshape、permute、bmm等),减少错误率。需要说明的是,尽管einsum函数内部进行了大量计算优化,但其主要优势在于表达式简洁,如果与单步reshape等pytorch实现的矩阵运算操作相比,其运算速度与内存占用不一定占优势。
1.矩阵乘法:‘ij,jk->ik’ 表示形状为(i,j)与形状为(j,k)的矩阵进行矩阵乘法,得到新矩阵形状为(i,k)。这也是torch.einsum最常规的用法。
2.维度调换:'ij->ji’表示形状为(i,j)的矩阵维度调换成为形状为(j,i)的矩阵。
torch.einsum还有多种用法,遇到再来添加
三、python中变量前面有个*
在Python中,变量前面的星号(*)有多种用法,主要与函数参数或解包序列有关。
1、在函数参数中,星号(*)用来表示任意多个参数,这些参数会被当作元组传递。例如:
def fun(*args):
for i in args:
print(i)
fun(1, 2, 3, 4)
2、在函数参数中,星号(*)还可以用来解包序列。例如:
def fun(a, b, c, d):
print(a, b, c, d)
args = (1, 2, 3, 4)
fun(*args)
3、在函数参数中,星号(*)还可以与命名参数,或者字典一起使用。例如:
def fun(*args, a=1):
print(args, a)
fun(1, 2, 3, a=4)
def fun(*args, **kwargs):
print(args, kwargs)
fun(1, 2, 3, a=4, b=5)
4、 在解包列表或元组时,星号(*)也可以用来解包选定项。例如:
lst = [1, 2, 3, 4, 5]
a, *b, c = lst
print(a, b, c)
四、numpy.prod
计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=0))按列计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=1))按行计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=0))计算所有元素和
五、torch.chunk
对于一个输入tensor,torch.chunk方法会按照dim指定的维度将输入tensor划分为若干个chunk,划分的数量为chunks。
torch.chunk(input, chunks, dim=0)
temp=torch.randn((4,6))
print(torch.chunk(temp,2,0))行方向分块
print(torch.chunk(temp,2,1))列方向分块