首页 > 其他分享 >Stable-Diffusion-webUI 代码阅读02 —— 按钮?按一下!

Stable-Diffusion-webUI 代码阅读02 —— 按钮?按一下!

时间:2023-07-18 21:23:08浏览次数:48  
标签:Diffusion 02 prompt gr hr button elem webUI id

Stable-Diffusion-webUI 代码阅读02 —— 按钮?按一下!

由于实习工作需要,决定用几天时间阅读一遍stable-diffusion-webui的代码。

本文参考知乎专栏,并且添加了一些自己的理解,感谢大佬!知乎专栏:自动做游戏:AI技术落地于游戏开发 - 知乎 (zhihu.com)

最近工作主要侧重于OneFlow框架应用于SD的加速和不同Sampler的支持适配的工作,所以阅读代码也将其作为切入点。

由于本人刚刚入门,故许多内容比较粗糙,有问题希望多多批评指正!

webui本身更新较快,本文所阅读的内容为 AUTOMATIC1111大佬的1.4.1,这里是项目地址


点击了txt2img之后...

按钮位置

接上一篇文章代码阅读01的末尾部分,可以在modules\ui.pycreate_ui函数中找到txt2img的模块(末尾有代码块)

由于我是小白,对于js不是很熟练,只能通过函数和变量的命名来推断对应的UI部分,锁定了如下的代码块:

image

txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)

虽然这一行很长,但是可以看得出来,webUI对于generate,skip等按钮是做了顶层模块化的,因为他们无论是在txt2img,还是img2img等位置都有。

于是顺着代码,我们找到create_toprow

def create_toprow(is_img2img):
    id_part = "img2img" if is_img2img else "txt2img"

    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])

            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])

        button_interrogate = None
        button_deepbooru = None
        if is_img2img:
            with gr.Column(scale=1, elem_classes="interrogate-col"):
                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")

        with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
            with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
                skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')

                skip.click(
                    fn=lambda: shared.state.skip(),
                    inputs=[],
                    outputs=[],
                )

                interrupt.click(
                    fn=lambda: shared.state.interrupt(),
                    inputs=[],
                    outputs=[],
                )

            with gr.Row(elem_id=f"{id_part}_tools"):
                paste = ToolButton(value=paste_symbol, elem_id="paste")
                clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
                extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
                save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
                restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)

                token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
                token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
                negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
                negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")

                clear_prompt_button.click(
                    fn=lambda *x: x,
                    _js="confirm_clear_prompt",
                    inputs=[prompt, negative_prompt],
                    outputs=[prompt, negative_prompt],
                )

            with gr.Row(elem_id=f"{id_part}_styles_row"):
                prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
                create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")

    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button

同样地,我们依旧是排除没用的部分,直接去定位"Generate"按钮,

image

可以看得到,这个按钮所绑定的变量为submit

submit事件

通过全局搜索的方式,最终找到了这个submit的点击事件

image

函数的在webui.py 的create_ui()函数里

wrap_gradio_gpu_call

观察这个函数,发现执行了wrap_gradio_gpu_call这个方法,此方法:

# \modules\call_queue.py

def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):

        # if the first argument is a string that says "task(...)", it is treated as a job id
        if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
            id_task = args[0]
            progress.add_task_to_queue(id_task)
        else:
            id_task = None

        with queue_lock:
            shared.state.begin()
            progress.start_task(id_task)

            try:
                res = func(*args, **kwargs)
                progress.record_results(id_task, res)
            finally:
                progress.finish_task(id_task)

            shared.state.end()

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)

可以发现,此方法是给wrap_gradio_call的一个包裹,在下一个小节会介绍这部分的代码。

可以看到,代码里加入了queue_lock,此队列为多个用户一起执行生成任务的时候,满足先来后到的排队要求,根据锁的方式来满足任务的独立。

此部分代码包裹的目的为排队,即保证先到先服务。

但是这里其实是有问题的,我和师兄在进行尝试时就发现过,模型的切换会出现问题。

在参考了知乎大佬的博客后,很巧合地他谈及了这个问题,这个问题主要原因是,input的参数并不包含全部信息,比如不包含模型名,当不同的两个用户切换模型时,就会发生冲突。而不同的一些插件的加载也容易破坏队列任务中的参数。

wrap_gradio_call

