首页 > 其他分享 >【OJ题解-1】稀疏矩阵乘法

【OJ题解-1】稀疏矩阵乘法

时间:2024-09-25 18:50:18浏览次数:19  
标签:MATRIX int 题解 矩阵 OJ 乘法 data col row

一、试题题面

计算两个稀疏矩阵相乘,输出相乘的结果

【输入输出约定】

输入:

第一行输入三个正整数p、q、r,表示p×q和q×r的两个矩阵相乘;(约定0<p,q,r≤1000)

然后是第一个矩阵的输入,首先是一个整数m,表示矩阵一有m个非零元素;然后是m行,每行三个整数i,j,d,表示第i行,第j列的元素为d(约定行号、列号从1开始);这m行是按照行号-列号递增的方式排列的,即第一行、第二行、…每一行中,列号递增排列。

第二个矩阵的输入方法同上。

输出:

首先是一个整数cnt,表示结果矩阵有cnt个非零元素;然后是cnt行,每行三个整数,分别是非零元素的行号、列号、和数据值。要求这些数据按照行号-列号递增的方式排列的。

【测试数据样例】

输入:

4 4 4
4
1 4 7
2 1 2
3 3 3
4 1 1
3
1 2 2
3 3 -5
4 1 3

输出:

4 4 4
4
1 4 7
2 1 2
3 3 3
4 1 1
3
1 2 2
3 3 -5
4 1 3

样例解释:

\begin{pmatrix} 0& 0& 0& 7\\ 2& 0& 0& 0\\ 0& 0& 3& 0\\ 1& 0& 0& 0 \end{pmatrix}\times\begin{pmatrix} 0& 2& 0& 0\\ 0& 0& 0& 0\\ 0& 0& -5& 0\\ 3& 0& 0& 0 \end{pmatrix}=\begin{pmatrix} 21& 0& 0& 0\\ 0& 4& 0& 0\\ 0& 0& -15& 0\\ 0& 2& 0& 0 \end{pmatrix}

二、试题理解

稀疏矩阵,援引百度百科介绍如下“在矩阵中,若数值为0的元素数目远远多于非0元素的数目,并且非0元素分布没有规律时,则称该矩阵为稀疏矩阵”。如下给出了一个较为稀疏的矩阵,可以较为直观的感受。

\begin{pmatrix} 0& 0& 0& 0& 0& 0& 0& 0\\ 0& 0& 0& 0& 0& 0& 0& 0\\ 0& 0& 0& 0& 1& 0& 2& 0\\ 0& 0& 0& 0& 0& 0& 0& 0\\ 0& 0& 6& 0& 0& 0& 0& 0\\ 0& 0& 0& 0& 3& 0& 0& 0\\ 5& 0& 0& 0& 0& 0& 0& 0\\ 0& 0& 0& 0& 0& 0& 0&0 \end{pmatrix}

在没有前提条件(或矩阵是“稠密矩阵”)时,矩阵的存储和运算都相对直观,我们可以使用一个二维数组(如m[p][q])来存储一个矩阵,并简单使用循环迭代的方式计算矩阵乘法的结果,直观上空间复杂度为\Theta(n^{2}),时间复杂度为\Theta(n^{3}),对于相对“稠密”的矩阵也可以获得相对良好的效果。

但是对于稀疏矩阵,即矩阵中多数元素为0的情形,存储空间与运算时间都有进一步压缩的可能。如果仍然按照上述处理一般矩阵的方法处理本题,我们可以注意到二维数组中多数元素存储为0,且循环迭代中大量运算都在计算0×0或0×非零元素。

于是,可以想到我们仅存储矩阵中非零元素的坐标以及数值,并通过遍历这一系列坐标与数值实现矩阵乘法。两个稀疏矩阵相乘得到的矩阵同样相对稀疏,也可以使用相同的方式存储。

