首页 > 其他分享 >strict=False 但还是size mismatch 的解决办法

strict=False 但还是size mismatch 的解决办法

时间:2023-08-25 12:44:06浏览次数:33  
标签:... decoder.0 False weight mismatch strict dict pretrained

问题描述:

# RuntimeError: 
Error(s) in loading state_dict for Fusion_Generator: size mismatch for fg_decoder.0.weight: copying a param with shape torch.Size([4096, 1024]),g_decoder.0.weight: copying a param with shape torch.Size([4096, 1024]...
出现两个参数的不匹配。

具体内容如下:

model = GAN(opt)
loaded = torch.load(model_path)
assert (opt.epoch == loaded['epoch'])
model.load_state_dict(loaded['model'], strict=False)   # 这里爆出上述Error,定位到下面的函数

def load_state_dict(self, pretrained_dict, strict=False):

    for k in pretrained_dict:
        if k ...
             ...
             ...
        elif k == "generator":
            self.generator.load_state_dict(pretrained_dict[k], strict=strict)  # 这里虽然strict传入的是False,忽略不匹配参数,仍有上述问题
        elif k ...
               ...

在参考 这里 后,如果只是pop()掉fg_decoder.0.weightbg_decoder.0.weight会有新的问题出现(一般问题通过pop掉能解决问题),即



KeyError: 'fg_decoder.0.weight,bg_decoder.0.weight'

即不能识别上述两个键值,这时可以通过打印模型参数具体内容查看:

def load_state_dict(self, pretrained_dict, strict=False):

    for k in pretrained_dict:
        if k ...
             ...
             ...
        elif k == "fusion_generator":
            for u in pretrained_dict[k].keys():
                print(u," ",pretrained_dict[k][u])
            self.fusion_generator.load_state_dict(pretrained_dict[k], strict=strict)  #
        elif k ...
               ...

打印结果

fg_decoder.0.weight xxxxxx tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')

fg_decoder.0.bias xxxxxx tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0') fg_decoder.1.weight xxxxxx tensor([1.0362, 0.9969, 0.9892, ..., 0.9939, 1.0122, 1.0190], device='cuda:0') fg_decoder.1.bias xxxxxx tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0') fg_decoder.1.running_mean xxxxxx tensor([ 0.1915, -0.5510, 0.5370, ..., -0.1265, 0.8344, 1.4391], device='cuda:0') fg_decoder.1.running_var xxxxxx tensor([0.9402, 0.7382, 0.0167, ..., 0.3988, 0.1081, 0.4470], device='cuda:0') fg_decoder.1.num_batches_tracked xxxxxx tensor(3880, device='cuda:0') fg_decoder.3.weight xxxxxx tensor([[ 0.0211, -0.0072, 0.0030, ..., 0.0090, 0.0120, 0.0043], [ 0.0221, -0.0320, -0.0050, ..., 0.0239, 0.0035, 0.0438], [ 0.0246, -0.0091, 0.0146, ..., -0.0003, 0.0257, -0.0025], ..., [ 0.0077, -0.0209, -0.0017, ..., 0.0135, 0.0418, 0.0052], [ 0.0109, 0.0066, -0.0093, ..., 0.0048, -0.0019, -0.0381], [ 0.0145, -0.0165, 0.0095, ..., 0.0252, -0.0184, 0.0178]], device='cuda:0')
....
bg_decoder.0.weight xxxxxx tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
....

可以发现fg_decoder.0.weight和bg_decoder.0.weight都在里面,并且对应为pretrained_dict[k][u]

所以!!!在有序字典中将对应报错内容删除后,就能解决size mismatch问题

def load_state_dict(self, pretrained_dict, strict=False):

    for k in pretrained_dict:
        if k ...
             ...
             ...
        elif k == "fusion_generator":
            for u in list(pretrained_dict[k].keys()):# (小坑)加list防止同时读写报错
                if u == "fg_decoder.0.weight" or u == "bg_decoder.0.weight":
                    pretrained_dict[k].pop(u)
            self.fusion_generator.load_state_dict(pretrained_dict[k], strict=strict)  #
        elif k ...
               ...

成功解决问题~

