首页 > 其他分享 > PyTorchStepByStep - Chapter 2: Rethinking the Training Loop

PyTorchStepByStep - Chapter 2: Rethinking the Training Loop

时间:2024-10-12 21:34:10 | 浏览次数:13
标签:Chapter Rethinking Training tensor loss step train model fn



def make_train_step_fn(model, loss_fn, optimizer):
    """Build a function that performs one complete training step.

    Args:
        model: The model to train. It is called as ``model(x)`` and must
            provide a ``train()`` method.
        loss_fn: Callable invoked as ``loss_fn(yhat, y)``; its result must
            support ``backward()`` and ``item()``.
        optimizer: Optimizer providing ``step()`` and ``zero_grad()``.

    Returns:
        A function ``perform_train_step_fn(x, y)`` that runs the forward
        pass, computes the loss, backpropagates, updates the parameters,
        clears the gradients and returns the loss as a Python number.
    """
    def perform_train_step_fn(x, y):
        # Set model to TRAIN mode
        model.train()

        # Step 1 - Compute model's predictions - forward pass
        yhat = model(x)

        # Step 2 - Compute the loss
        loss = loss_fn(yhat, y)

        # Step 3 - Compute the gradients for the model's parameters
        loss.backward()

        # Step 4 - Update parameters using gradients and the learning rate,
        # then clear the gradients so they do not accumulate into the next step
        optimizer.step()
        optimizer.zero_grad()

        # Return the loss
        return loss.item()
    return perform_train_step_fn


%%writefile model_configuration/v1.py

# Model configuration: selects the device, then builds the model, optimizer,
# loss function and train-step function used by the training script.
# NOTE(review): assumes `torch`, `nn` and `make_train_step_fn` are already
# defined in the namespace this script runs in - confirm how it is executed.

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set learning rate
lr = 0.1


# Linear model with one input feature and one output, wrapped in a
# Sequential container and moved to the selected device
model = nn.Sequential(nn.Linear(1, 1)).to(device)

# Define an SGD optimizer to update the parameters (now retrieved directly from the model)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Define an MSE loss function (averaged over all elements)
loss_fn = nn.MSELoss(reduction='mean')

# Create the train_step function for model, loss function and optimizer
train_step_fn = make_train_step_fn(model, loss_fn, optimizer)



%%writefile model_training/v1.py

# Number of training steps to run; each step uses the full training tensors
n_epochs = 1000

# Loss returned by each train step, in epoch order
losses = []

for epoch in range(n_epochs):
    # Perform one train step and return the corresponding loss.
    # reshape(-1, 1) turns each tensor into a single-column 2-D tensor.
    loss = train_step_fn(x_train_tensor.reshape(-1, 1), y_train_tensor.reshape(-1, 1))
    # Keep track of the loss; otherwise `losses` would stay empty
    losses.append(loss)



class CustomDataset(Dataset):
    """Dataset pairing a feature container with a label container.

    Item ``i`` is the tuple ``(x_tensor[i], y_tensor[i])``; the dataset
    length is the length of ``x_tensor``.
    """

    def __init__(self, x_tensor, y_tensor):
        # Store both containers as-is; nothing is copied or converted
        self.x, self.y = x_tensor, y_tensor

    def __len__(self):
        """Return the number of entries, taken from the features."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        features = self.x[index]
        label = self.y[index]
        return features, label

# Wait, is this a CPU tensor now? Why? Where is .to(device)?
# torch.from_numpy() produces a CPU tensor and nothing here calls .to(device),
# so both tensors stay on the CPU; .float() casts them to float32.
# NOTE(review): presumably deliberate, so the full dataset does not occupy GPU
# memory and only the data actually fed to the model is moved to the device -
# confirm that the consumer sends its inputs to the model's device.
# `x_train` and `y_train` are NumPy arrays defined outside this snippet.
x_train_tensor = torch.from_numpy(x_train).float()
y_train_tensor = torch.from_numpy(y_train).float()

# Wrap the feature and label tensors so they can be indexed as pairs
train_data = CustomDataset(x_train_tensor, y_train_tensor)
# Sample output from the author's run; actual values depend on the data
print(train_data[0])  # (tensor(0.8446), tensor(2.8032))



From: https://www.cnblogs.com/zhangzhihui/p/18461534


  • 【HITCON-Training】Lab 12 - SecretGarden