此篇为《Learning to Compare Relation Network for Few-Shot Learning》
只实现了基于Omniglot数据集的小样本代码
datas为数据集
models为训练好的模型
venv为配置文件
下面的py文件是具体实现代码
1.结构
2.问题:KeyError: '..\datas\omniglot_resized'
报错信息:
File "LearningToCompare_FSL-master/omniglot/omniglot_train_few_shot.py", line 163, in main
task = tg.OmniglotTask(metatrain_character_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 72, in <listcomp>
self.train_labels = [labels[self.get_class(x)] for x in self.train_roots]
KeyError: '..\\datas\\omniglot_resized'
由于linux和window路径的转换,需要把把'/'改成'\'即可。
修改一:
def get_class(self, sample):
return os.path.join(*sample.split('\\')[:-1])
修改二:
3.问题:IndexError: invalid index of a 0-dim tensor.
报错信息:
File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in main
print("episode:",episode+1,"loss",loss.data[0])
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
按要求改成
if (episode + 1) % 100 == 0:
print("episode:", episode + 1, "loss", loss.item())
4.问题:RuntimeError: output with shape [1, 28, 28]
报错信息:
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 107, in __getitem__
image = self.transform(image)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__
img = t(img)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 163, in __call__
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\functional.py", line 208, in normalize
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
这个是使用Omniglot数据集时的报错,主要原因在于使用 torch.transforms 中 normalize 用了 3 通道,而实际使用的数据集Omniglot 图片大小是 [1, 28, 28],只需要把
normalize =transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
改成
normalize = transforms.Normalize(mean=[0.92206], std=[0.08426])
dataset = Omniglot(task,split=split,transform=transforms.Compose([Rotate(rotation)
5.问题:AttributeError: module 'torch. nn’has no attribute
报错信息:
Traceback (most recent call last):
File"D:/Omnight/omniglot/omniglot_train few shot.py",line 263, in <modulemainO
File"D:/Omnight/omniglot/omniglot train few shot.py",line 140,in main
if mn.path. exists(str("./models omniglot_feature_encoder_"+ str(CLASS _AM)way_"+ str(SANPLE_ANM_PER_CLASS) +'shot. phk1')):AttributeError: module 'torch. nn’has no attribute 'path'
torch.nn模块是PyTorch中用于神经网络构建和操作的核心模块,它包含了各种层、损失函数和激活函数等。并不包含文件或目录的处理函数,所以没有path函数。因此,更改为使用os.path.exists是一个正确的解决方法,以检查文件或目录是否存在。所以将nn改成os就行了
6.问题:IndexErrorscatter_(: Expected dtype int64 for index.
报错信息
"" vpTraceback (most recent call last) :
File "D;:/Omnight/omniglot/omniglot_train_few shot.py",line 264, in <module>
main()
File"D:/Omnight/omniglot/omniglot_train_few_shot.py",line 188, in main
one_hot_labels = Variable(torch. zeros (BATCA_NMM_PER_CLASS*CLASS_NM,CLASS _NUM .scatter_(l, batch_labels.vier-1,1),..ua(lFb)IndexErrorscatter_(: Expected dtype int64 for index.
scatter_()函数内的索引错误,此函数内部的参数必须是64位整数,改为
one_hot_labels = Variable(
torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1).long(), 1)).cuda(GPU)
loss = mse(relations, one_hot_labels)
7.问题:NotADirectoryBrror:[WinError 267]目录名称无效。:Alphabet_of_the_Magi\ .DS_Store
报错信息:
NotADirectoryBrror:[WinError 267]目录名称无效。:'.\ |datas\ omniglot_resized Alphabet_of_the_Magi\ .DS_Store
只要删除这个文件就行。
解释:
.DS_Store是Mac OS系统在文件夹中生成的隐藏文件,是特定于 Mac OS 系统的文件,是用于存储文件夹的元数据和自定义显示属性。如果出现与 .DS_Store 相关的报错,可能是因为程序在处理文件夹时意外地尝试读取或操作了 .DS_Store 文件,这可能是由于编程代码中没有正确处理隐藏文件的情况,或者没有明确地指定忽略这些文件。在数据集文件夹中删除这个文件就行。