DeepSDF
- DeepSDF is an auto-decoder based approach that learns a continuous SDF
representation for a class of shapes. Once trained, it can be used for
- shape representation
- interpolation
- shape completion
Bibtex #
@inproceedings{park2019deepsdf,
title={Deepsdf: Learning continuous signed distance functions for shape
representation},
author={Park, Jeong Joon and Florence, Peter and Straub, Julian and Newcombe,
Richard and Lovegrove, Steven},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and
Pattern Recognition},
pages={165--174},
year={2019}
}
Overview #
- Not tied to a grid structure
- During training, the auto-decoder jointly optimizes the network parameters and one latent code per training shape. Once trained, a shape is reconstructed from its SDF observations by optimizing a fresh latent code while the network parameters are kept fixed, so that the optimized code minimizes the error against the observed SDF values.
Architecture #

Figure 1: Architecture of the Auto-decoder
Implementation #
Model #
import torch.nn as nn
import torch
class DeepSDFDecoder(nn.Module):
    """
    Fully connected decoder that maps a (latent code, 3D point) pair to one SDF value.

    The network has eight weight-normalized hidden layers of width 512 followed by a
    plain linear output layer. The network input is concatenated again to the
    activations after the fourth layer (skip connection), which is why that layer
    produces ``512 - (latent_size + 3)`` features.
    """

    def __init__(self, latent_size):
        """
        :param latent_size: latent code vector length
        :raises ValueError: if latent_size + 3 does not leave room for the skip connection
        """
        super().__init__()
        dropout_prob = 0.2
        hidden_size = 512
        # latent code concatenated with an xyz query point
        input_size = latent_size + 3
        if latent_size < 0 or input_size >= hidden_size:
            raise ValueError(
                f"latent_size must be in [0, {hidden_size - 4}], got {latent_size}"
            )

        # NOTE: Define model
        # Attribute names lin0..lin8 are kept so existing checkpoints still load.
        self.lin0 = nn.utils.weight_norm(nn.Linear(input_size, hidden_size))
        self.lin1 = nn.utils.weight_norm(nn.Linear(hidden_size, hidden_size))
        self.lin2 = nn.utils.weight_norm(nn.Linear(hidden_size, hidden_size))
        # narrower output so that cat(lin3 output, input) is hidden_size wide again
        self.lin3 = nn.utils.weight_norm(nn.Linear(hidden_size, hidden_size - input_size))
        self.lin4 = nn.utils.weight_norm(nn.Linear(hidden_size, hidden_size))
        self.lin5 = nn.utils.weight_norm(nn.Linear(hidden_size, hidden_size))
        self.lin6 = nn.utils.weight_norm(nn.Linear(hidden_size, hidden_size))
        self.lin7 = nn.utils.weight_norm(nn.Linear(hidden_size, hidden_size))
        self.lin8 = nn.Linear(hidden_size, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x_in):
        """
        :param x_in: B x (latent_size + 3) tensor
        :return: B x 1 tensor
        """
        # NOTE: implement forward pass
        x = x_in
        for layer in (self.lin0, self.lin1, self.lin2, self.lin3):
            x = self.dropout(self.relu(layer(x)))

        # skip connection: re-inject the raw input halfway through the network
        x = torch.cat((x, x_in), dim=1)

        for layer in (self.lin4, self.lin5, self.lin6, self.lin7):
            x = self.dropout(self.relu(layer(x)))

        # final layer has no activation / dropout
        return self.lin8(x)