我们可以通过链式存储结构实现上述设想。由于链式结构,我们需要维护一个指针指向下一个结点。而直观上两个矩阵相乘时一个矩阵按行遍历,另一个矩阵按列遍历,因此在这里我们需要维护两个指针nextInRow与nextInColumn,分别指向同行中下一个结点与同列中下一个结点,鉴于此有些人习惯于将其称为十字链表。

三、实现细节

结点Node的定义如上,包含了3个整型变量data,row,col与两个指针nextInRow与nextInColumn。

struct Node{
    int data;
    int row,col;
    Node* nextInRow;
    Node* nextInColumn;
    Node(int data, int row, int col): data(data),row(row),col(col),nextInRow(nullptr),nextInColumn(nullptr){}
};

对于矩阵类定义如下:包含了两个整型变量rowN,colN表示矩阵的行数与列数,以及指向每行或每列第一个非零元素的指针数组rows[MAXN]和columns[MAXN]。由于输入数据需要被添加至新定义的矩阵中,插入操作相对频繁,为提高效率同时维护了rows_tail[MAXN]和columns_tail[MAXN],分别指向每行或每列最后一个非零元素。

class MATRIX{
    struct Node* columns[MAXN];
    struct Node* rows[MAXN];
    struct Node* columns_tail[MAXN];
    struct Node* rows_tail[MAXN];
    int rowN,colN;
public:
    MATRIX(int m, int n): rowN(m), colN(n){
        for(int i = 0; i < MAXN; i++){
            columns[i] = nullptr;
            rows[i] = nullptr;
            columns_tail[i] = nullptr;
            rows_tail[i] = nullptr;
        }
    }
    void insert(int data, int row, int col);
    mul_output* multiple(MATRIX * m1, MATRIX *m2);
    void print_MATRIX(MATRIX *m);
};

此外,由于我们需要同时输出结果矩阵的非零元素数量cnt与非零元素的行号、列号、和数据值,我们新定义一个结构体作为矩阵相乘函数的返回值的类型,其包含了结果矩阵与其中非零元素的个数。

struct mul_output{
    MATRIX *m;
    int cnt;
};

在这里我们要注意先给出classMATRIX的一个简略定义,否则可能会遇到编译错误。

插入操作函数定义如下。其本质就是将输入的i,j,data新定义一个结点并将其加入相应链表的末尾并维持上述几个指针。

void MATRIX::insert(int data, int row, int col){
    Node *p = new Node(data, row, col);
    if(rows[row] == nullptr){
        rows[row] = p;
        rows_tail[row] = p;
    }
    else{
        rows_tail[row]->nextInRow = p;
        rows_tail[row] = p;
    }
    if(columns[col] == nullptr){
        columns[col] = p;
        columns_tail[col] = p;
    }
    else{
        columns_tail[col]->nextInColumn = p;
        columns_tail[col] = p;
    }
    return ;
}

最后给出矩阵相乘的实现函数,直观上依然是一个矩阵按行遍历,另一个矩阵按列遍历;但是遍历时会“跳过”零元素。实现上,当遍历时对应的两个结点满足p->col == q->row时我们将两个结点对应的值相乘。当一行/一列遍历完后,如果发现结果矩阵相应位置元素不为0,则向结果矩阵中插入结点。我们的遍历也保证了新节点插入结果矩阵时有与输入时一致的性质,于是可以复用上述的插入函数。

mul_output* MATRIX::multiple(MATRIX *m1, MATRIX *m2){
    MATRIX *m3 = new MATRIX(m1->rowN, m2->colN);
    int cnt_m3 = 0;
    mul_output *mo = new mul_output;
    for(int i = 1; i <= m1->rowN; i++){
        for(int j = 1; j <= m2->colN; j++){
            int sum = 0;
            Node *p = m1->rows[i];
            Node *q = m2->columns[j];
            while(p && q){
                if(p->col == q->row){
                    sum += p->data * q->data;
                    p = p->nextInRow;
                    q = q->nextInColumn;
                }
                else if(p->col < q->row){
                    p = p->nextInRow;
                }
                else{
                    q = q->nextInColumn;
                }
            }
            if(sum != 0){
                m3->insert(sum, i, j);
                cnt_m3++;
            }
        }
    }
    mo->m = m3;
    mo->cnt = cnt_m3;
    return mo;
}

