首页 > tensor乘运算

tensor乘运算

torch.mul(a, b) 是矩阵 对应位相乘,即点乘操作, a和b的维度必须相等,a的维度是(1,2), 则b的维度必须是(1,2), 返回还是(1,2)的矩阵

torch.mm(a,b)是矩阵a和b矩阵相乘,a的维度是(1,2),b的维度是(2,3),返回是(1,3)的矩阵

torch.bmm(a,b)是矩阵a和b在维度1、2上矩阵相乘,一般要求是 三维矩阵,a的维度是(64,1,2),b的维度是(64,2,3)返回的是(64,1,3)矩阵

import torchif __name__ == "__main__":x = torch.ones(1, 2)y = torch.ones(1, 2) * 2z = torch.mul(x, y)print("torch.mul() example ")print("x.shape is ", x.shape) print("y.shape is ", y.shape)print("z.shape is ", z.shape) ## [1, 2]x = torch.ones(1, 2)y = torch.ones(2, 3) z = torch.mm(x, y)print("torch.mm() example ")print("x.shape is ", x.shape)print("y.shape is ", y.shape)print("z.shape is ", z.shape) ## [1,3]x = torch.randn(64, 1, 2)y = torch.randn(64, 2, 3)z = torch.bmm(x, y)print("torch.bmm example ")print("x.shape is ", x.shape)print("y.shape is ", y.shape)print("z.shape is ", z.shape) ## [64, 1, 3]

转载于:https://www.cnblogs.com/yeran/p/11288171.html

更多相关:

  • #coding:utf-8'''Created on 2017年10月25日@author: li.liu'''import pymysqldb=pymysql.connect('localhost','root','root','test',charset='utf8')m=db.cursor()'''try:#a=raw_inpu...

  • python数据类型:int、string、float、boolean 可变变量:list 不可变变量:string、元组tuple 1.list list就是列表、array、数组 列表根据下标(0123)取值,下标也叫索引、角标、编号 new_stus =['刘德华','刘嘉玲','孙俪','范冰冰'] 最前面一个元素下标是0,最...

  • from pathlib import Path srcPath = Path(‘../src/‘) [x for x in srcPath.iterdir() if srcPath.is_dir()] 列出指定目录及子目录下的所有文件 from pathlib import Path srcPath = Path(‘../tenso...

  • 我在使用OpenResty编写lua代码时,需要使用到lua的正则表达式,其中pattern是这样的, --热水器设置时间 local s = '12:33' local pattern = "(20|21|22|23|[01][0-9]):([0-5][0-9])" local matched = string.match(s, "...

  • 在分析ats的访问日志时,我经常会遇到将一些特殊字段对齐显示的需求,网上调研了一下,发现使用column -t就可以轻松搞定,比如 找到ATS的access.log中的200响应时间过长的日志 cat access.log | grep ' 200 ' | awk -F '"' '{print $3}' > taoyx.log co...

  • 转自:http://stackoverflow.com/questions/8377091/what-are-the-differences-between-cv-8u-and-cv-32f-and-what-should-i-worry-about CV_8U is unsigned 8bit/pixel - ie a pixel...