首页 > 其他分享 >交叉验证

交叉验证

时间:2023-05-31 21:45:03浏览次数:24  
标签:交叉 torch db batch loader train 验证 data

交叉验证

在实际情况中,数据集是分为训练集和测试集的。而测试集通常被用户保留,并不对外公开,以防止在测试模型时作弊,故意使用让模型效果更好的数据进行测试,以至于模型遇上新的数据效果很差。

image-20230531211645662

于是我们通常将训练集进行分割,一部分用于训练,一部分用以测试,这里的测试其实叫做验证。

image-20230531211926532

由于数据集的一部分用以测试而获取不到,这部分的数据损失可能对训练结果造成影响。为了减小这种影响,我们充分的利用可用的数据集,将训练集中的不同部分作为验证集,进行交叉验证,以减小验证集选取的偶然性对结果造成的影响。

image-20230531170055521

import torch
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms

batch_size = 200
learning_rate = 0.01
epochs = 10

# 加载训练集 60k
train_db = datasets.MNIST('../data',train=True,download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,),(0.3081,))
                   ]
                   ))

train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size,
    shuffle=True
)

# 加载测试集 10k
test_db = datasets.MNIST('../data',train=False,download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,),(0.3081,))
                   ])
                         )

test_loader = torch.utils.data.DataLoader(
    test_db,
    batch_size=batch_size,
    shuffle=True
)

# 打印训练集(60k)和测试集(10k)的大小
print('train:',len(train_db),'test:',len(test_db))

# 将训练数据集(60k)划分为训练集(50k)和验证集(10k)
train_db,val_db = torch.utils.data.random_split(train_db,[50000,10000])

# 打印训练集和测试集的大小
print('train_db:',len(train_db),'val_db:',len(val_db))

train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size,
    shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_db,
    batch_size=batch_size,
    shuffle=True
)


class MLP(torch.nn.Module):
    def __init__(self):
        super(MLP,self).__init__()

        self.model = torch.nn.Sequential(
            torch.nn.Linear(784,200),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(200, 200),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(200, 10),
            torch.nn.LeakyReLU(inplace=True),
        )

    def forward(self,x):
        x = self.model(x)

        return x


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = MLP().to(device)

optimizer = optim.SGD(net.parameters(),lr=learning_rate)
criteon = torch.nn.CrossEntropyLoss()

for epoch in range(epochs):
    for batch_idx,(data,target) in enumerate(train_loader):
        data = data.view(-1,28*28).to(device)
        target =target.to(device)

        logits = net(data)
        loss = criteon(logits,target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx % 100) == 0:
            print('Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()
            ))

    test_loss = 0
    correct = 0
    for data,target in val_loader:
        data = data.view(-1,28 * 28).to(device)
        target = target.to(device)
        logits = net(data)
        test_loss += criteon(logits,target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target).sum().item()

    test_loss /= len(val_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss,
        correct,
        len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)
    ))

标签:交叉,torch,db,batch,loader,train,验证,data
From: https://www.cnblogs.com/dxmstudy/p/17447405.html

相关文章

  • 解决方案 | Windows 验证账号出现 0x80190001错误解决
    一、问题描述点击windows开始→账户→更改账户设置→验证,出现下面的错误。 二、解决方法网上流行的是这个方法,https://blog.csdn.net/qq_36393978/article/details/107413791 ,但是这个其实是恢复网络刷新dns的方法,大家可试一试。 如果不行,试试下面的方法,在任务栏搜索框......
  • 【解决一个小问题】macbook m2 上交叉编译 gozstd
    作者:张富春(ahfuzhang),转载时请注明作者和引用链接,谢谢!cnblogs博客zhihuGithub公众号:一本正经的瞎扯已知zstd是一个优秀的压缩库,gozstd封装了这个库。一开始在macbookm2芯片的笔记本上开发包含了gozstd的程序时,一切正常。发布的时候,需要分别编译linux+arm64......
  • 阿里云验证码短信功能---SpringBoot项目
    阿里云官网:https://www.aliyun.com/activity/2023caigouji/shangyuncaigouji?utm_content=se_1013408957准备工作注册阿里云账号申请AccessKeyID和AccessKeySecret搜索“短信服务SMS”,选择“免费开通”即可选择国内消息,申请签名管理和模板管理准备完成后我们可以获取Access......
  • dockerfile镜像私有仓库需要https登录验证改成http
    ERROR:failedtodorequest:Head"https://192.168.16.185:8088/v2/jenkins/python_common_api/manifests/base":http:servergaveHTTPresponsetoHTTPSclientDockerfile:1--------------------1|>>>FROM192.168.16.185:8088/jenkins/p......
  • 验证码识别
    验证码识别是基于线上的打码平台识别验证码-打码平台:1.超级鹰(http://www.chaojiying.com/)-注册(用户中心身份)-登录(用户中心身份)-1.查询余额,请充值-2.创建一个软件ID(899370)-3.下载示例代码2.云打码3.打......
  • 常见LOSS函数之Cross Entropy(交叉熵)
    常见LOSS函数之CrossEntropy(交叉熵)交叉熵是分类问题常用的损失函数。熵熵表示稳定性或者说惊喜度,熵越高表示越稳定,其惊喜度就越低。示例一般用H(p)表示熵交叉熵交叉熵=熵+散度散度代表分布的重叠程度,散度越高重合度越少。当两个分布p和q相等时,散度等于0,此时交叉熵......
  • python selenium web网站登录缺口图片验证码识别
    deflogin():driver=webdriver.Chrome("browser_driver/chromedriver.exe")driver.get("http://xxxxxx/#/login")driver.maximize_window()sleep(1)driver.find_element(By.CSS_SELECTOR,'[placeholder="请输入手机号&qu......
  • SpringSecurity 添加验证码的两种方式
    一验证码生产<dependency><groupId>com.github.penggle</groupId><artifactId>kaptcha</artifactId><version>2.3.2</version></dependency>@ConfigurationpublicclassKaptchaConfig{@BeanPro......
  • SQL高级篇~动态交叉表
    QL动态交叉表(DynamicCrosstab)是SQL查询语言中的一种高级技术,可以将行数据转换为列数据,实现更加直观的数据展示方式。它允许我们在不知道列名和列数的情况下动态地将行数据转换为列数据,并将其呈现在一个表格中。这在数据分析和报表生成方面非常有用。一般情况下,我们使用SELECT语......
  • python爬虫 requests访问http网站之443报错(ssl验证)
    报错信息:urllib3.exceptions.MaxRetryError:HTTPSConnectionPool(host='ssr4.scrape.center',port=443):Maxretriesexceededwithurl:/page/1(CausedbySSLError(SSLCertVerificationError(1,'[SSL:CERTIFICATE_VERIFY_FAILED]certificateverifyfa......