最后,依据约定的输入与已经完成的函数给出了主函数。

int main(){
    int p , q , r , cnt , data_temp , row_temp , column_temp;
    scanf("%d %d %d", &p, &q, &r);
    MATRIX *m1 = new MATRIX(p, q);
    MATRIX *m2 = new MATRIX(q, r);
    scanf("%d", &cnt);
    for(int i = 0; i < cnt; i++){
        scanf("%d %d %d", &row_temp, &column_temp, &data_temp);
        m1->insert(data_temp, row_temp, column_temp);
    }
    scanf("%d", &cnt);
    for(int i = 0; i < cnt; i++){
        scanf("%d %d %d", &row_temp, &column_temp, &data_temp);
        m2->insert(data_temp, row_temp, column_temp);
    }
    MATRIX *m3 = new MATRIX(p, r);
    mul_output *mo = m3->multiple(m1, m2);
    m3 = mo->m;
    printf("%d\n", mo->cnt);
    m3->print_MATRIX(m3);
    return 0;
}

于是,我们就实现了两个稀疏矩阵的乘法,完整代码在第四部分直接给出,供各位读者参考批评。

四、完整实现代码

#include <iostream>
#define Past_Dream FW
using namespace std;
const int MAXN = 1005;
struct Node{
    int data;
    int row,col;
    Node* nextInRow;
    Node* nextInColumn;
    Node(int data, int row, int col): data(data),row(row),col(col),nextInRow(nullptr),nextInColumn(nullptr){}
};
class MATRIX;
struct mul_output{
    MATRIX *m;
    int cnt;
};
class MATRIX{
    struct Node* columns[MAXN];
    struct Node* rows[MAXN];
    struct Node* columns_tail[MAXN];
    struct Node* rows_tail[MAXN];
    int rowN,colN;
public:
    MATRIX(int m, int n): rowN(m), colN(n){
        for(int i = 0; i < MAXN; i++){
            columns[i] = nullptr;
            rows[i] = nullptr;
            columns_tail[i] = nullptr;
            rows_tail[i] = nullptr;
        }
    }
    void insert(int data, int row, int col);
    mul_output* multiple(MATRIX * m1, MATRIX *m2);
    void print_MATRIX(MATRIX *m);
};
void MATRIX::insert(int data, int row, int col){
    Node *p = new Node(data, row, col);
    if(rows[row] == nullptr){
        rows[row] = p;
        rows_tail[row] = p;
    }
    else{
        rows_tail[row]->nextInRow = p;
        rows_tail[row] = p;
    }
    if(columns[col] == nullptr){
        columns[col] = p;
        columns_tail[col] = p;
    }
    else{
        columns_tail[col]->nextInColumn = p;
        columns_tail[col] = p;
    }
    return ;
}
mul_output* MATRIX::multiple(MATRIX *m1, MATRIX *m2){
    MATRIX *m3 = new MATRIX(m1->rowN, m2->colN);
    int cnt_m3 = 0;
    mul_output *mo = new mul_output;
    for(int i = 1; i <= m1->rowN; i++){
        for(int j = 1; j <= m2->colN; j++){
            int sum = 0;
            Node *p = m1->rows[i];
            Node *q = m2->columns[j];
            while(p && q){
                if(p->col == q->row){
                    sum += p->data * q->data;
                    p = p->nextInRow;
                    q = q->nextInColumn;
                }
                else if(p->col < q->row){
                    p = p->nextInRow;
                }
                else{
                    q = q->nextInColumn;
                }
            }
            if(sum != 0){
                m3->insert(sum, i, j);
                cnt_m3++;
            }
        }
    }
    mo->m = m3;
    mo->cnt = cnt_m3;
    return mo;
}
void MATRIX::print_MATRIX(MATRIX *m){
    for(int i = 0; i <= m->rowN; i++){
        Node *p = m->rows[i];
        while(p){
            printf("%d %d %d\n", p->row, p->col, p->data);
            p = p->nextInRow;
        }
    }
}
int main(){
    int p , q , r , cnt , data_temp , row_temp , column_temp;
    scanf("%d %d %d", &p, &q, &r);
    MATRIX *m1 = new MATRIX(p, q);
    MATRIX *m2 = new MATRIX(q, r);
    scanf("%d", &cnt);
    for(int i = 0; i < cnt; i++){
        scanf("%d %d %d", &row_temp, &column_temp, &data_temp);
        m1->insert(data_temp, row_temp, column_temp);
    }
    scanf("%d", &cnt);
    for(int i = 0; i < cnt; i++){
        scanf("%d %d %d", &row_temp, &column_temp, &data_temp);
        m2->insert(data_temp, row_temp, column_temp);
    }
    MATRIX *m3 = new MATRIX(p, r);
    mul_output *mo = m3->multiple(m1, m2);
    m3 = mo->m;
    printf("%d\n", mo->cnt);
    m3->print_MATRIX(m3);
    return 0;
}

