首页 > 其他分享 >tvm实现卷积操作

tvm实现卷积操作

时间:2024-05-26 22:45:07浏览次数:23  
标签:weight conv 卷积 tvm 操作 ic data out

 

https://blog.csdn.net/sinat_31425585/article/details/103797339

import tvm
import numpy as np
import mxnet as mx


def padding(X, ph, pw):
    assert len(X.shape) >= 2
    nh, nw = X.shape[-2], X.shape[-1]
    return tvm.te.compute(
        (*X.shape[0:-2], nh + ph * 2, nw + pw * 2),
        lambda *i: tvm.te.if_then_else(
            tvm.te.any(i[-2] < ph, i[-2] >= nh + ph, i[-1] < pw, i[-1] >= nw + pw),
            0, X[i[:-2] + (i[-2] - ph, i[-1] - pw)]
        ), name='PaddedX'
    )


# 输入size:n
# 卷积核size:k
# 填充size:p
# 步长size:s
def conv_out_size(n, k, p, s):
    return (n - k + 2 * p) // s + 1


def conv(oc, ic, nh, nw, kh, kw, ph=0, pw=0, sh=1, sw=1):
    # reduction axes
    ric = tvm.te.reduce_axis((0, ic), name='ric')
    rkh = tvm.te.reduce_axis((0, kh), name='rkh')
    rkw = tvm.te.reduce_axis((0, kw), name='rkw')

    # output height and width
    oh = conv_out_size(nh, kh, ph, sh)
    ow = conv_out_size(nw, kw, pw, sw)

    # pad x and then conpute y
    X = tvm.te.placeholder((ic, nh, nw), name='x')
    K = tvm.te.placeholder((oc, ic, kh, kw), name='k')
    # 对输入填充
    PaddedX = padding(X, ph, pw) if ph * pw != 0 else X
    Y = tvm.te.compute(
        (oc, oh, ow),
        lambda c, i, j: tvm.te.sum(
            PaddedX[ric, i * sh + rkh, j * sw + rkw] * K[c, ric, rkh, rkw],
            axis=[ric, rkh, rkw]
        ), name='Y'
    )

    return X, K, Y, PaddedX


def get_conv_data(oc, ic, n, k, p=0, s=1, constructor=None):
    np.random.seed(0)
    data = np.random.normal(size=(ic, n, n)).astype('float32')
    weight = np.random.normal(size=(oc, ic, k, k)).astype('float32')
    on = conv_out_size(n, k, p, s)
    out = np.empty((oc, on, on), dtype='float32')
    if constructor:
        data, weight, out = (constructor(x) for x in [data, weight, out])

    return data, weight, out


oc, ic, n, k, p, s = 4, 6, 12, 3, 1, 1
X, K, Y, _ = conv(oc, ic, n, n, k, k, p, p, s, s)
sch = tvm.te.create_schedule(Y.op)
mod = tvm.build(sch, [X, K, Y])
print(tvm.lower(sch, [X, K, Y], simple_mode=True))

data, weight, out = get_conv_data(oc, ic, n, k, p, s, tvm.nd.array)
mod(data, weight, out)


def get_conv_data_mxnet(oc, ic, n, k, p, s, ctx='cpu'):
    ctx = getattr(mx, ctx)()
    data, weight, out = get_conv_data(oc, ic, n, k, p, s,
                                      lambda x: mx.nd.array(x, ctx=ctx))
    data, out = data.expand_dims(axis=0), out.expand_dims(axis=0)
    bias = mx.nd.zeros(out.shape[1], ctx=ctx)
    return data, weight, bias, out


def conv_mxnet(data, weight, bias, out, k, p, s):
    mx.nd.Convolution(data, weight, bias, kernel=(k, k), stride=(s, s),
                      pad=(p, p), num_filter=out.shape[1], out=out)


data, weight, bias, out_mx = get_conv_data_mxnet(oc, ic, n, k, p, s)
conv_mxnet(data, weight, bias, out_mx, k, p, s)
np.testing.assert_allclose(out_mx[0].asnumpy(), out.asnumpy(), atol=1e-5)

 

标签:weight,conv,卷积,tvm,操作,ic,data,out
From: https://www.cnblogs.com/xiaochouk/p/18214440

