
PROTO-CLIP: Vision-Language Prototypical Network for Few-Shot Learning


Motivation

In the setting of robot perception:

  • Object model-based methods build a 3D model of each object and use it for recognition, but large collections of 3D models are hard to obtain in the real world
  • Object category-based methods recognize a limited set of categories, and it is hard to collect large numbers of images for every category
  • Internet data such as ImageNet and Visual Genome has a domain gap with robot manipulation

Therefore, learning to recognize a new object from a few example images is important for scaling up the set of objects a robot can recognize.


Based on the few-shot training images, the method adapts the image and text encoders simultaneously and aligns the image prototypes and text prototypes during adaptation, improving few-shot classification performance.

CLIP-based Few-Shot Learning

  • N-way, K-shot Setting

    • N classes with K images per class, i.e. $N\times K$ training images

    • After training, the model is tested on the same N classes

  • The original CLIP is referred to as Zero-shot CLIP

  • Linear-probe CLIP trains a logistic regression classifier on top of the CLIP image features

  • adapt the text encoder

    • CoOp: learnable text prompts

      [8] Learning to prompt for vision-language models. IJCV2022

  • adapt the image encoder

    • Clip-adapter: appends a lightweight adapter of two linear transformation layers after the CLIP encoder and blends its output with the original feature via a residual connection, for few-shot learning

      [9] Clip-adapter: Better vision-language models with feature adapters 2021

    • Tip-adapter: builds a key-value cache model, where the keys are CLIP image features of the few-shot images and the values are one-hot vectors of their class labels; the cache can also be fine-tuned as learnable parameters to improve classification accuracy (a rough sketch of the cache model follows this list).

      [10] Tip-adapter: Training-free adaption of clip for few-shot classification 2022

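A rough sketch of the cache-model idea (not the authors' code; tensor names, shapes, and the combination with the zero-shot CLIP logits follow my reading of the Tip-Adapter paper and are assumptions):

```python
import torch

def tip_adapter_logits(query_feat, cache_keys, cache_values, clip_text_weights,
                       beta=5.5, alpha=1.0):
    """Sketch of Tip-Adapter's key-value cache classifier (shapes are assumptions).

    query_feat:        [B, D]    L2-normalized CLIP features of the test images
    cache_keys:        [N*K, D]  L2-normalized CLIP features of the few-shot images
    cache_values:      [N*K, N]  one-hot class labels (learnable in the fine-tuned variant)
    clip_text_weights: [D, N]    CLIP text embeddings of the class prompts
    """
    affinity = query_feat @ cache_keys.t()                             # [B, N*K]
    cache_logits = torch.exp(-beta * (1.0 - affinity)) @ cache_values  # [B, N]
    clip_logits = 100.0 * query_feat @ clip_text_weights               # zero-shot CLIP logits
    return clip_logits + alpha * cache_logits
```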

Meta-learning-based Few-Shot Learning

  • setting:

    • Each class has a support set and a query set. During training, the class labels of both the support set and the query set are available; at test time only the support set is available, and the goal is to predict the labels of the query set

    • The difference from the N-way, K-shot setting in CLIP-based few-shot learning is that the latter does not split the classes: training and testing are performed on the same classes, whereas meta-learning-based approaches use the training classes to train a meta-learner and then adapt it to novel classes using their support sets

Meta-learning-based approaches train a meta-learner with the training classes $C_{train}$ that can be adapted to the novel classes $C_{test}$ using their support sets.

  • Non-episodic approaches use all the data in $C_{train}$ for training, such as k-NN and its ‘Finetuned’ variants
  • Episodic approaches construct episodes, i.e., a subset of the training classes, to train the meta-learner
    • Prototypical Networks
    • Matching Networks
    • Relation Networks
    • Model Agnostic Meta-Learning (MAML)
    • Proto-MAML
    • CrossTransformers
  • In this work, we consider training and testing in the same classes following previous CLIP-based few-shot learning methods [8, 9, 10].

Others

  • Few-shot learning dataset for robotic environments:
    • [15] Fewsol: A dataset for few-shot object learning in robotic environments

Methodology


This paper directly considers the N-way K-shot classification setting.

  • The CLIP model is frozen throughout. The CLIP image encoder and text encoder are used to obtain classification probabilities; the final conditional classification probability is a weighted sum of the image probability and the text probability, controlled by the hyperparameter $\alpha$:

    $P\left(y=k \mid \mathbf{x}^{q}, \mathcal{S}\right)=\alpha \underbrace{P\left(y=k \mid \mathbf{x}^{q}, \mathcal{S}_{x}\right)}_{\text {image probability }}+(1-\alpha) \underbrace{P\left(y=k \mid \mathbf{x}^{q}, \mathcal{S}_{y}\right)}_{\text {text probability }},$

  • Prototypical networks are used to model the conditional distribution of the class label $y$ given a query image $\mathbf{x}^q$ and the support set $\mathcal{S}=\{\mathbf{x}_i^s,y_i^s\}^M_{i=1}$:

    $\begin{aligned}
    P\left(y=k \mid \mathbf{x}^{q}, \mathcal{S}_{x}\right) & =\frac{\exp \left(-\beta\left\|g_{\mathbf{w}_{1}}\left(\mathbf{x}^{q}\right)-\mathbf{c}_{k}^{x}\right\|_{2}^{2}\right)}{\sum_{k^{\prime}=1}^{N} \exp \left(-\beta\left\|g_{\mathbf{w}_{1}}\left(\mathbf{x}^{q}\right)-\mathbf{c}_{k^{\prime}}^{x}\right\|_{2}^{2}\right)}, \\
    P\left(y=k \mid \mathbf{x}^{q}, \mathcal{S}_{y}\right) & =\frac{\exp \left(-\beta\left\|g_{\mathbf{w}_{1}}\left(\mathbf{x}^{q}\right)-\mathbf{c}_{k}^{y}\right\|_{2}^{2}\right)}{\sum_{k^{\prime}=1}^{N} \exp \left(-\beta\left\|g_{\mathbf{w}_{1}}\left(\mathbf{x}^{q}\right)-\mathbf{c}_{k^{\prime}}^{y}\right\|_{2}^{2}\right)},
    \end{aligned}$

    • where $g_{\mathbf{w}_{1}}(\cdot)$ denotes the CLIP image encoder combined with an adapter network (learnable parameters $\mathbf{w}_1$), used to compute the feature of the query image
    • $\mathbf{c}_{k}^{x},\mathbf{c}_{k}^{y}$ are the “prototypes” for class $k$ computed using images and text, respectively
    • $\beta \in \mathbb{R}^+$ is a hyperparameter to sharpen the probability distributions

Code:

```python
import torch
import torch.nn.functional as F


def P(zq_imgs_flat, z_img_proto, z_text_proto, alpha, beta):
    """
    Returns probability dist, p = alpha * p_i + (1-alpha) * p_t
    """
    # compute pairwise squared Euclidean distances (query, prototypes)
    xq_img_proto_dists = torch.cdist(
        zq_imgs_flat.float(), z_img_proto.float(), p=2).pow(2)
    xq_text_proto_dists = torch.cdist(
        zq_imgs_flat.float(), z_text_proto.float(), p=2).pow(2)

    # P(y=k | query_image, support_images)
    p_i = F.softmax(beta * (-xq_img_proto_dists), dim=1)

    # P(y=k | query_image, support_text)
    p_t = F.softmax(beta * (-xq_text_proto_dists), dim=1)

    # total probability = alpha * p_image + (1-alpha) * p_text
    p = alpha * p_i + (1 - alpha) * p_t

    return p
```
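A quick sanity check with dummy features, continuing from the snippet above (the feature dimension and class count are illustrative assumptions):

```python
zq = torch.randn(8, 1024)    # 8 query image features
cx = torch.randn(37, 1024)   # 37 image prototypes
cy = torch.randn(37, 1024)   # 37 text prototypes
probs = P(zq, cx, cy, alpha=0.5, beta=1.0)
print(probs.shape)           # torch.Size([8, 37]); each row sums to 1
```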
  • The prototypes for class $k$ are computed as:

    $\mathbf{c}_{k}^{x}=\frac{1}{M_{k}} \sum_{y_{i}^{s}=k} \phi_{\text {Image }}\left(\mathbf{x}_{i}^{s}\right), \mathbf{c}_{k}^{y}=\frac{1}{\tilde{M}_{k}} \sum_{j=1}^{\tilde{M}_{k}} \phi_{\text {Text }}\left(\operatorname{Prompt}_{j}\left(y_{i}^{s}=k\right)\right),$

    • $M_k$ is the number of image examples with label $k$, and $\tilde M_k$ is the number of prompts for label $k$
    • Each prototype is the mean of the features of multiple examples (see the sketch below)
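A minimal sketch of how these prototypes could be computed from pre-extracted CLIP features (function and tensor names are assumptions, not the authors' code):

```python
import torch

def compute_prototypes(support_feats, support_labels, text_feats, num_classes):
    """
    support_feats:  [M, D]     CLIP image features of the support images
    support_labels: [M]        class indices in [0, num_classes)
    text_feats:     [N, P, D]  CLIP text features, P prompts per class
    Returns image prototypes c^x [N, D] and text prototypes c^y [N, D].
    """
    img_protos = torch.stack([
        support_feats[support_labels == k].mean(dim=0)  # average the M_k examples of class k
        for k in range(num_classes)
    ])
    txt_protos = text_feats.mean(dim=1)                 # average the prompts of each class
    return img_protos, txt_protos
```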

Learning the memories and the adapter

The adapter can also be convolution-based, which uses fewer parameters; the two adapter variants each have their own advantages (see the experiments section). A rough sketch of a convolution-based adapter is given below.

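The exact adapter architectures are given in the paper's figures (omitted here). As a hedged sketch only, a lightweight convolution-based adapter could reshape the CLIP embedding into a small 2D map, refine it with a couple of conv layers, and keep a residual connection; all layer sizes below are illustrative assumptions, not the paper's configuration:

```python
import torch
import torch.nn as nn

class ConvAdapter(nn.Module):
    """Illustrative convolution-based adapter: reshape the CLIP embedding into a
    2D map, refine it with small conv layers, and add a residual connection."""

    def __init__(self, dim=1024, side=32):
        super().__init__()
        assert side * side == dim, "embedding must reshape into a square map"
        self.side = side
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):                       # x: [B, dim]
        b, d = x.shape
        m = x.view(b, 1, self.side, self.side)  # [B, 1, side, side]
        m = self.conv(m).view(b, d)
        return x + m                            # residual keeps the original CLIP feature
```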

Loss Functions

  • Classification loss for the query images

    $\mathcal{L}_{1}\left(\mathbf{W}_{\text {image }}, \mathbf{W}_{\text {text }}, \mathbf{w}_{1}\right)=-\log P\left(y^{q}=k \mid \mathbf{x}^{q}, \mathcal{S}\right)$

    • $\mathbf{W}_{\text {image }}$ denotes the image memory, $\mathbf{W}_{\text {text }}$ the text memory, and $\mathbf{w}_{1}$ the learnable parameters of the adapter network
    • $P\left(y^{q}=k \mid \mathbf{x}^{q}, \mathcal{S}\right)$ is the conditional probability distribution defined above
  • The image prototypes and text prototypes are aligned during training

    • $\{\mathbf{c}_{1}^{x},\mathbf{c}_{2}^{x},\ldots,\mathbf{c}_{N}^{x}\}$ denote the image prototypes and $\{\mathbf{c}_{1}^{y},\mathbf{c}_{2}^{y},\ldots,\mathbf{c}_{N}^{y}\}$ the text prototypes

    • Contrastive learning pulls $\mathbf{c}_{k}^{x}$ and $\mathbf{c}_{k}^{y}$ closer in the feature space while pushing them away from the other prototypes, using the InfoNCE loss:

      $\mathcal{L}_{2}^{k}\left(\mathbf{c}_{k}^{x},\left\{\mathbf{c}_{k^{\prime}}^{y}\right\}_{k^{\prime}=1}^{N}\right)=-\log \frac{\exp \left(\mathbf{c}_{k}^{x} \cdot \mathbf{c}_{k}^{y}\right)}{\sum_{k^{\prime}=1}^{N} \exp \left(\mathbf{c}_{k}^{x} \cdot \mathbf{c}_{k^{\prime}}^{y}\right)}, \mathcal{L}_{3}^{k}\left(\mathbf{c}_{k}^{y},\left\{\mathbf{c}_{k^{\prime}}^{x}\right\}_{k^{\prime}=1}^{N}\right)=-\log \frac{\exp \left(\mathbf{c}_{k}^{y} \cdot \mathbf{c}_{k}^{x}\right)}{\sum_{k^{\prime}=1}^{N} \exp \left(\mathbf{c}_{k}^{y} \cdot \mathbf{c}_{k^{\prime}}^{x}\right)}$

  • The total loss function during training:

    $\begin{array}{c}
    \mathcal{L}=-\frac{1}{L} \sum_{j=1}^{L} \log P\left(y_{j}^{q}=k \mid \mathbf{x}_{j}^{q}, \mathcal{S}\right)+\frac{1}{N} \sum_{k=1}^{N}\left(\mathcal{L}_{2}^{k}\left(\mathbf{c}_{k}^{x},\left\{\mathbf{c}_{k^{\prime}}^{y}\right\}_{k^{\prime}=1}^{N}\right)+\mathcal{L}_{3}^{k}\left(\mathbf{c}_{k}^{y},\left\{\mathbf{c}_{k^{\prime}}^{x}\right\}_{k^{\prime}=1}^{N}\right)\right)
    \end{array}$

```python
# p = P(zq_imgs_flat, z_img_proto, z_text_proto, alpha, beta); loss is accumulated from 0
if len(cfg['losses']) == 0 or 'L1' in cfg['losses']:
    # L1: classification loss on the query images (negative log-likelihood)
    nloss = nn.NLLLoss()
    loss += nloss(torch.log(p), target_inds)

if 'L2' in cfg['losses']:
    # L2: image-to-text prototype alignment loss
    img2txt_align_loss = InfoNCELoss(z_img_proto, z_text_proto)
    loss += img2txt_align_loss

if 'L3' in cfg['losses']:
    # L3: text-to-image prototype alignment loss
    txt2img_align_loss = InfoNCELoss(z_text_proto, z_img_proto)
    loss += txt2img_align_loss
```
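`InfoNCELoss` in the snippet above is not shown. A minimal sketch consistent with the $\mathcal{L}_2$/$\mathcal{L}_3$ definitions (averaging the per-class terms over the $N$ classes) could be:

```python
import torch
import torch.nn.functional as F

def InfoNCELoss(anchors, positives):
    """
    anchors, positives: [N, D] prototypes; row k of `positives` is the positive pair
    for row k of `anchors`. Returns the mean over classes of
    -log( exp(a_k . p_k) / sum_k' exp(a_k . p_k') ).
    """
    logits = anchors @ positives.t()                                 # [N, N] dot products
    targets = torch.arange(anchors.shape[0], device=anchors.device)  # positives on the diagonal
    return F.cross_entropy(logits, targets)
```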

Text prompts

Taking oxford_pets as an example (37 classes):

template = ['a photo of a {}, a type of pet.']
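A hedged sketch of turning such a template into text prototypes with the OpenAI `clip` package (the class-name handling and normalization details are assumptions, not necessarily the repository's exact code):

```python
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("RN50", device=device)

template = ['a photo of a {}, a type of pet.']
classnames = ['abyssinian', 'american_bulldog']  # ...the 37 OxfordPets classes

with torch.no_grad():
    text_protos = []
    for name in classnames:
        prompts = [t.format(name.replace('_', ' ')) for t in template]
        tokens = clip.tokenize(prompts).to(device)
        feats = model.encode_text(tokens)
        feats = feats / feats.norm(dim=-1, keepdim=True)  # L2-normalize each prompt feature
        text_protos.append(feats.mean(dim=0))             # average over the prompts of this class
    text_protos = torch.stack(text_protos)                # [N, D] text prototypes
```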

Experiment

Datasets and Evaluation Metric

  • ImageNet [5], StanfordCars [24], UCF101 [25], Caltech101 [26], Flowers102 [27], SUN397 [28], DTD [29], EuroSAT [30], FGVCAircraft [31], OxfordPets [32], Food101 [33], and the FewSOL dataset [15]


  • Tip-F is the fine-tuned version of Tip-Adapter
  • PROTO-CLIP: the image memory and the text memory are not trained, and no adapter is used
  • PROTO-CLIP-F: the image memory and/or the text memory are trained together with the adapter; image features are pre-extracted
  • PROTO-CLIP-F-$Q^T$: data augmentation is applied during training, so image features are extracted on the fly during training


Limitations

  • Performs poorly in the low-shot regime, worse than Tip-Adapter
  • A hyperparameter grid search is needed for every new dataset, and each dataset requires different settings