How to train with ArcFace loss to improve model classification accuracy

If your classification accuracy is low, try this method; it will blow your mind.

Yiwen Lai
9 min read · Oct 17, 2021
We are going to dive deep to investigate how ArcFace works.

I bet you have seen tons of COVID-19 x-ray image projects, but “bear” with me: this post is not like the rest of the ones you see online, and we have probably all had enough of COVID-19 by now. Our focus here is to understand how ArcFace works, implement it first-hand, and uncover what we need to do to make it tick.

I hope you are convinced to stay for a while, so let’s get started. Here is the post outline.

  1. Data preprocessing
  2. Baseline model
  3. What is ArcFace loss
  4. ArcFace Model
  5. Model Evaluation
  6. Summary

All the code can be found in this repository.

1. Data preprocessing

Samples of x-ray images

The chest x-ray images were downloaded from Kaggle. There are plenty of sources that provide them, but for this project we only need a small subset. We need to ensure that our dataset is balanced across the 4 classes: normal, bacterial, COVID-19, and viral. Note that having both viral and COVID-19 labels creates an additional challenge: COVID-19 is itself a viral infection, so we are asking the model to separate a subset (COVID-19) from the superset it belongs to (viral infections).
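The post does not say how the balanced subset was drawn. One simple way to get equal counts per class is random undersampling down to the smallest class. The sketch below assumes the images are listed as (path, label) pairs; `balance_by_undersampling` and the toy file names are hypothetical, not taken from the author's repository.

```python
import random
from collections import Counter


def balance_by_undersampling(samples, seed=0):
    """Randomly undersample every class down to the size of the smallest class.

    samples: list of (image_path, label) pairs, e.g. ("xray_001.png", "Viral").
    Returns a new shuffled list with the same number of samples per label.
    """
    rng = random.Random(seed)
    by_label = {}
    for path, label in samples:
        by_label.setdefault(label, []).append(path)

    target = min(len(paths) for paths in by_label.values())
    balanced = []
    for label, paths in by_label.items():
        # sample without replacement so no image is duplicated
        for path in rng.sample(paths, target):
            balanced.append((path, label))
    rng.shuffle(balanced)
    return balanced


if __name__ == "__main__":
    # toy input with deliberately unequal class sizes (hypothetical file names)
    toy = ([(f"n{i}.png", "Normal") for i in range(10)]
           + [(f"b{i}.png", "Bacterial") for i in range(7)]
           + [(f"c{i}.png", "COVID-19") for i in range(5)]
           + [(f"v{i}.png", "Viral") for i in range(8)])
    # every label ends up with 5 samples, the size of the smallest class
    print(Counter(label for _, label in balance_by_undersampling(toy)))
```

Undersampling throws data away; with only a few thousand images that is usually acceptable, and it keeps every sample unique.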

When viewing these x-ray images, you will soon notice that some of them are annotated, which might give the model a chance to cheat by looking at the annotations. We can mitigate this by cropping away the annotations at the sides. I also noticed that the contrast differs from image to image, and that the set contains some side-view and top-down-view x-ray images. We can mitigate the contrast differences with histogram equalization; as for the different views, we will ignore this issue and include them in the training.
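Here is a dependency-free sketch of the two fixes described above: cropping a border off every side, then global histogram equalization of an 8-bit grayscale image stored as a list of rows. The 10% crop fraction is an assumption for illustration; in practice you would more likely use a library routine such as OpenCV's `cv2.equalizeHist`.

```python
def crop_borders(image, fraction=0.1):
    """Drop `fraction` of the height/width from every side (to hide edge annotations)."""
    h, w = len(image), len(image[0])
    dy, dx = int(h * fraction), int(w * fraction)
    return [row[dx:w - dx] for row in image[dy:h - dy]]


def equalize_histogram(image, levels=256):
    """Global histogram equalization for an 8-bit grayscale image (list of rows)."""
    pixels = [p for row in image for p in row]
    n = len(pixels)
    hist = [0] * levels
    for p in pixels:
        hist[p] += 1
    # cumulative distribution function of the pixel intensities
    cdf, running = [], 0
    for count in hist:
        running += count
        cdf.append(running)
    cdf_min = next(c for c in cdf if c > 0)
    if n == cdf_min:  # constant image: nothing to stretch
        return [row[:] for row in image]
    # map each intensity so the output CDF is roughly linear over 0..levels-1
    lut = [round((c - cdf_min) / (n - cdf_min) * (levels - 1)) for c in cdf]
    return [[lut[p] for p in row] for row in image]


if __name__ == "__main__":
    img = [[50, 50, 60], [60, 70, 70], [80, 80, 80]]
    print(equalize_histogram(img))  # darkest value maps to 0, brightest to 255
```

Equalization stretches the intensity range of every image to the full 0 to 255 scale, so images taken with different exposure settings end up with comparable contrast.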

We then create a data generator that applies the following image augmentations: crop, rotate, flip, and histogram equalization. If your machine has at least 32GB of RAM, I would suggest loading all the images into memory. This greatly reduces training time: from 18s to 3s per epoch, a 6x speed-up.
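A minimal sketch of an in-memory batch generator, assuming the images have already been cropped and equalized once up front and are held in RAM as 2-D lists. It reshuffles every epoch and applies a random horizontal flip on the fly; rotation is left out to keep it short. In Keras you would typically wrap this logic in a `tf.keras.utils.Sequence` subclass instead; `batch_generator` here is a hypothetical helper.

```python
import random


def batch_generator(images, labels, batch_size=32, seed=0):
    """Endlessly yield shuffled (batch_images, batch_labels) from data held in RAM.

    images: list of preprocessed 2-D images (lists of rows); labels: class indices.
    Random horizontal flips are applied on the fly as a cheap augmentation.
    """
    rng = random.Random(seed)
    indices = list(range(len(images)))
    while True:
        rng.shuffle(indices)  # new sample order every epoch
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            batch_x, batch_y = [], []
            for i in chunk:
                img = images[i]
                if rng.random() < 0.5:
                    img = [row[::-1] for row in img]  # horizontal flip
                batch_x.append(img)
                batch_y.append(labels[i])
            yield batch_x, batch_y
```

Because nothing is read from disk inside the loop, the per-epoch cost is dominated by the augmentation itself, which is where the speed-up from loading everything into memory comes from.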

Finally, we ended up with the following dataset:

Bacterial: 563
Normal: 563
COVID-19: 563
Viral: 563

2. Baseline Model

Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0
_________________________________________________________________
conv2d (Conv2D)              (None, 222, 222, 32)      896
_________________________________________________________________
average_pooling2d (AveragePo (None, 74, 74, 32)        0
_________________________________________________________________
batch_normalization (BatchNo (None, 74, 74, 32)        128
_________________________________________________________________
dropout (Dropout)            (None, 74, 74, 32)        0
_________________________________________________________________
flatten (Flatten)            (None, 175232)            0
_________________________________________________________________
dense (Dense)                (None, 64)                11214912
_________________________________________________________________
batch_normalization_1 (Batch (None, 64)                256
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 260
=================================================================
Total params: 11,216,452
Trainable params: 11,216,260
Non-trainable params: 192
_________________________________________________________________
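To see where the 11.2M parameters in the summary come from, the counts can be recomputed by hand. The 3x3 kernel and pool size of 3 are inferred from the output shapes (224 to 222, then 222 to 74), not stated in the post.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # one kernel_h x kernel_w x in_channels weight tensor per filter, plus one bias each
    return kernel_h * kernel_w * in_channels * filters + filters


def dense_params(inputs, units):
    # full weight matrix plus one bias per unit
    return inputs * units + units


def batchnorm_params(channels):
    # gamma and beta are trainable; moving mean and moving variance are not
    return {"trainable": 2 * channels, "non_trainable": 2 * channels}


conv = conv2d_params(3, 3, 3, 32)   # 3x3 kernel inferred from 224 -> 222 ('valid' padding)
flat = 74 * 74 * 32                 # feature map size after pooling, then flattened
dense = dense_params(flat, 64)
head = dense_params(64, 4)
bn1, bn2 = batchnorm_params(32), batchnorm_params(64)

trainable = conv + dense + head + bn1["trainable"] + bn2["trainable"]
non_trainable = bn1["non_trainable"] + bn2["non_trainable"]
total = trainable + non_trainable
print(conv, flat, dense, head, total, trainable, non_trainable)
# prints: 896 175232 11214912 260 11216452 11216260 192
```

Over 99.9% of the parameters sit in the single Dense layer after Flatten; the convolutional part of this baseline is tiny.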

Our baseline model is deliberately simple. We chose such a simple architecture because we do not want the complexity of a well-known architecture like the popular ResNet50; the accuracy of the model itself is not the subject of this experiment. What we want is to study the properties of ArcFace and how far it can push accuracy beyond the baseline.

Training summary

              precision    recall  f1-score   support

      Normal       0.74      0.97      0.84        90
   Bacterial       0.68      0.64      0.66        97
       Viral       0.71      0.57      0.63        98
    COVID-19       0.96      0.94      0.95       115

    accuracy                           0.78       400
   macro avg       0.77      0.78      0.77       400
weighted avg       0.78      0.78      0.78       400

After training, we reach a weighted-average f1-score of 0.78, not bad for such a simple model. Notice that the model has difficulty learning the Bacterial and Viral classes. Next, let’s take a look at the training graph.
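The summary rows of the report can be reproduced from the per-class precision and recall: F1 is their harmonic mean, the macro average weights every class equally, and the weighted average weights each class by its support. This is a worked check on the numbers in the report above, not part of the author's code.

```python
# Per-class precision, recall and support copied from the baseline report.
report = {
    "Normal":    (0.74, 0.97, 90),
    "Bacterial": (0.68, 0.64, 97),
    "Viral":     (0.71, 0.57, 98),
    "COVID-19":  (0.96, 0.94, 115),
}


def f1(precision, recall):
    # harmonic mean of precision and recall
    return 2 * precision * recall / (precision + recall)


f1_scores = {name: f1(p, r) for name, (p, r, _) in report.items()}
total_support = sum(s for _, _, s in report.values())
macro_f1 = sum(f1_scores.values()) / len(f1_scores)
weighted_f1 = sum(f1_scores[n] * s for n, (_, _, s) in report.items()) / total_support

for name, score in f1_scores.items():
    print(f"{name:10s} f1 = {score:.2f}")
print(f"macro f1 = {macro_f1:.2f}, weighted f1 = {weighted_f1:.2f}")
```

Running it gives a macro F1 of 0.77 and a weighted F1 of 0.78, matching the report: the weaker Bacterial and Viral scores are partly offset by the strong, well-supported COVID-19 class.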

simple_model loss and accuracy graph

At a quick glance, the model does not show signs of serious over-fitting or under-fitting. But notice the spikes here and there; these are probably caused by the data generator feeding the model a higher proportion of Bacterial or Viral samples during some epochs. I would also like to highlight that the validation loss plateaus at around 0.5. Take note of this number: it will be an important reference for checking whether our ArcFace model is training properly, as I will explain later in this post.

Softmax + Cross-Entropy

With the training put aside, let’s recap what Softmax and Cross-Entropy do. Softmax is an activation function that outputs a probability for each class, and these probabilities sum to one. Cross-Entropy is a loss function that takes the negative logarithm of the probability assigned to the true class (averaged over the samples in a batch). These two are often used together for classification tasks.
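For reference, the pair can be written in a few lines of plain Python. With a one-hot target, cross-entropy reduces to the negative log-probability of the true class; the logits below are made up for illustration.

```python
import math


def softmax(logits):
    # subtract the max for numerical stability; the result sums to 1
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, true_index):
    # with a one-hot target this is -log(probability of the true class)
    return -math.log(probs[true_index])


probs = softmax([2.0, 1.0, 0.1, -1.0])   # hypothetical logits for 4 classes
print([round(p, 3) for p in probs])
print(cross_entropy(probs, 0), cross_entropy(probs, 3))

# a uniform prediction over 4 classes gives log(4) ≈ 1.386, a typical starting loss
print(cross_entropy(softmax([0.0, 0.0, 0.0, 0.0]), 2))
```

The uniform case is a handy sanity check: an untrained 4-class model should start near a loss of 1.386.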

Yes, we have all known this for ages. What’s your point?

The Voynich Manuscript

My point is that the image embeddings produced by this method live in Euclidean space, where the classes are separated by hyperplanes. Important features are extracted by measuring the similarity between embeddings. The embedding dimension can be as high as 512. Let’s try to visualize this: imagine each plane is a piece of paper, so the image embedding space will look like a book with 512 pages. Keep this picture in mind; it will help in understanding why ArcFace performs so well later.

3. What is ArcFace Loss

ArcFace, or Additive Angular Margin Loss, is a loss function used in state-of-the-art face recognition. It can handle a very large number of classes (individual faces). With the world population approaching 8 billion as of 2021, this means that, given enough resources, you could in principle train an earth-level facial recognition system that recognizes everyone.

Creepy.

Trying to unlock your phone with a mask.

Unlike Softmax + Cross-Entropy, ArcFace loss uses geodesic distance on a hypersphere: both the embeddings and the class weights are L2-normalized, so only the angle between them matters, and an additive angular margin is applied to that angle. Let’s try to visualize this again: imagine each sphere is now a book, so a dimension of 512 will be a shelf full of books. In theory, this should produce more discriminative features than Softmax + Cross-Entropy.
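For reference, this is the loss as formulated in the ArcFace paper (Deng et al.), where x is the embedding, W_j the weight vector of class j, y the true label, s the scale, and m the additive angular margin in radians. Because both vectors are normalized, each logit depends only on the angle θ_j, which is exactly the geodesic distance on the unit hypersphere.

```latex
\cos\theta_j = \frac{W_j^{\top} x}{\lVert W_j \rVert \, \lVert x \rVert},
\qquad
\mathcal{L} = -\log
\frac{e^{s\cos(\theta_y + m)}}
     {e^{s\cos(\theta_y + m)} + \sum_{j \neq y} e^{s\cos\theta_j}}
```

Setting m = 0 recovers a plain softmax cross-entropy over scaled cosines; the margin forces the true class to win by an extra angle m.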

Interstellar — Inside a black hole

Another difference is that Euclidean distance completely ignores the shape of the space when measuring the path from the start point to the endpoint, whereas geodesic distance is constrained to follow the given shape (on a sphere, the path is an arc along the surface).
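A small sketch contrasting the two distances for points on a unit hypersphere: the Euclidean (chord) distance cuts straight through the sphere, while the geodesic distance is the arc along its surface, i.e. the angle between the unit vectors. For unit vectors the two are related by chord = 2 sin(θ/2).

```python
import math


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def euclidean(u, v):
    # straight-line (chord) distance: cuts through the inside of the sphere
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def geodesic(u, v):
    # arc length on the unit hypersphere = angle between the unit vectors (radians)
    u, v = normalize(u), normalize(v)
    cos = sum(a * b for a, b in zip(u, v))
    return math.acos(max(-1.0, min(1.0, cos)))  # clamp against rounding error


a, b = normalize([1.0, 0.0, 0.0]), normalize([0.0, 1.0, 0.0])
print(euclidean(a, b), geodesic(a, b))  # sqrt(2) ≈ 1.414 vs pi/2 ≈ 1.571
```

Note that `geodesic` ignores vector length entirely: [3, 0] and [5, 0] are at distance 0, which is why ArcFace only cares about directions.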

4. ArcFace Model

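Next, we will look at the actual implementation of ArcFace in our baseline model. Adding the layer is very easy: just replace the classifier head with the ArcFace layer. Do take note of the input connections to the ArcFace layer: besides the embedding, it also takes the label as an input, because the angular margin is applied only to the true class.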
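The author's actual Keras layer lives in the linked repository and is not reproduced here. As an illustration of what such a head computes, here is a plain-Python sketch of the forward pass (`ArcFaceHead` is a hypothetical name). It shows why the layer needs two inputs: the label decides which logit gets the margin. The fallback branch for θ + m > π follows a convention used in many open-source ArcFace implementations.

```python
import math


class ArcFaceHead:
    """Sketch of an ArcFace classification head (hypothetical class, plain Python).

    weights: one vector per class, same length as the embedding.
    forward() needs the label because the margin is added only to the true class.
    """

    def __init__(self, weights, s=30.0, m=0.5):
        self.weights = weights
        self.s = s
        self.m = m
        # when theta + m would exceed pi, cos(theta + m) stops decreasing, so many
        # implementations switch to a linear penalty cos(theta) - sin(pi - m) * m
        self.threshold = math.cos(math.pi - m)
        self.fallback = math.sin(math.pi - m) * m

    @staticmethod
    def _normalize(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    def cosines(self, embedding):
        # cosine between the normalized embedding and each normalized class weight
        x = self._normalize(embedding)
        return [sum(a * b for a, b in zip(x, self._normalize(w))) for w in self.weights]

    def forward(self, embedding, label):
        logits = []
        for j, c in enumerate(self.cosines(embedding)):
            if j == label:
                theta = math.acos(max(-1.0, min(1.0, c)))
                target = math.cos(theta + self.m) if c > self.threshold else c - self.fallback
                logits.append(self.s * target)
            else:
                logits.append(self.s * c)
        return logits

    def loss(self, embedding, label):
        # softmax cross-entropy on the margin-adjusted, scaled logits
        logits = self.forward(embedding, label)
        top = max(logits)
        log_sum = top + math.log(sum(math.exp(z - top) for z in logits))
        return log_sum - logits[label]

    def predict(self, embedding):
        # at inference there is no label, so classify by the largest plain cosine
        cos = self.cosines(embedding)
        return max(range(len(cos)), key=cos.__getitem__)
```

The label is only used during training; at inference time the prediction should come from the plain cosines, with no margin and no label input.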

Simple! Our job is done; let’s have some coffee. ☕ …wait a minute…

simple_arcface_model summary

Immediately we find that the training is stuck: every epoch learns nothing, and sometimes the val_loss even gets worse. After a few rounds of tweaking, Googling, and Kaggling on how to train this model, I found that getting an ArcFace model to converge from scratch is hard. It requires a good embedding to start with, which is then fine-tuned with ArcFace. From experience, it works best to transfer learn from the baseline model, freezing the top layers (in blue) and tuning the lower layers (in red). The intuition is that we want the ArcFace layer to tune the embeddings while keeping the “internal workings” that are used to derive them.
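What "frozen" means mechanically: frozen weights are simply excluded from the gradient update, so only the remaining layers move. The sketch below shows one plain SGD step over parameter groups; the layer names and which group is frozen are hypothetical. In Keras the equivalent is setting `layer.trainable = False` on the chosen layers before calling `model.compile()`.

```python
def sgd_step(params, grads, frozen, lr=1e-2):
    """One plain SGD update that skips frozen parameter groups.

    params/grads: dict mapping a (hypothetical) layer name to a list of floats.
    frozen: set of layer names whose weights must stay fixed.
    """
    updated = {}
    for name, values in params.items():
        if name in frozen:
            updated[name] = list(values)  # frozen: copied through unchanged
        else:
            updated[name] = [w - lr * g for w, g in zip(values, grads[name])]
    return updated


params = {"conv2d": [0.5, -0.2], "dense": [0.1, 0.3], "arcface": [1.0, -1.0]}
grads = {"conv2d": [1.0, 1.0], "dense": [2.0, -2.0], "arcface": [0.5, 0.5]}
# which layers to freeze is an assumption here; lr=1e-2 matches the post
new = sgd_step(params, grads, frozen={"conv2d"})
print(new)
```

Freezing keeps the pretrained feature extractor intact, so the new head starts from embeddings that already separate the classes reasonably well instead of from noise.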

Next, we change the parameters of ArcFace. We need to set a higher learning rate of 1e-2. We also need to change s (scale) from 30 to 6 and m (angular margin) from 0.5 to 2. This is required because ArcFace was originally designed to classify a very large number of classes, whereas our task has only 4; using the default values skews the training too much.

In short, set s lower and m higher when you have only a few classes.
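One way to see what s does: it multiplies the cosines before the softmax, so it controls how quickly the predicted probability saturates. The sketch below ignores the margin and uses made-up cosines for a 4-class problem, with the target clearly but not perfectly closest.

```python
import math


def target_probability(cosines, target, s):
    """Softmax probability of the target class when the logits are s * cosine."""
    logits = [s * c for c in cosines]
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    return exps[target] / sum(exps)


# hypothetical cosines: target at 0.9, the other three classes at 0.1
cosines = [0.9, 0.1, 0.1, 0.1]
for s in (6.0, 30.0):
    print(s, round(target_probability(cosines, 0, s), 4))
```

With s = 6 the target probability is roughly 0.976, while with s = 30 it is indistinguishable from 1.0, so the loss (and its gradient) collapses to almost nothing as soon as the target is only moderately closer than the other classes. With just 4 classes that happens early, which is one plausible reading of why the defaults "skew" training here.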

5. Model evaluation

              precision    recall  f1-score   support

      Normal       1.00      1.00      1.00       107
   Bacterial       1.00      1.00      1.00       122
       Viral       1.00      1.00      1.00       106
    COVID-19       1.00      1.00      1.00       117

    accuracy                           1.00       452
   macro avg       1.00      1.00      1.00       452
weighted avg       1.00      1.00      1.00       452

Let’s look at the results of the training. We got a 100% f1-score on every single class! The performance gain from simply switching to the ArcFace layer seems unreal. For a data scientist, this is usually a sign of a bug and calls for some verification.

The 100% accuracy training error

It turns out that there is a common error when training ArcFace that leads to 100% accuracy within just a few epochs. It comes from the calculation of the logits in the ArcFace loss: when the s (scale) value is set too small, the logits return a bunch of -1 values and the model effectively takes your input label as the output value. To solve this issue, adjust s to a higher value, and the problem should be resolved.

Simple_Arcface_Model training loss and accuracy

Remember the baseline model’s val_loss plateau of about 0.5? We can now use it as a reference: our current val_loss is significantly lower.

Verify the validation loss

To confirm that our model has not fallen into the “100% accuracy training error”, we can compare its val_loss with that of the baseline training. Since the result is below 0.5, this gives us some reassurance that the model is training correctly. In addition, we can check the model by clustering its embeddings: we reduce them to two dimensions with an algorithm such as UMAP and plot them on a graph. They should form clear clusters (4 in this case), which shows that the learned embeddings are meaningful.
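As a numeric complement to eyeballing the UMAP plot (this is a swapped-in technique, not something the post uses), you could compute a silhouette score on the embeddings or on their 2-D projection: values near 1 mean tight, well-separated clusters, values near 0 or below mean overlap. scikit-learn provides this as `sklearn.metrics.silhouette_score`; below is a small pure-Python version for illustration. Scores computed on UMAP output are only indicative, since UMAP does not preserve global distances.

```python
import math


def silhouette_score(points, labels):
    """Mean silhouette coefficient of labelled points (pure Python, O(n^2)).

    For each point: a = mean distance to its own cluster, b = lowest mean distance
    to any other cluster, s = (b - a) / max(a, b). Points alone in their cluster
    score 0. Needs at least two clusters.
    """
    clusters = {}
    for p, lab in zip(points, labels):
        clusters.setdefault(lab, []).append(p)
    scores = []
    for p, lab in zip(points, labels):
        own = clusters[lab]
        if len(own) == 1:
            scores.append(0.0)
            continue
        # p itself contributes a zero distance, hence the division by len(own) - 1
        a = sum(math.dist(p, q) for q in own) / (len(own) - 1)
        b = min(
            sum(math.dist(p, q) for q in other) / len(other)
            for other_lab, other in clusters.items() if other_lab != lab
        )
        denom = max(a, b)
        scores.append(0.0 if denom == 0 else (b - a) / denom)
    return sum(scores) / len(scores)


pts = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
print(silhouette_score(pts, ["A", "A", "A", "B", "B", "B"]))  # well separated: close to 1
```

Tracking this number for the baseline and the ArcFace model gives a simple way to compare how cleanly each one separates the four classes.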

In the graph, the red cluster (Bacterial) contains some Viral embeddings. Looking back at the baseline model, which also struggled with these two classes, we can see that this property is preserved even after ArcFace tuning. Overall, though, after ArcFace tuning we get meaningful clusters even with only 2 dimensions.

Baseline model performance

              precision    recall  f1-score   support

   Bacterial       0.68      0.64      0.66        97
       Viral       0.71      0.57      0.63        98

UMAP: property of the baseline model preserved after ArcFace tuning.

6. Summary

ArcFace can be used to improve classification model accuracy with minimal change to an existing architecture. The cost of this performance gain is low: it only requires additional verification steps to ensure that the results produced are valid.

In summary, here are a few points to take note of when training an ArcFace model:

  • Fine-tuning an existing model with frozen layers converges faster.
  • Set s (scale) lower and m (angular margin) higher if the number of classes is small.
  • Monitor val_loss and val_acc during training to detect errors.
  • Use clustering of the embeddings as part of the verification process.

Written by Yiwen Lai

🤖 AI² | NTU Computer Science Graduate | NUS M.Tech Knowledge Engineering | https://twitter.com/Niel_Lai
