
Integrate with PyTorch

PyTorch is a popular open source machine learning framework based on the Torch library, used for applications such as computer vision and natural language processing.

PyTorch enables fast, flexible experimentation and efficient production through a user-friendly front-end, distributed training, and an ecosystem of tools and libraries.

Instrument PyTorch with Comet to start managing experiments, create dataset versions and track hyperparameters for faster and easier reproducibility and collaboration.


Note: If you are using PyTorch TensorBoard, see our TensorBoard Integration.

Note: This integration also supports PyTorch Distributed Data Parallel. See below.

Start logging

Connect Comet to your existing code by creating a Comet Experiment.

Add the following lines of code to your script or notebook:

import comet_ml
import torch
import torchvision

experiment = comet_ml.Experiment(
    api_key="<Your API Key>",
    project_name="<Your Project Name>"
)

# Your code here

Note

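There are other ways to configure Comet, for example with a .comet.config file or environment variables instead of hard-coding your API key. See Configure Comet for details.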

Log automatically

Once an Experiment is created, Comet automatically logs the following PyTorch items by default, with no additional configuration:

  • Model and graph description
  • Training loss

You can easily turn the automatic logging on and off for any or all items. See Configure Comet for PyTorch for more details.

Note

Don't see what you need to log here? We have your back. You can manually log any kind of data to Comet using the Experiment object. For example, use experiment.log_image to log images, or experiment.log_audio to log audio.

End-to-end example

The following is a basic example of using Comet with PyTorch.

If you can't wait, check out the results of this example PyTorch experiment for a preview of what's to come.

Install dependencies

pip install comet_ml torch torchvision tqdm

Run the example

import comet_ml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

comet_ml.init(project_name="comet-example-pytorch")
experiment = comet_ml.Experiment()

hyper_params = {"batch_size": 100, "num_epochs": 2, "learning_rate": 0.01}
experiment.log_parameters(hyper_params)

# MNIST Dataset
dataset = datasets.MNIST(
    root="./data/", train=True, transform=transforms.ToTensor(), download=True
)

# Data Loader (Input Pipeline)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=hyper_params["batch_size"], shuffle=True
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(model, optimizer, criterion, dataloader, epoch):
    model.train()
    total_loss = 0
    correct = 0
    for batch_idx, (images, labels) in tqdm(enumerate(dataloader)):
        optimizer.zero_grad()
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)

        loss = criterion(outputs, labels)
        pred = outputs.argmax(
            dim=1, keepdim=True
        )  # get the index of the max log-probability

        loss.backward()
        optimizer.step()

        # Compute train accuracy
        batch_correct = pred.eq(labels.view_as(pred)).sum().item()
        batch_total = labels.size(0)

        total_loss += loss.item() * batch_total  # loss.item() is the batch mean; weight it by batch size
        correct += batch_correct

        # Log batch_accuracy to Comet; step is each batch
        experiment.log_metric("batch_accuracy", batch_correct / batch_total)

    total_loss /= len(dataloader.dataset)
    correct /= len(dataloader.dataset)

    experiment.log_metrics({"accuracy": correct, "loss": total_loss}, epoch=epoch)


model = Net().to(device)

# Loss and Optimizer
criterion = nn.NLLLoss()  # the model already returns log-probabilities (log_softmax)
optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params["learning_rate"])

# Train the Model
with experiment.train():
    print("Running Model Training")
    for epoch in range(hyper_params["num_epochs"]):
        train(model, optimizer, criterion, dataloader, epoch)

# Flush any pending data and close the experiment (required in notebooks)
experiment.end()
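The epoch-level aggregation in train can be sketched without PyTorch. nn.CrossEntropyLoss and nn.NLLLoss average over the batch by default (reduction="mean"), so loss.item() is a per-batch mean. Weighting it by the batch size before dividing by the dataset size gives the per-sample mean loss for the epoch, even when the last batch is smaller. Accuracy is already a count of samples. The helper epoch_summary is hypothetical and is not part of Comet or PyTorch.

```python
from typing import Iterable, Tuple


def epoch_summary(batches: Iterable[Tuple[float, int, int]]) -> Tuple[float, float]:
    """Aggregate per-batch results into epoch-level (loss, accuracy).

    Each item is (mean_batch_loss, n_correct, n_samples), where
    mean_batch_loss is what loss.item() returns for reduction="mean".
    """
    weighted_loss = 0.0
    correct = 0
    total = 0
    for mean_loss, n_correct, n_samples in batches:
        # Undo the per-batch averaging so every sample counts equally.
        weighted_loss += mean_loss * n_samples
        correct += n_correct
        total += n_samples
    if total == 0:
        raise ValueError("no samples seen")
    return weighted_loss / total, correct / total


# Two full batches of 100 samples and a final, smaller batch of 50.
loss, accuracy = epoch_summary([(0.40, 90, 100), (0.20, 95, 100), (0.10, 48, 50)])
print(f"loss={loss:.3f} accuracy={accuracy:.3f}")  # loss=0.260 accuracy=0.932
```

Averaging the three batch means directly would give about 0.233 here, because the small final batch would count as much as a full one.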

Try it out!

Don't just take our word for it, try it out for yourself.

PyTorch model saving and loading

Comet provides helpers that make it easy to save your PyTorch models and load them back.

Saving a model

To save a PyTorch model, use the comet_ml.integration.pytorch.log_model helper:

import torch.nn as nn
from comet_ml import Experiment
from comet_ml.integration.pytorch import log_model

experiment = Experiment()

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()

# Train the model (train() stands in for your own training loop)
train(model)

# Save the model for inference
log_model(experiment, model, model_name="TheModel")

The model file will be saved as an Experiment Model which is visible in the Experiment assets tab. From there you will be able to register it in the Model Registry.

The previous code snippet is tailored for inference. If you want to log a general checkpoint to resume training, replace the last line of the snippet with:

# Save the model for resuming training
# (epoch, optimizer and loss come from your training loop)
model_checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    # add any other state you need to resume training
}
log_model(experiment, model_checkpoint, model_name="TheModel")

comet_ml.integration.pytorch.log_model uses torch.save under the hood; consult the official PyTorch documentation for more details and for more advanced use cases.

Check out the reference documentation for more details.

Loading a model

Once you have saved a model using comet_ml.integration.pytorch.log_model, you can load it back with its counterpart comet_ml.integration.pytorch.load_model.

Here is how you can load a model from the Model Registry for inference:

import torch.nn as nn
from comet_ml.integration.pytorch import load_model

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()

# Load the model state dict from Comet Registry
model.load_state_dict(load_model("registry://WORKSPACE/TheModel:1.2.4"))

model.eval()

prediction = model(...)  # replace ... with a batch of input tensors

You can load a PyTorch model from various sources:

  • file://data/my-model: load the state_dict from the relative file path data/my-model.
  • file:///path/to/my-model: load the state_dict from the absolute file path /path/to/my-model.
  • registry://<workspace>/<registry_name>: load the state_dict from the Model Registry entry identified by the workspace and registry name, using its latest version.
  • registry://<workspace>/<registry_name>:<version>: load the state_dict from the Model Registry entry identified by the workspace, registry name, and explicit version.
  • experiment://<experiment_key>/<model_name>: load the state_dict from an Experiment, identified by the Experiment key and the model name.
  • experiment://<workspace>/<project_name>/<experiment_name>/<model_name>: load the state_dict from an Experiment, identified by the workspace name, project name, experiment name, and model name.
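To make the URI forms above concrete, here is a hypothetical parser that splits each form into its components. It only illustrates the list. It is not Comet's implementation, and the names parse_model_uri and ModelSource are made up for this sketch. In practice you pass the URI string straight to load_model.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelSource:
    kind: str  # "file", "registry" or "experiment"
    path: Optional[str] = None  # file:// only
    workspace: Optional[str] = None
    registry_name: Optional[str] = None
    version: Optional[str] = None  # None means "use the latest version"
    experiment_key: Optional[str] = None
    project_name: Optional[str] = None
    experiment_name: Optional[str] = None
    model_name: Optional[str] = None


def parse_model_uri(uri: str) -> ModelSource:
    """Split a model URI into its parts (hypothetical helper, not Comet's parser)."""
    scheme, sep, rest = uri.partition("://")
    if not sep or not rest:
        raise ValueError(f"not a model URI: {uri!r}")

    if scheme == "file":
        # "file://data/my-model" keeps a relative path; "file:///path/..."
        # keeps the leading slash, so the path is absolute.
        return ModelSource(kind="file", path=rest)

    parts = rest.split("/")
    if scheme == "registry" and len(parts) == 2:
        # An optional ":<version>" suffix pins a specific version.
        name, _, version = parts[1].partition(":")
        return ModelSource(
            kind="registry", workspace=parts[0], registry_name=name, version=version or None
        )
    if scheme == "experiment" and len(parts) == 2:
        return ModelSource(kind="experiment", experiment_key=parts[0], model_name=parts[1])
    if scheme == "experiment" and len(parts) == 4:
        workspace, project_name, experiment_name, model_name = parts
        return ModelSource(
            kind="experiment",
            workspace=workspace,
            project_name=project_name,
            experiment_name=experiment_name,
            model_name=model_name,
        )
    raise ValueError(f"unsupported model URI: {uri!r}")


source = parse_model_uri("registry://WORKSPACE/TheModel:1.2.4")
print(source.kind, source.workspace, source.registry_name, source.version)  # registry WORKSPACE TheModel 1.2.4
```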

The previous code snippet is tailored for inference. If you want to load a general checkpoint to resume training, use the following instead:

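import torch

# Initialize the model and an optimizer configured as during training
# (Adam is only an example; re-create the optimizer you actually used)
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())

# Load the checkpoint dictionary from a Comet Experiment
checkpoint = load_model("experiment://e1098c4e1e764ff89881b868e4c70f5/TheModel")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]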

model.train()

comet_ml.integration.pytorch.load_model uses torch.load under the hood; consult the official PyTorch documentation for more details and for more advanced use cases.

Check out the reference documentation for more details.

PyTorch Distributed Data Parallel

Are you running distributed training with PyTorch? There is an example of logging PyTorch DDP with Comet in the comet-examples repository.

Configure Comet for PyTorch

You can control which PyTorch items are logged automatically in any of three ways. To set them in code, pass these arguments when you create the Experiment:

experiment = comet_ml.Experiment(
    log_graph=True, # Can be True or False.
    auto_metric_logging=True # Can be True or False
)

To set them in your .comet.config file, add or remove these fields under the [comet_auto_log] section:

[comet_auto_log]
# can be true or false
graph=true
# can be true or false
metrics=true

To set them with environment variables, export the following before running your script:

export COMET_AUTO_LOG_GRAPH=true # Can be true or false
export COMET_AUTO_LOG_METRICS=true # Can be true or false
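In the example above, graph and metrics under [comet_auto_log] correspond to COMET_AUTO_LOG_GRAPH and COMET_AUTO_LOG_METRICS. The following hypothetical helper, auto_log_setting, sketches how such a true/false value could be resolved from either source. It assumes that an environment variable overrides the config file and that an unset option falls back to the default of true. It is not Comet's own loader, so check Configure Comet for the precedence Comet actually applies.

```python
import configparser
from typing import Mapping

# Mirrors the [comet_auto_log] section shown above, with metrics switched off.
EXAMPLE_CONFIG = """
[comet_auto_log]
graph = true
metrics = false
"""

TRUE_VALUES = {"1", "true", "yes", "on"}
FALSE_VALUES = {"0", "false", "no", "off"}


def to_bool(value: str) -> bool:
    """Interpret a true/false string from a config file or the environment."""
    lowered = value.strip().lower()
    if lowered in TRUE_VALUES:
        return True
    if lowered in FALSE_VALUES:
        return False
    raise ValueError(f"expected true or false, got {value!r}")


def auto_log_setting(
    key: str, config_text: str, env: Mapping[str, str], default: bool = True
) -> bool:
    """Hypothetical resolver (not Comet's loader).

    Assumed precedence: environment variable, then [comet_auto_log] in the
    config file, then the default (auto-logging is on unless disabled).
    """
    env_value = env.get(f"COMET_AUTO_LOG_{key.upper()}")
    if env_value is not None:
        return to_bool(env_value)

    parser = configparser.ConfigParser()
    parser.read_string(config_text)
    if parser.has_option("comet_auto_log", key):
        return to_bool(parser.get("comet_auto_log", key))
    return default


# The config file disables metrics, but the environment turns them back on.
print(auto_log_setting("metrics", EXAMPLE_CONFIG, env={}))  # False
print(auto_log_setting("metrics", EXAMPLE_CONFIG, env={"COMET_AUTO_LOG_METRICS": "true"}))  # True
```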

For more information about configuring Comet, see Configure Comet.

Feb. 9, 2024