An Introduction to Deep Learning on Meshes

Course Objective

This course introduces the fundamentals of machine learning techniques on meshes. We start with a short introduction to machine learning on regular structures (e.g., images) and discuss how these techniques generalize to irregular mesh structures. Although several alternative representations exist (e.g., implicit functions, voxels, point clouds), we focus our discussion on techniques defined directly on surface triangle meshes, which have been one of the most commonly used representations in graphics for decades. Specifically, we will cover the building blocks of a mesh convolutional neural network, including convolution, pooling, regularization, and loss functions. We will also cover how to apply convolutional networks to classic geometry processing tasks, such as geometric texture synthesis, shape classification, and subdivision. We also propose a mesh MNIST dataset for research prototyping and for hands-on demos of these learning tasks.

The course materials are still under active development. If you have any comments or are interested in contributing, please contact us.

Course Materials

Datasets

The MNIST dataset contains thousands of handwritten digits and is commonly used as the "Hello World" of deep learning on images. We create an analogous 3D mesh MNIST dataset for 3D deep learning (Link).

Demos

We introduce a series of self-contained examples based on open-source libraries such as JAX and PyTorch. The purpose of these examples is to demonstrate how to implement a simple machine learning model on meshes.

1. Simple mesh CNN without pooling

We present a basic example of using a mesh CNN to classify meshes of "1" and meshes of "2" from our meshMNIST dataset. We will cover the full pipeline: loading a shape, computing input features, defining a mesh CNN, and training the network. In practice, we recommend using well-optimized functions for tasks such as loading training data or building network components. For clarity, however, we implement an as-simple-as-possible mesh CNN from scratch in JAX without those utility functions.

Let's get started by importing the basic libraries that we will need in this tutorial.
We represent each shape as a surface triangle mesh, stored as an indexed face set. Specifically, each mesh is represented by two matrices:
  • a #vertices by 3 matrix of vertex positions V.
  • a #faces by 3 matrix of triangle indices F into the vertex list V.
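As a minimal illustration of this representation, consider the corner tetrahedron below. It is a toy mesh chosen for this example, not a meshMNIST shape: four vertices and four consistently oriented triangles. Indexing V with F gathers the corner positions of every triangle at once. This is the usual way to compute per-face quantities such as normals and areas.

```python
import numpy as np

# Toy mesh (our own example, not from meshMNIST): the corner tetrahedron,
# with triangles oriented counter-clockwise when seen from outside.
V = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0]])   # #vertices x 3 vertex positions
F = np.array([[0, 2, 1],
              [0, 1, 3],
              [0, 3, 2],
              [1, 2, 3]])         # #faces x 3 indices into V

# V[F] gathers the three corner positions of every face: shape (#faces, 3, 3).
corners = V[F]
# Unnormalized face normals from the cross product of two triangle edges;
# their length is twice the triangle area.
normals = np.cross(corners[:, 1] - corners[:, 0], corners[:, 2] - corners[:, 0])
areas = 0.5 * np.linalg.norm(normals, axis=1)
```

Because the vertices are shared through F, a mesh with many triangles stores each vertex position only once.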
We then precompute the information needed for training, which includes:
  • a #edges by 2 matrix of (undirected) edge indices E.
  • a #edges by 4 matrix flap of flap edge indices into the edge list E (the 4 other edges of the two triangles adjacent to each edge).
  • a #edges by n matrix of input features fE, where n is the feature dimension.
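A minimal NumPy sketch of how E and flap could be precomputed from F follows. It assumes a closed, manifold, consistently oriented mesh, so that every edge is shared by exactly two triangles. The function name edges_and_flaps is our own and is not part of the tutorial. Each flap row is ordered (a, b, c, d): a and b are the next and previous edges in one adjacent triangle, and c and d are the same in the other triangle. This ordering is an assumed convention.

```python
import numpy as np


def edges_and_flaps(F):
    """Build undirected edges E (#E x 2) and flap edge indices (#E x 4) from faces F.

    Assumes a closed, manifold, consistently oriented triangle mesh, so every
    undirected edge is shared by exactly two triangles.
    """
    F = np.asarray(F)
    n_faces = F.shape[0]
    # Half-edge p of face f points from F[f, p] to F[f, (p + 1) % 3]; row 3f + p.
    half_edges = np.stack([F, np.roll(F, -1, axis=1)], axis=2).reshape(-1, 2)
    # Sorting each endpoint pair makes both half-edges of an edge map to one row of E.
    E, he2e = np.unique(np.sort(half_edges, axis=1), axis=0, return_inverse=True)
    he2e = he2e.reshape(n_faces, 3)  # he2e[f, p] = edge index of half-edge (f, p)
    if not np.all(np.bincount(he2e.ravel(), minlength=len(E)) == 2):
        raise ValueError("every edge must be shared by exactly two faces")
    # After a stable sort by edge index, consecutive pairs of half-edges
    # belong to edges 0, 1, 2, ...
    pairs = np.argsort(he2e.ravel(), kind="stable").reshape(-1, 2)
    f, p = pairs // 3, pairs % 3  # (#E, 2) face indices and positions within the face
    flap = np.stack([
        he2e[f[:, 0], (p[:, 0] + 1) % 3],  # a: next edge in the first triangle
        he2e[f[:, 0], (p[:, 0] + 2) % 3],  # b: previous edge in the first triangle
        he2e[f[:, 1], (p[:, 1] + 1) % 3],  # c: next edge in the second triangle
        he2e[f[:, 1], (p[:, 1] + 2) % 3],  # d: previous edge in the second triangle
    ], axis=1)
    return E, flap
```

On a tetrahedron with F = [[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]], this yields 6 edges. Each edge gets 4 flap edges: every edge except the one opposite to it.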
The function compute_input computes the input features for each edge. Which input features to compute is a design choice; here, for simplicity, we compute the dihedral angle and the 4 edge length ratios. Since this preprocessing usually takes a while, we often process the entire dataset once and save the results.

After processing the meshes, we can define the mesh CNN. Like many mesh CNNs (e.g., MeshCNN), our architecture consists of a stack of convolution layers, followed by global pooling and fully connected layers. For simplicity, we omit mesh pooling in this toy example.
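One possible implementation of compute_input is sketched below in NumPy, under our own definitions, since the exact definitions are not specified here. The dihedral angle is taken as the unsigned angle in [0, π] between the two triangles sharing the edge, with π meaning flat. The 4 edge length ratios are the lengths of the flap edges a, b, c, d divided by the length of the edge itself. The opposite vertex of each triangle is read off flap edges a and c. This assumes flap rows are ordered (a, b, c, d), with a and b on one adjacent triangle and c and d on the other.

```python
import numpy as np


def compute_input(V, E, flap):
    """Per-edge input features fE (#E x 5): dihedral angle + 4 edge length ratios.

    Assumes flap rows are ordered (a, b, c, d) with a, b on one triangle
    adjacent to the edge and c, d on the other.
    """
    V = np.asarray(V, dtype=float)
    E, flap = np.asarray(E), np.asarray(flap)
    u = V[E[:, 0]]
    edge_vec = V[E[:, 1]] - u
    edge_len = np.linalg.norm(edge_vec, axis=1)
    direction = edge_vec / edge_len[:, None]

    def opposite_vertex(k):
        # Flap edge k shares exactly one endpoint with the center edge;
        # its other endpoint is the opposite vertex of that triangle.
        Ek = E[flap[:, k]]
        first_is_shared = (Ek[:, 0] == E[:, 0]) | (Ek[:, 0] == E[:, 1])
        return np.where(first_is_shared, Ek[:, 1], Ek[:, 0])

    def perpendicular(w):
        # Component of (V[w] - u) orthogonal to the edge direction.
        r = V[w] - u
        return r - np.sum(r * direction, axis=1, keepdims=True) * direction

    r1 = perpendicular(opposite_vertex(0))  # toward the first triangle
    r2 = perpendicular(opposite_vertex(2))  # toward the second triangle
    cos = np.sum(r1 * r2, axis=1) / (np.linalg.norm(r1, axis=1) * np.linalg.norm(r2, axis=1))
    dihedral = np.arccos(np.clip(cos, -1.0, 1.0))  # unsigned, pi for a flat region

    # Lengths of the 4 flap edges divided by the center edge length.
    flap_len = np.linalg.norm(V[E[flap, 0]] - V[E[flap, 1]], axis=2)  # (#E, 4)
    ratios = flap_len / edge_len[:, None]
    return np.concatenate([dihedral[:, None], ratios], axis=1)
```

To cache the results, one option is to save each shape once, for example with np.savez("shape.npz", V=V, F=F, E=E, flap=flap, fE=fE), and reload it later with np.load. The file name is only illustrative.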

In each mesh convolution, the input is an n-dimensional function defined on the edges (fE in our case). For other applications, this could be a function defined on faces or vertices. No matter "where" we store the input function, the first step in mesh convolution is to re-index it into a format that allows fast convolution. In this example, we re-index the #edges by n edge function fE into flap functions fP following MeshCNN Equation 2. This pairs each edge's own feature with order-invariant combinations of the features of its four flap edges a, b, c, d, where a and b lie on one adjacent triangle and c and d on the other:

$$ f_P(e) = \big( f_E(e),\ |f_E(a) - f_E(c)|,\ f_E(a) + f_E(c),\ |f_E(b) - f_E(d)|,\ f_E(b) + f_E(d) \big) $$

The absolute values are taken elementwise. Each flap function is an n by 5 matrix, so fP is a #edges by n by 5 3D array. Mesh convolution can then be performed easily as a 2D convolution, where each convolution filter has size n by 5 and outputs a scalar value. We implement it from scratch for clarity, but one should switch to a better-optimized implementation for performance.

With the structure of mesh convolution in mind, we can now initialize the network parameters and define the forward pass (convolutions, global pooling, and a fully connected multilayer perceptron). The rest of the pipeline is the same as in, e.g., image-based classification networks: we define the loss function, the optimizer, and the parameter update function, and then start training the network. The complete version of this tutorial can be found here.

The training code in this tutorial takes one shape at a time and uses stochastic gradient descent to optimize the network. One could also accumulate the gradients from a batch of shapes before taking a gradient step. Ideally, one should also use a data loader so that loading the data does not become the bottleneck of training, and a validation set for early stopping to avoid overfitting. After training, one can use the learned parameters to make predictions.
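The JAX sketch below covers the pieces described in this paragraph:

- re-indexing edge features into flap features following MeshCNN Equation 2;
- a mesh convolution written as a contraction with (n_out, n, 5) filters, which is equivalent to applying n_out filters of size n by 5 that each output one scalar per edge;
- a forward pass with global pooling and a two-layer MLP;
- a cross-entropy loss and a plain SGD step.

The function names, the choice of mean pooling (rather than max), and the initialization scale are our own assumptions, not necessarily what the complete tutorial uses.

```python
import jax
import jax.numpy as jnp


def edge_to_flap(fE, flap):
    """Re-index edge features fE (#E x n) into flap features fP (#E x n x 5).

    Uses (e, |a - c|, a + c, |b - d|, b + d), assuming flap rows are ordered
    (a, b, c, d) with a, b on one adjacent triangle and c, d on the other, so
    fP does not depend on which triangle is listed first.
    """
    a, b, c, d = (fE[flap[:, k]] for k in range(4))
    return jnp.stack([fE, jnp.abs(a - c), a + c, jnp.abs(b - d), b + d], axis=2)


def mesh_conv(fE, flap, W, bias):
    """Mesh convolution: W is (n_out, n_in, 5); each filter maps an n_in x 5 flap to a scalar."""
    fP = edge_to_flap(fE, flap)                     # (#E, n_in, 5)
    return jnp.einsum("eik,oik->eo", fP, W) + bias  # (#E, n_out)


def init_params(key, conv_channels, hidden, n_classes):
    """Hypothetical initializer; conv_channels = [n, c1, c2, ...] starts with the input dim."""
    params = {"conv": [], "mlp": []}
    for n_in, n_out in zip(conv_channels[:-1], conv_channels[1:]):
        key, sub = jax.random.split(key)
        W = jax.random.normal(sub, (n_out, n_in, 5)) * jnp.sqrt(2.0 / (5 * n_in))
        params["conv"].append((W, jnp.zeros(n_out)))
    dims = [conv_channels[-1], hidden, n_classes]
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        key, sub = jax.random.split(key)
        W = jax.random.normal(sub, (d_in, d_out)) * jnp.sqrt(2.0 / d_in)
        params["mlp"].append((W, jnp.zeros(d_out)))
    return params


def forward(params, fE, flap):
    """Mesh convolutions -> global mean pooling over edges -> MLP class logits."""
    x = fE
    for W, b in params["conv"]:
        x = jax.nn.relu(mesh_conv(x, flap, W, b))
    x = jnp.mean(x, axis=0)  # one feature vector per shape, independent of #edges
    (W1, b1), (W2, b2) = params["mlp"]
    x = jax.nn.relu(x @ W1 + b1)
    return x @ W2 + b2       # (n_classes,)


def loss_fn(params, fE, flap, label):
    """Cross-entropy loss for a single shape with an integer class label."""
    return -jax.nn.log_softmax(forward(params, fE, flap))[label]


@jax.jit
def sgd_step(params, fE, flap, label, lr):
    """One stochastic gradient descent step on a single shape."""
    loss, grads = jax.value_and_grad(loss_fn)(params, fE, flap, label)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


def predict(params, fE, flap):
    """Predicted class index for one shape."""
    return int(jnp.argmax(forward(params, fE, flap)))
```

A training loop would then call params, loss = sgd_step(params, fE, flap, label, lr) for each (fE, flap, label) in the training set, one shape at a time. After training, predict(params, fE, flap) would be applied to held-out shapes. Because the filters are shared across all edges and pooling averages over edges, the same parameters apply to meshes with different numbers of edges.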

Contacts

This course on deep learning on meshes is developed by Rana Hanocka and Hsueh-Ti Derek Liu. If you have any questions or suggestions, please contact us.