Reference only code reading

Date: 2024-01-31 18:12:49
Tags: style, code, self, attention, module, only, attn1, forward, reading

Reference-only code reading: a walkthrough of the code logic

This note analyzes the reference-only code in the ControlNet extension of the Stable Diffusion WebUI (sd webui).

Control Net Hook

Inside controlnet_main_entry, the entry point looks like this:

# def controlnet_main_entry():
self.latest_network = UnetHook(lowvram=is_low_vram)
self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p,
                                 batch_option_uint_separate=batch_option_uint_separate,
                                 batch_option_style_align=batch_option_style_align)

What happens inside the hook function?

The core hook part:

def hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p, ...):
    # ...
    model._original_forward = model.forward
    outer.original_forward = model.forward
    model.forward = forward_webui.__get__(model, UNetModel)
  • model is the original UNet model
  • outer is the newly created UnetHook object itself

So the original UNet model's forward is hijacked by forward_webui, and the original forward is saved as UnetHook.original_forward (and also on the model as _original_forward).
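The hijack relies on the fact that calling __get__ on a plain Python function returns a method bound to the given instance. The sketch below shows that pattern in isolation. TinyModel, TinyHook, hooked_forward and restore are hypothetical stand-ins invented for illustration; they are not the extension's real classes.

```python
class TinyModel:
    """Hypothetical stand-in for the UNet; not the real UNetModel."""

    def forward(self, x):
        return x + 1


class TinyHook:
    """Hypothetical stand-in for UnetHook."""

    def hook(self, model):
        outer = self  # the hook object, captured by the closure below

        def hooked_forward(self, x):
            # `self` is the model here, because of the __get__ binding below.
            # Extra work can happen before or after the original call.
            return outer.original_forward(x) * 10

        # Keep the bound original method in two places, as in the snippet above.
        model._original_forward = model.forward
        outer.original_forward = model.forward
        # function.__get__(instance, cls) returns a method bound to `instance`.
        model.forward = hooked_forward.__get__(model, TinyModel)

    def restore(self, model):
        # Hypothetical helper: put the saved method back.
        model.forward = model._original_forward


model = TinyModel()
hook = TinyHook()
hook.hook(model)
hooked_result = model.forward(1)    # goes through hooked_forward
hook.restore(model)
restored_result = model.forward(1)  # original behaviour again
```

Because the replacement is stored on the instance, only this one model object is affected; the class and any other instances keep their original forward.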

For the other ControlNet types, the control module runs inside the hooked forward function. Its outputs are saved as total_controlnet_embedding and added to the original results in the U-Net middle and decoder blocks (not in the encoder blocks).

Hacks related to reference-only

Hook

For the reference part, however, there is no control-model forward pass. Instead, the hook function replaces the inner forward of every basic transformer block:

# def hook(...):
all_modules = torch_dfs(model)

if need_attention_hijack:
    attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock) or isinstance(module, BasicTransformerBlockSGM)]
    attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])

    for i, module in enumerate(attn_modules):
        if getattr(module, '_original_inner_forward_cn_hijack', None) is None:
            module._original_inner_forward_cn_hijack = module._forward
        module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
        module.bank = []
        module.style_cfgs = []
        module.attn_weight = float(i) / float(len(attn_modules))
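To make the effect of the sort and of attn_weight concrete, here is a minimal sketch with made-up blocks. FakeBlock, the block names, the channel counts and the writers helper are assumptions for illustration only. The comparison in writers mirrors the write condition outer.attention_auto_machine_weight > self.attn_weight used in the hacked forward shown later in this note.

```python
from dataclasses import dataclass


@dataclass
class FakeBlock:
    """Hypothetical stand-in for a BasicTransformerBlock."""
    name: str
    channels: int            # plays the role of norm1.normalized_shape[0]
    attn_weight: float = 0.0


blocks = [
    FakeBlock("input_320", 320),
    FakeBlock("middle_1280", 1280),
    FakeBlock("output_640", 640),
    FakeBlock("input_640", 640),
]

# Same ordering rule as the snippet above: widest blocks first.
blocks = sorted(blocks, key=lambda b: -b.channels)
for i, block in enumerate(blocks):
    block.attn_weight = float(i) / float(len(blocks))

weights = {b.name: b.attn_weight for b in blocks}


def writers(threshold):
    """Names of the blocks that would store context for a given weight threshold."""
    return [b.name for b in blocks if threshold > b.attn_weight]


half = writers(0.5)   # only the widest blocks pass the check
full = writers(1.0)   # every block passes the check
```

Under this ordering the widest (deepest) blocks get the smallest attn_weight, so they are the first to take part as the threshold rises from 0 to 1.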

Hacked basic transformer forward

The basic transformer module is hacked with the function hacked_basic_transformer_inner_forward.

What does this function do?

self_attention_context = x_norm1
if outer.attention_auto_machine == AutoMachine.Write:
    if outer.attention_auto_machine_weight > self.attn_weight:
        self.bank.append(self_attention_context.detach().clone())
        self.style_cfgs.append(outer.current_style_fidelity)
if outer.attention_auto_machine == AutoMachine.Read:
    if len(self.bank) > 0:
        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
        self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))
        self_attn1_c = self_attn1_uc.clone()
        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
            self_attn1_c[outer.current_uc_indices] = self.attn1(
                x_norm1[outer.current_uc_indices],
                context=self_attention_context[outer.current_uc_indices])
        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
    self.bank = []
    self.style_cfgs = []

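First, it hacks only the self-attention part (attn1); the cross-attention module is left as-is. The hook has two states, Write and Read:

  • Write: save x_norm1 to self.bank, but only when outer.attention_auto_machine_weight > self.attn_weight.
  • Read: concatenate the current self_attention_context with the values in self.bank (the previously saved x_norm1) and use the result as the context for attn1. The unconditional part of the batch is handled separately: for the indices in outer.current_uc_indices, attn1 is recomputed without the bank, and the two results are blended by style_cfg. The bank is cleared afterwards.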
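The Read branch can be sketched in pure Python with a toy attention that has no learned projections. Everything here is a simplification and an assumption for illustration: attention and read_step are hypothetical helpers, tokens are plain lists of floats, and a single list of reference tokens is shared by all samples, whereas the real code works on batched tensors and concatenates along dim=1.

```python
import math


def attention(queries, context):
    """Toy attention without projections: softmax(q . k) weighted sum of context."""
    dim = len(context[0])
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) for k in context]
        peak = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        out.append([sum(e / total * v[d] for e, v in zip(exps, context))
                    for d in range(dim)])
    return out


def read_step(batch, bank_tokens, uc_indices, style_cfg):
    """Mimic the Read branch for a batch of samples (each a list of tokens)."""
    # Every sample attends to its own tokens plus the stored reference tokens.
    attn_uc = [attention(x, x + bank_tokens) for x in batch]
    # Copy, then recompute the unconditional samples without the reference.
    attn_c = list(attn_uc)
    if uc_indices and style_cfg > 1e-5:
        for i in uc_indices:
            attn_c[i] = attention(batch[i], batch[i])
    # Linear blend controlled by the style fidelity.
    return [
        [[style_cfg * c + (1.0 - style_cfg) * u for c, u in zip(tok_c, tok_u)]
         for tok_c, tok_u in zip(sample_c, sample_u)]
        for sample_c, sample_u in zip(attn_c, attn_uc)
    ]


x = [[1.0, 0.0], [0.0, 1.0]]   # tokens of the image being generated
ref_tokens = [[5.0, 5.0]]      # tokens saved during the Write pass
batch = [x, x]                 # sample 0: conditional, sample 1: unconditional

high_fidelity = read_step(batch, ref_tokens, uc_indices=[1], style_cfg=1.0)
low_fidelity = read_step(batch, ref_tokens, uc_indices=[1], style_cfg=0.0)
```

With style_cfg at 1.0 the unconditional sample ignores the reference entirely, while at 0.0 both samples attend to it. Values in between interpolate linearly, as in the last line of the Read branch.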

Hacked forward

So what's inside the main hijacked forward function related to reference-only?

Get ref_xt from used_hint_cond_latent (the latent control image), and then:

  1. Set the auto machine to Write and call outer.original_forward() with x=ref_xt. This runs the UNet on the latent control image so that each hooked block saves its x_norm1 to self.bank.
  2. Set the auto machine to Read and run the normal UNet process with the real input x.
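The two passes can be sketched as a small state machine. ToyBlock, ToyHook and the string inputs are hypothetical stand-ins, and the AutoMachine enum here is redefined locally for the example. The real blocks run attention instead of recording their inputs.

```python
from enum import Enum


class AutoMachine(Enum):
    """Local stand-in for the extension's Read/Write switch."""
    Read = "Read"
    Write = "Write"


class ToyBlock:
    """Hypothetical stand-in for one hooked transformer block."""

    def __init__(self):
        self.bank = []
        self.seen_contexts = []

    def forward(self, x, mode):
        if mode is AutoMachine.Write:
            self.bank.append(x)  # remember the reference features
        elif mode is AutoMachine.Read:
            # Record what attn1 would receive as context, then clear the bank.
            self.seen_contexts.append([x] + self.bank)
            self.bank = []
        return x


class ToyHook:
    """Hypothetical stand-in for UnetHook with the hijacked forward."""

    def __init__(self, blocks):
        self.blocks = blocks
        self.mode = AutoMachine.Read

    def original_forward(self, x):
        for block in self.blocks:
            x = block.forward(x, self.mode)
        return x

    def hooked_forward(self, x, ref_xt):
        # Pass 1: run the reference latent only to fill the banks.
        self.mode = AutoMachine.Write
        self.original_forward(ref_xt)  # the output is discarded
        # Pass 2: run the real input; blocks read and clear their banks.
        self.mode = AutoMachine.Read
        return self.original_forward(x)


blocks = [ToyBlock(), ToyBlock()]
hook = ToyHook(blocks)
result = hook.hooked_forward("x", "ref_xt")
```

After one call, every block has seen the pair of its own input and the stored reference as context, and every bank is empty again for the next sampling step.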

In short, the first pass saves the reference context to self.bank, and the second pass concatenates it with the current context to form the context for attn1 in the real forward pass.

As you can see: self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))

From: https://www.cnblogs.com/psklf/p/17999391