def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list,
            # do not print out more than a 100 KB of text
            max_debug_str_len = 131072
            message = "Error completing request"
            arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
            if len(arg_str) > max_debug_str_len:
                arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
            errors.report(f"{message}\n{arg_str}", exc_info=True)

            shared.state.job = ""
            shared.state.job_count = 0

            if extra_outputs_array is None:
                extra_outputs_array = [None, '']

            error_message = f'{type(e).__name__}: {e}'
            res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0

        if not add_stats:
            return tuple(res)

        elapsed = time.perf_counter() - t
        elapsed_m = int(elapsed // 60)
        elapsed_s = elapsed % 60
        elapsed_text = f"{elapsed_s:.2f}s"
        if elapsed_m > 0:
            elapsed_text = f"{elapsed_m}m "+elapsed_text

        if run_memmon:
            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
            active_peak = mem_stats['active_peak']
            reserved_peak = mem_stats['reserved_peak']
            sys_peak = mem_stats['system_peak']
            sys_total = mem_stats['total']
            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)

            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
        else:
            vram_html = ''

        # last item is always HTML
        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"

        return tuple(res)

    return f

可以从函数名看出来,这个部分也是为了包裹gradio的请求。

此部分代码主要包裹内容为参数,为了将一些性能信息显示到html页面上

image

可以看到,此部分的包裹信息,就是在webUI使用时,生成图像的下方所包含的参数信息。

核心功能

txt2img

两个包裹函数看完了,而功能的核心实现为包裹函数内传入的参数,即modules.txt2img

image

image

接下来看txt2img

# \modules\txt2img.py

def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
    override_settings = create_override_settings_dict(override_settings_texts)

    p = processing.StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
        prompt=prompt,
        styles=prompt_styles,
        negative_prompt=negative_prompt,
        seed=seed,
        subseed=subseed,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
        seed_enable_extras=seed_enable_extras,
        sampler_name=sd_samplers.samplers[sampler_index].name,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        restore_faces=restore_faces,
        tiling=tiling,
        enable_hr=enable_hr,
        denoising_strength=denoising_strength if enable_hr else None,
        hr_scale=hr_scale,
        hr_upscaler=hr_upscaler,
        hr_second_pass_steps=hr_second_pass_steps,
        hr_resize_x=hr_resize_x,
        hr_resize_y=hr_resize_y,
        hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
        hr_prompt=hr_prompt,
        hr_negative_prompt=hr_negative_prompt,
        override_settings=override_settings,
    )

    p.scripts = modules.scripts.scripts_txt2img
    p.script_args = args

    if cmd_opts.enable_console_prompts:
        print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

    processed = modules.scripts.scripts_txt2img.run(p, *args)

    if processed is None:
        processed = processing.process_images(p)

    p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    if opts.samples_log_stdout:
        print(generation_info_js)

    if opts.do_not_show_images:
        processed.images = []

    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)

可以看到,依次调用了如下的方法:

  • 重写的设置(略)

  • StableDiffusionProcessingTxt2Img方法构造的对象p

  • modules.scripts.scripts_txt2img.run方法,利用对象p和参数args生成了结果processed

    • 若processed在调用后为空,则调用processing.process_images(p)
  • 在有了结果之后,将内容写入html(略)

StableDiffusionProcessingTxt2Img

我们可以在modules/processing.py里找到

image

  • sd_samplers.create_sampler创建了一个采样器对象并且赋值给sampler属性
  • 接下来使用了shared.latent_upscale_modes字典根据self.hr_upscaler的值获取对应的缩放模式。
  • x=create_random_tensors(...)生成一个随机噪音
  • 之后判断是否进行高清修复enable_hr,如果没有则直接返回采样器

注意,这个sample方法,只是定义了一个采样器,但是还未执行

采样器

create_sampler

# modules/samplers.py
def create_sampler(name, model):
    config = find_sampler_config(name)

    assert config is not None, f'bad sampler name: {name}'

    sampler = config.constructor(model)
    sampler.config = config

    return sampler

可以看到,采样器的生成做了一下几个事情

  • 加载配置

    • def find_sampler_config(name):
          if name is not None:
              config = all_samplers_map.get(name, None)
          else:
              config = all_samplers[0]
      
          return config
      
    • 这里就是一个在一个以name作为键值的map里查找

    • map的定义如下:

    • all_samplers = [
          *sd_samplers_kdiffusion.samplers_data_k_diffusion,
          *sd_samplers_compvis.samplers_data_compvis,
      ]
      

      可以发现,采样器来源于开源项目 K_Diffusion里的采样器数据和Compvis小组自己的采样器数据

  • 调用config.constructor来加载模型

config.constructor

我们采取调试的方法,选择Euler a作为sampler

image