相关文章

  • charles常用操作
    参考:https://www.cnblogs.com/xiaocainiao920/p/8073073.html      charles修改请求体内容          重发网络请求&模拟慢速网络&过滤网络请求 ......
  • 嵌入式实时操作系统笔记3:FreeRTOS移植(STM32F407)_编写简单的FreeRTOS任务例程
    上文讲到UC/OSIII系统的移植,那篇文章是失败了的,网络上的资料真是层次不清,多有遗漏步骤,导致单片机连操作系统的初始化都卡在那,这次换个赛道,学FreeRTOS吧......今日任务如标题所示:FreeRTOS移植(STM32F407)_编写简单的FreeRTOS任务例程文章提供测试代码讲解、完整工程下载、测......
  • 【考研数据结构知识点详解及整理——C语言描述】第二章线性表的定义和基本操作
    25计算机考研,数据结构知识点整理(内容借鉴了王道408+数据结构教材),还会不断完善所整理的内容,后续的内容也会不断更新(可以关注),若有错误和不足欢迎各位朋友指出!目录 一.线性表的定义二.线性表的基本操作一.线性表的定义(1)线性表是具有相同数据类型的n(n>0)个数据元素的有......
  • UNiX强大的操作系统和编程环境
    Android设计模式一:EIT造型什么是EIT 造型?EIT造型,一种比类的范围更大,比模式(Pattern)稍微小的一种新的代码造型。造型的模型EIT造形是一种基本的结构(Structure),一种概念(Concept);我们称它为”EIT造形(Form)”。参考:https://www.cnblogs.com/myEIT/articles/3294583.html......
  • 操作系统学习
     Ubuntu(乌班图)、RedHat(红帽)、CentOS、Debain[蝶变]、Fedora、SuSE、OpenSUSE3、Linux和Windows区别 4、Linux和Unix关系UNIX是Linux的父亲"这个说法更怡当。之所以要介绍它们的关系,是因为要告诉读者,在学习的时候,其实Linux与UNIX有很多的共通之处,简单地说,如果你......
  • wetool企业版使用教程及下载方式 微兔该如何使用 wetool还能用吗 wetool扳手工具wetoo
    今天给大家推荐一款我们目前在使用的电脑群发工具掘金小蜜,不仅可以无限多开,方便你同时管理多个账号,群发功能更是十分强大,轻松释放你的双手。掘金小蜜(只支持Win7及以上操作系统,没有推Mac版和手机客户端。可直接可直接复制链接网页下载  lhttps://jjxx.lanzouo.com/s/jjxm......
  • 微服务中的鉴权操作详解(附代码)
    微服务架构中的鉴权是确保系统安全的重要部分,主要用于验证请求者的身份并授权其访问特定资源。鉴权的基本概念认证(Authentication):验证用户或服务的身份。授权(Authorization):决定认证通过的用户或服务可以访问哪些资源。常用鉴权策略API密钥:简单但安全性较低,适用于内......
  • Python面试宝典:Python中与数据库连接和操作相关的面试笔试题(1000加面试笔试题助你轻松
    Python面试宝典:1000加python面试题助你轻松捕获大厂Offer【第二部分:Python高级特性:第十五章:数据库编程:第一节:数据库连接和操作】第十五章:数据库编程第一节:数据库连接和操作数据库API规范:DB-API使用SQLite数据库使用MySQL数据库使用ORM工具注意事项python中和......
  • CATIA入门操作——萌新宝宝遇到的奇奇怪怪的问题解决,持续更新中。。。
    目录引出发生肾么事了??鼠标中键旋转不了解决:特征树不显示参数关系我的窗口去哪了?插曲:草图工具的调出插曲:颜色工具栏显示弹窗警告警告:创建约束是临时的操作技巧技巧:快速隐藏不相关元素工具栏怎么变成水平?总结异形弹簧新建几何体草图编辑,画一条样条线进行扫掠,圆心和半......
  • JDBC & 数据库连接池:详述Java 数据库操作的基础,数据库连接池的使用以及原理,比较常用数
    JDBC基础 JDBC的定义和目的 JDBC(JavaDatabaseConnectivity)是一个用于执行SQL语句的JavaAPI,可以与多种关系数据库进行交互,这的API由一组用Java语言编写的类和接口组成。 JDBC鼓励供应商使用JDBC驱动程序,该驱动程序可以通过数据库管理系统的客户机接口与各个数......