标签:...,decoder.0,False,weight,mismatch,strict,dict,pretrained
From: https://www.cnblogs.com/ygsworld/p/17656637.html

相关文章

  • Commit failed (details follow): Working copy text base is corrupt Checksum misma
    问题:提交一个svn文件报错,提交其他文件没有报错解决办法:(网上看了很多方法都解决不了):1、把文件拷贝到svn目录外放着2、把svn目录下文件移除,然后commitsvn3、把目录外的文件拷贝进来,先Add,然后commit就成功了......
  • Winform项目中出现 "已经可见的窗体不能显示为模式对话框。在调用 showDialog 之前应
    1问题描述最近做一个winform项目,启动程序弹出的加载进度窗体时,发生如标题所示的异常。2尝试debug根据异常提示,在进度窗体弹出前添加代码Visable=false;--未解决逐步debug调试发现Form弹框运行了2次,由此查出bug所在。由于我是用的单例模式,在Program.cs中运行的还是new......
  • Vue packages version mismatch:
    报错原因:vue与vue-template-compiler版本不匹配。解决办法:上图中说了看看使用vue-loader的版本,我的是13版本大于10.0版本,这个时候需要更新vue-template-compiler//卸载npmuninstallvue-template-compiler//添加和vue一样的版本[email protected]......
  • python+playwright 学习-72 设置window.navigator.webdriver属性为false 跳过网站反爬
    前言有些网站有反爬机制,比如用代码启动的浏览器会被检测到,需要人机验证,用脚本去点击或者滑动滑动虽然能滑动,但是会认证失败。用playwright和selenium启动的浏览器都会用个webdriver属性。浏览器会根据这个属性判断是否是人工正常操作。window.navigator.webdriver属性人......
  • csv reader utf-8报错:strict 改为ignore
    classBufferedIncrementalDecoder(IncrementalDecoder):"""ThissubclassofIncrementalDecodercanbeusedasthebaseclassforanincrementaldecoderifthedecodermustbeabletohandleincompletebytesequences."......
  • 模型超参数基本都没改,测试时加载模型报模型结构不匹配,设置模糊加载模型即:model.load_s
    原因多卡训练;单卡模糊加载进行测试。训练时,通过torch.nn.DataParallel(self.model)进行多卡并行训练;测试时,用单卡模糊加载保存的模型权重,很多模型参数都没有加载成功,自然会导致测试效果很差。解决方法测试时,使用多卡加载模型时,删掉'module.'前缀;或者用单卡加载模型进行测试。......
  • 关于 HTTP 响应头字段 Strict-Transport-Security
    在Chrome开发者工具的Network面板里,当观察到一个请求的ResponseHeader字段名称为"Strict-Transport-Security",并且其值为"max-age=31536000;includeSubDomains;preload"时,这代表网站启用了严格传输安全(StrictTransportSecurity,HSTS)策略。HSTS是一种安全机制,旨在提高网站的安......
  • AutoX——当Android中clickable属性显示为false,实际可点击的布局如何处理
    前言最近在写一个关于某音的脚本,包含刷视频/点赞/收藏/分享/评论等一些列功能,借助于AutoX来实现,虽然我老早就买了AutoJsPro但是最新版本阉割的有点厉害。。。内容思索很简单就是,找到布局后,获取坐标信息,使用click去触发;varbtn=className("android.widget.TextView").t......
  • 浅谈-HttpSession session = request.getSession(false)
    当使用request.getSession(false)方法时,如果当前请求没有关联的会话,则不会创建新的会话,而是返回null。这意味着,如果当前客户端没有携带有效的会话标识符(如JSESSIONID),或者会话已过期或被销毁,则request.getSession(false)方法将返回null。下面是一个示例来解释这个方法的用......
  • 使用redis-py的两个类Redis和StrictRedis时遇到的坑
    redis-py提供两个类Redis和StrictRedis用于实现Redis的命令,StrictRedis用于实现大部分官方的命令,并使用官方的语法和命令(比如,SET命令对应与StrictRedis.set方法)。Redis是StrictRedis的子类,用于向后兼容旧版本的redis-py。简单说,官方推荐使用StrictRedis方法。  不推荐Redis类,原......