首页 > 其他分享 >cuda 加速矩阵乘法

cuda 加速矩阵乘法

时间:2024-03-15 14:55:05浏览次数:31  
标签:int void 矩阵 grid cuda sizeof SIZE BLOCK 乘法

对于一个m * n的矩阵a和一个n * k的矩阵b

因为最后得到一个m * k的矩阵c,那么我们可以分配m * k个线程。

在线程(i,j)里矩阵a的第i行和矩阵b的第j列进行点积运算得到c[i][j]

#include<iostream>
#include "cuda_runtime.h"
#include "device_launch_parameters.h"

# define BLOCK_SIZE 2

__global__ void gpu_matrix_mult(int* a, int* b, int* c, int m, int n, int k)
{
        //row和col是该线程所在行数和列数
	int row = blockIdx.y * blockDim.y + threadIdx.y;
	int col = blockIdx.x * blockDim.x + threadIdx.x;

	int sum = 0;
	if (col < k && row < m)
	{
		for (int i = 0; i < n; i++)
		{
			sum += a[row * n + i] * b[i * k + col];
		}
		c[row * k + col] = sum;
	}
}
int main()
{
	int m = 100, n = 100, k = 100;
	
	int* h_a, * h_b, * h_c;
	cudaMallocHost((void**)&h_a, sizeof(int) * m * n);
	cudaMallocHost((void**)&h_b, sizeof(int) * n * k);
	cudaMallocHost((void**)&h_c, sizeof(int) * m * k);

	for (int i = 0; i < m; ++i)
	{
		for (int j = 0; j < n; ++j) 
			h_a[i * n + j] = rand() % 1024;
	}
	for (int i = 0; i < n; ++i)
	{
		for (int j = 0; j < k; ++j)
			h_b[i * k + j] = rand() % 1024;
	}

	int* d_a, * d_b, * d_c;
	cudaMalloc((void**)&d_a, sizeof(int) * m * n);
	cudaMalloc((void**)&d_b, sizeof(int) * n * k);
	cudaMalloc((void**)&d_c, sizeof(int) * m * k);

	cudaMemcpy(d_a, h_a, sizeof(int) * m * n, cudaMemcpyHostToDevice);
	cudaMemcpy(d_b, h_b, sizeof(int) * n * k, cudaMemcpyHostToDevice);
	
        //BLOCK_SIZE是一个block边的大小
        //grid_rows是一个grid有几行block
        //grid_cols是一个grid有几列block
        //dimGrid是一个grid一行有几个block,一列有几个block
        //dimBlock是一个block一行有几个thread,一列有几个thread
	unsigned int grid_rows = (m + BLOCK_SIZE - 1) / BLOCK_SIZE; 
	unsigned int grid_cols = (k + BLOCK_SIZE - 1) / BLOCK_SIZE;
	dim3 dimGrid(grid_cols, grid_rows);
	dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);

	gpu_matrix_mult<<<dimGrid , dimBlock>>>(d_a, d_b, d_c, m, n, k);
	cudaMemcpy(h_c, d_c, sizeof(int) * m * k, cudaMemcpyDeviceToHost);
	for (int i = 0; i < m*k; i++)
	{
		std::cout << h_c[i] << std::endl;
	}
	return 0;
}

标签:int,void,矩阵,grid,cuda,sizeof,SIZE,BLOCK,乘法
From: https://www.cnblogs.com/algoshimo/p/18075392

相关文章

  • 关于矩阵的相关运算
    定义肥肠的简单即将矩阵作为一个整体进行加减乘(包括快速幂)等的运算操作对,没有除至于为什么我不会告诉你是因为huge没有讲的不过对于除法我们可以通过增广原除数矩阵再进行高斯消元balabala的操作求得除数矩阵的逆矩阵再将被除数矩阵乘上除数矩阵的逆矩阵得到ans矩阵我们......
  • C#实现图的邻接矩阵和邻接表结构
    原文链接:https://blog.csdn.net/weixin_41883890/article/details/125517599本文介绍C#实现图的邻接矩阵和邻接表结构。逻辑结构分为两部分:V和E集合,其中,V是顶点,E是边。因此,用一个一维数组存放图中所有顶点数据;用一个二维数组存放顶点间关系(边或弧)的数据,这个二维数组称为邻接矩阵......
  • LeetCodeHot100 73. 矩阵置零 54. 螺旋矩阵 48. 旋转图像 240. 搜索二维矩阵 II
    73.矩阵置零https://leetcode.cn/problems/set-matrix-zeroes/description/?envType=study-plan-v2&envId=top-100-likedpublicvoidsetZeroes(int[][]matrix){inttop=0,bottom=matrix.length,left=0,right=matrix[0].length;int[][]flag......
  • 使用for语法打印出九九乘法表
    先打印第一列publicclassForDemo03{publicstaticvoidmain(String[]args){for(inti=1;i<=9;i++){System.out.println(1+"*"+i+"="+(1*i));//1*1~1*9}}}把固定乘数"1"用一个循环......
  • [SCOI 2009] 迷路 (矩阵快速幂)
    [SCOI2009]迷路传送门问题描述Windy在有向图中迷路了。该有向图有\({N}\)个节点,Windy从节点\({1}\)出发,他必须恰好在\({T}\)时刻到达节点\({N}\)。现在给出该有向图,你能告诉Windy总共有多少种不同的路径吗?注意:Windy不能在某个节点逗留,且通过某有向边的时间严格......
  • 矩阵快速幂
    矩阵快速幂例题6.方程【算法赛】-蓝桥云课P1962斐波那契数列-洛谷P2455SDOI2006-洛谷P7112【模板】行列式求值-洛谷P1939矩阵加速(数列)-洛谷我们线代课已经讲到矩阵了,自己也终于把之前卡了好久的矩阵快速幂的题过了ヾ(≧▽≦*)o补充知识矩阵与矩阵乘......
  • 罐头鱼AI视频矩阵系统介绍|视频矩阵获客
    智能化管理,轻松批量剪辑短视频!AI系统助力您的视频营销提效!    随着短视频营销的兴起,我们推出了一款AI批量剪辑短视频系统,让视频制作更加智能高效。以下是系统的主要功能特点:首页显示:清晰展示账号登录状态、可绑定账号数量、已绑定账号情况和最近上传的视频素材,让您......
  • 矩阵模板("+" "-" "*")
    structmat{ intn,m; inta[maxn][maxn]; voidzero() { memset(a,0,sizeof(a)); } voidone() { zero(); for(inti=1;i<=n;i++) { a[i][i]=i; } } voidresize(intx,inty) { n=x; m=y; } matoperator+(constmat&A)const { mat......
  • 深度学习服务器版本查看指令集合(显卡,Ubuntu,CUDA,gcc,conda,torch)
    1.查看显卡版本nvidia-smi-a|grepNVIDIA2.查看Ubuntu版本cat/proc/versionuname-a3.查看CUDA版本nvcc-V4.查看gcc版本gcc-v5.查看conda版本conda-V6.查看torch版本print(torch.__version__) #torch版本torch.version.cuda #torch对......
  • 独立按键与矩阵键盘
    独立按键轻触按键:相当于一种电子开关,按下时开关接通,松开时开关断开,实现原理是通过轻触按键内部的金属弹片受力弹动来实现接通与断开。 独立按键在开发板内部的原理图如下:4个独立按键的右端都公共接地,左端引出四个编号,接单片机的I/O口上。当单片机上电时,所有I/O口默认都......