Inference #
import random
from pathlib import Path
import torch
import numpy as np
from exercise_3.data.shape_implicit import ShapeImplicit
from exercise_3.model.deepsdf import DeepSDFDecoder
from exercise_3.util.misc import evaluate_model_on_grid
class InferenceHandlerDeepSDF:
    """
    Inference utilities for a trained DeepSDF model: reconstruction from SDF
    observations, latent-space interpolation and decoding of training latent codes.
    """

    def __init__(self, latent_code_length, experiment, device):
        """
        :param latent_code_length: latent code length for the trained DeepSDF model
        :param experiment: path to experiment folder for the trained model; should contain "model_best.ckpt" and "latent_best.ckpt"
        :param device: torch device where inference is run
        """
        self.latent_code_length = latent_code_length
        self.experiment = Path(experiment)
        self.device = device
        self.truncation_distance = 0.01
        # number of observations sampled per optimization step in `reconstruct`
        self.num_samples = 4096

    def get_model(self):
        """
        :return: trained deep sdf model loaded from disk
        """
        model = DeepSDFDecoder(self.latent_code_length)
        model.load_state_dict(torch.load(self.experiment / "model_best.ckpt", map_location='cpu'))
        # eval mode disables dropout
        model.eval()
        model.to(self.device)
        return model

    def get_latent_codes(self):
        """
        :return: latent codes which were optimized during training
        """
        latent_codes = torch.nn.Embedding.from_pretrained(torch.load(self.experiment / "latent_best.ckpt", map_location='cpu')['weight'])
        latent_codes.to(self.device)
        return latent_codes

    def _optimize_latent(self, model, points, sdf, num_optimization_iters):
        """
        Fit a single latent code to the observations; only the latent code is stepped.
        :param model: decoder taking a B x (latent_code_length + 3) tensor
        :param points: observed points, indexed as points[idx, :]
        :param sdf: observed sdf values, indexed as sdf[idx, :]
        :param num_optimization_iters: number of optimizer steps
        :return: optimized latent code tensor of shape 1 x latent_code_length
        """
        # NOTE: define loss criterion for optimization
        loss_l1 = torch.nn.L1Loss()

        # initialize the latent vector that will be optimized
        latent = torch.ones(1, self.latent_code_length).normal_(mean=0, std=0.01).to(self.device)
        latent.requires_grad_(True)

        # NOTE: create optimizer on latent, use a learning rate of 0.005
        # The optimizer must own the latent code only; the decoder weights stay fixed.
        optimizer = torch.optim.Adam([latent], lr=0.005)

        for iter_idx in range(num_optimization_iters):
            # NOTE: zero out gradients
            optimizer.zero_grad()

            # NOTE: sample a random batch from the observations, batch size =
            # self.num_samples (sampling is with replacement)
            batch_indices = np.random.choice(points.shape[0], self.num_samples)
            batch_points = points[batch_indices, :]
            batch_sdf = sdf[batch_indices, :]

            # move batch to device
            batch_points = batch_points.to(self.device)
            batch_sdf = batch_sdf.to(self.device)

            # same latent code is used per point, therefore expand it to have same length as batch points
            latent_codes = latent.expand(self.num_samples, -1)

            # NOTE: forward pass with latent_codes and batch_points
            nn_input = torch.cat((latent_codes, batch_points), dim=1)
            predicted_sdf = model(nn_input)

            # NOTE: truncate predicted sdf between -0.1, 0.1
            predicted_sdf = torch.clip(predicted_sdf, -0.1, 0.1)

            # compute loss wrt to observed sdf
            loss = loss_l1(predicted_sdf, batch_sdf)

            # regularize latent code
            loss += 1e-4 * torch.mean(latent.pow(2))

            # NOTE: backwards and step
            loss.backward()
            optimizer.step()

            # loss logging
            if iter_idx % 50 == 0:
                print(f'[{iter_idx:05d}] optim_loss: {loss.cpu().item():.6f}')

        return latent

    def reconstruct(self, points, sdf, num_optimization_iters):
        """
        Reconstructs by optimizing a latent code that best represents the input sdf observations
        :param points: all observed points for the shape which needs to be reconstructed
        :param sdf: all observed sdf values corresponding to the points
        :param num_optimization_iters: optimization is performed for this many number of iterations
        :return: tuple with mesh representations of the reconstruction
        """
        model = self.get_model()
        # The network is kept fixed during reconstruction; freezing it also avoids
        # accumulating unused gradients on its weights.
        for param in model.parameters():
            param.requires_grad_(False)

        latent = self._optimize_latent(model, points, sdf, num_optimization_iters)
        print('Optimization complete.')

        # visualize the reconstructed shape
        vertices, faces = evaluate_model_on_grid(model, latent.detach().squeeze(0), self.device, 64, None)
        return vertices, faces

    def interpolate(self, shape_0_id, shape_1_id, num_interpolation_steps):
        """
        Interpolates latent codes between provided shapes and exports the intermediate reconstructions
        :param shape_0_id: first shape identifier
        :param shape_1_id: second shape identifier
        :param num_interpolation_steps: number of intermediate interpolated points; with 0 only the first shape is exported
        :return: None, saves the interpolated shapes to disk
        """
        # get saved model and latent codes
        model = self.get_model()
        train_latent_codes = self.get_latent_codes()

        # get indices of shape_ids latent codes
        train_items = ShapeImplicit(4096, "train").items
        latent_code_indices = torch.LongTensor([train_items.index(shape_0_id), train_items.index(shape_1_id)]).to(self.device)

        # get latent codes for provided shape ids
        latent_codes = train_latent_codes(latent_code_indices)

        # guard against division by zero when no interpolation steps are requested
        denominator = max(num_interpolation_steps, 1)
        for i in range(0, num_interpolation_steps + 1):
            # NOTE: interpolate the latent codes: latent_codes[0, :] and latent_codes[1, :]
            t = i / denominator
            interpolated_code = (1 - t) * latent_codes[0, :] + t * latent_codes[1, :]
            # reconstruct the shape at the interpolated latent code
            evaluate_model_on_grid(model, interpolated_code, self.device, 64, self.experiment / "interpolation" / f"{i:05d}_000.obj")

    def infer_from_latent_code(self, latent_code_index):
        """
        Reconstruct shape from a given latent code index
        :param latent_code_index: shape index for a shape in the train set for which reconstruction is performed
        :return: tuple with mesh representations of the reconstruction
        """
        # get saved model and latent codes
        model = self.get_model()
        train_latent_codes = self.get_latent_codes()

        # get latent code at given index
        latent_code_indices = torch.LongTensor([latent_code_index]).to(self.device)
        latent_codes = train_latent_codes(latent_code_indices)

        # reconstruct the shape at latent code
        vertices, faces = evaluate_model_on_grid(model, latent_codes[0], self.device, 64, None)
        return vertices, faces