以上代码通过了OJ评测,但是并不保证正确。由于笔者本身是小飞舞,文章中必然会存在些许的错误与不足,衷心恳请您能够批评和指正。


哦?看起来我们写完了稀疏矩阵的乘法。要不给大家唱首歌助助兴吧。“琴瑟愿与~共沐春秋~~滢溪潺潺~炊烟悠悠~~敢请东风~玉成双偶~~遥递佳信~知否知否~~”

去思考,去尝试,去突破,去成长

衷心感谢您的阅读,我们下次再见!

标签:MATRIX,int,题解,矩阵,OJ,乘法,data,col,row
From: https://blog.csdn.net/2301_80336585/article/details/142530527

相关文章

  • 题解:CF573D Bear and Cavalry
    因为这是远古题目,所以根据现在的评测机速度,用\(O(nq)\)的做法也是可以过的。也就是说,我们可以每次操作直接修改对应位置上的数字,然后设计一种\(O(n)\)的算法求解答案。这道题类似资源分配型动态规划,所以我们可以设\(dp_i\)表示分配前\(i\)个人的答案。直接写是不行的,我......
  • 题解:AT_abc204_e [ABC204E] Rush Hour 2
    变形的dijkstra。先思考什么情况下需要等待以及等待多长时间最优。我们把题目上的计算方法按照当前的时间\(t\)和通过所需的时间\(f(t)\)列个函数关系:\[f(t)=t+c+\lfloor\frac{d}{t+1}\rfloor\]然后用Desmos画个图可以得到图像(其实就是对勾函数):因为\(c,d\geq0\),所......
  • [湖北省选模拟 2023] 棋圣 / alphago 题解
    很牛的题目啊。-Alex_Wei发现这个操作比较复杂但限制较弱,考虑通过考察“不变的量”来刻画操作。容易发现若为二分图,则初始颜色不同的一定不能移动到一起。又因为在存在环的图上这个限制很弱/目前较难考虑,所以先考虑树的情况,发现答案存在可能取到的上界,令\(c_{i,j}\)为初......
  • Codeforces Round 974 (Div. 3)题解记录
    A.RobinHelps签到模拟,遍历一遍即可,注意没钱时不给钱。\(O(n)\)#include<iostream>#include<set>#include<map>#include<vector>#include<algorithm>#include<bitset>#include<math.h>#include<string>#include<string.h>#......
  • 算法题之图论 [NOIP2001 提高组] Car的旅行路线详细题解
    P1027[NOIP2001提高组]Car的旅行路线这道题的思路呢,就是建个图,然后跑一遍Floyd,比较最小值就可以解决了。but!它每个城市只给三个点(共四个),所以还得计算出第四个点坐标。这里根据矩形的中点公式来表示未知点的坐标:(这个思路源于大佬 _jimmywang_       ......