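"""Compute pairwise model-similarity matrices ("CAPA", "CAPA (det.)",
"Error Consistency") from cached per-model run outputs."""
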
import numpy as np

from lmsim.metrics import Metrics, CAPA, EC

from src.dataloading import load_run_data_cached
from src.utils import softmax, one_hot

def load_data_and_compute_similarities(models: list[str], dataset: str, metric_name: str) -> np.ndarray:
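    """Load each model's cached outputs on `dataset` and return the pairwise
    similarity matrix under the metric named by `metric_name`."""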
    # Load data
    probs = []
    gts = []
    for model in models:
        model_probs, model_gt = load_run_data_cached(model, dataset)
        probs.append(model_probs)
        gts.append(model_gt)
    
    # Compute pairwise similarities
    similarities = compute_pairwise_similarities(metric_name, probs, gts)
    return similarities


def compute_similarity(metric: Metrics, outputs_a: list[np.ndarray], outputs_b: list[np.ndarray], gt: list[int]) -> float:
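    """Score the agreement between two models' per-question outputs under
    `metric`, where `gt` holds the shared ground-truth labels."""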
    # Check that the models have the same number of responses
    assert len(outputs_a) == len(outputs_b) == len(gt), f"Models must have the same number of responses, got {len(outputs_a)}, {len(outputs_b)} and {len(gt)}"
    
    # Compute similarity values
    similarity = metric.compute_k(outputs_a, outputs_b, gt)
    
    return similarity


def compute_pairwise_similarities(metric_name: str, probs: list[list[np.ndarray]], gts: list[list[int]]) -> np.ndarray:
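    """Build the symmetric (n_models, n_models) similarity matrix for the
    chosen metric; pairs that fail to compute are filled with NaN."""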
    # Instantiate the chosen metric
    if metric_name == "CAPA":
        metric = CAPA()
    elif metric_name == "CAPA (det.)":
        metric = CAPA(prob=False)
        # Deterministic variant: convert raw outputs to one-hot predictions
        probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
    elif metric_name == "Error Consistency":
        # Error Consistency also operates on discrete predictions
        probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
        metric = EC()
    else:
        raise ValueError(f"Invalid metric: {metric_name}")

    similarities = np.zeros((len(probs), len(probs)))
    for i in range(len(probs)):
        for j in range(i, len(probs)):
            outputs_a = probs[i]
            outputs_b = probs[j]
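            # Copy the ground truths so the filtering below does not mutate the cached lists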
            gt_a = gts[i].copy()
            gt_b = gts[j].copy()

            # Probabilistic CAPA consumes probability distributions, so convert raw logits via softmax
            if metric_name == "CAPA":
                outputs_a = [softmax(logits) for logits in outputs_a]
                outputs_b = [softmax(logits) for logits in outputs_b]

                # Drop indices where the two runs' ground-truth labels differ
                # (assumes gt_a and gt_b are lists of integers)
                indices_to_remove = [idx for idx, (a, b) in enumerate(zip(gt_a, gt_b)) if a != b]
                for idx in sorted(indices_to_remove, reverse=True):
                    del outputs_a[idx]
                    del outputs_b[idx]
                    del gt_a[idx]
                    del gt_b[idx]

            try:
                similarities[i, j] = compute_similarity(metric, outputs_a, outputs_b, gt_a)
            except Exception as e:
                print(f"Failed to compute similarity between models {i} and {j}: {e}")
                similarities[i, j] = np.nan

            # Mirror the upper triangle so the matrix is symmetric
            similarities[j, i] = similarities[i, j]
    return similarities
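

# Minimal usage sketch: the model and dataset identifiers below are hypothetical
# placeholders, not names shipped with this repo; real values must match whatever
# load_run_data_cached expects.
if __name__ == "__main__":
    models = ["model-a", "model-b", "model-c"]  # hypothetical model identifiers
    sims = load_data_and_compute_similarities(models, dataset="mmlu", metric_name="CAPA")  # "mmlu" is a placeholder
    print(sims)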