Training #
from pathlib import Path
import torch
from exercise_3.model.deepsdf import DeepSDFDecoder
from exercise_3.data.shape_implicit import ShapeImplicit
from exercise_3.util.misc import evaluate_model_on_grid
def train(model, latent_vectors, train_dataloader, device, config):
    """
    Jointly optimize the decoder weights and the per-shape latent codes.

    :param model: decoder network, called on a (N x (latent_code_length + 3)) tensor
    :param latent_vectors: embedding module holding one latent code per training shape
    :param train_dataloader: iterable of batches with keys 'points', 'sdf' and 'indices'
    :param device: torch device used for the loss and the batches
    :param config: dict read for 'learning_rate_model', 'learning_rate_code', 'max_epochs',
                   'latent_code_length', 'lambda_code_regularization', 'print_every_n',
                   'visualize_every_n' and 'experiment_name'
    :return: None; best checkpoints and visualization meshes are written to disk
    """
    # L1 loss between predicted and observed sdf values
    loss_criterion = torch.nn.L1Loss()
    loss_criterion.to(device)

    # one parameter group for the network, one for the latent codes, each with its own lr
    parameter_groups = [
        {"params": model.parameters(), "lr": config["learning_rate_model"]},
        {"params": latent_vectors.parameters(), "lr": config["learning_rate_code"]},
    ]
    optimizer = torch.optim.Adam(parameter_groups)

    # halve the learning rates every 500 scheduler steps
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    model.train()

    # accumulated loss since the last print, and the best printed average so far
    train_loss_running = 0.
    best_loss = float('inf')

    for epoch in range(config['max_epochs']):

        for batch_idx, batch in enumerate(train_dataloader):
            ShapeImplicit.move_batch_to_device(batch, device)

            # clear gradients left over from the previous iteration
            optimizer.zero_grad()

            # total sdf samples in this batch = shapes in batch * points per shape
            points_per_shape = batch['points'].shape[1]
            num_points_per_batch = batch['points'].shape[0] * points_per_shape

            # look up one code per shape and repeat it for every sample of that shape
            per_shape_codes = latent_vectors(batch['indices'])
            per_sample_codes = per_shape_codes.unsqueeze(1).expand(-1, points_per_shape, -1)
            per_sample_codes = per_sample_codes.reshape((num_points_per_batch, config['latent_code_length']))

            # flatten the observations to one row per sample
            flat_points = batch['points'].reshape((num_points_per_batch, 3))
            target_sdf = batch['sdf'].reshape((num_points_per_batch, 1))

            # forward pass on [code | xyz], predictions truncated to [-0.1, 0.1]
            decoder_input = torch.cat((per_sample_codes, flat_points), dim=1)
            predicted_sdf = torch.clip(model(decoder_input), -0.1, 0.1)

            loss = loss_criterion(predicted_sdf, target_sdf)

            # latent code penalty, only added to the loss after epoch 100
            mean_code_norm = torch.mean(torch.norm(per_sample_codes, dim=1))
            code_regularization = mean_code_norm * config['lambda_code_regularization']
            if epoch > 100:
                loss = loss + code_regularization

            loss.backward()
            optimizer.step()

            # loss logging
            train_loss_running += loss.item()
            iteration = epoch * len(train_dataloader) + batch_idx

            if iteration % config['print_every_n'] == (config['print_every_n'] - 1):
                train_loss = train_loss_running / config["print_every_n"]
                print(f'[{epoch:03d}/{batch_idx:05d}] train_loss: {train_loss:.6f}')

                # checkpoint network and latent codes whenever the averaged loss improves
                if train_loss < best_loss:
                    torch.save(model.state_dict(), f'exercise_3/runs/{config["experiment_name"]}/model_best.ckpt')
                    torch.save(latent_vectors.state_dict(), f'exercise_3/runs/{config["experiment_name"]}/latent_best.ckpt')
                    best_loss = train_loss

                train_loss_running = 0.

            # export meshes for (at most) the first 5 training latent codes
            if iteration % config['visualize_every_n'] == (config['visualize_every_n'] - 1):
                model.eval()
                vis_indices = torch.LongTensor(range(min(5, latent_vectors.num_embeddings))).to(device)
                latent_vectors_for_vis = latent_vectors(vis_indices)
                for latent_idx in range(latent_vectors_for_vis.shape[0]):
                    evaluate_model_on_grid(model, latent_vectors_for_vis[latent_idx, :], device, 64, f'exercise_3/runs/{config["experiment_name"]}/meshes/{iteration:05d}_{latent_idx:03d}.obj')
                # back to train mode for the next iteration
                model.train()

        # lr scheduler update, once per epoch
        scheduler.step()
