一、试题题面
计算两个稀疏矩阵相乘,输出相乘的结果
【输入输出约定】
输入:
第一行输入三个正整数p、q、r,表示p×q和q×r的两个矩阵相乘;(约定0<p,q,r≤1000)
然后是第一个矩阵的输入,首先是一个整数m,表示矩阵一有m个非零元素;然后是m行,每行三个整数i,j,d,表示第i行,第j列的元素为d(约定行号、列号从1开始);这m行是按照行号-列号递增的方式排列的,即第一行、第二行、…每一行中,列号递增排列。
第二个矩阵的输入方法同上。
输出:
首先是一个整数cnt,表示结果矩阵有cnt个非零元素;然后是cnt行,每行三个整数,分别是非零元素的行号、列号、和数据值。要求这些数据按照行号-列号递增的方式排列的。
【测试数据样例】
输入:
4 4 4
4
1 4 7
2 1 2
3 3 3
4 1 1
3
1 2 2
3 3 -5
4 1 3
输出:
4 4 4
4
1 4 7
2 1 2
3 3 3
4 1 1
3
1 2 2
3 3 -5
4 1 3
样例解释:
二、试题理解
稀疏矩阵,援引百度百科介绍如下“在矩阵中,若数值为0的元素数目远远多于非0元素的数目,并且非0元素分布没有规律时,则称该矩阵为稀疏矩阵”。如下给出了一个较为稀疏的矩阵,可以较为直观的感受。
在没有前提条件(或矩阵是“稠密矩阵”)时,矩阵的存储和运算都相对直观,我们可以使用一个二维数组(如m[p][q])来存储一个矩阵,并简单使用循环迭代的方式计算矩阵乘法的结果,直观上空间复杂度为,时间复杂度为,对于相对“稠密”的矩阵也可以获得相对良好的效果。
但是对于稀疏矩阵,即矩阵中多数元素为0的情形,存储空间与运算时间都有进一步压缩的可能。如果仍然按照上述处理一般矩阵的方法处理本题,我们可以注意到二维数组中多数元素存储为0,且循环迭代中大量运算都在计算0×0或0×非零元素。
于是,可以想到我们仅存储矩阵中非零元素的坐标以及数值,并通过遍历这一系列坐标与数值实现矩阵乘法。两个稀疏矩阵相乘得到的矩阵同样相对稀疏,也可以使用相同的方式存储。
我们可以通过链式存储结构实现上述设想。由于链式结构,我们需要维护一个指针指向下一个结点。而直观上两个矩阵相乘时一个矩阵按行遍历,另一个矩阵按列遍历,因此在这里我们需要维护两个指针nextInRow与nextInColumn,分别指向同行中下一个结点与同列中下一个结点,鉴于此有些人习惯于将其称为十字链表。
三、实现细节
结点Node的定义如上,包含了3个整型变量data,row,col与两个指针nextInRow与nextInColumn。
struct Node{
int data;
int row,col;
Node* nextInRow;
Node* nextInColumn;
Node(int data, int row, int col): data(data),row(row),col(col),nextInRow(nullptr),nextInColumn(nullptr){}
};
对于矩阵类定义如下:包含了两个整型变量rowN,colN表示矩阵的行数与列数,以及指向每行或每列第一个非零元素的指针数组rows[MAXN]和columns[MAXN]。由于输入数据需要被添加至新定义的矩阵中,插入操作相对频繁,为提高效率同时维护了rows_tail[MAXN]和columns_tail[MAXN],分别指向每行或每列最后一个非零元素。
class MATRIX{
struct Node* columns[MAXN];
struct Node* rows[MAXN];
struct Node* columns_tail[MAXN];
struct Node* rows_tail[MAXN];
int rowN,colN;
public:
MATRIX(int m, int n): rowN(m), colN(n){
for(int i = 0; i < MAXN; i++){
columns[i] = nullptr;
rows[i] = nullptr;
columns_tail[i] = nullptr;
rows_tail[i] = nullptr;
}
}
void insert(int data, int row, int col);
mul_output* multiple(MATRIX * m1, MATRIX *m2);
void print_MATRIX(MATRIX *m);
};
此外,由于我们需要同时输出结果矩阵的非零元素数量cnt与非零元素的行号、列号、和数据值,我们新定义一个结构体作为矩阵相乘函数的返回值的类型,其包含了结果矩阵与其中非零元素的个数。
struct mul_output{
MATRIX *m;
int cnt;
};
在这里我们要注意先给出classMATRIX的一个简略定义,否则可能会遇到编译错误。
插入操作函数定义如下。其本质就是将输入的i,j,data新定义一个结点并将其加入相应链表的末尾并维持上述几个指针。
void MATRIX::insert(int data, int row, int col){
Node *p = new Node(data, row, col);
if(rows[row] == nullptr){
rows[row] = p;
rows_tail[row] = p;
}
else{
rows_tail[row]->nextInRow = p;
rows_tail[row] = p;
}
if(columns[col] == nullptr){
columns[col] = p;
columns_tail[col] = p;
}
else{
columns_tail[col]->nextInColumn = p;
columns_tail[col] = p;
}
return ;
}
最后给出矩阵相乘的实现函数,直观上依然是一个矩阵按行遍历,另一个矩阵按列遍历;但是遍历时会“跳过”零元素。实现上,当遍历时对应的两个结点满足p->col == q->row时我们将两个结点对应的值相乘。当一行/一列遍历完后,如果发现结果矩阵相应位置元素不为0,则向结果矩阵中插入结点。我们的遍历也保证了新节点插入结果矩阵时有与输入时一致的性质,于是可以复用上述的插入函数。
mul_output* MATRIX::multiple(MATRIX *m1, MATRIX *m2){
MATRIX *m3 = new MATRIX(m1->rowN, m2->colN);
int cnt_m3 = 0;
mul_output *mo = new mul_output;
for(int i = 1; i <= m1->rowN; i++){
for(int j = 1; j <= m2->colN; j++){
int sum = 0;
Node *p = m1->rows[i];
Node *q = m2->columns[j];
while(p && q){
if(p->col == q->row){
sum += p->data * q->data;
p = p->nextInRow;
q = q->nextInColumn;
}
else if(p->col < q->row){
p = p->nextInRow;
}
else{
q = q->nextInColumn;
}
}
if(sum != 0){
m3->insert(sum, i, j);
cnt_m3++;
}
}
}
mo->m = m3;
mo->cnt = cnt_m3;
return mo;
}
最后,依据约定的输入与已经完成的函数给出了主函数。
int main(){
int p , q , r , cnt , data_temp , row_temp , column_temp;
scanf("%d %d %d", &p, &q, &r);
MATRIX *m1 = new MATRIX(p, q);
MATRIX *m2 = new MATRIX(q, r);
scanf("%d", &cnt);
for(int i = 0; i < cnt; i++){
scanf("%d %d %d", &row_temp, &column_temp, &data_temp);
m1->insert(data_temp, row_temp, column_temp);
}
scanf("%d", &cnt);
for(int i = 0; i < cnt; i++){
scanf("%d %d %d", &row_temp, &column_temp, &data_temp);
m2->insert(data_temp, row_temp, column_temp);
}
MATRIX *m3 = new MATRIX(p, r);
mul_output *mo = m3->multiple(m1, m2);
m3 = mo->m;
printf("%d\n", mo->cnt);
m3->print_MATRIX(m3);
return 0;
}
于是,我们就实现了两个稀疏矩阵的乘法,完整代码在第四部分直接给出,供各位读者参考批评。
四、完整实现代码
#include <iostream>
#define Past_Dream FW
using namespace std;
const int MAXN = 1005;
struct Node{
int data;
int row,col;
Node* nextInRow;
Node* nextInColumn;
Node(int data, int row, int col): data(data),row(row),col(col),nextInRow(nullptr),nextInColumn(nullptr){}
};
class MATRIX;
struct mul_output{
MATRIX *m;
int cnt;
};
class MATRIX{
struct Node* columns[MAXN];
struct Node* rows[MAXN];
struct Node* columns_tail[MAXN];
struct Node* rows_tail[MAXN];
int rowN,colN;
public:
MATRIX(int m, int n): rowN(m), colN(n){
for(int i = 0; i < MAXN; i++){
columns[i] = nullptr;
rows[i] = nullptr;
columns_tail[i] = nullptr;
rows_tail[i] = nullptr;
}
}
void insert(int data, int row, int col);
mul_output* multiple(MATRIX * m1, MATRIX *m2);
void print_MATRIX(MATRIX *m);
};
void MATRIX::insert(int data, int row, int col){
Node *p = new Node(data, row, col);
if(rows[row] == nullptr){
rows[row] = p;
rows_tail[row] = p;
}
else{
rows_tail[row]->nextInRow = p;
rows_tail[row] = p;
}
if(columns[col] == nullptr){
columns[col] = p;
columns_tail[col] = p;
}
else{
columns_tail[col]->nextInColumn = p;
columns_tail[col] = p;
}
return ;
}
mul_output* MATRIX::multiple(MATRIX *m1, MATRIX *m2){
MATRIX *m3 = new MATRIX(m1->rowN, m2->colN);
int cnt_m3 = 0;
mul_output *mo = new mul_output;
for(int i = 1; i <= m1->rowN; i++){
for(int j = 1; j <= m2->colN; j++){
int sum = 0;
Node *p = m1->rows[i];
Node *q = m2->columns[j];
while(p && q){
if(p->col == q->row){
sum += p->data * q->data;
p = p->nextInRow;
q = q->nextInColumn;
}
else if(p->col < q->row){
p = p->nextInRow;
}
else{
q = q->nextInColumn;
}
}
if(sum != 0){
m3->insert(sum, i, j);
cnt_m3++;
}
}
}
mo->m = m3;
mo->cnt = cnt_m3;
return mo;
}
void MATRIX::print_MATRIX(MATRIX *m){
for(int i = 0; i <= m->rowN; i++){
Node *p = m->rows[i];
while(p){
printf("%d %d %d\n", p->row, p->col, p->data);
p = p->nextInRow;
}
}
}
int main(){
int p , q , r , cnt , data_temp , row_temp , column_temp;
scanf("%d %d %d", &p, &q, &r);
MATRIX *m1 = new MATRIX(p, q);
MATRIX *m2 = new MATRIX(q, r);
scanf("%d", &cnt);
for(int i = 0; i < cnt; i++){
scanf("%d %d %d", &row_temp, &column_temp, &data_temp);
m1->insert(data_temp, row_temp, column_temp);
}
scanf("%d", &cnt);
for(int i = 0; i < cnt; i++){
scanf("%d %d %d", &row_temp, &column_temp, &data_temp);
m2->insert(data_temp, row_temp, column_temp);
}
MATRIX *m3 = new MATRIX(p, r);
mul_output *mo = m3->multiple(m1, m2);
m3 = mo->m;
printf("%d\n", mo->cnt);
m3->print_MATRIX(m3);
return 0;
}
以上代码通过了OJ评测,但是并不保证正确。由于笔者本身是小飞舞,文章中必然会存在些许的错误与不足,衷心恳请您能够批评和指正。
哦?看起来我们写完了稀疏矩阵的乘法。要不给大家唱首歌助助兴吧。“琴瑟愿与~共沐春秋~~滢溪潺潺~炊烟悠悠~~敢请东风~玉成双偶~~遥递佳信~知否知否~~”
去思考,去尝试,去突破,去成长
衷心感谢您的阅读,我们下次再见!
标签:MATRIX,int,题解,矩阵,OJ,乘法,data,col,row From: https://blog.csdn.net/2301_80336585/article/details/142530527