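ControlNet TensorRT Optimization Notes, Part 2: Building the ControlNet Network from Scratch with the TRT API
As mentioned in the previous part, the ControlNet network can be rebuilt by hand with the TensorRT network-definition API. This avoids the precision loss of the intermediate ONNX conversion and keeps operators from being broken up into many small pieces during that conversion. Operators that TensorRT does not support can be added as plugins.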
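Basic Concepts
tensorrt.INetworkDefinition: the network definition object. It is either produced by a parser or built directly with the TensorRT API.
tensorrt.Builder: generates a CudaEngine from a NetworkDefinition and the matching BuilderConfig. The CudaEngine is the built, binary computation graph.
tensorrt.IExecutionContext: created from a CudaEngine. One CudaEngine can create multiple ExecutionContexts.
Notes:
- In what follows, network generally refers to a tensorrt.INetworkDefinition object.
- x is one of two things: a tensorrt.ITensor, typically the raw network input, or a tensorrt.ILayer, typically the output of an intermediate layer. A tensorrt.ITensor can be seen as an edge of the computation graph and a tensorrt.ILayer as a node.
- Operators with learned parameters take weight_map plus the parameter-name prefix. Every operator returns a tensorrt.ILayer object.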
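Commonly Used TRT API Functions
add_input(self: tensorrt.tensorrt.INetworkDefinition,
    name: str,
    dtype: tensorrt.tensorrt.DataType,
    shape: tensorrt.tensorrt.Dims) -> tensorrt.tensorrt.ITensor
Purpose: adds an input tensor to the network.
Parameters: name - name of the input tensor
    dtype - data type of the tensor, e.g. trt.float32
    shape - dimensions of the tensor; the total volume must be less than 2^30 elements
Returns: the new tensor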
add_scale(self: tensorrt.tensorrt.INetworkDefinition,
    input: tensorrt.tensorrt.ITensor,
    mode: tensorrt.tensorrt.ScaleMode,
    shift: tensorrt.tensorrt.Weights,
    scale: tensorrt.tensorrt.Weights,
    power: tensorrt.tensorrt.Weights) -> tensorrt.tensorrt.IScaleLayer
Purpose: scales every element, computing $output = (input \cdot scale + shift)^{power}$
Parameters: input - the input tensor, with at least three dimensions
    mode - the scale mode, e.g. trt.ScaleMode.UNIFORM, which applies the same coefficients to every element
    shift - Weights object, the shift value in the formula
    scale - Weights object, the scale value in the formula
    power - Weights object, the power value in the formula
When a Weights object is provided, its shape depends on mode:
    UNIFORM: a single element
    CHANNEL: one element per channel
    ELEMENTWISE: the same shape as input
Returns: a new layer, or None
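To make the formula and the three modes concrete, here is a small NumPy reference. It is only a sketch for checking values, not TensorRT code: `scale_reference` is a hypothetical helper, and the NCHW layout of the input is an assumption.

```python
import numpy as np

def scale_reference(x, mode, shift=0.0, scale=1.0, power=1.0):
    """NumPy reference for output = (input * scale + shift) ** power (hypothetical helper)."""
    shift, scale, power = (np.asarray(w, dtype=np.float32) for w in (shift, scale, power))
    if mode == "CHANNEL":
        # one coefficient per channel: reshape (C,) -> (1, C, 1, 1) for an NCHW input
        shift, scale, power = (w.reshape(1, -1, 1, 1) if w.ndim == 1 else w
                               for w in (shift, scale, power))
    # UNIFORM coefficients are scalars and ELEMENTWISE ones match the input, so both broadcast as-is
    return (x * scale + shift) ** power

x = np.ones((1, 2, 2, 2), dtype=np.float32)
uniform = scale_reference(x, "UNIFORM", shift=1.0, scale=2.0, power=2.0)
channel = scale_reference(x, "CHANNEL", scale=[1.0, 3.0])
```

With an all-ones input, the UNIFORM call computes (1 * 2 + 1)^2 for every element, and the CHANNEL call multiplies the two channels by 1 and 3 respectively.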
add_slice(self: tensorrt.tensorrt.INetworkDefinition,
    input: tensorrt.tensorrt.ITensor,
    start: tensorrt.tensorrt.Dims,
    shape: tensorrt.tensorrt.Dims,
    stride: tensorrt.tensorrt.Dims) -> tensorrt.tensorrt.ISliceLayer
Purpose: slices a tensor
Parameters: input - the input tensor
    start - the start index along each axis
    shape - the output shape
    stride - the slicing stride along each axis
Returns: a new layer, or None
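The start/shape/stride triple maps directly onto ordinary strided indexing. The NumPy sketch below illustrates that correspondence; `slice_reference` is a hypothetical helper written for this note and is not part of TensorRT.

```python
import numpy as np

def slice_reference(x, start, shape, stride):
    """NumPy equivalent of a slice described by start/shape/stride (hypothetical helper)."""
    # along each axis take `n` elements beginning at `b`, stepping by `s`
    index = tuple(slice(b, b + n * s, s) for b, n, s in zip(start, shape, stride))
    return x[index]

x = np.arange(2 * 3 * 8).reshape(2, 3, 8)
# second half of the last axis, the same pattern feed_forward uses to split its projection
second_half = slice_reference(x, start=(0, 0, 4), shape=(2, 3, 4), stride=(1, 1, 1))
# every other element of the last axis
every_other = slice_reference(x, start=(0, 0, 0), shape=(2, 3, 4), stride=(1, 1, 2))
```

Note that shape is the size of the output, not an end index, which is why the two calls above return arrays of the same shape.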
add_constant(self: tensorrt.tensorrt.INetworkDefinition,
    shape: tensorrt.tensorrt.Dims,
    weights: tensorrt.tensorrt.Weights) -> tensorrt.tensorrt.IConstantLayer
Purpose: adds a constant layer, which turns a Weights object into a layer and therefore into a tensor
Parameters: shape - the shape of the constant
    weights - the Weights object
Returns: a new layer, or None
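add_elementwise(self: tensorrt.tensorrt.INetworkDefinition,
    input1: tensorrt.tensorrt.ITensor,
    input2: tensorrt.tensorrt.ITensor,
    op: tensorrt.tensorrt.ElementWiseOperation) -> tensorrt.tensorrt.IElementWiseLayer
Purpose: binary elementwise operation
Parameters: input1 (input2) - the input tensors. They must have the same number of dimensions, and along each axis the lengths must be equal or one of them must be 1 (broadcast); the code below relies on this when adding a [1, 1, ch] bias
    op - the binary operation, a member of ElementWiseOperation, e.g.:
    trt.ElementWiseOperation.PROD (product)
    trt.ElementWiseOperation.SUM (sum)
Returns: a new layer, or None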
add_unary(self: tensorrt.tensorrt.INetworkDefinition,
    input: tensorrt.tensorrt.ITensor,
    op: tensorrt.tensorrt.UnaryOperation) -> tensorrt.tensorrt.IUnaryLayer
Purpose: unary elementwise operation
Parameters: input - the input tensor
    op - the unary operation, a member of UnaryOperation, e.g.:
    trt.UnaryOperation.EXP (exponential)
    trt.UnaryOperation.LOG (natural logarithm)
Returns: a new layer, or None
add_convolution(self: tensorrt.tensorrt.INetworkDefinition,
    input: tensorrt.tensorrt.ITensor,
    num_output_maps: int,
    kernel_shape: tensorrt.tensorrt.DimsHW,
    kernel: tensorrt.tensorrt.Weights,
    bias: tensorrt.tensorrt.Weights = None) -> tensorrt.tensorrt.IConvolutionLayer
Purpose: adds a 2D convolution
Parameters: input - the input tensor, a 4-D tensor
    num_output_maps - the number of output feature maps, i.e. the channel count seen by the next layer
    kernel_shape - the size of the convolution kernel
    kernel - the kernel weights
    bias - the bias weights
Returns: a new layer, or None
add_activation(self: tensorrt.tensorrt.INetworkDefinition,
    input: tensorrt.tensorrt.ITensor,
    type: tensorrt.tensorrt.ActivationType) -> tensorrt.tensorrt.IActivationLayer
Purpose: adds an activation layer that applies the activation element by element; the output shape equals the input shape
Parameters: input - the input tensor
    type - the activation type, e.g. RELU, SIGMOID, TANH, LEAKY_RELU; see tensorrt.ActivationType
Returns: a new layer, or None
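add_normalization(self: tensorrt.tensorrt.INetworkDefinition,
    input: tensorrt.tensorrt.ITensor,
    scale: tensorrt.tensorrt.ITensor,
    bias: tensorrt.tensorrt.ITensor,
    axesMask: int) -> tensorrt.tensorrt.INormalizationLayer
Purpose: adds a normalization layer that computes $Y = \frac{X - \mathrm{Mean}(X, axes)}{\sqrt{\mathrm{Variance}(X) + \epsilon}} \cdot S + B$. Internally TRT actually implements this with instance norm, so it sometimes has to be replaced by a hand-written version
Parameters: input - the input tensor
    scale - the scale parameter S of the normalization
    bias - the bias parameter B of the normalization
    axesMask - the axes over which the mean is taken, passed as a bitmask in which bit (1 << i) selects axis i
Returns: a new layer, or None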
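The bitmask encoding of axesMask (also used by the softmax layer's axes attribute) is easy to get wrong, so here is a minimal pure-Python sketch of it. Both helper names are hypothetical and exist only for illustration.

```python
def axes_to_mask(axes):
    """Pack a list of axis indices into a bitmask: bit i set means axis i is selected."""
    mask = 0
    for axis in axes:
        mask |= 1 << axis
    return mask

def mask_to_axes(mask):
    """Inverse of axes_to_mask: list the axis indices whose bits are set."""
    return [i for i in range(mask.bit_length()) if mask & (1 << i)]

last_axis_of_3d = axes_to_mask([2])        # same value as the `1 << 2` used in layer_norm
spatial_axes_nchw = axes_to_mask([2, 3])   # H and W of an NCHW tensor
```

So `1 << 2` selects axis index 2, which for a 3-D [batch, seq, channel] tensor is the last axis.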
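add_matrix_multiply(self: tensorrt.tensorrt.INetworkDefinition,
    input0: tensorrt.tensorrt.ITensor,
    op0: tensorrt.tensorrt.MatrixOperation,
    input1: tensorrt.tensorrt.ITensor,
    op1: tensorrt.tensorrt.MatrixOperation) -> tensorrt.tensorrt.IMatrixMultiplyLayer
Purpose: adds a matrix multiplication. There are four cases: matrix-matrix, matrix-vector, vector-matrix and vector-vector
Parameters: input0 - the first input tensor
    op0 - how input0 is treated: as is (NONE), transposed (TRANSPOSE) or as a vector (VECTOR)
    input1 - the second input tensor
    op1 - how input1 is treated: as is (NONE), transposed (TRANSPOSE) or as a vector (VECTOR)
Returns: a new layer, or None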
add_shuffle(self: tensorrt.tensorrt.INetworkDefinition,
    input: tensorrt.tensorrt.ITensor) -> tensorrt.tensorrt.IShuffleLayer
Purpose: adds a shuffle layer, which corresponds to the transpose and reshape operators
Parameters: input - the input tensor of the layer
Returns: a new layer, or None
add_softmax(self: tensorrt.tensorrt.INetworkDefinition,
    input: tensorrt.tensorrt.ITensor) -> tensorrt.tensorrt.ISoftMaxLayer
Purpose: adds a softmax layer that applies softmax along the axis given by the layer's axes attribute; axes is a bitmask
Parameters: input - the input tensor
Returns: a new layer, or None
add_gather(self: tensorrt.tensorrt.INetworkDefinition,
    input: tensorrt.tensorrt.ITensor,
    indices: tensorrt.tensorrt.ITensor,
    axis: int) -> tensorrt.tensorrt.IGatherLayer
Purpose: adds a gather layer that picks the entries selected by indices along axis
Parameters: input - the input tensor
    indices - the index tensor used to produce the output tensor
    axis - the axis to gather along; it cannot be the batch axis
Returns: a new layer, or None
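add_einsum(self: tensorrt.tensorrt.INetworkDefinition,
    inputs: List[tensorrt.tensorrt.ITensor],
    equation: str) -> tensorrt.tensorrt.IEinsumLayer
Purpose: adds an Einstein-summation layer, the counterpart of einsum; here it is mainly used for matrix multiplication
Parameters: inputs - the input tensors
    equation - the einsum equation
Returns: a new layer, or None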
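The two einsum equations used later in the attention code are plain batched matrix products. The NumPy sketch below checks that equivalence on random data; the shapes ([16, 6, 40], i.e. batch times heads, tokens, head dimension) are illustrative assumptions, and the equations are written without spaces for NumPy.

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((16, 6, 40)).astype(np.float32)   # [batch * heads, tokens, d]
k = rng.standard_normal((16, 6, 40)).astype(np.float32)
v = rng.standard_normal((16, 6, 40)).astype(np.float32)

# 'b i d, b j d -> b i j' is a batched Q @ K^T
scores = np.einsum('bid,bjd->bij', q, k)
scores_matmul = q @ k.transpose(0, 2, 1)

# 'b i j, b j d -> b i d' is a batched weights @ V
out = np.einsum('bij,bjd->bid', scores, v)
out_matmul = scores @ v
```

This is also why the same computation could be expressed with add_matrix_multiply and a TRANSPOSE on the second operand.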
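Key TRT Operators
Convolution operator
TRT supports convolution natively, so this simply calls add_convolution. Note that conv may also receive the raw input of the first layer (an ITensor rather than an ILayer), which is why the input is checked with isinstance.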
def conv(network, weight_map, x, ch, pre, kernel, padding, stride):
    x = network.add_convolution(
        input=x if isinstance(x, trt.ITensor) else x.get_output(0),
        num_output_maps=ch,
        kernel_shape=(kernel, kernel),
        kernel=weight_map['{}.weight'.format(pre)],
        bias=weight_map['{}.bias'.format(pre)])
    assert x
    x.padding = (padding, padding)
    x.stride = (stride, stride)
    return x
Activation operator
SiLU is split into two operations, SIGMOID and PROD, which is essentially the same as what the ONNX export produces.
def silu(network, x):
    y = network.add_activation(x.get_output(0), trt.ActivationType.SIGMOID)
    assert y
    x = network.add_elementwise(x.get_output(0), y.get_output(0), trt.ElementWiseOperation.PROD)
    return x
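For reference, the decomposition corresponds to SiLU(x) = x * sigmoid(x). A minimal pure-Python sketch of the same two steps (the function names are illustrative only):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def silu_reference(x):
    """SiLU(x) = x * sigmoid(x): a SIGMOID activation followed by an elementwise PROD."""
    return x * sigmoid(x)

values = [silu_reference(x) for x in (-1.0, 0.0, 1.0)]
```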
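Normalization operators
GroupNorm is implemented with a plugin. Two attributes are defined through PluginField: epsilon, the small constant added for numerical stability, and bSwish, which controls whether the Swish (SiLU) activation is applied inside the plugin.
Its inputs are the output of the previous layer, the weights and the bias; its output is the group-normalized value.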
import ctypes
import numpy as np
import tensorrt as trt

# load the custom plugin library and register the plugins with TensorRT
ctypes.CDLL('./trt/libmyplugins.so.1', mode=ctypes.RTLD_GLOBAL)
TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
trt.init_libnvinfer_plugins(TRT_LOGGER, '')
gn_plugin_creator = trt.get_plugin_registry().get_plugin_creator('GroupNorm', "1")

# EPS is a module-level default epsilon defined elsewhere in the original script
def group_norm(network, weight_map, h, pre, epsilon=EPS, silu=False):
    ch = h.get_output(0).shape[1]
    # plugin_creator = trt.get_plugin_registry().get_plugin_creator('GroupNorm', "1")
    plugin_creator = gn_plugin_creator
    s = network.add_constant([1, ch, 1, 1], weight_map['{}.weight'.format(pre)])
    b = network.add_constant([1, ch, 1, 1], weight_map['{}.bias'.format(pre)])
    eps_attr = trt.PluginField("epsilon", np.array([epsilon], dtype=np.float32), type=trt.PluginFieldType.FLOAT32)
    silu_attr = trt.PluginField("bSwish", np.array([1 if silu else 0], dtype=np.int32), type=trt.PluginFieldType.INT32)
    field_collection = trt.PluginFieldCollection([eps_attr, silu_attr])
    plugin = plugin_creator.create_plugin(name='{}.group_norm'.format(pre), field_collection=field_collection)
    n = network.add_plugin_v2(inputs=[h.get_output(0), s.get_output(0), b.get_output(0)], plugin=plugin)
    return n
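layer_norm performs the following computation:
Y = (X - Mean(X, axes)) / Sqrt(Variance(X) + epsilon) * S + B
The result differs depending on which axes are used. Here axesMask = 1 << 2 selects axis index 2, so normalization runs over the third axis, which is the last (channel) axis of the 3-D input. For sequence tasks the first axis is the batch and the second axis is the sequence length.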
def layer_norm(network, weight_map, h, pre, epsilon=EPS):
    scale_np = weight_map['{}.weight'.format(pre)]
    ch = scale_np.shape[0]
    scale = network.add_constant([1, 1, ch], scale_np)
    bias_np = weight_map['{}.bias'.format(pre)]
    bias = network.add_constant([1, 1, ch], bias_np)
    n = network.add_normalization(
        h.get_output(0),
        scale=scale.get_output(0),
        bias=bias.get_output(0),
        axesMask=1 << 2)
    assert n
    n.epsilon = epsilon
    return n
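A NumPy reference for the same computation can be handy when comparing the TRT output against expected values. This is a sketch only: `layer_norm_reference` is a hypothetical helper, and eps=1e-5 is an assumed value because the article does not show what EPS is set to.

```python
import numpy as np

def layer_norm_reference(x, weight, bias, eps=1e-5):
    """Normalize a [B, N, C] array over axis 2, the axis selected by axesMask = 1 << 2."""
    mean = x.mean(axis=2, keepdims=True)
    var = x.var(axis=2, keepdims=True)   # biased variance, as in the formula
    # weight and bias are broadcast as [1, 1, C], matching the constant shapes above
    return (x - mean) / np.sqrt(var + eps) * weight.reshape(1, 1, -1) + bias.reshape(1, 1, -1)

x = np.random.default_rng(0).standard_normal((2, 5, 8))
y = layer_norm_reference(x, weight=np.ones(8), bias=np.zeros(8))
```

With unit weight and zero bias, every [C]-length row of the output has mean close to 0 and variance close to 1.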
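Attention operators
Because TRT does not directly support this multiply-add on 4-D tensors, H and W are merged into a single axis. The MHA here has 8 heads, and the heads are folded into the batch axis for the computation, which gives the following sequence of shapes:
[2, h * w, c] -> [2, h * w, 8, d] -> [2, 8, h * w, d] -> [16, h * w, d]
In terms of the actual computation, q, k and v are each obtained with a separate matrix multiplication. This leaves room for optimization: computing the three products together rather than separately would parallelize better.
The q·k product is computed with add_einsum, and the result after softmax is multiplied with v. Note that the final result has to be restored to [2, h * w, c].
The last step is the to_out projection (matrix multiply plus bias); the residual connection is added by the caller, basic_transformer.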
def self_attention(network, weight_map, i, ch, x):
    heads = 8
    dim_head = ch / heads
    scale = dim_head ** -0.5
    wq = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_q.weight'.format(i)])
    wk = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_k.weight'.format(i)])
    wv = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_v.weight'.format(i)])
    q = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    wq.get_output(0), trt.MatrixOperation.TRANSPOSE)
    k = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    wk.get_output(0), trt.MatrixOperation.TRANSPOSE)
    v = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    wv.get_output(0), trt.MatrixOperation.TRANSPOSE)
    # q [2, h * w, c] -> [2, h * w, 8, d] -> [2, 8, h * w, d] -> [16, h * w, d]
    q = network.add_shuffle(q.get_output(0))
    q.reshape_dims = (2, -1, 8, ch // 8)
    q.second_transpose = trt.Permutation([0, 2, 1, 3])
    q = network.add_shuffle(q.get_output(0))
    q.reshape_dims = (16, -1, ch // 8)
    k = network.add_shuffle(k.get_output(0))
    k.reshape_dims = (2, -1, 8, ch // 8)
    k.second_transpose = trt.Permutation([0, 2, 1, 3])
    k = network.add_shuffle(k.get_output(0))
    k.reshape_dims = (16, -1, ch // 8)
    v = network.add_shuffle(v.get_output(0))
    v.reshape_dims = (2, -1, 8, ch // 8)
    v.second_transpose = trt.Permutation([0, 2, 1, 3])
    v = network.add_shuffle(v.get_output(0))
    v.reshape_dims = (16, -1, ch // 8)
    s = network.add_einsum([q.get_output(0), k.get_output(0)], 'b i d, b j d -> b i j')
    print(s.get_output(0).shape)
    s = network.add_scale(s.get_output(0), mode=trt.ScaleMode.UNIFORM,
                          scale=trt.Weights(np.array([scale], np.float32)))
    s = network.add_softmax(s.get_output(0))
    s.axes = 1<<2
    out = network.add_einsum([s.get_output(0), v.get_output(0)], 'b i j, b j d -> b i d')
    # [16, h * w, d] -> [2, 8, h * w, d] -> [2, h * w, 8, d] -> [2, h * w, c]
    out = network.add_shuffle(out.get_output(0))
    out.reshape_dims = (2, 8, -1, ch // 8)
    out.second_transpose = trt.Permutation([0, 2, 1, 3])
    out = network.add_shuffle(out.get_output(0))
    out.reshape_dims = (2, -1, ch)
    # to_out
    outw = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_out.0.weight'.format(i)])
    outb = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.attn1.to_out.0.bias'.format(i)])
    out = network.add_matrix_multiply(out.get_output(0), trt.MatrixOperation.NONE,
                                      outw.get_output(0), trt.MatrixOperation.TRANSPOSE)
    out = network.add_elementwise(out.get_output(0), outb.get_output(0), trt.ElementWiseOperation.SUM)
    return out
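The chain of shuffle layers is easier to follow in NumPy. The sketch below reproduces the head split, the scaled softmax attention and the head merge on a small random input. It is a reference under simplifying assumptions, not the TRT implementation: the q/k/v and to_out projections are omitted, and `split_heads`, `merge_heads` and `softmax` are hypothetical helpers.

```python
import numpy as np

def split_heads(x, heads=8):
    """[B, N, C] -> [B, N, heads, d] -> [B, heads, N, d] -> [B * heads, N, d]"""
    b, n, c = x.shape
    return x.reshape(b, n, heads, c // heads).transpose(0, 2, 1, 3).reshape(b * heads, n, c // heads)

def merge_heads(x, heads=8):
    """[B * heads, N, d] -> [B, heads, N, d] -> [B, N, heads, d] -> [B, N, C]"""
    bh, n, d = x.shape
    return x.reshape(bh // heads, heads, n, d).transpose(0, 2, 1, 3).reshape(bh // heads, n, heads * d)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.default_rng(0).standard_normal((2, 6, 320))
q = k = v = split_heads(x)   # the linear projections are omitted for brevity
weights = softmax(np.einsum('bid,bjd->bij', q, k) * (320 // 8) ** -0.5)
out = merge_heads(np.einsum('bij,bjd->bid', weights, v))
```

Splitting and then merging is an exact round trip, which is a quick way to check that the two permutations are inverses of each other.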
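cross attention is similar to self attention. The difference is that k and v are taken from context: context['context'] holds the result computed earlier from the context, and k and v are sliced out of it, with context['start'] acting as a running offset that advances on every call. Only q is computed from a weight and the output of the previous layer.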
def cross_attention(network, weight_map, i, ch, x, context):
    heads = 8
    dim_head = ch / heads
    scale = dim_head ** -0.5
    wq = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn2.to_q.weight'.format(i)])
    q = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    wq.get_output(0), trt.MatrixOperation.TRANSPOSE)
    # [2, h*w, c]
    dim = ch // 8
    k = network.add_slice(context['context'],
                          trt.Dims([0, 0, 8 * context['start']]),
                          trt.Dims([2, 77, ch]),
                          trt.Dims([1, 1, 1]))
    v = network.add_slice(context['context'],
                          trt.Dims([0, 0, 8 * (context['start'] + dim)]),
                          trt.Dims([2, 77, ch]),
                          trt.Dims([1, 1, 1]))
    context['start'] += 2 * dim
    q = network.add_shuffle(q.get_output(0))
    q.reshape_dims = (2, -1, 8, ch // 8)
    q.second_transpose = trt.Permutation([0, 2, 1, 3])
    q = network.add_shuffle(q.get_output(0))
    q.reshape_dims = (16, -1, ch // 8)
    k = network.add_shuffle(k.get_output(0))
    k.reshape_dims = (2, -1, 8, ch // 8)
    k.second_transpose = trt.Permutation([0, 2, 1, 3])
    k = network.add_shuffle(k.get_output(0))
    k.reshape_dims = (16, -1, ch // 8)
    v = network.add_shuffle(v.get_output(0))
    v.reshape_dims = (2, -1, 8, ch // 8)
    v.second_transpose = trt.Permutation([0, 2, 1, 3])
    v = network.add_shuffle(v.get_output(0))
    v.reshape_dims = (16, -1, ch // 8)
    s = network.add_einsum([q.get_output(0), k.get_output(0)], 'b i d, b j d -> b i j')
    print(s.get_output(0).shape)
    # scale = network.add_constant((1, 1, 1), np.array([scale], np.float32))
    # s = network.add_elementwise(s.get_output(0), scale.get_output(0), trt.ElementWiseOperation.PROD)
    s = network.add_scale(s.get_output(0), mode=trt.ScaleMode.UNIFORM,
                          scale=trt.Weights(np.array([scale], np.float32)))
    s = network.add_softmax(s.get_output(0))
    s.axes = 1<<2
    out = network.add_einsum([s.get_output(0), v.get_output(0)], 'b i j, b j d -> b i d')
    out = network.add_shuffle(out.get_output(0))
    out.reshape_dims = (2, 8, -1, ch // 8)
    out.second_transpose = trt.Permutation([0, 2, 1, 3])
    out = network.add_shuffle(out.get_output(0))
    out.reshape_dims = (2, -1, ch)
    # to_out
    outw = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn2.to_out.0.weight'.format(i)])
    outb = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.attn2.to_out.0.bias'.format(i)])
    out = network.add_matrix_multiply(out.get_output(0), trt.MatrixOperation.NONE,
                                      outw.get_output(0), trt.MatrixOperation.TRANSPOSE)
    out = network.add_elementwise(out.get_output(0), outb.get_output(0), trt.ElementWiseOperation.SUM)
    return out
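The context['start'] bookkeeping decides which columns of the shared context tensor each cross-attention call reads. The pure-Python sketch below replays that arithmetic for a made-up sequence of channel counts. `kv_offsets` is a hypothetical helper, and the channel list is an illustrative assumption rather than the model's actual call order.

```python
def kv_offsets(channels, heads=8):
    """Replay the context['start'] arithmetic for a sequence of cross-attention calls."""
    start = 0
    layout = []
    for ch in channels:
        dim = ch // heads
        k_begin = heads * start              # offset used for the k slice
        v_begin = heads * (start + dim)      # offset used for the v slice
        layout.append({'ch': ch, 'k': (k_begin, k_begin + ch), 'v': (v_begin, v_begin + ch)})
        start += 2 * dim                     # same update as context['start'] += 2 * dim
    return layout, heads * start

layout, total_width = kv_offsets([320, 320, 640])
```

Each call consumes a k block and a v block of width ch laid side by side, so the context tensor has to be as wide as the sum of 2 * ch over all cross-attention layers.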
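The FFN is grouped with the attention operators here as well. It consists of fully connected layers and a GELU activation: in the code, the output of the first projection is split into two halves, and one half gates the other through GELU. Note that the matrix multiply and the bias add are computed as separate layers.
add_unary is the unary operator; here it computes the error function (ERF) needed by GELU.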
def feed_forward(network, weight_map, i, ch, x):
    w1 = network.add_constant((1, ch * 8, ch), weight_map['{}.transformer_blocks.0.ff.net.0.proj.weight'.format(i)])
    b1 = network.add_constant((1, 1, ch * 8), weight_map['{}.transformer_blocks.0.ff.net.0.proj.bias'.format(i)])
    n = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    w1.get_output(0), trt.MatrixOperation.TRANSPOSE)
    n = network.add_elementwise(n.get_output(0), b1.get_output(0), trt.ElementWiseOperation.SUM)
    hw = n.get_output(0).shape[1]
    # w = n.get_output(0).shape[3]
    n1 = network.add_slice(n.get_output(0), trt.Dims([0, 0, 0]), trt.Dims([2, hw, ch * 4]), trt.Dims([1, 1, 1]))
    n2 = network.add_slice(n.get_output(0), trt.Dims([0, 0, ch * 4]), trt.Dims([2, hw, ch * 4]), trt.Dims([1, 1, 1]))
    # gelu
    e = network.add_scale(n2.get_output(0), mode=trt.ScaleMode.UNIFORM, scale=trt.Weights(np.array([2 ** -0.5], np.float32)))
    e = network.add_unary(e.get_output(0), trt.UnaryOperation.ERF)
    e = network.add_scale(e.get_output(0), mode=trt.ScaleMode.UNIFORM,
                          scale=trt.Weights(np.array([0.5], np.float32)),
                          shift=trt.Weights(np.array([0.5], np.float32)))
    n = network.add_elementwise(n2.get_output(0), e.get_output(0), trt.ElementWiseOperation.PROD)
    n = network.add_elementwise(n.get_output(0), n1.get_output(0), trt.ElementWiseOperation.PROD)
    w2 = network.add_constant((1, ch, ch * 4), weight_map['{}.transformer_blocks.0.ff.net.2.weight'.format(i)])
    b2 = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.ff.net.2.bias'.format(i)])
    n = network.add_matrix_multiply(n.get_output(0), trt.MatrixOperation.NONE,
                                    w2.get_output(0), trt.MatrixOperation.TRANSPOSE)
    n = network.add_elementwise(n.get_output(0), b2.get_output(0), trt.ElementWiseOperation.SUM)
    return n
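The scale, ERF, scale, PROD sequence in the middle of feed_forward is the exact GELU, x * 0.5 * (1 + erf(x / sqrt(2))), applied to the second half of the projection and multiplied by the first half. A pure-Python sketch of that gating on a single row, with hypothetical helper names:

```python
import math

def gelu(x):
    """Exact GELU: x * 0.5 * (1 + erf(x / sqrt(2)))."""
    return x * (0.5 * math.erf(x * 2 ** -0.5) + 0.5)

def geglu(row, ch_hidden):
    """Split a projected row into value and gate halves and return value * GELU(gate)."""
    value, gate = row[:ch_hidden], row[ch_hidden:]
    return [v * gelu(g) for v, g in zip(value, gate)]

# value half is [1.0, 2.0], gate half is [0.0, 10.0]
out = geglu([1.0, 2.0, 0.0, 10.0], ch_hidden=2)
```

A gate of 0 blocks its value entirely, while a large positive gate passes roughly value * gate through.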
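Key Modules
Transformer modules
The basic transformer is not discussed in detail here: it is the standard attn1 -> attn2 -> ffn sequence. Note that because TRT does not support the 4-D operation, one extra reshape is needed before and after.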
def basic_transformer(network, weight_map, i, ch, x, context):
    H = x.get_output(0).shape[2]
    W = x.get_output(0).shape[3]
    # n c h w -> b (h w) c
    x = network.add_shuffle(x.get_output(0))
    x.first_transpose = trt.Permutation([0, 2, 3, 1])
    x.reshape_dims = (2, -1, ch)
    # attn1
    n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm1'.format(i))
    attn1 = self_attention(network, weight_map, i, ch, n)
    x = network.add_elementwise(attn1.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)
    # attn2
    n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm2'.format(i))
    attn2 = cross_attention(network, weight_map, i, ch, n, context)
    x = network.add_elementwise(attn2.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)
    # ff
    n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm3'.format(i))
    ff = feed_forward(network, weight_map, i, ch, n)
    x = network.add_elementwise(ff.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)
    # n (h w) c -> n c h w
    x = network.add_shuffle(x.get_output(0))
    x.first_transpose = trt.Permutation([0, 2, 1])
    x.reshape_dims = (2, ch, H, W)
    return x
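The two shuffle layers at the start and end of basic_transformer are inverse layout changes. The NumPy sketch below mirrors them (transpose first, then reshape) on a small tensor; the sizes are illustrative assumptions.

```python
import numpy as np

n, c, h, w = 2, 320, 4, 6
x = np.arange(n * c * h * w, dtype=np.float32).reshape(n, c, h, w)

# n c h w -> n (h w) c : first_transpose [0, 2, 3, 1], then reshape to (n, -1, c)
tokens = x.transpose(0, 2, 3, 1).reshape(n, -1, c)

# n (h w) c -> n c h w : first_transpose [0, 2, 1], then reshape to (n, c, h, w)
restored = tokens.transpose(0, 2, 1).reshape(n, c, h, w)
```

After the first shuffle, token index row * w + col holds the channel vector of pixel (row, col), and the second shuffle restores the original tensor exactly.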
spatial_transformer wraps basic_transformer with a GroupNorm, two 1x1 conv projections (proj_in and proj_out) and a residual connection.
def spatial_transformer(network, weight_map, i, ch, h, context):
    # return h
    # norm
    n = group_norm(network, weight_map, h, '{}.norm'.format(i), 1e-6)
    # proj_in
    n = conv(network, weight_map, n, ch, '{}.proj_in'.format(i), 1, 0, 1)
    # BasicTransformerBlock
    n = basic_transformer(network, weight_map, i, ch, n, context)
    # proj_out
    n = conv(network, weight_map, n, ch, '{}.proj_out'.format(i), 1, 0, 1)
    h = network.add_elementwise(n.get_output(0), h.get_output(0), trt.ElementWiseOperation.SUM)
    return h
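Sampling modules
Downsampling is a stride-2 convolution. Upsampling is a 2x nearest-neighbor resize (ResizeMode.NEAREST) followed by a convolution. zero_convs is a 1x1 convolution that leaves the feature-map size unchanged.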
def input_first(network, weight_map, pre, h):
    h = conv(network, weight_map, h, 320, '{}.input_blocks.0.0'.format(pre), 3, 1, 1)
    return h

def downsample(network, weight_map, i, ch, x):
    x = conv(network, weight_map, x, ch, '{}.op'.format(i), 3, 1, 2)
    return x

def upsample(network, weight_map, i, ch, x):
    x = network.add_resize(x.get_output(0))
    x.scales = [1, 1, 2, 2]
    x.resize_mode = trt.ResizeMode.NEAREST
    x = conv(network, weight_map, x, ch, '{}.conv'.format(i), 3, 1, 1)
    return x

def zero_convs(network, weight_map, x, i):
    ch = x.get_output(0).shape[1]
    x = conv(network, weight_map, x, ch, 'control_model.zero_convs.{}.0'.format(i), 1, 0, 1)
    return x
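The effect of each sampling module on the spatial size follows from the standard convolution output-size formula. A small pure-Python check, using the 32 x 48 feature map that appears in the slices of this article (`conv_out_size` is a hypothetical helper):

```python
def conv_out_size(size, kernel, padding, stride):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

h, w = 32, 48
# downsample: 3x3 kernel, padding 1, stride 2
down = (conv_out_size(h, 3, 1, 2), conv_out_size(w, 3, 1, 2))
# upsample: 2x resize, then 3x3 kernel, padding 1, stride 1
up = (conv_out_size(2 * h, 3, 1, 1), conv_out_size(2 * w, 3, 1, 1))
# zero conv: 1x1 kernel, padding 0, stride 1
zero = (conv_out_size(h, 1, 0, 1), conv_out_size(w, 1, 0, 1))
```

Downsampling halves both sides, upsampling doubles them, and the zero conv leaves them untouched.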
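Block modules
resblock is a residual module built from two convolution stages: GroupNorm + SiLU + 3x3 conv, then the timestep embedding is added, then GroupNorm + SiLU + 3x3 conv again. The input is added back at the end, going through a 1x1 skip_connection conv when the channel count changes.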
def resblock(network, weight_map, embed_weight, i, ch, h, emb):
    print('resblock: ', h.get_output(0).shape, '{}.in_layers.0'.format(i))
    ## in_layers
    # group_norm
    n = group_norm(network, weight_map, h, '{}.in_layers.0'.format(i), silu=True)
    # silu
    # n = silu(network, n)
    # conv_nd
    n = conv(network, weight_map, n, ch, '{}.in_layers.2'.format(i), 3, 1, 1)
    print('in_layers: ', n.get_output(0).shape)
    ## emb_layers
    m = network.add_constant([20, ch, 1, 1], embed_weight.pop(0))
    m = network.add_gather(m.get_output(0), emb, axis=0)
    print('emb_layers: ', m.get_output(0).shape)
    n = network.add_elementwise(n.get_output(0), m.get_output(0), trt.ElementWiseOperation.SUM)
    ## out_layers
    n = group_norm(network, weight_map, n, '{}.out_layers.0'.format(i), silu=True)
    # n = silu(network, n)
    n = conv(network, weight_map, n, ch, '{}.out_layers.3'.format(i), 3, 1, 1)
    print('out_layers: ', n.get_output(0).shape)
    in_ch = h.get_output(0).shape[1]
    if in_ch != ch:
        # skip_connection
        h = conv(network, weight_map, h, ch, '{}.skip_connection'.format(i), 1, 0, 1)
    h = network.add_elementwise(n.get_output(0), h.get_output(0), trt.ElementWiseOperation.SUM)
    return h
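input_block is made of resblocks and spatial_transformers at different levels with different channel counts.
middle_block is a combination of resblock and spatial_transformer.
output_blocks is similar to input_block, except that the downsampling of input_block becomes upsampling in output_blocks.
These three blocks are the main components of the Unet and correspond to its process of first downsampling to a feature state and then upsampling back to the image resolution.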
def input_block(network, weight_map, embed_weight, h, emb, context, model_name):
    hs = []
    h = input_first(network, weight_map, model_name, h)
    h = network.add_slice(h.get_output(0), trt.Dims([0, 0, 0, 0]), trt.Dims([2, 320, 32, 48]), trt.Dims([1, 1, 1, 1]))
    h.mode = trt.SliceMode.WRAP
    #return h
    hs.append(h)
    channel_mult = [1, 2, 4, 4]
    num_res_blocks = [2] * 4
    model_channels = 320
    index = 1
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        for nr in range(num_res_blocks[level]):
            pre = '{}.input_blocks.{}'.format(model_name, index)
            h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
            print('resblock: ', h.get_output(0).shape)
            if level != len(channel_mult) -1:
                h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
            hs.append(h)
            # ch = mult * model_channels
            index = index + 1
        if level != len(channel_mult) - 1:
            pre = '{}.input_blocks.{}'.format(model_name, index)
            out_ch = ch
            h = downsample(network, weight_map, '{}.0'.format(pre), out_ch, h)
            hs.append(h)
            index = index + 1
    # if index == 10:
    return hs, h
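To see which blocks the loop in input_block produces, its control flow can be replayed without TensorRT. The sketch below lists, for every entry appended to hs, the block index, its kind and its channel count. `input_block_schedule` is a hypothetical helper, and the kind labels are names chosen for this note.

```python
def input_block_schedule(channel_mult=(1, 2, 4, 4), num_res_blocks=2, model_channels=320):
    """List (index, kind, channels) for every entry appended to hs by input_block."""
    schedule = [(0, 'conv_in', model_channels)]
    index = 1
    last = len(channel_mult) - 1
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        for _ in range(num_res_blocks):
            # the last level has resblocks only, no spatial_transformer
            kind = 'res' if level == last else 'res+attn'
            schedule.append((index, kind, ch))
            index += 1
        if level != last:
            schedule.append((index, 'downsample', ch))
            index += 1
    return schedule

schedule = input_block_schedule()
```

The result has 12 entries, which matches the 12 skip tensors that output_blocks later pops from hs.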
def middle_block(network, weight_map, embed_weight, h, emb, context, model_name):
    pre = '{}.middle_block'.format(model_name)
    h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), 1280, h, emb)
    h = spatial_transformer(network, weight_map, '{}.1'.format(pre), 1280, h, context)
    h = resblock(network, weight_map, embed_weight, '{}.2'.format(pre), 1280, h, emb)
    return h
def output_blocks(network, weight_map, embed_weight, h, emb, context, control, hs):
    channel_mult = [1, 2, 4, 4]
    num_res_blocks = [2] * 4
    model_channels = 320
    index = 0
    for level, mult in list(enumerate(channel_mult))[::-1]:
        ch = model_channels * mult
        for i in range(num_res_blocks[level] + 1):
            # control and hs hold ILayer objects (see control_net and input_block),
            # so their output tensors are taken with get_output(0)
            print(control[-1].get_output(0).shape, hs[-1].get_output(0).shape, len(hs), h.get_output(0).shape)
            c = network.add_elementwise(control.pop().get_output(0), hs.pop().get_output(0),
                                        trt.ElementWiseOperation.SUM)
            h = network.add_concatenation([h.get_output(0), c.get_output(0)])
            print('output: ', index, h.get_output(0).shape)
            pre = 'model.diffusion_model.output_blocks.{}'.format(index)
            h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
            print('resblock: ', h.get_output(0).shape)
            if level != len(channel_mult) -1:
                h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
            if level and i == num_res_blocks[level]:
                h = upsample(network, weight_map,
                             '{}.{}'.format(pre, 1 if level == len(channel_mult) - 1 else 2), ch, h)
            index = index + 1
    print(h.get_output(0).shape, len(hs), len(control), index)
    return h
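The channel counts seen by each resblock in output_blocks come from concatenating h with a popped skip tensor. The sketch below replays that bookkeeping in pure Python; `output_block_schedule` is a hypothetical helper, and it assumes the skip tensors were pushed in the order produced by input_block.

```python
def output_block_schedule(channel_mult=(1, 2, 4, 4), num_res_blocks=2, model_channels=320):
    """List (index, concat_channels, out_channels, upsample) for each output block."""
    # channels of the skip tensors pushed by the input blocks, in push order
    skips = [model_channels]
    last = len(channel_mult) - 1
    for level, mult in enumerate(channel_mult):
        skips += [model_channels * mult] * num_res_blocks
        if level != last:
            skips.append(model_channels * mult)
    h_ch = model_channels * channel_mult[-1]   # channels leaving the middle block
    schedule, index = [], 0
    for level, mult in reversed(list(enumerate(channel_mult))):
        ch = model_channels * mult
        for i in range(num_res_blocks + 1):
            concat_ch = h_ch + skips.pop()     # h concatenated with the skip tensor
            upsample = bool(level) and i == num_res_blocks
            schedule.append((index, concat_ch, ch, upsample))
            h_ch = ch
            index += 1
    return schedule

schedule = output_block_schedule()
```

The concatenated channel count is what the first conv of each resblock receives, which is why in_ch differs from ch there and the skip_connection conv is used.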
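input_block_control is the first half of control_net. Its structure and parameters are the same as those of the Unet input blocks, but hint is added to the first feature map and every level gets an additional zero_convs layer with learned parameters.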
def input_block_control(network, weight_map, embed_weight, h, emb, context, hint):
    hs = []
    h = input_first(network, weight_map, 'control_model', h)
    h = network.add_elementwise(h.get_output(0), hint, trt.ElementWiseOperation.SUM)
    h = network.add_slice(h.get_output(0), trt.Dims([0, 0, 0, 0]), trt.Dims([2, 320, 32, 48]), trt.Dims([1, 1, 1, 1]))
    h.mode = trt.SliceMode.WRAP
    hs.append(zero_convs(network, weight_map, h, 0))
    # h [2, 320, 32, 48]
    channel_mult = [1, 2, 4, 4]
    num_res_blocks = [2] * 4
    model_channels = 320
    index = 1
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        for nr in range(num_res_blocks[level]):
            pre = 'control_model.input_blocks.{}'.format(index)
            h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
            print('resblock: ', h.get_output(0).shape)
            if level != len(channel_mult) -1:
                h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
            hs.append(zero_convs(network, weight_map, h, index))
            # ch = mult * model_channels
            index = index + 1
        if level != len(channel_mult) - 1:
            pre = 'control_model.input_blocks.{}'.format(index)
            out_ch = ch
            h = downsample(network, weight_map, '{}.0'.format(pre), out_ch, h)
            hs.append(zero_convs(network, weight_map, h, index))
            index = index + 1
    # if index == 10:
    return hs, h
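Network Construction
ControlNet
Here h, hint and emb pass through input_block_control, which returns the list of control features and the final h. h then goes through middle_block and the middle_block_out convolution, and the result is appended to the list, so control holds the control features at every scale.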
def control_net(network, weight_map, embed_weight, h, hint, emb, context):
    # #####################
    # # time_embed
    # #####################
    #####################
    # input_blocks
    #####################
    control, h = input_block_control(network, weight_map, embed_weight, h, emb, context, hint)
    print(h.get_output(0).shape)
    #####################
    # middle_blocks
    #####################
    h = middle_block(network, weight_map, embed_weight, h, emb, context, 'control_model')
    h = conv(network, weight_map, h, 1280, 'control_model.middle_block_out.0', 1, 0, 1)
    control.append(h)
    return control
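Unet
The Unet is comparatively simple: the input passes through input_block, middle_block and output_blocks, the control features are added in along the way, and a final GroupNorm + SiLU + conv produces the returned result.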
def unet(network, weight_map, embed_weight, h, emb, context, control):
    # #####################
    # # time_embed
    # #####################
    #####################
    # input_blocks
    #####################
    hs, h = input_block(network, weight_map, embed_weight, h, emb, context, 'model.diffusion_model')
    print(h.get_output(0).shape)
    #####################
    # middle_blocks
    #####################
    h = middle_block(network, weight_map, embed_weight, h, emb, context, 'model.diffusion_model')
    print(h.get_output(0).shape)
    h = network.add_elementwise(h.get_output(0), control.pop().get_output(0), trt.ElementWiseOperation.SUM)
    #####################
    # output_blocks
    #####################
    h = output_blocks(network, weight_map, embed_weight, h, emb, context, control, hs)
    # out
    # group_norm
    # h = group_norm_sile(network, weight_map, h)
    h = group_norm(network, weight_map, h, 'model.diffusion_model.out.0', silu=True)
    # silu
    # h = silu(network, h)
    # conv_nd
    h = conv(network, weight_map, h, 4, 'model.diffusion_model.out.2', 3, 1, 1)
    return h
References
- nvidia python api: https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/
- xiatwhu: https://github.com/deeplearning/xiatwhu/trt2023