def main(config):
    """
    Function for training DeepSDF
    :param config: configuration for training - has the following keys
                   'experiment_name': name of the experiment, checkpoint will be saved to folder "exercise_3/runs/<experiment_name>"
                   'device': device on which model is trained, e.g. 'cpu' or 'cuda:0'
                   'num_sample_points': number of sdf samples per shape while training
                   'latent_code_length': length of deepsdf latent vector
                   'batch_size': batch size for the training dataloader
                   'resume_ckpt': None if training from scratch, otherwise checkpoint path prefix; the files
                                  "<resume_ckpt>_model.ckpt" and "<resume_ckpt>_latent.ckpt" are loaded
                   'learning_rate_model': learning rate of model optimizer
                   'learning_rate_code': learning rate of latent code optimizer
                   'lambda_code_regularization': latent code regularization loss coefficient
                   'max_epochs': total number of epochs after which training should stop
                   'print_every_n': print train loss every n iterations
                   'visualize_every_n': visualize some training shapes every n iterations
                   'is_overfit': if True, training uses the 'overfit' split instead of 'train'. Useful for debugging.
    """
    # declare device
    device = torch.device('cpu')
    if torch.cuda.is_available() and config['device'].startswith('cuda'):
        device = torch.device(config['device'])
        print('Using device:', config['device'])
    else:
        print('Using CPU')

    # create dataloaders
    train_dataset = ShapeImplicit(config['num_sample_points'], 'train' if not config['is_overfit'] else 'overfit')
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,  # Datasets return data one sample at a time; Dataloaders use them and aggregate samples into batches
        batch_size=config['batch_size'],  # The size of batches is defined here
        shuffle=True,  # Shuffling the order of samples is useful during training to prevent that the network learns to depend on the order of the input data
        num_workers=0,  # Data is usually loaded in parallel by num_workers
        pin_memory=True  # This is an implementation detail to speed up data uploading to the GPU
    )

    # Instantiate model
    model = DeepSDFDecoder(config['latent_code_length'])

    # Instantiate latent vectors for each training shape
    latent_vectors = torch.nn.Embedding(len(train_dataset), config['latent_code_length'], max_norm=1.0)

    # Load model if resuming from checkpoint
    if config['resume_ckpt'] is not None:
        model.load_state_dict(torch.load(config['resume_ckpt'] + "_model.ckpt", map_location='cpu'))
        latent_state = torch.load(config['resume_ckpt'] + "_latent.ckpt", map_location='cpu')
        # accept either a raw weight tensor or an embedding state dict ({'weight': tensor})
        if isinstance(latent_state, dict):
            latent_state = latent_state['weight']
        # freeze=False keeps the codes trainable (from_pretrained freezes by default);
        # max_norm matches the freshly initialized embedding above
        latent_vectors = torch.nn.Embedding.from_pretrained(latent_state, freeze=False, max_norm=1.0)

    # Move model to specified device
    model.to(device)
    latent_vectors.to(device)

    # Create folder for saving checkpoints
    Path(f'exercise_3/runs/{config["experiment_name"]}').mkdir(exist_ok=True, parents=True)

    # Start training
    train(model, latent_vectors, train_dataloader, device, config)