Reference-only code reading (code logic walkthrough)
An analysis of the reference-only code in the ControlNet extension for SD WebUI (sd-webui-controlnet).
ControlNet Hook
The entry point inside `controlnet_main_entry` looks like:
```python
# def controlnet_main_entry():
self.latest_network = UnetHook(lowvram=is_low_vram)
self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p,
                         batch_option_uint_separate=batch_option_uint_separate,
                         batch_option_style_align=batch_option_style_align)
```
What happens inside the hook function?
The core part of the hook:
```python
# def hook(self, model, sd_ldm, control_params, process, ...):
# ...
model._original_forward = model.forward
outer.original_forward = model.forward
model.forward = forward_webui.__get__(model, UNetModel)
```
- `model` is the original U-Net model.
- `outer` is the newly created `UnetHook` object itself.

So the original U-Net model's `forward` is hijacked by `forward_webui`, and the original `forward` is saved as `outer.original_forward`, an attribute of the `UnetHook` instance.
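The hijack relies on plain Python method binding: calling `__get__(instance, cls)` on a function returns a method bound to that instance. The toy sketch below reproduces the pattern with hypothetical `ToyUNet`/`ToyHook` classes standing in for the real `UNetModel`/`UnetHook`; it assumes, as the excerpt suggests, that `forward_webui` reaches the saved original forward through `outer`.

```python
class ToyUNet:
    """Hypothetical stand-in for UNetModel; only forward() matters here."""

    def forward(self, x):
        return x * 2  # pretend this is the plain U-Net output


class ToyHook:
    """Hypothetical stand-in for UnetHook."""

    def hook(self, model):
        outer = self  # the hook object, captured by the closure below

        # Keep handles to the original bound method, as hook() does.
        model._original_forward = model.forward
        outer.original_forward = model.forward

        def forward_webui(self, x):
            # Here `self` is the model instance; `outer` comes from the enclosing scope.
            base = outer.original_forward(x)
            return base + 1  # pretend this is the extra control contribution

        # __get__ turns the plain function into a method bound to this model instance,
        # stored as an instance attribute that shadows ToyUNet.forward.
        model.forward = forward_webui.__get__(model, ToyUNet)


unet = ToyUNet()
ToyHook().hook(unet)
print(unet.forward(3))            # 7
print(unet._original_forward(3))  # 6
```

Because the replacement is set on the instance rather than the class, only this model object is affected, and the original method stays reachable for the hook to call.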
For the other ControlNet types, the control model runs inside this hooked forward function. Its outputs are collected in `total_controlnet_embedding` and added to the U-Net's intermediate results in the middle block and the decoder blocks (not in the encoder blocks).
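To make the injection points concrete, here is a toy sketch on plain floats. The block count, arithmetic, and the name `toy_unet` are hypothetical; the pop order (the middle block takes the last residual, the decoder skip connections take the rest in reverse) is an assumption about how `total_controlnet_embedding` is consumed.

```python
def toy_unet(x, control=None):
    """Toy 'U-Net' on floats: 3 encoder blocks, 1 middle block, 3 decoder blocks.

    `control` is a hypothetical list of 4 residuals, assumed ordered like
    total_controlnet_embedding: one per skip connection (encoder order),
    with the middle-block residual last.
    """
    control = list(control) if control is not None else [0.0] * 4  # copy; we pop from it
    hs = []
    h = x
    # Encoder: runs unchanged, no ControlNet residual is added here.
    for scale in (1.0, 2.0, 3.0):
        h = h * scale + 1.0
        hs.append(h)
    # Middle block, then its residual (the last element).
    h = h + 0.5
    h = h + control.pop()
    # Decoder: each skip connection gets its matching residual before being merged.
    for _ in range(3):
        skip = hs.pop() + control.pop()
        h = (h + skip) / 2.0  # stand-in for concatenation + conv
    return h


print(toy_unet(1.0))                        # 6.3125
print(toy_unet(1.0, [0.0, 0.0, 0.0, 1.0]))  # 6.4375
```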
Reference-only hacks
Hook
For reference-only, however, there is no control-model forward pass. Instead, the hook function replaces the inner forward of every basic transformer block:
```python
# def hook(...):
all_modules = torch_dfs(model)

if need_attention_hijack:
    attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock) or isinstance(module, BasicTransformerBlockSGM)]
    attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])

    for i, module in enumerate(attn_modules):
        if getattr(module, '_original_inner_forward_cn_hijack', None) is None:
            module._original_inner_forward_cn_hijack = module._forward
        module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
        module.bank = []
        module.style_cfgs = []
        module.attn_weight = float(i) / float(len(attn_modules))
```
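Sorting by `-x.norm1.normalized_shape[0]` puts the widest blocks first, so they receive the smallest `attn_weight`; the Write branch later stores a block's features only if `outer.attention_auto_machine_weight > self.attn_weight`. A toy sketch with hypothetical block names and channel counts shows which blocks take part for a given threshold:

```python
from dataclasses import dataclass


@dataclass
class FakeBlock:
    """Hypothetical stand-in for BasicTransformerBlock; only the channel count matters."""
    name: str
    channels: int  # plays the role of norm1.normalized_shape[0]


def assign_attn_weights(blocks):
    # Same ordering rule as the hook: widest blocks first (sorted() is stable for ties).
    ordered = sorted(blocks, key=lambda b: -b.channels)
    return {b.name: float(i) / float(len(ordered)) for i, b in enumerate(ordered)}


def blocks_that_write(weights, attention_auto_machine_weight):
    # Mirrors the Write-branch test: outer.attention_auto_machine_weight > self.attn_weight
    return sorted(name for name, w in weights.items() if attention_auto_machine_weight > w)


blocks = [FakeBlock("in_320", 320), FakeBlock("in_640", 640),
          FakeBlock("mid_1280", 1280), FakeBlock("out_640", 640)]
weights = assign_attn_weights(blocks)
print(weights)                          # {'mid_1280': 0.0, 'in_640': 0.25, 'out_640': 0.5, 'in_320': 0.75}
print(blocks_that_write(weights, 0.5))  # ['in_640', 'mid_1280']
```

With a threshold of 1.0 every block writes to its bank; lowering it drops the narrowest blocks first, which in SD1.x U-Nets are the highest-resolution ones.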
Hacked basic transformer forward
The basic transformer module is hacked with the function `hacked_basic_transformer_inner_forward`.
What does this function do?
```python
# Excerpt from hacked_basic_transformer_inner_forward(self, x, context=None),
# where x_norm1 = self.norm1(x)
self_attention_context = x_norm1
if outer.attention_auto_machine == AutoMachine.Write:
    if outer.attention_auto_machine_weight > self.attn_weight:
        # Save the normalized hidden states of the reference pass
        self.bank.append(self_attention_context.detach().clone())
        self.style_cfgs.append(outer.current_style_fidelity)
if outer.attention_auto_machine == AutoMachine.Read:
    if len(self.bank) > 0:
        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
        # Keys/values come from the current tokens plus all saved reference tokens
        self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))
        self_attn1_c = self_attn1_uc.clone()
        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
            # Unconditional samples: plain self-attention without the reference
            self_attn1_c[outer.current_uc_indices] = self.attn1(
                x_norm1[outer.current_uc_indices],
                context=self_attention_context[outer.current_uc_indices])
        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
    # Clear the bank so the next step starts empty
    self.bank = []
    self.style_cfgs = []
```
First, it only hacks the self-attention part (`attn1`); the cross-attention module (`attn2`) is left as-is. There are two states, Read and Write:
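- Write: append `x_norm1` (detached and cloned) to `self.bank`, together with the current style fidelity in `self.style_cfgs`. This only happens when `outer.attention_auto_machine_weight > self.attn_weight`.
- Read: concatenate the current `self_attention_context` with the tensors in `self.bank` (the `x_norm1` saved during Write) along the token dimension (`dim=1`) and use the result as the context for `attn1`. For the unconditional batch entries (`outer.current_uc_indices`), plain self-attention without the bank is also computed, and the two results are blended by the averaged style fidelity `style_cfg`. The bank is then cleared.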
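The blend at the end of the Read branch can be checked with plain numbers. In this toy sketch each batch entry's attention output is reduced to a single float, and `blend_read_outputs` is a hypothetical name; the arithmetic follows the excerpt line by line.

```python
def blend_read_outputs(ref_attn, plain_attn, uc_indices, style_cfg):
    """Toy per-sample version of the Read-branch blend (one float per batch entry).

    ref_attn:   attn1 with context = current tokens + bank (self_attn1_uc in the excerpt)
    plain_attn: attn1 with context = current tokens only
    uc_indices: batch indices of the unconditional entries
    """
    self_attn1_uc = list(ref_attn)
    self_attn1_c = list(self_attn1_uc)  # like .clone()
    if len(uc_indices) > 0 and style_cfg > 1e-5:
        for i in uc_indices:
            self_attn1_c[i] = plain_attn[i]  # unconditional entries ignore the reference
    return [style_cfg * c + (1.0 - style_cfg) * uc
            for c, uc in zip(self_attn1_c, self_attn1_uc)]


# Batch = [conditional, unconditional]
print(blend_read_outputs([10.0, 10.0], [2.0, 2.0], [1], 0.5))  # [10.0, 6.0]
```

Conditional entries always keep the reference attention. Unconditional entries move toward plain self-attention as `style_cfg` grows, so classifier-free guidance then also pushes the result toward the reference; with `style_cfg = 0` both halves see the reference equally.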
Hacked forward
So which parts of the main hijacked forward function (`forward_webui`) relate to reference-only?
It first gets `ref_xt` from `used_hint_cond_latent` (the latent control image) and then:
- Set the auto machine to Write and call `outer.original_forward()` with `x=ref_xt`. That is, the latent control image is run through the U-Net, and each hijacked block saves its `x_norm1` to `self.bank`.
- Set the auto machine to Read and run the normal U-Net process with the real input `x`.
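The two passes can be sketched with a toy state machine. Everything here is hypothetical except the `AutoMachine.Read`/`Write` states and the bank idea: `ToyBlock` "attends" by averaging floats, `q_sample` is the standard DDPM forward-noising formula and is only an assumption about how `ref_xt` is derived from the reference latent, and the Read pass simply calls `original_forward` again, whereas the real `forward_webui` runs its own U-Net loop.

```python
import math
from enum import Enum


class AutoMachine(Enum):
    Read = "Read"
    Write = "Write"


class ToyBlock:
    """Hypothetical block with a bank; its 'attention' just averages the context values."""

    def __init__(self):
        self.bank = []

    def forward(self, outer, x_norm1):
        context = [x_norm1]
        if outer.attention_auto_machine == AutoMachine.Write:
            self.bank.append(x_norm1)      # remember the reference features
        if outer.attention_auto_machine == AutoMachine.Read:
            context = context + self.bank  # current + reference, like torch.cat(dim=1)
            self.bank = []                 # consumed once per step
        return sum(context) / len(context)


class ToyOuter:
    """Hypothetical stand-in for UnetHook holding the state machine."""

    def __init__(self, blocks):
        self.blocks = blocks
        self.attention_auto_machine = AutoMachine.Read

    def original_forward(self, x):
        return [block.forward(self, x) for block in self.blocks]


def q_sample(x0, alpha_bar_t, noise):
    # DDPM forward noising; ASSUMED here as the way ref_xt is obtained.
    return math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * noise


def reference_only_step(outer, x, ref_latent, alpha_bar_t, noise):
    ref_xt = q_sample(ref_latent, alpha_bar_t, noise)
    outer.attention_auto_machine = AutoMachine.Write
    outer.original_forward(ref_xt)                # pass 1: only fills the banks
    outer.attention_auto_machine = AutoMachine.Read
    return outer.original_forward(x)              # pass 2: real input reads the banks


outer = ToyOuter([ToyBlock(), ToyBlock()])
print(reference_only_step(outer, x=2.0, ref_latent=4.0, alpha_bar_t=1.0, noise=0.5))  # [3.0, 3.0]
```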
In short, the reference context is first saved to `self.bank`; then, during the real forward pass, it is concatenated with the current context to form the context of `attn1`. As the Read branch shows: `self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))`
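Why concatenating along `dim=1` works: the queries still come from the current tokens, so the output keeps the current sequence length, while the keys and values now also include the reference tokens. The sketch below is a toy single-head scaled dot-product attention in pure Python with identity Q/K/V projections (the real `attn1` has learned projections and several heads); `attention` and the token values are hypothetical.

```python
import math


def attention(queries, context):
    """Toy single-head scaled dot-product attention; keys = values = context (no projections).

    queries: list of N token vectors; context: list of M token vectors.
    Returns N vectors: the output length follows the queries, not the context.
    """
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in context]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, context)) / total for j in range(d)])
    return out


x_norm1 = [[1.0, 0.0], [0.0, 1.0]]             # 2 tokens of the current image
bank = [[[3.0, 3.0], [3.0, 3.0], [3.0, 3.0]]]  # one saved tensor with 3 reference tokens

# torch.cat([self_attention_context] + self.bank, dim=1) becomes list concatenation on the token axis.
context = x_norm1 + [tok for saved in bank for tok in saved]
plain = attention(x_norm1, x_norm1)
with_ref = attention(x_norm1, context)
print(len(context), len(with_ref))  # 5 2
```

Each current token can now attend to the reference tokens, which pulls its output toward the reference features; with an empty bank the result reduces to ordinary self-attention.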