Introduction to PaliGemma
PaliGemma is a new family of vision-language models (VLMs) from Google, released as part of its push toward multimodal models and model efficiency. Large language models have grown ever larger in pursuit of universal inputs and outputs. PaliGemma instead focuses on relatively compact architectures that can be fine-tuned to perform competitively with larger models.
To learn more about PaliGemma’s architecture, high-level design details, as well as downstream use cases, check out our introductory article.
Getting Started with Fine-Tuning
To fine-tune PaliGemma on your own task or dataset, you will first need to convert your annotation files into a JSON Lines file (.jsonl). Each line of that file is a JSON object containing the following key-value pairs: image (the image file name), prefix (the task prompt given to the model), and suffix (the expected output text).
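As a minimal sketch, the snippet below writes annotations to a JSON Lines file and reads them back. It assumes the image/prefix/suffix keys used in Google's fine-tuning notebook. The file name and the example records are hypothetical.

```python
import json
import os
import tempfile

def write_jsonl(records, path):
    """Write one JSON object per line (the JSON Lines format)."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

def read_jsonl(path):
    """Read a JSON Lines file back into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Hypothetical records: one image per line, with a task prefix and a target suffix.
records = [
    {"image": "images/0001.jpg", "prefix": "caption en",
     "suffix": "two people are running side by side"},
    {"image": "images/0002.jpg", "prefix": "detect cat",
     "suffix": "<loc0256><loc0000><loc1023><loc0512> cat"},
]

# Round-trip through a temporary file; in practice this would be e.g. train.jsonl.
path = os.path.join(tempfile.mkdtemp(), "train.jsonl")
write_jsonl(records, path)
loaded = read_jsonl(path)
print(loaded == records)  # True
```

Writing one object per line, rather than one large JSON array, lets training code stream examples without loading the whole file.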
Preparing Caption-Based Datasets
To fine-tune PaliGemma to output text for downstream image captioning or visual question answering (VQA), the prefixes should describe the task that you want the model to learn. For image captioning, prefixes can simply be standardized as “caption”. If your dataset is multilingual, you can optionally append the language, such as “caption en” for English captions. Suffixes contain the ground truth of what the caption should look like, such as “two people are running side by side”.
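A small helper for building caption and VQA records might look like the sketch below. The function names are hypothetical. The "answer <lang> <question>" VQA prefix follows the prompt style of Google's mix checkpoints, and we treat it here as an assumed convention.

```python
def caption_record(image, caption, lang=None):
    """Build a captioning example; append a language code for multilingual data."""
    prefix = "caption" if lang is None else f"caption {lang}"
    return {"image": image, "prefix": prefix, "suffix": caption}

def vqa_record(image, question, answer, lang="en"):
    """Build a VQA example; the 'answer <lang> <question>' prefix is an assumed convention."""
    return {"image": image, "prefix": f"answer {lang} {question}", "suffix": answer}

print(caption_record("images/0001.jpg", "two people are running side by side", lang="en"))
```

Keeping prefixes standardized matters more than the exact wording. The model learns to associate each prefix string with a task, so inconsistent prefixes dilute that signal.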
Challenges on Caption-Based Downstream Tasks
PaliGemma is trained with a per-token loss, so the model may fail to learn from the data if the captions are largely similar. Consider training the model to predict the location of an object in a VQA setting. Seemingly logical suffixes (answers) would be “the object is in the top-right of the image” or “the object is in the bottom-left of the image”. If we treat each word as a single token, the two answers share more than 80% of their tokens, and the only difference is the location itself.

The location is the important piece of information that the model needs to learn. Yet the model is not heavily penalized for predicting the wrong location, as long as it gets the majority of the tokens right. The model can end up settling for simply guessing the location, which is catastrophic for this use case. More generally, this shows how hard it is to mitigate model hallucination when the underlying loss cannot precisely target the tokens that matter.
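To make the overlap concrete, the sketch below compares the two answers position by position. It treats each whitespace-separated word as one token. This is a simplification, since PaliGemma's actual tokenizer splits text differently.

```python
def compare_tokens(truth, pred):
    """Compare two answers position by position, using words as stand-in tokens."""
    t, p = truth.split(), pred.split()
    same = sum(a == b for a, b in zip(t, p))
    differing = [(i, a, b) for i, (a, b) in enumerate(zip(t, p)) if a != b]
    return same, max(len(t), len(p)), differing

same, total, differing = compare_tokens(
    "the object is in the top-right of the image",
    "the object is in the bottom-left of the image",
)
print(f"{same}/{total} tokens match ({same / total:.0%})")  # 8/9 tokens match (89%)
print(differing)  # [(5, 'top-right', 'bottom-left')]
```

Under a loss averaged over tokens, the one position that carries the answer is just one of nine terms.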
Preparing Object Detection Datasets
For object detection, prefixes can be standardized as “detect [CLASS]”. In a multiclass scenario, semicolons separate the classes: “detect [CLASS1] ; [CLASS2] ; …”. For each object to be detected in the image, the suffix contains four location tokens describing the bounding box coordinates, followed by the class name, in the following format: “<locYYYY><locXXXX><locYYYY><locXXXX> [CLASS]”. Multiple objects are separated by semicolons.
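The sketch below converts pixel-space boxes into location tokens and assembles a detection prefix and suffix. It assumes the convention in Google's PaliGemma documentation: four tokens in ymin, xmin, ymax, xmax order, with each coordinate quantized into 1,024 bins relative to the image size. The exact rounding rule is an assumption, so check it against the notebook you use. The function names are hypothetical.

```python
def loc_token(value, size, bins=1024):
    """Quantize a pixel coordinate into one of `bins` location tokens (floor, assumed)."""
    index = min(int(value / size * bins), bins - 1)  # clamp the far edge into the last bin
    return f"<loc{index:04d}>"

def box_to_loc_tokens(box_xyxy, width, height):
    """Convert an (xmin, ymin, xmax, ymax) pixel box into y/x-ordered location tokens."""
    xmin, ymin, xmax, ymax = box_xyxy
    return (loc_token(ymin, height) + loc_token(xmin, width)
            + loc_token(ymax, height) + loc_token(xmax, width))

def detection_example(objects, width, height):
    """Build (prefix, suffix) from a list of (class_name, box_xyxy) annotations."""
    classes = sorted({name for name, _ in objects})
    prefix = "detect " + " ; ".join(classes)
    suffix = " ; ".join(f"{box_to_loc_tokens(box, width, height)} {name}"
                        for name, box in objects)
    return prefix, suffix

prefix, suffix = detection_example(
    [("cat", (0, 120, 320, 480)), ("dog", (320, 0, 640, 240))], width=640, height=480)
print(prefix)  # detect cat ; dog
print(suffix)  # <loc0256><loc0000><loc1023><loc0512> cat ; <loc0000><loc0512><loc0512><loc1023> dog
```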
By default, Google’s mix pretrained models represent each bounding box as [ymin, xmin, ymax, xmax]. Every coordinate is normalized to the image size and quantized into one of 1,024 bins (<loc0000> to <loc1023>). You can modify your training data if you want the model to learn to output bounding boxes in [xmin, ymin, width, height] or even [cx, cy, width, height] instead.
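If you switch formats, apply the conversion to your annotations before they are quantized into location tokens, and use the same convention when decoding predictions. Below is a minimal sketch of the conversions for pixel-space boxes. The function names are our own.

```python
def xyxy_to_xywh(box):
    """(xmin, ymin, xmax, ymax) -> (xmin, ymin, width, height)."""
    xmin, ymin, xmax, ymax = box
    return (xmin, ymin, xmax - xmin, ymax - ymin)

def xyxy_to_cxcywh(box):
    """(xmin, ymin, xmax, ymax) -> (center_x, center_y, width, height)."""
    xmin, ymin, xmax, ymax = box
    return ((xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin)

def cxcywh_to_xyxy(box):
    """(center_x, center_y, width, height) -> (xmin, ymin, xmax, ymax)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)

print(xyxy_to_xywh((10, 20, 50, 80)))    # (10, 20, 40, 60)
print(xyxy_to_cxcywh((10, 20, 50, 80)))  # (30.0, 50.0, 40, 60)
```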
Challenges on Object Detection Downstream Tasks
PaliGemma works well when each image contains only one or a few objects to be detected. In our ablation testing, however, model performance dropped quite significantly as we increased the number of annotated objects per image. This can be attributed to the difference in loss functions between PaliGemma and traditional CNNs.

In typical CNN-based object detection, the loss function takes into account factors such as the overlap (IoU) between ground truth and predicted bounding boxes. PaliGemma instead uses a token loss, which compares the predicted <loc> tokens with the ground-truth <loc> tokens position by position. Because token position strongly affects this loss, the model can be heavily penalized for outputting bounding boxes in a different order from the ground truth, even when it correctly and accurately identifies every instance.
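The order sensitivity can be illustrated without a model. The sketch below compares location tokens position by position, as a positional token loss does. The prediction lists the same two boxes in reverse order. An order-invariant comparison of the objects shows that nothing is actually wrong. This is a simplified illustration, not PaliGemma's actual loss computation.

```python
import re

LOC = re.compile(r"<loc(\d{4})>")

def loc_ids(text):
    """Extract location-token indices in the order they appear."""
    return [int(m) for m in LOC.findall(text)]

def positional_mismatches(truth, pred):
    """Count location tokens that differ position by position (what a token loss compares)."""
    return sum(a != b for a, b in zip(loc_ids(truth), loc_ids(pred)))

def objects(text):
    """Split a detection string into a sorted list of per-object strings."""
    return sorted(part.strip() for part in text.split(";"))

def same_objects(truth, pred):
    """Order-invariant check: do both strings describe the same set of objects?"""
    return objects(truth) == objects(pred)

truth = "<loc0256><loc0000><loc1023><loc0512> cat ; <loc0000><loc0512><loc0512><loc1023> dog"
pred = "<loc0000><loc0512><loc0512><loc1023> dog ; <loc0256><loc0000><loc1023><loc0512> cat"
print(positional_mismatches(truth, pred))  # 8 (every location token "wrong")
print(same_objects(truth, pred))           # True (the same two detections)
```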
Furthermore, because the number of objects per image varies, the model output may be constrained by the maximum output token length, which is set as a hyperparameter. Outputs can then be truncated mid-object, producing bounding boxes with only two coordinates or bounding boxes with missing classes. Segmentation outputs are affected in the same way. Even if every output token is correct, the number of objects detected is still capped by the maximum output length.
Consider a maximum output length of 256 tokens, and assume roughly six tokens per object: four location tokens, a single-token class name, and a separator. The model could then output at most 42 objects. For segmentation, each object also carries 16 segmentation tokens, so the model can output only 11 object masks. If the ground truth contains more objects than the token limit allows, the model is fundamentally incapable of reproducing it. A simple fix would be to increase the maximum output token length, but this significantly increases the model’s VRAM consumption. That scales poorly for real-time and memory-constrained use cases.
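The object limits above follow from simple integer division. The per-object token counts are assumptions:

- **Detection, six tokens:** four location tokens, one token for a single-token class name, and one for the “;” separator.
- **Segmentation, 22 tokens:** the same six plus 16 segmentation tokens.

Longer class names lower the limits further.

```python
def max_objects(max_tokens, tokens_per_object):
    """How many complete objects fit into a fixed output-token budget."""
    return max_tokens // tokens_per_object

LOC_TOKENS, SEG_TOKENS = 4, 16
CLASS_TOKENS, SEPARATOR_TOKENS = 1, 1  # assumption: single-token class names

detection_cost = LOC_TOKENS + CLASS_TOKENS + SEPARATOR_TOKENS                   # 6
segmentation_cost = LOC_TOKENS + SEG_TOKENS + CLASS_TOKENS + SEPARATOR_TOKENS  # 22

print(max_objects(256, detection_cost))     # 42
print(max_objects(256, segmentation_cost))  # 11
```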
Preparing Segmentation Datasets
Fine-tuning PaliGemma for segmentation is similar in overall structure to object detection. Each object still has its own class or identifier in the training examples, along with four location tokens describing the coordinates that bound the object. The main difference is the 16 additional segmentation tokens that the model must learn. These represent a binary mask that fits within the bounding box described by the four location tokens. There are 128 possible segmentation tokens, and their associated codebook is used out-of-the-box from the work by Ning et al. (2023). The encoder outputs token IDs that map to segmentation tokens in the following format.
Each mask object looks like this, with multiple objects separated by semicolons:
<locXXXX><locXXXX><locXXXX><locXXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX><segXXX> [CLASS]
Once these segmentation tokens for each instance mask in an image have been added to the corresponding suffix, one can simply follow the same training steps to fine-tune PaliGemma.
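Once the 16 codebook IDs for each mask are available, assembling a segmentation suffix is string formatting. The sketch assumes the token spellings <loc0000> to <loc1023> and <seg000> to <seg127> from Google's PaliGemma release. It also assumes the IDs come from the mask encoder described in the next section. The function name is hypothetical.

```python
def segmentation_object(loc_ids, seg_ids, class_name):
    """Format one instance as 4 location tokens + 16 segmentation tokens + class name."""
    if len(loc_ids) != 4 or len(seg_ids) != 16:
        raise ValueError("expected 4 location IDs and 16 segmentation IDs")
    if not all(0 <= i < 1024 for i in loc_ids) or not all(0 <= i < 128 for i in seg_ids):
        raise ValueError("token ID out of range")
    locs = "".join(f"<loc{i:04d}>" for i in loc_ids)
    segs = "".join(f"<seg{i:03d}>" for i in seg_ids)
    return f"{locs}{segs} {class_name}"

cat = segmentation_object([256, 0, 1023, 512], list(range(16)), "cat")
dog = segmentation_object([0, 512, 512, 1023], [127] * 16, "dog")
suffix = " ; ".join([cat, dog])  # objects separated by semicolons
print(cat)
```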
To avoid repeating the steps above, we will focus on the encoding and decoding of this binary mask representation for instance segmentation.
Encoding Instance Masks into Tokens
The encoder is a small neural network that takes in a 64x64 binary mask and encodes it into a 1D array of 16 integer IDs, which are then mapped to their associated tokens. The encoder works in the following stages:
- **Downsampling:** four convolutional downsampling layers reduce the spatial dimensions of the input.
- **Residual block:** a ResNet-style block of convolutional and ReLU layers follows, and its output is added back to its input.
- **Final convolution:** the output of the residual block passes through a final convolutional layer.
- **Quantization:** this output is converted into tokens using a latent codebook of 128 embeddings, each of length 512. The output is reshaped into a 16x512 array, and each of the 16 vectors is mapped to its nearest codebook embedding by L2 distance. The indices of the matching embeddings are returned as the token IDs.
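The final quantization step maps each encoder output vector to its nearest codebook entry by L2 distance. It can be sketched in plain Python. A toy 2-D codebook with three entries stands in for the real 128 x 512 codebook. This illustrates vector quantization in general, not the released encoder's code.

```python
def squared_l2(a, b):
    """Squared Euclidean distance (same argmin as L2, without the square root)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def quantize(vectors, codebook):
    """Return, for each vector, the index of its nearest codebook embedding."""
    return [min(range(len(codebook)), key=lambda k: squared_l2(v, codebook[k]))
            for v in vectors]

codebook = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]         # toy: 3 embeddings of length 2
encoder_outputs = [[0.9, 0.1], [0.2, 0.8], [0.7, 0.9]]  # toy: 3 vectors instead of 16
token_ids = quantize(encoder_outputs, codebook)
print(token_ids)                                      # [0, 1, 2]
print("".join(f"<seg{i:03d}>" for i in token_ids))    # <seg000><seg001><seg002>
```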
Decoding Segmentation Tokens into Instance Masks
Decoding reverses the encoding process with a mirrored architecture. The predicted token IDs are first converted back into their codebook embeddings. These embeddings pass through a convolutional block, two residual blocks, four transposed convolutional blocks, and a final convolutional layer to produce a 64x64 mask. This mask can then be resized and placed within its associated bounding box to get a full-size mask for the image.
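The last step resizes the 64x64 decoder output to the predicted box and pastes it into a full-size canvas. A sketch of that step follows. Nearest-neighbor resizing and a 0/1 list-of-lists mask are simplifying assumptions. Real pipelines typically use an image library and may threshold a soft mask first.

```python
def resize_nearest(mask, out_h, out_w):
    """Nearest-neighbor resize of a 2-D 0/1 mask given as a list of rows."""
    in_h, in_w = len(mask), len(mask[0])
    return [[mask[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
            for r in range(out_h)]

def paste_mask(mask, box_xyxy, image_h, image_w):
    """Scale a small mask to an integer pixel box (assumed inside the image, non-empty)
    and place it on an otherwise empty image-sized canvas."""
    xmin, ymin, xmax, ymax = box_xyxy
    canvas = [[0] * image_w for _ in range(image_h)]
    resized = resize_nearest(mask, ymax - ymin, xmax - xmin)
    for r, row in enumerate(resized):
        for c, value in enumerate(row):
            canvas[ymin + r][xmin + c] = value
    return canvas

small = [[1, 0],
         [0, 1]]  # stands in for a 64x64 decoder output
full = paste_mask(small, (1, 1, 5, 5), image_h=6, image_w=6)
for row in full:
    print(row)
```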
Challenges on Segmentation Downstream Tasks
One source of issues is the lossy, black-box nature of the binary mask encoding and decoding process. The mapping between tokens and binary masks passes through a learned encoder and decoder, which makes it non-intuitive. There is no clear indication of how the latent token embeddings relate to each other, or how each one contributes to the binary mask. As a result, PaliGemma learns to output segmentation tokens the same way it learns any other token. The encoder and decoder are also not fine-tuned alongside PaliGemma, so the overall network never learns or exploits any fundamental spatial intuition about the image.
The other limitation is that the expressiveness of the latent codebook has only been tested empirically, and it may not provide the accuracy a given task needs. In our own usage, decoded masks are at best approximately close but not exact. At worst, they fall well short of recovering the original binary mask at its native resolution, with pixel artifacts and noise around the original mask. The codebook is compact, with only 128 vectors, but encoding and then directly decoding a binary mask is lossy. Even pixel-perfect ground-truth masks can therefore be corrupted simply by being converted into the segmentation tokens that serve as ground truth for PaliGemma’s training data. As shown in the diagram below, recovery of mask details is not guaranteed during the encoding and decoding process.
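You can gauge how much detail the round trip destroys on your own data. Encode and decode each ground-truth mask, then compare the result with the original using intersection over union (IoU). The encode/decode step itself is not shown. The function below only scores two binary masks of equal size, and the masks in the example are made up.

```python
def mask_iou(a, b):
    """IoU of two equal-sized binary masks given as lists of rows of 0/1."""
    inter = union = 0
    for row_a, row_b in zip(a, b):
        for x, y in zip(row_a, row_b):
            inter += 1 if (x and y) else 0
            union += 1 if (x or y) else 0
    return inter / union if union else 1.0  # two empty masks count as identical

original = [[0, 1, 1, 0],
            [0, 1, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 0, 0]]
reconstructed = [[0, 1, 1, 0],   # hypothetical round-trip output:
                 [0, 1, 1, 1],   # one spurious pixel
                 [0, 0, 1, 0],   # one missing pixel
                 [0, 0, 0, 0]]
print(f"{mask_iou(original, reconstructed):.3f}")  # 0.714
```

Masks whose round-trip IoU falls below a threshold you choose can be flagged for inspection before training.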
Additionally, PaliGemma’s decoded mask has an output resolution of 64x64 pixels and is subsequently scaled to fit its bounding box. Some cases need more detail than that: large images, or masks that must be precise along their borders. PaliGemma’s mask encoder and decoder cannot reconstruct such details.
Training Setup
Once you have created your JSONL file and collected your training images, we recommend splitting your dataset into three parts. A train split is used for training, a validation split for evaluating model accuracy, and a test split for evaluating the model’s generalizability on unseen data. The suggested ratios are 70%, 20%, and 10% respectively, but these can be adjusted depending on the overall size of your dataset.
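Below is a minimal sketch of a reproducible 70/20/10 split over JSONL records. The fixed seed and the rule of giving the remainder to the test split are our own choices. If the same image appears in several records, you would want to split by image instead so that no image leaks across splits.

```python
import random

def split_dataset(records, ratios=(0.7, 0.2, 0.1), seed=0):
    """Shuffle records reproducibly and split them into train/validation/test lists."""
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * ratios[0])
    n_val = round(len(shuffled) * ratios[1])
    return (shuffled[:n_train],
            shuffled[n_train:n_train + n_val],
            shuffled[n_train + n_val:])  # the remainder goes to test

train, val, test = split_dataset(range(100))
print(len(train), len(val), len(test))  # 70 20 10
```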
With these, you can then refer to Google’s fine-tuning notebook to implement your own fine-tuning pipeline in JAX. You can also try out the notebooks we have developed, which include preprocessed datasets with their corresponding JSON Lines annotation files. You will only need to change the relative file paths to your image dataset and labels.
Challenges with Fine-Tuning VLMs
VLMs have been gaining traction for bridging the gap between text and images through multimodal inputs and outputs. However, the complexity of multimodality introduces issues during fine-tuning that can significantly affect model performance. We list some of the major issues below.
Misalignment between LLM and Image Encoder Tokens
Architecturally, VLMs face the intricate task of aligning linguistic and visual information. The language model component (LLM) excels at generating fluent text, but integrating visual knowledge poses a significant challenge. VLMs combine separately pre-trained models to extract visual and textual information. Attention modules are trained to align the language and image encoder tokens, but discrepancies can persist because the separately generated tokens lack deep alignment. As a result, the generated text can fail to accurately reflect the visual content. The disjointed training of the language and image encoders contributes to this misalignment and hinders coherent multimodal understanding.
Loss of Generality
Like LLMs, VLMs are often sought out for their universal zero-shot ability for a broad range of tasks. While model fine-tuning certainly improves the model's efficacy on targeted tasks or datasets, it may lead to a loss of generality due to catastrophic forgetting. Fine-tuning the model for specific tasks can compromise its ability to handle diverse inputs or tasks it hasn't explicitly been trained on. This lack of flexibility poses challenges in real-world scenarios where a broad understanding across various domains is imperative, and makes it necessary for users to conduct thorough and extensive evaluations of their model’s expected capabilities.
Hallucinations
Hallucinations in VLMs refer to the generation of outputs that lack grounding in the input data, leading to potentially erroneous responses. These hallucinations can stem from various sources, including biases in training data, complexities in multimodal interactions, or limitations in model architectures. In critical applications such as medical diagnosis or autonomous systems, hallucinations undermine the reliability of the model's outputs, highlighting the need for robust mitigation strategies.
General Challenges with Fine-Tuning PaliGemma
PaliGemma inherits the generic flaws of typical VLMs, but new task-specific challenges arise when fine-tuning it on downstream tasks. We look at three particular domains (image captioning, object detection, and segmentation) and highlight some considerations for each. Below, we describe issues that apply to all tasks as well as issues unique to individual tasks.
Lack of Task-Specific Loss Functions
The most commonly used pipeline for fine-tuning PaliGemma relies exclusively on next-token cross-entropy as the training loss that updates the model’s weights. This is the quantity underlying perplexity, a standard natural language processing metric. It is largely appropriate for a generic language model. However, such a loss does not emphasize, or even acknowledge, important task-specific metrics such as bounding box overlap or mask accuracy.
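For reference, perplexity is the exponential of the mean per-token negative log-likelihood, which is what a cross-entropy training loss averages. The sketch makes this relationship explicit using hypothetical probabilities assigned to the target tokens. It also shows why the loss ignores geometry. At a single position, only the probability on the correct token counts. A near miss on a box edge and a wildly wrong one are therefore penalized equally.

```python
import math

def mean_nll(target_probs):
    """Average negative log-likelihood of the target tokens (cross-entropy loss)."""
    return -sum(math.log(p) for p in target_probs) / len(target_probs)

def perplexity(target_probs):
    """Perplexity is exp(mean NLL)."""
    return math.exp(mean_nll(target_probs))

def token_nll(distribution, target):
    """Cross-entropy at one position: only the target token's probability matters."""
    return -math.log(distribution[target])

print(round(perplexity([0.5, 0.5, 0.5, 0.5]), 6))  # 2.0

# Hypothetical predictions when the correct token is <loc0512>:
near_miss = {"<loc0512>": 0.1, "<loc0513>": 0.9}  # off by one bin
far_miss = {"<loc0512>": 0.1, "<loc0001>": 0.9}   # off by hundreds of bins
print(token_nll(near_miss, "<loc0512>") == token_nll(far_miss, "<loc0512>"))  # True
```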
Instance segmentation is the most demanding of the tasks listed. It commonly uses three losses:

- **Bounding box loss:** measures the distance between predicted and ground-truth coordinates.
- **Classification loss:** tracks whether the predicted class matches the ground-truth class.
- **Mask loss:** evaluates how well the predicted binary pixel masks align with the ground truth.

None of these task losses can easily be added without inducing training instability. They may not improve results anyway, because the underlying architecture does not lend itself to intuitively learning the visual 2D array patterns that these losses are designed to encourage.
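For contrast, the sketch below computes the kind of geometry-aware box term a conventional detector might use: 1 - IoU between predicted and ground-truth boxes in (xmin, ymin, xmax, ymax) format. It works on continuous coordinates, whereas PaliGemma emits discrete tokens chosen by sampling or argmax. This is one reason such a term cannot simply be bolted onto its loss. The code is plain Python without gradients. A training implementation would need a differentiable framework.

```python
def box_area(box):
    """Area of an (xmin, ymin, xmax, ymax) box."""
    return (box[2] - box[0]) * (box[3] - box[1])

def box_iou(a, b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = box_area(a) + box_area(b) - inter
    return inter / union if union > 0 else 0.0

def iou_loss(pred, truth):
    """Box regression loss that shrinks as the boxes overlap more."""
    return 1.0 - box_iou(pred, truth)

print(iou_loss((0, 0, 10, 10), (0, 0, 10, 10)))            # 0.0
print(round(iou_loss((0, 0, 10, 10), (5, 0, 15, 10)), 3))  # 0.667
```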
Lack of Innate Confidence Scores
Confidence scores are an important part of traditional computer vision networks. They help interpret model performance and drive post-processing for the best results at inference time. Token probability distributions can also be used to interpret model performance. However, they are not currently restricted to the relevant tokens, for example comparing a location token only against other location tokens. Such filtering can be implemented naively, but it ignores the probability mass assigned to every other token. Holding onto full probability distributions is also computationally expensive given the size of the vocabulary. Finally, token probability distributions do not directly align with computer vision task outputs such as bounding boxes or segmentation masks. Aggregating them is therefore only an approximate, and possibly misleading, measure of task output confidence.
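As a rough illustration of the naive filtering described above, the sketch renormalizes one position's logits over location tokens only. It then combines a box's four positions with a geometric mean. Both the toy logits dictionary and this scoring rule are hypothetical choices, not part of any PaliGemma pipeline. In a real model the logits span the full vocabulary at every step, and the result is the approximate proxy described above, not a calibrated confidence.

```python
import math

def restricted_prob(logits, chosen, allowed_prefix="<loc"):
    """Softmax over only the tokens whose names start with `allowed_prefix`;
    return the probability of `chosen` within that subset."""
    subset = {tok: z for tok, z in logits.items() if tok.startswith(allowed_prefix)}
    m = max(subset.values())  # subtract the max for numerical stability
    total = sum(math.exp(z - m) for z in subset.values())
    return math.exp(subset[chosen] - m) / total

def box_confidence(per_position_logits, chosen_tokens):
    """Geometric mean of the restricted probabilities of a box's four location tokens."""
    probs = [restricted_prob(lg, tok) for lg, tok in zip(per_position_logits, chosen_tokens)]
    return math.exp(sum(math.log(p) for p in probs) / len(probs))

# Toy logits: "cat" has the highest logit but is excluded from the location-token subset.
logits = {"<loc0100>": 2.0, "<loc0101>": 2.0, "cat": 5.0}
conf = box_confidence([logits] * 4, ["<loc0100>"] * 4)
print(round(conf, 6))  # 0.5
```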
The lack of reliable confidence scores also severely limits certain computer vision post-processing techniques. One example is non-maximum suppression (NMS) for object detection and instance segmentation, which removes highly overlapping predictions to produce a clean, best estimate of the model’s predictions. Without confidence scores, there is no principled way to prefer one predicted instance over another. NMS therefore cannot be applied, and predictions cannot be cleaned.
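For reference, greedy NMS depends on exactly the ingredient PaliGemma lacks: a score per prediction that decides which overlapping box survives. The sketch assumes the scores are supplied externally, for example from a token-probability proxy like the one discussed in the previous paragraph, with all the caveats that carries. The boxes and scores are made up.

```python
def box_iou(a, b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def nms(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression; returns indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap too much with any box already kept.
        if all(box_iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep

boxes = [(0, 0, 10, 10), (1, 0, 11, 10), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]  # hypothetical scores
print(nms(boxes, scores))  # [0, 2]
```

Without the scores list, the sort that drives the whole procedure has nothing to order by.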
Potential Solutions
PaliGemma has shortcomings inherent to being a VLM, but some techniques can mitigate these problems, such as implementing guardrails and carefully tuning the model hyperparameters. We detail some of these techniques below and will provide further commentary in an upcoming blog post on tackling the challenges of fine-tuning PaliGemma and other VLMs.
Implementing Guardrails for Model Alignment
Implementing guardrails for VLMs like PaliGemma is essential to minimize hallucinations and ensure accurate and reliable outputs, especially when dealing with complex, multimodal inputs. Guardrails can include both input preprocessing and output postprocessing techniques designed to enhance model performance and stability.
On the input side, preprocessing can involve filtering out redundant or irrelevant tokens and ensuring consistent ordering of bounding box tokens, which helps the model better understand spatial relationships.
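One simple canonical order sorts objects top-to-bottom, then left-to-right, using their first two location tokens. This assumes PaliGemma's documented ymin, xmin, ymax, xmax token order, under which that is a plain sort on the parsed indices. The sketch also assumes objects are separated by “;” as in the detection suffix format. The ordering rule itself is our illustrative choice.

```python
import re

LOC = re.compile(r"<loc(\d{4})>")

def sort_key(obj):
    """Sort by (ymin, xmin), parsed from an object's first two location tokens."""
    return [int(x) for x in LOC.findall(obj)][:2]

def canonicalize(suffix):
    """Reorder the objects in a detection suffix into a fixed spatial order."""
    objects = [part.strip() for part in suffix.split(";") if part.strip()]
    return " ; ".join(sorted(objects, key=sort_key))

suffix = "<loc0256><loc0000><loc1023><loc0512> cat ; <loc0000><loc0512><loc0512><loc1023> dog"
print(canonicalize(suffix))
# <loc0000><loc0512><loc0512><loc1023> dog ; <loc0256><loc0000><loc1023><loc0512> cat
```

Applying the same function to every training suffix means the model always sees objects in one predictable order.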
For outputs, postprocessing steps can remove invalid tokens and reorder the predicted bounding boxes into the same canonical order used for the ground-truth annotations. This aligns the model’s predictions more closely with expected results.
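A minimal output guardrail along these lines works in three steps:

1. Parse the generated detection string.
2. Drop objects that do not have exactly four in-range location tokens followed by a known class name, such as objects cut off by the output-token limit.
3. Return the survivors sorted top-to-bottom, then left-to-right.

The regular expression, the allowed-class check, and the degenerate-box rule are our assumptions about what counts as valid output.

```python
import re

OBJECT = re.compile(r"^((?:<loc\d{4}>){4})\s*([^<>;]+)$")
LOC = re.compile(r"<loc(\d{4})>")

def clean_detections(text, allowed_classes):
    """Parse a detection string into [(ymin, xmin, ymax, xmax, class_name), ...],
    keeping only well-formed objects with a known class and a non-degenerate box."""
    results = []
    for part in text.split(";"):
        match = OBJECT.match(part.strip())
        if not match:
            continue  # truncated or malformed object, e.g. only two location tokens
        name = match.group(2).strip()
        if name not in allowed_classes:
            continue  # unknown or hallucinated class name
        ymin, xmin, ymax, xmax = (int(v) for v in LOC.findall(match.group(1)))
        if max(ymin, xmin, ymax, xmax) > 1023 or ymin >= ymax or xmin >= xmax:
            continue  # out-of-range or zero-area box
        results.append((ymin, xmin, ymax, xmax, name))
    return sorted(results)  # canonical order: top-to-bottom, then left-to-right

raw = ("<loc0256><loc0000><loc1023><loc0512> cat ; "
       "<loc0000><loc0512><loc0512><loc1023> dog ; "
       "<loc0100><loc0200> cat")  # the last object was cut off by the token limit
print(clean_detections(raw, {"cat", "dog"}))
# [(0, 512, 512, 1023, 'dog'), (256, 0, 1023, 512, 'cat')]
```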
In line with the industry’s recent focus on responsible AI, another form of guardrail is to add stages that pass the inputs and outputs through LLMs. Tools like LlamaGuard or NeMo Guardrails can then ensure that the inputs and outputs do not violate predefined societal or safety principles. Of course, this adds cost and complexity, since it requires hosting a separate LLM service just for guardrailing.
Additionally, custom loss functions that reflect the specific needs of the downstream task can provide a more precise training signal, further aligning model predictions with desired outcomes. These guardrails not only improve the coherence and relevance of the model's outputs but also enhance its robustness, making it more capable of handling the intricacies of vision-language tasks in diverse real-world scenarios.
Tuning Model Hyperparameters
To maximize the chances that PaliGemma works well for your specific downstream task, you may need to modify the default hyperparameters in Google’s fine-tuning notebook. We list the ones we experimented with in the following sections.
Token Length & Batch Size
For object detection and segmentation, you may need to balance token length and batch size to account for varying numbers of objects while avoiding out-of-memory errors in your compute environment. In our testing with these use cases, Google’s default parameters worked well on a T4 GPU: a token length of 128 and a batch size of 8.
Optimizers & Learning Rate
Google’s fine-tuning notebook uses plain Stochastic Gradient Descent (SGD) as the optimizer, with a learning rate of 0.03. Based on our experience with traditional CNNs, we wanted to explore whether popular optimizers such as Adam would also work well for object detection and segmentation. Through some testing, we found that the following configurations worked well.
Of course, these are simply ballpark figures that you may need to tune further based on what you observe during the training process. For example, if you observe that the losses are decreasing too slowly, you may need to increase the learning rate; and conversely if you observe that the losses are exploding, you may need to decrease the learning rate. Check out our documentation for best practices on dealing with overfitting and underfitting.
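To show what switching optimizers changes, the sketch contrasts the two update rules on a toy one-dimensional loss, (x - 3)^2. It is not fine-tuning code. The toy loss, the step counts, and Adam's learning rate of 0.1 are arbitrary, and only SGD's 0.03 matches the notebook default mentioned above. In a JAX pipeline you would normally take these optimizers from a library such as Optax (optax.sgd, optax.adam) rather than write them by hand.

```python
import math

def sgd_step(param, grad, lr=0.03):
    """Plain SGD: step against the gradient, scaled by a fixed learning rate."""
    return param - lr * grad

class Adam:
    """Minimal Adam for a single scalar parameter (standard Kingma & Ba update)."""
    def __init__(self, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = self.v = 0.0
        self.t = 0

    def step(self, param, grad):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad       # running mean of gradients
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad ** 2  # running mean of squared gradients
        m_hat = self.m / (1 - self.beta1 ** self.t)                   # bias correction
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

def grad(x):
    return 2 * (x - 3)  # gradient of the toy loss (x - 3)^2

x_sgd = 0.0
for _ in range(200):
    x_sgd = sgd_step(x_sgd, grad(x_sgd))

adam, x_adam = Adam(lr=0.1), 0.0
for _ in range(500):
    x_adam = adam.step(x_adam, grad(x_adam))
print(round(x_sgd, 3), round(x_adam, 3))  # both end close to 3.0
```

The design difference is that SGD's step size scales with the raw gradient, while Adam normalizes by a running estimate of gradient magnitude. Adam's learning rate therefore behaves more like a per-step distance, which is why the two optimizers need different learning rates.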
If you wish to identify an optimal set of hyperparameters, you can explore hyperparameter tuning using search methods like grid search or even Bayesian optimization.
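A grid search simply tries every combination of candidate values and keeps the best validation result. In the sketch below, the search space is hypothetical, and fake_run stands in for a real fine-tuning plus validation run that returns a score where higher is better.

```python
import itertools

def grid_search(train_and_evaluate, grid):
    """Try every combination in `grid`; return (best_score, best_params). Higher is better."""
    best = None
    keys = sorted(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        params = dict(zip(keys, values))
        score = train_and_evaluate(**params)
        if best is None or score > best[0]:
            best = (score, params)
    return best

# Hypothetical search space and stand-in objective.
grid = {"learning_rate": [0.003, 0.01, 0.03], "batch_size": [4, 8]}

def fake_run(learning_rate, batch_size):
    return 1.0 - abs(learning_rate - 0.01) - 0.001 * abs(batch_size - 8)

print(grid_search(fake_run, grid))  # (1.0, {'batch_size': 8, 'learning_rate': 0.01})
```

Each grid point is a full fine-tuning run, so the cost grows multiplicatively with every hyperparameter added. Bayesian optimization becomes attractive for exactly that reason.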
Note also that PaliGemma fine-tuning runs can be highly variable, and some runs can degenerate. This is an unfortunate byproduct of large language models in general. It does not necessarily indicate a problem with the dataset or the underlying training pipeline.
What’s Next?
If you have questions, feel free to join our Community Slack to post your questions, or contact us to fine-tune your own PaliGemma model on Datature Nexus.
For more detailed information about the model functionality, customization options, or answers to any common questions you might have, read more on our Developer Portal.
Developer’s Roadmap
Datature recognizes the importance of VLMs in practical use cases. With this in mind, we will be incorporating the PaliGemma architecture for fine-tuning, such that Datature Nexus users will be able to import and annotate multimodal datasets that can be used to train a PaliGemma model for their specific use cases. To learn more about the underlying PaliGemma architecture, you can read this article here. We also note that while achieving academic benchmarks is significant, practical deployment still requires other considerations such as guardrailing to ensure consistency and quality.