---
license: apache-2.0
---
# **m**utual **i**nformation **C**ontrastive **S**entence **E**mbedding (**miCSE**):
[arXiv:2211.04928](https://arxiv.org/abs/2211.04928)
Language model accompanying the arXiv pre-print "_**miCSE**: Mutual Information Contrastive Learning for Low-shot Sentence Embeddings_".
# Brief Model Description
The **miCSE** language model is trained for sentence similarity computation. Training imposes alignment between the attention patterns of different views (embeddings of augmentations) during contrastive learning. Learning sentence embeddings with **miCSE** entails enforcing syntactic consistency across the augmented views of every single sentence, which makes contrastive self-supervised learning more sample-efficient. This is achieved by regularizing the attention distribution. Regularizing the attention space enables learning representations in a self-supervised fashion even when the _training corpus is comparatively small_. This is particularly interesting for _real-world applications_, where training data is significantly smaller than Wikipedia.
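To make the training idea concrete, here is a minimal, illustrative sketch (not the released training code): two dropout-augmented views of the same batch are encoded, an InfoNCE contrastive loss pulls matching _**[CLS]**_ embeddings together, and an extra term penalizes disagreement between the attention distributions of the two views. The backbone name, the `attn_reg_weight` parameter, and the symmetric-KL regularizer are assumptions for illustration only; the actual miCSE objective uses a mutual-information-based attention regularizer, as described in the paper.
```python
# Illustrative sketch only: contrastive loss over two dropout views plus an
# attention-alignment regularizer. The backbone, weighting, and symmetric-KL
# term are stand-ins; the real miCSE regularizer is mutual-information based.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.train()  # keep dropout active so the two passes yield different views

sentences = ["A small training corpus.", "Contrastive learning with few samples."]
batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

def encode(batch):
    out = model(**batch, output_attentions=True, return_dict=True)
    cls = out.last_hidden_state[:, 0]     # [CLS] sentence embedding
    attn = torch.stack(out.attentions)    # (layers, batch, heads, seq, seq)
    return cls, attn

# Two views of the same sentences (dropout acts as the augmentation, as in SimCSE)
z1, a1 = encode(batch)
z2, a2 = encode(batch)

# InfoNCE contrastive loss: matching views are positives, the rest are negatives
temperature = 0.05
sim = F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1) / temperature
labels = torch.arange(sim.size(0))
contrastive_loss = F.cross_entropy(sim, labels)

# Attention-alignment regularizer (stand-in): symmetric KL between the
# attention distributions of the two views, averaged over layers/heads/queries
p = a1.clamp_min(1e-12)
q = a2.clamp_min(1e-12)
attn_reg = 0.5 * ((p * (p / q).log()).sum(-1) + (q * (q / p).log()).sum(-1)).mean()

attn_reg_weight = 1.0  # hypothetical weighting of the regularizer
loss = contrastive_loss + attn_reg_weight * attn_reg
loss.backward()
```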
# Intended Use
The model is intended for encoding sentences or short paragraphs. Given an input text, the model produces a vector embedding that captures its semantics. The embedding can be used for numerous tasks, e.g., **retrieval**, **clustering**, or **sentence similarity** comparison (see the example below). Sentence representations correspond to the embedding of the _**[CLS]**_ token.
# Model Usage
```python
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn

tokenizer = AutoTokenizer.from_pretrained("sap-ai-research/miCSE")
model = AutoModel.from_pretrained("sap-ai-research/miCSE")

# Encode a list of sentences with a predefined maximum number of tokens (max_length)
max_length = 32
sentences = [
    "This is a sentence for testing miCSE.",
    "This is yet another test sentence for the mutual information Contrastive Sentence Embeddings model."
]
batch = tokenizer.batch_encode_plus(
    sentences,
    return_tensors='pt',
    padding=True,
    max_length=max_length,
    truncation=True
)

# Compute the embeddings and keep only the [CLS] embedding (the first token)
# Get raw embeddings (no gradients)
with torch.no_grad():
    outputs = model(**batch, output_hidden_states=True, return_dict=True)
    embeddings = outputs.last_hidden_state[:, 0]

# Define the similarity metric, e.g., cosine similarity
sim = nn.CosineSimilarity(dim=-1)

# Compute the similarity between the first and the second sentence
cos_sim = sim(embeddings.unsqueeze(1),
              embeddings.unsqueeze(0))
print(f"Similarity: {cos_sim[0, 1].item()}")
```
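Beyond pairwise similarity, the embeddings can also be used for retrieval, as mentioned above. The snippet below is a small illustrative sketch: a few corpus sentences are embedded via the _**[CLS]**_ token, and the sentence closest to a query is returned by cosine similarity. The corpus sentences and the `embed` helper are made up for this example.
```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("sap-ai-research/miCSE")
model = AutoModel.from_pretrained("sap-ai-research/miCSE")
model.eval()

corpus = [
    "The cat sits on the mat.",
    "Stock markets fell sharply on Monday.",
    "A kitten is resting on a rug.",
]
query = "A cat is lying on a carpet."

def embed(texts):
    # Helper made up for this example: returns [CLS] embeddings for a list of texts
    batch = tokenizer(texts, padding=True, truncation=True, max_length=32, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**batch, return_dict=True)
    return outputs.last_hidden_state[:, 0]

corpus_emb = embed(corpus)
query_emb = embed([query])

# Rank corpus sentences by cosine similarity to the query
sim = nn.CosineSimilarity(dim=-1)
scores = sim(query_emb, corpus_emb)          # shape: (len(corpus),)
best = int(torch.argmax(scores))
print(f"Best match: {corpus[best]!r} (similarity {scores[best].item():.3f})")
```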
# Training data
The model was trained on a random collection of **English** sentences from Wikipedia: [Training data file](https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt)
# Benchmark
Model results on the SentEval benchmark (STS tasks, Spearman's correlation):

| STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness | Avg.  |
|:-----:|:-----:|:-----:|:-----:|:-----:|:------------:|:---------------:|:-----:|
| 71.71 | 83.09 | 75.46 | 83.13 | 80.22 |    79.70     |      73.62      | 78.13 |
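The sketch below shows how such an evaluation is typically run with the [SentEval](https://github.com/facebookresearch/SentEval) toolkit. The data path and evaluation parameters are assumptions, not the official evaluation setup, so the official SentEval/SimCSE evaluation scripts should be used to reproduce the reported numbers exactly.
```python
# Illustrative SentEval STS evaluation sketch; assumes the SentEval toolkit and
# its task data are installed locally. Paths and parameters are placeholders.
import torch
import senteval
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("sap-ai-research/miCSE")
model = AutoModel.from_pretrained("sap-ai-research/miCSE")
model.eval()

def prepare(params, samples):
    return

def batcher(params, batch):
    # SentEval passes tokenized sentences; re-join them before encoding
    sentences = [" ".join(tokens) for tokens in batch]
    enc = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        out = model(**enc, return_dict=True)
    return out.last_hidden_state[:, 0].cpu().numpy()  # [CLS] embeddings

params = {"task_path": "PATH_TO_SENTEVAL_DATA", "usepytorch": True, "kfold": 10}
se = senteval.engine.SE(params, batcher, prepare)
tasks = ["STS12", "STS13", "STS14", "STS15", "STS16", "STSBenchmark", "SICKRelatedness"]
results = se.eval(tasks)
print(results)
```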
## Citations
If you use this code in your research or want to refer to our work, please cite:
```
@article{Klein2022miCSEMI,
title={miCSE: Mutual Information Contrastive Learning for Low-shot Sentence Embeddings},
author={Tassilo Klein and Moin Nabi},
journal={ArXiv},
year={2022},
volume={abs/2211.04928}
}
```
#### Authors:
- [Tassilo Klein](https://tjklein.github.io/)
- [Moin Nabi](https://moinnabi.github.io/)