在调试的时候发现遇到了阻碍:

image

想想,原因应该是在于,k_diffusion是作为外部库的,具体的位置在哪呢?突然想到了在我的第一篇文章内的webui-macos-env的阅读,好像见过类似的库的导入。但我们作为windows启动,那么答案显而易见了,就是之前刻意跳过的launch_utils.py里的内容。

全局搜索k-diffusion.git,最终找到

image

但是!这种外部库十分不方便,尤其是不便于索引到对应的位置并进行调试。

对此尝试了一种新的解决方法,篇幅原因,放到下节再说。

标签:Diffusion,02,prompt,gr,hr,button,elem,webUI,id
From: https://www.cnblogs.com/ZzTzZ/p/17564176.html

相关文章

  • 20230718巴蜀暑期集训测试总结
    T1做了\(3h\),时间复杂度不对,小样例都还有一个没过。考虑容斥,不连通的情况枚举\(1\)号点所在连通块。设\(f_{S,i}\)表示\(S\)连通且选了\(i\)条边的方案数。设\(inb_s\)表示\(S\)内部的边数。那么有转移:\[f_{S,i}=\binom{inb_S}i-\sum_{T\subsetneqqS,1\inT}......
  • 【专题】2022年中国制造业数字化转型研究报告PDF合集分享(附原数据表)
    报告链接:http://tecdat.cn/?p=32145原文出处:拓端数据公众号本文中所说的制造业数字化转型,指的是在制造企业的设计、生产、管理、销售及服务的每一个环节中,将新一代信息技术应用到制造企业的设计、生产、管理、销售及服务的每一个环节中,并可以以每一个环节中产生的数据为基础,展开......
  • 2023冲刺国赛模拟 38.1
    又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又又垫......
  • Databend 开源周报第 102 期
    Databend是一款现代云数仓。专为弹性和高效设计,为您的大规模分析需求保驾护航。自由且开源。即刻体验云服务:https://app.databend.cn。What'sOnInDatabend探索Databend本周新进展,遇到更贴近你心意的Databend。为指定列创建BloomIndex创建bloomindex将会消耗大......
  • CS5212/CS5202 DP转VGA芯片设计方案
    CS5212内置MCU控制器,超低待机功率<100uW,用于设计DP端口到VGA转换器,也可以用于主板DP转VGA方案,CS5212AN芯片功能特性:2-lane通道VESADP1.1兼容接收机VGA输出接口,DAC速度高达210MHz,8位分辨率高达1920x1200x60(RB,缩小消隐),24位色深,1920x1440x60(RB,缩小消隐),或2048x152x60(RB,缩小消隐......
  • 8-102-(LeetCode- 207&210) 课程表
    1.题目 读题  考查点 2.解法思路这个问题可以用图论的方法来解决,具体思路如下:将课程和先修课程看作有向图的节点和边,如果要学习课程ai,则必须先学习课程bi,表示为bi->ai。判断图中是否存在环,如果存在环,则说明有些课程无法完成,返回false;如果不存在环,则说明所有课程都......
  • P6835 [Cnoi2020] 线形生物题解
    P6835[Cnoi2020]线形生物题解题目描述求从\(1\)到\(n+1\)的链的期望,其中有\(m\)条返祖边:\(u->v\)这条边\(u\gev\),等概率,求期望Solution这种爬楼梯的题一般求解\(E(x\rightarrowx+1)\),则最后答案为\(\sum_{i=1}^nE(i\rightarrowi+1)\)我们考虑从\(x\rightarr......
  • 原子计算器Atom4 2024版发布下载
    2024版更新记录:2024EditionupdateRecord:1、能计算2个粒子的弹性碰撞后的各自的速度。2、能计算球的对心和对切线轴的质量、半径和转动惯量之间的关系和数据。1,thevelocitycanningcomputetheempressoftheelasticcollisionof2gramulatesrespectively.2,can......
  • Day11(2023.07.18)
    行程8:45    到达上海市信息安全测评认证中心(黄浦区陆家浜路1308号)9:00  改文件11:30--13:00   吃饭休息13:00 创建项目,熟悉软件,生成报告等..17:00      下班......
  • 2023HWS_RE复现
    2023HWS_RE复现参考wp:https://oacia.cc/hws-2023/Android参考这篇文章:https://www.52pojie.cn/thread-1680984-1-1.html360加固去掉利用gg修改器+lua脚本去掉360加固后然后就ok了,但是dex文件直接丢尽jadx会报错把多余的数据删了就行然后可以正常看到so里面看,sm4加......