What is K-Shot Learning?
Deep neural networks have achieved state-of-the-art performance on many computer vision tasks. However, much of this improvement can be attributed to their reliance on large amounts of supervised training data. In many practical cases such data is not available, and few-shot learning is an approach designed for exactly this situation. Few-shot learning is a type of supervised learning that aims to generalise rapidly to new tasks, using prior knowledge, when only a few supervised samples are available. At the extreme end is one-shot learning, where the model is given only one reference example per class at inference time and must then find other instances of that class in new images. Zero-Shot Learning goes further still: no examples of the novel classes are available for direct comparison at test time. For this reason, Zero-Shot Learning is generally considered a fundamentally different problem. Naturally, as information restrictions become more severe, approaches must also adapt.
More formally, given a set of training images containing objects from a set of training classes, N-Way K-Shot learning seeks to solve computer vision tasks accurately for object instances in new images that belong to novel classes absent from the training data, given only K labelled object instances for each of N novel classes as prior information.
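To make the N-way K-shot setup concrete, here is a minimal Python sketch of how an evaluation episode might be sampled. The `(instance_id, class_name)` format and the `sample_episode` helper are hypothetical; real detection benchmarks sample annotated boxes inside images rather than standalone instances.

```python
import random
from collections import defaultdict


def sample_episode(annotations, n_way, k_shot, n_query, seed=None):
    """Sample one N-way K-shot episode.

    annotations: list of (instance_id, class_name) pairs, a hypothetical
    format standing in for annotated object instances.
    Returns (support, query): k_shot support and n_query query instances for
    each of n_way randomly chosen classes, with no overlap between the sets.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for instance_id, cls in annotations:
        by_class[cls].append(instance_id)

    # Only classes with enough instances for both support and query qualify.
    eligible = sorted(c for c, ids in by_class.items() if len(ids) >= k_shot + n_query)
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + n_query instances")

    support, query = [], []
    for cls in rng.sample(eligible, n_way):
        picked = rng.sample(by_class[cls], k_shot + n_query)
        support += [(i, cls) for i in picked[:k_shot]]
        query += [(i, cls) for i in picked[k_shot:]]
    return support, query


data = [(f"obj_{c}_{i}", c) for c in "abcde" for i in range(6)]  # 5 novel classes
support, query = sample_episode(data, n_way=3, k_shot=2, n_query=1, seed=0)
print(len(support), len(query))  # 6 3
```

With N = 3 and K = 2 the support set always holds six instances, two per sampled class.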
In this article, we are primarily concerned with the object detection task. Object detection is particularly difficult in the few-shot setting, and it has prompted a wide variety of methods.
Approaches to K-Shot Learning generally build on transfer learning. They rely on feature extraction from standard computer vision networks such as ResNet and then process the resulting feature maps in various ways. These approaches can be grouped into several categories: meta-learning, attention-based adaptations, and fine-tuning.
Fine-tuning is more common when more data is available, that is, when K is larger. It is appealing because of its simplicity and because it is less susceptible to over-training on novel datasets. Fine-tuning simply trains the object detection model on a small dataset of strictly limited novel instances so that it adapts to the novel classes, and it has shown surprising practical success. A common baseline model for this approach is Faster R-CNN, which combines a region proposal network (RPN) with region of interest (RoI) pooling over the feature maps produced by a CNN.
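A common, cheap form of fine-tuning freezes the pre-trained feature extractor and trains only the last layer on the K novel examples. The toy Python sketch below shows that idea for classification alone: `frozen_backbone` is a hypothetical stand-in for a pre-trained network, the 2-D inputs are made up, and box regression is omitted.

```python
import math


def frozen_backbone(x):
    # Hypothetical stand-in for a pre-trained feature extractor. Its
    # "weights" are fixed, so fine-tuning never changes these features.
    return [x[0], x[1], x[0] * x[1]]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def fine_tune_head(samples, n_classes, lr=0.5, epochs=200):
    """Train only a linear classification head on frozen features with SGD."""
    dim = len(frozen_backbone(samples[0][0]))
    W = [[0.0] * dim for _ in range(n_classes)]
    b = [0.0] * n_classes
    for _ in range(epochs):
        for x, y in samples:
            f = frozen_backbone(x)
            p = softmax([sum(w * v for w, v in zip(W[c], f)) + b[c] for c in range(n_classes)])
            for c in range(n_classes):
                g = p[c] - (1.0 if c == y else 0.0)  # d(cross-entropy)/d(logit_c)
                for j in range(dim):
                    W[c][j] -= lr * g * f[j]
                b[c] -= lr * g
    return W, b


def predict(W, b, x):
    f = frozen_backbone(x)
    logits = [sum(w * v for w, v in zip(row, f)) + bias for row, bias in zip(W, b)]
    return max(range(len(logits)), key=logits.__getitem__)


# K = 2 labelled examples ("shots") for each of 2 novel classes.
shots = [([1.0, 1.0], 0), ([0.9, 1.1], 0), ([-1.0, -1.0], 1), ([-1.1, -0.8], 1)]
W, b = fine_tune_head(shots, n_classes=2)
print(predict(W, b, [1.2, 0.9]), predict(W, b, [-0.9, -1.2]))  # 0 1
```

Because only `W` and `b` are updated, the number of trainable parameters stays tiny, which is what makes this variant workable with so few examples.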
Meta-learning is another popular approach to K-Shot Learning: it trains a framework to generalise to new classes from few examples. The general idea behind meta-learning is to learn a generalised embedding space that allows sufficient inter-class discrimination based on semantic relationships. Siamese networks are one of the most common architectures for object detection in this setting. A Siamese network applies identical subnetworks, sharing the same weights, to different inputs.
The outputs of those subnetworks can then be compared for similarity using a distance metric. Such networks are commonly trained with triplet loss. Given a metric, we measure the distance between an anchor image and a known similar (positive) image, and the distance between the anchor and a known dissimilar (negative) image; training pushes the anchor-positive distance to be smaller than the anchor-negative distance by at least a margin. However, most earlier methods have two main problems: poor detection average precision (AP), and strong bias caused by class imbalance and insufficient data. Previous works mainly alleviate these issues with additional datasets, multi-relation attention mechanisms, and extra sub-modules, but these make the models more costly to train and run.
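The sketch below shows both ideas in plain Python: one shared embedding function applied to all three inputs (the Siamese weight sharing) and the standard triplet loss max(0, d(a, p) - d(a, n) + margin). The single linear layer, the Euclidean distance, and `margin=1.0` are illustrative assumptions, not the setup of any particular detector.

```python
import math


def embed(x, weights):
    # Shared subnetwork: every Siamese branch calls this same function with the
    # same weights (here a single linear layer, a hypothetical toy choice).
    return [sum(w * v for w, v in zip(row, x)) for row in weights]


def euclidean(a, b):
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))


def triplet_loss(anchor, positive, negative, weights, margin=1.0):
    """max(0, d(a, p) - d(a, n) + margin).

    The loss is zero once the negative is at least `margin` farther from the
    anchor than the positive; otherwise training pushes the embeddings apart.
    """
    a, p, n = (embed(x, weights) for x in (anchor, positive, negative))
    return max(0.0, euclidean(a, p) - euclidean(a, n) + margin)


W = [[1.0, 0.0], [0.0, 1.0]]  # identity embedding, for illustration only
easy = triplet_loss([0.0, 0.0], [0.1, 0.0], [3.0, 4.0], W)  # d(a,p)=0.1, d(a,n)=5
hard = triplet_loss([0.0, 0.0], [3.0, 4.0], [0.1, 0.0], W)  # d(a,p)=5, d(a,n)=0.1
print(easy, round(hard, 6))  # 0.0 5.9
```

Only "hard" triplets like the second one produce a gradient, which is why triplet mining strategies matter in practice.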
Attention-based frameworks are popular because they map support features onto query features in the feature maps. They use learned comparison and matching knowledge to generalise rapidly to new instances. This idea has been implemented in numerous papers, but it suffers from a large model size, because it adds an entire extra framework on top of a base object detection model.
📐 Approach 1: OS2D One-Shot Object Detection
Problem Setup
OS2D aims to detect objects accurately in novel images given a single example of a novel object, using prior knowledge built up from a larger training dataset.
General Approach
Using typical feature extractors from well-established, pre-trained computer vision networks, OS2D extracts feature maps, resizes them to compatible sizes, and calculates correlation tensors from them. The correlations are then traced back through the feature-mapping process, so that matched anchor features can be mapped back to the image to determine accurate bounding boxes for the matching regions.
Model Goal
The model’s goal is to minimise the following losses. The first is a hinge-embedding loss with margins for recognition, which measures how well the predicted recognition scores agree with the ground truth. The second is a smooth L1 localisation loss, summed over the four coordinates that describe the predicted and ground-truth bounding boxes; this is the standard loss for measuring bounding box inaccuracy. Additionally, the authors experimented with a ranked list loss to better account for class imbalance by weighting the far more common negative class, with the aim of improving training results.
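Both losses can be sketched in a few lines of Python. The two-margin hinge form below (penalise positives scoring under `m_pos`, negatives scoring over `m_neg`) is our reading of the recognition loss, and the margin values are placeholders; the smooth L1 uses `beta=1.0` and raw corner coordinates for readability, whereas detectors usually regress normalised box offsets.

```python
def smooth_l1(pred, target, beta=1.0):
    """Smooth L1 summed over the four box coordinates (x1, y1, x2, y2)."""
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        # Quadratic near zero, linear for large errors (robust to outliers).
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total


def hinge_embedding_loss(score, is_positive, m_pos=0.6, m_neg=0.5):
    """Positives are penalised when their score falls below m_pos, negatives
    when their score rises above m_neg. Margin values are illustrative."""
    if is_positive:
        return max(m_pos - score, 0.0)
    return max(score - m_neg, 0.0)


print(smooth_l1([0, 0, 10, 10], [0.5, 0, 10, 13]))  # 0.125 + 2.5 = 2.625
print(hinge_embedding_loss(0.9, True), hinge_embedding_loss(0.3, False))  # 0.0 0.0
```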
Architecture
OS2D builds on the general framework of Rocco et al. First, dense spatial feature maps are extracted from a ResNet or VGG backbone. Features are matched by computing a correlation tensor that contains the correlation between every pair of spatial locations in the two feature maps. These correlation tensors are then reshaped and fed into a regular 2D convolutional neural network. Finally, outputs are computed by resampling the correlations and selecting boxes.
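A minimal sketch of the correlation step, assuming features are L2-normalised first so that each entry is a cosine similarity (a common choice in Rocco et al.-style matching; OS2D’s exact normalisation may differ). Feature maps are nested Python lists purely for self-containment; the last line flattens the class-image locations into channels, which is one way to reshape the 4-D tensor for a regular 2-D CNN.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # avoid dividing by zero
    return [x / norm for x in v]


def correlation_tensor(feat_a, feat_b):
    """Build a 4-D correlation tensor between two dense feature maps.

    feat_a: Ha x Wa grid of C-dimensional vectors (e.g. the input image).
    feat_b: Hb x Wb grid of C-dimensional vectors (e.g. the class image).
    Returns corr with corr[i][j][k][l] = cosine similarity between
    feat_a[i][j] and feat_b[k][l], i.e. one score per pair of locations.
    """
    a = [[l2_normalize(v) for v in row] for row in feat_a]
    b = [[l2_normalize(v) for v in row] for row in feat_b]
    return [[[[sum(x * y for x, y in zip(va, vb)) for vb in row_b] for row_b in b]
             for va in row_a] for row_a in a]


feat_input = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
              [[0.0, 0.5], [2.0, 0.0], [0.0, -1.0]]]   # 2 x 3 grid, C = 2
feat_class = [[[1.0, 0.0], [0.0, 1.0]]]                 # 1 x 2 grid, C = 2
corr = correlation_tensor(feat_input, feat_class)
print(len(corr), len(corr[0]), len(corr[0][0]), len(corr[0][0][0]))  # 2 3 1 2

# Flatten class-image locations (k, l) into channels: Ha x Wa x (Hb * Wb).
as_channels = [[[c for row in cell for c in row] for cell in row_a] for row_a in corr]
```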
In OS2D, the feature extractor computes dense feature maps for both the input image and the class image, and then applies a classic computer vision technique, bilinear resampling, to the class feature map to obtain a fixed feature size, which allows the network to be trained. The same neural network is then applied to the correlation tensor to obtain transformations defined globally with respect to the input feature map coordinates. These transformations are fed into a grid sampler, which produces a grid of points aligning the class image with each location of the input image.
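Bilinear resampling of a feature map to a fixed size can be sketched per channel as below. The "align corners" convention (output corners land exactly on input corners) is an assumption; frameworks offer both conventions, and OS2D’s implementation may use the other one.

```python
def bilinear_resize(grid, out_h, out_w):
    """Resample one feature channel (a list of rows) to out_h x out_w.

    Uses "align corners" sampling: output corners coincide with input corners.
    """
    in_h, in_w = len(grid), len(grid[0])
    out = []
    for i in range(out_h):
        y = i * (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        dy = y - y0
        row = []
        for j in range(out_w):
            x = j * (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)
            dx = x - x0
            # Interpolate along x on the two neighbouring rows, then along y.
            top = grid[y0][x0] * (1 - dx) + grid[y0][x1] * dx
            bottom = grid[y1][x0] * (1 - dx) + grid[y1][x1] * dx
            row.append(top * (1 - dy) + bottom * dy)
        out.append(row)
    return out


def resize_feature_map(fmap, out_h, out_w):
    """Apply the same resampling independently to every channel (C x H x W)."""
    return [bilinear_resize(channel, out_h, out_w) for channel in fmap]


print(bilinear_resize([[0, 2], [4, 6]], 3, 3))
# [[0.0, 1.0, 2.0], [2.0, 3.0, 4.0], [4.0, 5.0, 6.0]]
```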
Training Methodology
Training images are downsampled using random crop and scale data augmentation to obtain a more manageable image size without distorting the aspect ratio. Each batch collects class images for the classes annotated within the batch as positive samples, and images from random other classes as negative samples. Predictions are then made, and losses are computed by comparing the positive and negative examples.
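The positive/negative class selection for a batch might look like the following sketch. `sample_batch_classes`, the per-image label lists, and the number of negatives are hypothetical; cropping, rescaling, and class-image preprocessing are omitted.

```python
import random


def sample_batch_classes(batch_annotations, all_classes, num_negatives, seed=None):
    """Choose the class images used for one training batch.

    batch_annotations: for each image in the batch, the class labels of its boxes.
    Positives are the classes annotated somewhere in the batch; negatives are
    drawn at random from the remaining classes.
    """
    rng = random.Random(seed)
    positives = sorted({cls for labels in batch_annotations for cls in labels})
    pool = sorted(set(all_classes) - set(positives))
    negatives = rng.sample(pool, min(num_negatives, len(pool)))
    return positives, negatives


batch = [["cat", "dog"], ["dog"]]
classes = ["cat", "dog", "car", "bus", "cup"]
pos, neg = sample_batch_classes(batch, classes, num_negatives=2, seed=1)
print(pos, len(neg))  # ['cat', 'dog'] 2
```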
Demo
In the following example, we use two class images rather than the intended single image, giving the model more class instances to compare against and making it more robust to varying query images. Using both images is optional.
📐 Approach 2: Decoupled Faster R-CNN (DeFRCN) for Few-Shot Object Detection
Problem Setup
DeFRCN aims to detect objects accurately in novel images by fine-tuning the model on a fixed number of novel object instances after general training on a larger dataset.
General Approach
Faster R-CNN has been widely used for object detection, with strong practical performance. However, to accommodate the few-shot scenario, DeFRCN separates the RPN and RCNN tasks by stopping gradients, scaling gradients, or decoupling the tasks altogether, because the authors believe that optimising for both tasks jointly can create conflicting goals that lead to sub-optimal training. This decoupling should also allow better fine-tuning on novel image data.
Model Goal
Similar to OS2D, the model uses Siamese-style pairwise scores from positive and negative examples to compare query features with support-set feature prototypes.
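As we understand the Prototypical Calibration Block (PCB) described below, it compares each proposal’s feature with per-class prototypes (mean support features) by cosine similarity and mixes that score with the detector’s softmax score. The sketch assumes a simple weighted average with a hypothetical `alpha=0.5` and toy 2-D features; it is not DeFRCN’s code.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def build_prototypes(support_features):
    """Class prototype = mean of that class's support feature vectors."""
    return {cls: [sum(col) / len(feats) for col in zip(*feats)]
            for cls, feats in support_features.items()}


def calibrate_scores(softmax_scores, roi_feature, prototypes, alpha=0.5):
    """Fuse detector scores with prototype similarity for one region proposal:
    s_c = alpha * softmax_c + (1 - alpha) * cos(roi_feature, prototype_c)."""
    return {cls: alpha * p + (1 - alpha) * cosine(roi_feature, prototypes[cls])
            for cls, p in softmax_scores.items()}


protos = build_prototypes({"a": [[1.0, 0.0], [0.8, 0.2]],
                           "b": [[0.0, 1.0], [0.2, 0.8]]})
fused = calibrate_scores({"a": 0.55, "b": 0.45}, [0.1, 1.0], protos)
print(max(fused, key=fused.get))  # b: the proposal is much closer to prototype "b"
```

Here the detector slightly prefers class "a", but the prototype similarity overturns that decision, which is the kind of correction the calibration step is meant to provide.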
Architecture
A typical Faster R-CNN consists of a backbone network that feeds both a Region Proposal Network (RPN) and an RCNN head; the RCNN head also takes the proposals from the RPN and ends in a box classifier and a box regressor. Normally, all of these components are connected end to end for gradient flow. In DeFRCN, however, the gradients flowing between the RCNN head and the backbone are scaled, the RPN no longer passes gradients back to the backbone, and the box regressor and box classifier are decoupled into separate tasks.
DeFRCN additionally contains two Gradient Decoupled Layers (GDLs) and an offline Prototypical Calibration Block (PCB). The GDLs are inserted between the shared backbone and the RPN, and between the backbone and the RCNN head, to adjust the degree of decoupling. They learn affine transformations of the feature maps in the forward pass and multiply the gradients by constants in the backward pass. The PCB takes the region proposals from the few-shot detector as input and boosts the original softmax scores with additional prototype-based pairwise scores.
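A framework-free sketch of a Gradient Decoupled Layer as described above: an affine map in the forward pass and a constant gradient multiplier `lam` in the backward pass. The `lam` values in the usage lines (0 for the RPN branch, 0.75 for the RCNN branch) are hypothetical, and leaving the affine parameters’ own gradients unscaled is our assumption. In a real framework this would be written as a custom autograd function.

```python
class GradientDecoupledLayer:
    """Toy GDL on a C-dimensional feature vector (no autograd framework).

    Forward: learnable per-channel affine map y = a * x + b.
    Backward: the gradient sent back to the backbone is multiplied by `lam`.
    lam = 0 stops the gradient (full decoupling); 0 < lam < 1 scales it down.
    """

    def __init__(self, num_channels, lam):
        self.a = [1.0] * num_channels  # affine scale, initialised to identity
        self.b = [0.0] * num_channels  # affine shift
        self.lam = lam

    def forward(self, x):
        self._x = list(x)
        return [ai * xi + bi for ai, xi, bi in zip(self.a, x, self.b)]

    def backward(self, grad_out):
        # Gradients for the layer's own affine parameters are not scaled.
        self.grad_a = [g * xi for g, xi in zip(grad_out, self._x)]
        self.grad_b = list(grad_out)
        # Gradient w.r.t. the input: chain rule through y = a*x + b, times lam.
        return [self.lam * ai * g for ai, g in zip(self.a, grad_out)]


backbone_feat = [1.0, 2.0]
rpn_gdl = GradientDecoupledLayer(2, lam=0.0)    # hypothetical: stop RPN gradients
rcnn_gdl = GradientDecoupledLayer(2, lam=0.75)  # hypothetical: scale RCNN gradients
rpn_gdl.forward(backbone_feat)
rcnn_gdl.forward(backbone_feat)
# Pretend these upstream gradients came from the RPN and RCNN losses.
g_rpn = rpn_gdl.backward([0.3, 0.3])
g_rcnn = rcnn_gdl.backward([1.0, -1.0])
backbone_grad = [x + y for x, y in zip(g_rpn, g_rcnn)]
print(backbone_grad)  # [0.75, -0.75]
```

The backbone only "feels" the RCNN loss here, and at three quarters strength, while the RPN still trains its own parameters normally.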
Training Methodology
The paper uses standard datasets such as PASCAL VOC and COCO, whose training and validation splits are already defined; the few-shot examples are randomly selected using different random seeds and differ for each number of shots being tested. The model is first trained on the overall dataset, using pairwise scores, to build general object detection capability. Then, for evaluation, it is fine-tuned on the novel classes using the few-shot novel data to improve performance on the novel dataset.
📐 Approach 3: FewX for Few-Shot Object Detection
Problem Setup
FewX aims to detect objects of unseen categories from a few annotated examples, with or without fine-tuning on the novel classes.
General Approach
FewX combines several techniques to improve generalisation to novel categories: an Attention-RPN, a Multi-Relation Detector, and a contrastive training strategy. Together they exploit the similarity between support and query sets to detect novel objects, while the contrastive training and the multi-relation detector suppress false detections. Furthermore, the base model is trained on a dataset the authors built specifically for few-shot object detection, so that it does not need the fine-tuning step that many other few-shot learning models depend on.
Model Goal
During training, for each sampled proposal, the loss is constructed to be the sum of matching loss and box loss where the box loss is the traditional bounding-box loss defined above, and the matching loss is binary cross entropy loss.
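A per-proposal version of this loss in Python, assuming the box term is smooth L1 and is applied only to proposals that match the support class (the usual detector convention, not confirmed here). `proposal_loss` and its arguments are hypothetical names.

```python
import math


def binary_cross_entropy(prob, label, eps=1e-7):
    prob = min(max(prob, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))


def smooth_l1(pred, target, beta=1.0):
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total


def proposal_loss(match_prob, is_match, pred_box=None, gt_box=None):
    """Loss for one sampled proposal: matching loss (BCE) + box loss.

    Assumption: the box loss is only applied to proposals that match the
    support class, as is usual for detector box regression.
    """
    loss = binary_cross_entropy(match_prob, 1.0 if is_match else 0.0)
    if is_match:
        loss += smooth_l1(pred_box, gt_box)
    return loss


print(round(proposal_loss(0.9, True, [0, 0, 10, 10], [0, 0, 10, 12]), 4))  # 1.6054
```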
Architecture
The base of the model is still fundamentally a CNN (ResNet) that extracts features from the image. The first novel component is the attention-based RPN, which builds attention on the query feature map from the support information so that background and non-matching categories can be filtered out. The second is the multi-relation detector, which is designed to capture relations at several levels (global, local, and patch correlation), again using attention to refine its ability to pick up such information.
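The support-to-query attention in the Attention-RPN can be sketched as a depth-wise cross-correlation: the support feature map is average-pooled into a small kernel, and each query channel is correlated only with the matching kernel channel. The pooled kernel size `k` is left as a parameter (our assumption; check the paper for the value it uses), and the nested-list tensors are for illustration only.

```python
def avg_pool_to(channel, k):
    """Average-pool one H x W channel down to k x k (assumes H, W divisible by k)."""
    h, w = len(channel) // k, len(channel[0]) // k
    return [[sum(channel[i * h + y][j * w + x] for y in range(h) for x in range(w)) / (h * w)
             for j in range(k)] for i in range(k)]


def attention_rpn_features(query, support, k=1):
    """Depth-wise cross-correlation of query features with a support kernel.

    query: C x H x W, support: C x Hs x Ws. The support map is pooled to a
    k x k kernel per channel, and each query channel is correlated only with
    its own kernel channel (valid padding), giving C x (H-k+1) x (W-k+1).
    """
    out = []
    for q_ch, s_ch in zip(query, support):
        kern = avg_pool_to(s_ch, k)
        H, W = len(q_ch), len(q_ch[0])
        out.append([[sum(q_ch[y + i][x + j] * kern[i][j] for i in range(k) for j in range(k))
                     for x in range(W - k + 1)] for y in range(H - k + 1)])
    return out


query = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]   # one channel, 3 x 3
support = [[[1, 0], [0, 1]]]                  # one channel, 2 x 2
print(attention_rpn_features(query, support, k=1))  # query scaled by support mean 0.5
print(attention_rpn_features(query, support, k=2))  # [[[6.0, 8.0], [12.0, 14.0]]]
```

The resulting attention feature map, rather than the raw query features, is what the RPN would then score for proposals.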
Training Methodology
The training methodology is unusual compared with other few-shot approaches. First, the authors build their own few-shot dataset from a variety of existing datasets to better emulate the few-shot scenario and increase adaptability and generalisability. Their 2-way contrastive training builds training triplets, each made of a query image, a positive support image, and a negative support image. The loss rewards matching the query against the positive support image and also penalises matches against the negative support image, whose class is treated as background; this encourages the model to learn inter-class differences.
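Triplet construction for the 2-way contrastive training could be sketched as follows. It assumes each query image is paired with a single target class, and `build_contrastive_triplets` plus the image ids are hypothetical; the loss that consumes these triplets is not shown.

```python
import random


def build_contrastive_triplets(queries, support_by_class, seed=None):
    """Build (query, positive support, negative support) training triplets.

    queries: list of (query_image_id, target_class) pairs; assumes each query
    is paired with one target class (a simplification).
    support_by_class: class name -> list of support image ids.
    """
    rng = random.Random(seed)
    classes = sorted(support_by_class)
    triplets = []
    for query_id, cls in queries:
        positive = (rng.choice(support_by_class[cls]), cls)
        # The negative support comes from any other class.
        neg_cls = rng.choice([c for c in classes if c != cls])
        negative = (rng.choice(support_by_class[neg_cls]), neg_cls)
        triplets.append((query_id, positive, negative))
    return triplets


supports = {"cat": ["s_cat_1", "s_cat_2"], "dog": ["s_dog_1"], "car": ["s_car_1"]}
queries = [("q1", "cat"), ("q2", "car")]
triplets = build_contrastive_triplets(queries, supports, seed=0)
for t in triplets:
    print(t)
```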
📐 Approach 4: Zero-Shot Instance Segmentation
Problem Setup
Instance segmentation is a downstream problem of object detection. The particular difficulty here is that zero-shot instance segmentation requires absolutely no prior experience with the instances appearing in the novel images, so the model must learn the difference between foreground and background. This approach could be used in an object detection pipeline by segmenting first and then handling classification separately.
General Approach
The main idea is to learn the relationship between visual and semantic concepts from seen-class data and transfer it to detect unseen objects. The authors use a new, more semantics-oriented classification head, which should push the model to be more discriminative about visual-semantic relationships. They also use a background-aware RPN, which brings semantics into proposal generation to help learn a more reasonable word vector for the background class from images.
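At inference, the core zero-shot step is to project a region’s visual feature into the word-vector space and pick the class with the most similar embedding, which works even for classes that had no training images. The projection matrix `W`, the 3-D vectors, and the `background` entry (standing in for a learned background vector) are toy assumptions; real word2vec embeddings are commonly 300-dimensional.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def project(visual_feature, W):
    """Hypothetical learned linear map from visual features to word-vector space."""
    return [sum(w * v for w, v in zip(row, visual_feature)) for row in W]


def zero_shot_classify(visual_feature, W, class_vectors):
    """Label a region with the class whose word vector is most similar."""
    semantic = project(visual_feature, W)
    return max(class_vectors, key=lambda c: cosine(semantic, class_vectors[c]))


# Toy 3-D "word vectors"; the values are made up for illustration.
class_vectors = {
    "horse": [1.0, 0.2, 0.0],       # seen during training
    "zebra": [0.8, 0.2, 0.9],       # unseen: no training images at all
    "background": [0.0, 1.0, 0.0],  # stand-in for a learned background vector
}
W = [[1.0, 0.0], [0.0, 0.1], [0.0, 1.0]]  # maps 2-D visual features to 3-D
print(zero_shot_classify([1.0, 1.0], W, class_vectors))  # zebra
```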
Model Goal
The loss function combines the loss of the background-aware RPN, the loss of the zero-shot detector, and the loss of the semantic classification head. The background-aware RPN uses a foreground-background classification cross-entropy loss and a smooth L1 regression loss. The zero-shot detector uses a classification loss, an L1 regression loss, and a reconstruction loss. The semantic head uses a reconstruction loss and a per-pixel binary classification loss.
Architecture
Again, the backbone is a CNN that provides the feature maps from which the semantic word-vector representations are built. The semantic head uses an encoder-decoder architecture: it encodes visual features into the word-vector space and then decodes the vector back into the visual space as a reconstruction. The authors implement a synchronised background strategy and train all components together synchronously. The classification module comes after the encoder, and only these parts are used at inference. word2vec is used as the semantic word-vector space for the class embeddings.
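A toy version of the encoder-decoder semantic head, using plain linear maps and mean squared error as the reconstruction loss (both assumptions; the actual head is a learned network over feature maps). It illustrates why only the encoder side is needed at inference: the decoder exists to provide the reconstruction training signal.

```python
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def semantic_head(visual, encoder, decoder):
    """Encode visual features into word-vector space, then decode them back.

    Returns (semantic, reconstruction, reconstruction_loss). Only `semantic`
    (followed by classification) would be needed at inference; the decoder
    and its reconstruction loss serve as a training signal.
    """
    semantic = matvec(encoder, visual)
    reconstruction = matvec(decoder, semantic)
    loss = sum((r - v) ** 2 for r, v in zip(reconstruction, visual)) / len(visual)
    return semantic, reconstruction, loss


enc = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]    # hypothetical 3-D visual -> 2-D semantic
dec = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # hypothetical 2-D semantic -> 3-D visual
print(semantic_head([1.0, 2.0, 3.0], enc, dec))  # ([1.0, 2.0], [1.0, 2.0, 0.0], 3.0)
```

The non-zero loss comes from the third visual dimension, which this encoder discards; training would push the encoder to keep information that the decoder needs.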
Training Methodology
The model is trained on standard datasets with segmentation annotations, using predictions on the seen classes and the loss described above. At inference, query images contain a mixture of novel classes and classes seen during training.
Conclusion
All four methods have their own benefits, excel in their own areas, and have their own drawbacks.
Zero-shot instance segmentation is the most impressive, since it has no image data at all for the test classes. However, it relies on an external semantic embedding space, which is not always available for potentially unique datasets and instance classes. Additionally, given the lack of information, its precision in practice is much more prone to variance.
Few-shot approaches in general will struggle more in one-shot or zero-shot settings than models designed for zero-shot and one-shot learning.
However, they are more likely to generalise better as more class instance examples become available. Of the four approaches, FewX is the most robust to variation in class instances, given its more extensive network. However, it is by far the most computationally expensive, with sizable GPU memory requirements.
DeFRCN is another few-shot model, and a more compact one. However, its ability to generalise to new class instances relies on the novel fine-tuning process, which can be time-consuming for certain use cases.
Lastly, OS2D is the one-shot approach and is easily plug-and-play. However, because of its one-shot setup, OS2D is sensitive to the particular class example it is given, and its predictions vary when the query image looks quite different from that example. Additionally, the current model setup is not designed to benefit consistently from additional class examples.
In short, each has its own pros and cons and one should carefully consider what they value before implementing one of these models for their own use case.
References
Algorithms and code used from:
OS2D: https://github.com/aosokin/os2d
DeFRCN: https://github.com/er-muyue/DeFRCN
FewX: https://github.com/fanq15/FewX
Zero-Shot Instance Segmentation: https://github.com/zhengye1995/Zero-shot-Instance-Segmentation