3D-EPN for shape completion

Bibtex #

@inproceedings{dai2017shape,
  title={Shape Completion Using 3D-Encoder-Predictor CNNs and Shape Synthesis},
  author={Dai, Angela and Ruizhongtai Qi, Charles and Nie{\ss}ner, Matthias},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={5868--5877},
  year={2017}
}

Architecture #

Stage 1: Encoder-Predictor Network #

-> Gives a coarse estimate of the structure

  • Uses a distance field (DF) representation for the 3D input
  • A fairly standard stack of 3D CNNs that spatially compress the volume
  • Skip connections keep the generated output close to the input
  • The decoder predicts the completed distance field
  • A 3D classification network predicts the shape class, giving the generator additional global information about the shape
  • An L1 loss on the predicted distance field (a minimal sketch follows below)
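A minimal sketch of that loss, assuming log-scaled 32^3 distance field volumes as described later in this note (tensor names are illustrative):

import torch
import torch.nn as nn

# Illustrative tensors: a batch of 8 predicted and ground-truth 32^3 distance fields
pred_df   = torch.rand(8, 32, 32, 32)  # network output, already log-scaled
target_df = torch.rand(8, 32, 32, 32)  # ground-truth DF from the dataloader

# Targets are log-scaled the same way as the predictions: log(|d| + 1)
target_log = torch.log(torch.abs(target_df) + 1)

# L1 loss between the log-scaled volumes
loss = nn.L1Loss()(pred_df, target_log)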

Insight:

  • When large parts of the shape are missing, purely local methods suffer (obviously)

Parts (Omitting classification) #

Encoder
4 layers, each containing a 3D convolution (kernel size 4, as seen in the visualization), a 3D batch norm (except in the very first layer), and a leaky ReLU with a negative slope of 0.2. The goal is to reduce the spatial dimension from 32x32x32 to 1x1x1 and to grow the feature dimension from 2 (absolute value and sign of the SDF) to num_features * 8. We do this with a stride of 2 and padding of 1 for all convolutions except the last one, which uses a stride of 1 and no padding. The feature channels are increased from 2 to num_features in the first layer and then doubled with every subsequent layer.
Decoder
Same architecture as the encoder, just mirrored: going from num_features * 8 * 2 (the 2 is explained under skip connections below) down to 1 (the DF values), while the spatial dimensions go from 1x1x1 back to 32x32x32. Each layer now uses a 3D transposed convolution, together with 3D batch norm and a plain ReLU (no leaky ReLUs anymore). Note that the last layer uses neither batch norm nor a ReLU, since we do not want to constrain the range of possible values for the prediction.
Bottleneck
This is realized with 2 fully connected layers, each mapping a vector of size 640 (which is num_features * 8) to a vector of size 640 and followed by a ReLU activation.
Skip connections
allow the decoder to use information from the encoder and also improve gradient flow. We use them here to connect the output of encoder layer 1 to decoder layer 4, the output of encoder layer 2 to decoder layer 3, and so on. This means that the input to a decoder layer is the concatenation of the previous decoder output with the corresponding encoder output along the feature dimension. Hence, the number of input features for each decoder layer is twice that of the mirrored encoder layer, as mentioned above.
Log scaling
You also need to scale the final outputs of the network logarithmically: out = log(abs(out) + 1). This is the same transformation applied to the target shapes in the dataloader, and it ensures that prediction and target volumes are comparable. (A quick sanity check of the layer dimensions follows below.)
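To verify the spatial dimensions claimed above, here is a small sketch of the standard convolution output-size formula, out = floor((in + 2 * padding - kernel) / stride) + 1, applied to the four encoder layers (the layer list mirrors the text; it is an illustration, not code from the paper):

def conv_out(size, kernel, stride, padding):
    # Standard output-size formula for a convolution layer
    return (size + 2 * padding - kernel) // stride + 1

size = 32
for layer, (k, s, p) in enumerate([(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)], start=1):
    size = conv_out(size, k, s, p)
    print(f"encoder layer {layer}: {size}^3")
# encoder layer 1: 16^3
# encoder layer 2: 8^3
# encoder layer 3: 4^3
# encoder layer 4: 1^3

# The mirrored transposed convolutions invert this:
# out = (in - 1) * stride - 2 * padding + kernel, giving 1 -> 4 -> 8 -> 16 -> 32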

Stage 2: “Upsample” the low-res result #

  1. Take the coarse prediction
  2. Look up (a couple of) nearest neighbors in a shape database (see the retrieval sketch below)
  3. Then, for parts of the reconstructed mesh:
    • Find the most similar part among the objects from (2)
    • Composite it in
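As an illustration of step 2, a minimal nearest-neighbor lookup over precomputed shape descriptors could look like this (the database size, the descriptor dimension, and names like db_features are assumptions for the sketch, not details from the paper):

import torch

db_features = torch.rand(1000, 640)  # hypothetical descriptors for 1000 database shapes
query       = torch.rand(640)        # descriptor of the coarse prediction

# k nearest neighbors by Euclidean distance in feature space
distances = torch.norm(db_features - query, dim=1)
knn_dists, knn_indices = torch.topk(distances, k=3, largest=False)
print(knn_indices)  # indices of the 3 most similar database shapes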

End-To-End considerations #

  • This can’t easily be trained end to end, since database lookups are not trivially differentiable
  • Training end to end would push the first network to learn to predict the most “useful” intermediate shape for retrieval and reconstruction (a hypothetical differentiable relaxation is sketched below)
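To make the second point concrete, one hypothetical differentiable relaxation (not part of the paper) would replace the hard lookup with a softmax-weighted blend over the database, letting gradients flow back into the predictor:

import torch

db_features = torch.rand(1000, 640)          # hypothetical shape descriptors
db_volumes  = torch.rand(1000, 32, 32, 32)   # hypothetical shape volumes
query = torch.rand(640, requires_grad=True)  # descriptor of the coarse prediction

# Soft attention over the database instead of a hard argmin: fully differentiable
weights = torch.softmax(-torch.norm(db_features - query, dim=1), dim=0)
blended = torch.einsum('n,nxyz->xyz', weights, db_volumes)

blended.sum().backward()  # gradients reach the query descriptor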

3D representation considerations #

Implementation #

Model (Omitting classification) #

import torch
import torch.nn as nn

class ThreeDEPN(nn.Module):
    def __init__(self):
        super().__init__()

        self.num_features = 80

        self.leaky_relu = nn.LeakyReLU(0.2)
        self.relu       = nn.ReLU()

        # NOTE: 4 Encoder layers
        self.e_conv1 = nn.Conv3d(2,                   self.num_features,   kernel_size=4, stride=2, padding=1)
        self.e_conv2 = nn.Conv3d(self.num_features,   self.num_features*2, kernel_size=4, stride=2, padding=1)
        self.e_bn2   = nn.BatchNorm3d(self.num_features*2)
        self.e_conv3 = nn.Conv3d(self.num_features*2, self.num_features*4, kernel_size=4, stride=2, padding=1)
        self.e_bn3   = nn.BatchNorm3d(self.num_features*4)
        self.e_conv4 = nn.Conv3d(self.num_features*4, self.num_features*8, kernel_size=4, stride=1, padding=0)
        self.e_bn4   = nn.BatchNorm3d(self.num_features*8)

        # NOTE: 2 Bottleneck layers
        self.bottleneck = nn.Sequential(
            nn.Linear(640, 640),
            self.relu,
            nn.Linear(640, 640),
            self.relu,
        )

        # NOTE: 4 Decoder layers
        self.d_conv1 = nn.ConvTranspose3d(self.num_features*8*2, self.num_features*4, kernel_size=4, stride=1, padding=0)
        self.d_bn1   = nn.BatchNorm3d(self.num_features*4)
        self.d_conv2 = nn.ConvTranspose3d(self.num_features*4*2, self.num_features*2, kernel_size=4, stride=2, padding=1)
        self.d_bn2   = nn.BatchNorm3d(self.num_features*2)
        self.d_conv3 = nn.ConvTranspose3d(self.num_features*2*2,   self.num_features, kernel_size=4, stride=2, padding=1)
        self.d_bn3   = nn.BatchNorm3d(self.num_features)
        self.d_conv4 = nn.ConvTranspose3d(self.num_features*2,                     1, kernel_size=4, stride=2, padding=1)



    def forward(self, x):
        b = x.shape[0]
        # Pass x through the encoder, keeping the intermediate outputs for the skip connections,
        # then reshape and apply the bottleneck layers

        # Encode
        x_e1 = self.e_conv1(x)
        x_e1 = self.leaky_relu(x_e1)

        x_e2 = self.e_conv2(x_e1)
        x_e2 = self.e_bn2(x_e2)
        x_e2 = self.leaky_relu(x_e2)

        x_e3 = self.e_conv3(x_e2)
        x_e3 = self.e_bn3(x_e3)
        x_e3 = self.leaky_relu(x_e3)

        x_e4 = self.e_conv4(x_e3)
        x_e4 = self.e_bn4(x_e4)
        x_e4 = self.leaky_relu(x_e4)

        # bottleneck
        x = x_e4.view(b, -1)
        x = self.bottleneck(x)
        x = x.view(x.shape[0], x.shape[1], 1, 1, 1)

        # NOTE: Pass x through the decoder, applying the skip connections in the process
        # Decode

        x_d1 = self.d_conv1(torch.cat((x,    x_e4), dim=1))
        x_d1 = self.d_bn1(x_d1)
        x_d1 = self.relu(x_d1)

        x_d2 = self.d_conv2(torch.cat((x_d1, x_e3), dim=1))
        x_d2 = self.d_bn2(x_d2)
        x_d2 = self.relu(x_d2)

        x_d3 = self.d_conv3(torch.cat((x_d2, x_e2), dim=1))
        x_d3 = self.d_bn3(x_d3)
        x_d3 = self.relu(x_d3)

        x_d4 = self.d_conv4(torch.cat((x_d3, x_e1), dim=1))

        x = x_d4
        x = torch.squeeze(x, dim=1)

        # NOTE: Log scaling
        x = torch.log(torch.abs(x) + 1)

        return x
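A quick smoke test of the model above, reusing its imports (the random input stands in for a real 2-channel SDF volume):

model = ThreeDEPN()
dummy_input = torch.rand(4, 2, 32, 32, 32)  # batch of 4, channels = (abs SDF, sign)
output = model(dummy_input)
print(output.shape)  # torch.Size([4, 32, 32, 32])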

Inference #

import numpy as np
import torch
from skimage.measure import marching_cubes

from exercise_3.model.threedepn import ThreeDEPN


class InferenceHandler3DEPN:
    def __init__(self, ckpt):
        """
        :param ckpt: checkpoint path to weights of the trained network
        """
        self.model = ThreeDEPN()
        self.model.load_state_dict(torch.load(ckpt, map_location='cpu'))
        self.model.eval()
        self.truncation_distance = 3

    def infer_single(self, input_sdf, target_df):
        """
        Reconstruct a full shape given a partial observation
        :param input_sdf: Input grid with partial SDF of shape 32x32x32
        :param target_df: Target grid with complete DF of shape 32x32x32
        :return: Tuple with mesh representations of input, reconstruction, and target
        """
        # NOTE: Apply the truncation distance: SDF values lie within [-3, 3], DF values within [0, 3]
        input_sdf = np.clip(input_sdf, -self.truncation_distance, self.truncation_distance).reshape((32, 32, 32))
        target_df = np.clip(target_df, 0, self.truncation_distance).reshape((32, 32, 32))

        with torch.no_grad():
            # Build the 2-channel input the encoder expects: absolute SDF values and their sign
            input_tensor = np.stack([np.abs(input_sdf), np.sign(input_sdf)]).astype(np.float32)
            input_tensor = torch.from_numpy(input_tensor).unsqueeze(0)  # shape (1, 2, 32, 32, 32)

            # Pass the input through the network and revert the log scaling: exp(x) - 1
            reconstructed_df = self.model(input_tensor)
            reconstructed_df = torch.exp(reconstructed_df) - 1

        # Extract iso-surface meshes at level 1 of the (unsigned) distance fields
        input_sdf = np.abs(input_sdf)
        input_mesh = marching_cubes(input_sdf, level=1)
        reconstructed_mesh = marching_cubes(reconstructed_df.squeeze(0).numpy(), level=1)
        target_mesh = marching_cubes(target_df, level=1)
        return input_mesh, reconstructed_mesh, target_mesh
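Putting the handler to use (the checkpoint and data paths below are placeholders, not paths from the original project):

handler = InferenceHandler3DEPN('runs/3depn/model_best.ckpt')  # hypothetical checkpoint path
input_sdf = np.load('data/sample_input_sdf.npy')               # hypothetical 32^3 partial SDF
target_df = np.load('data/sample_target_df.npy')               # hypothetical 32^3 complete DF
input_mesh, reconstructed_mesh, target_mesh = handler.infer_single(input_sdf, target_df)
# Each mesh is the (verts, faces, normals, values) tuple returned by marching_cubes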