Install CUDA and cuDNN:
conda install cudatoolkit==11.8.0
python3 -m pip install nvidia-cudnn-cu11==8.7.0.84
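A quick way to confirm that the pip step worked is to read the installed distribution's metadata. The sketch below uses only the standard library's `importlib.metadata`. It covers the `nvidia-cudnn-cu11` wheel only, not the conda `cudatoolkit` package (for that, `conda list cudatoolkit` shows the installed version) and not the GPU driver.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a pip distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # After the pip command above, this should report 8.7.0.84.
    print("nvidia-cudnn-cu11:", installed_version("nvidia-cudnn-cu11"))
```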
Download the index page that lists the installation packages:
wget https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cu118
Then find the package URL in that page that matches your Python version.
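The page fetched by wget is HTML. Assuming it lists each wheel as an `<a href="...">` link (an assumption about the page layout; the post does not show it), the links for one Python version can be picked out by the `cpXY` tag that appears twice in a CPython wheel filename (python tag and ABI tag). A minimal standard-library sketch:

```python
import re
import sys
from html.parser import HTMLParser
from typing import List, Optional, Tuple


class LinkCollector(HTMLParser):
    """Collect the href attribute of every <a> tag in a page."""

    def __init__(self) -> None:
        super().__init__()
        self.hrefs: List[str] = []

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        if tag != "a":
            return
        for name, value in attrs:
            if name == "href" and value:
                self.hrefs.append(value)


def current_python_tag() -> str:
    """CPython wheel tag of the running interpreter, e.g. "cp310" on Python 3.10."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def wheels_for_python(html: str, python_tag: str) -> List[str]:
    """Return .whl links whose python tag and ABI tag both equal python_tag."""
    collector = LinkCollector()
    collector.feed(html)
    tag = re.escape(python_tag)
    pattern = re.compile(rf"-{tag}-{tag}-[^/]*\.whl$")
    return [href for href in collector.hrefs if pattern.search(href)]
```

For example, `wheels_for_python(open(path, encoding="utf-8").read(), current_python_tag())`, where `path` is wherever wget saved the page, lists the candidate wheels for the interpreter you are running.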
Download the framework package (for a Python 3.10 environment):
wget https://oneflow-staging.oss-cn-beijing.aliyuncs.com/branch/master/cu118/82503845d553c722041d884d109fc1314568a4c1/oneflow-0.9.1.dev20240201%2Bcu118-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
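The `%2B` in this URL is the percent-encoded `+` of the local version label `+cu118`, which is why the `pip install` command below names the file with a `+`. The sketch below decodes the URL's last path segment and splits the filename into the fields of the wheel naming convention (name, version, python tag, ABI tag, platform tag), so the tags can be checked against your interpreter before installing. The post does not say whether wget keeps `%2B` in the saved filename; check with `ls` and adjust the path in the install command if needed.

```python
import posixpath
from typing import NamedTuple
from urllib.parse import unquote, urlparse


class WheelInfo(NamedTuple):
    name: str
    version: str
    python_tag: str
    abi_tag: str
    platform_tag: str


def wheel_filename_from_url(url: str) -> str:
    """Last path segment of the URL with percent-escapes decoded (%2B -> +)."""
    return unquote(posixpath.basename(urlparse(url).path))


def parse_wheel_filename(filename: str) -> WheelInfo:
    """Split name-version[-build]-python-abi-platform.whl into its fields."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 6:  # an optional build tag is present; drop it
        parts = parts[:2] + parts[3:]
    if len(parts) != 5:
        raise ValueError(f"unexpected wheel filename: {filename}")
    return WheelInfo(*parts)
```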
Install the dependencies manually:
pip install packaging
pip install typing_extensions
Install the framework package:
pip install ./oneflow-0.9.1.dev20240201+cu118-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
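Before running the full training script, a short import check confirms that the wheel installed and whether OneFlow can see a GPU. This sketch uses `oneflow.cuda.is_available()` (the same call the test code uses) and reads `__version__` defensively; it returns None instead of raising when `oneflow` cannot be imported.

```python
import importlib
from typing import Optional, Tuple


def oneflow_status() -> Optional[Tuple[str, bool]]:
    """Return (version, cuda_available), or None if oneflow cannot be imported."""
    try:
        flow = importlib.import_module("oneflow")
    except ImportError:
        return None
    version = str(getattr(flow, "__version__", "unknown"))
    return version, bool(flow.cuda.is_available())


if __name__ == "__main__":
    status = oneflow_status()
    if status is None:
        print("oneflow is not importable")
    else:
        print("oneflow", status[0], "| CUDA available:", status[1])
```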
Test code (it also needs the flowvision package, installable from PyPI as `flowvision`):
import oneflow as flow
import oneflow.nn as nn
import flowvision
import flowvision.transforms as transforms

BATCH_SIZE = 64
EPOCH_NUM = 1

DEVICE = "cuda" if flow.cuda.is_available() else "cpu"
print("Using {} device".format(DEVICE))

# CIFAR-10 training set, downloaded from the OneFlow mirror
training_data = flowvision.datasets.CIFAR10(
    root="data",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
    source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz",
)
train_dataloader = flow.utils.data.DataLoader(
    training_data, BATCH_SIZE, shuffle=True
)

# Replace the 1000-class ImageNet head with a 10-class head for CIFAR-10,
# then move the whole model to DEVICE so the new layers sit on the same
# device as the inputs.
model = flowvision.models.mobilenet_v2()
model.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(model.last_channel, 10))
model = model.to(DEVICE)
model.train()

loss_fn = nn.CrossEntropyLoss().to(DEVICE)

# Per-group learning rates: backbone, pooling layer (pooling has no
# learnable parameters, so this group is empty) and classifier head
param_groups = [
    {'params': model.features.parameters(), 'lr': 1e-3},
    {'params': model.adaptive_avg_pool2d.parameters(), 'lr': 1e-4},
    {'params': model.classifier.parameters(), 'lr': 1e-5},
]
optimizer = flow.optim.SGD(param_groups)

for t in range(EPOCH_NUM):
    print(f"Epoch {t+1}\n-------------------------------")
    size = len(train_dataloader.dataset)
    for batch, (x, y) in enumerate(train_dataloader):
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        # Compute prediction error
        pred = model(x)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        current = batch * BATCH_SIZE
        if batch % 5 == 0:
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")
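The `param_groups` list gives the backbone (`features`), the pooling layer and the classifier head their own learning rates. To illustrate what that means, here is a framework-free sketch of one plain SGD step (no momentum, no weight decay) over groups shaped like that list. It is a conceptual illustration, not OneFlow's `flow.optim.SGD` implementation, and the scalar parameters and their names (`features.w`, `classifier.w`) are hypothetical.

```python
from typing import Dict, List


def sgd_step(param_groups: List[Dict], grads: Dict[str, float]) -> None:
    """Apply p <- p - lr * grad in place, using each group's own lr.

    Each group is {"params": {name: value}, "lr": float}; `grads` maps the
    same (hypothetical) parameter names to their gradients.
    """
    for group in param_groups:
        lr = group["lr"]
        params = group["params"]
        for name in params:
            params[name] -= lr * grads[name]


groups = [
    {"params": {"features.w": 1.0}, "lr": 1e-3},
    {"params": {}, "lr": 1e-4},  # pooling layer: nothing to update
    {"params": {"classifier.w": 1.0}, "lr": 1e-5},
]
sgd_step(groups, {"features.w": 10.0, "classifier.w": 10.0})
print(groups[0]["params"], groups[2]["params"])
```

With the same gradient, the backbone weight moves 100 times further than the head weight, because its group's learning rate is 100 times larger.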
Run output: (screenshot not preserved)
From: https://www.cnblogs.com/devilmaycry812839668/p/18005051