Add files from other repo
- .gitignore +3 -0
- config.yaml +31 -0
- marcai/__init__.py +0 -0
- marcai/find_matches.py +73 -0
- marcai/pl/__init__.py +2 -0
- marcai/pl/attribute_selector.py +12 -0
- marcai/pl/marc_data_module.py +51 -0
- marcai/pl/similarity_vector_dataset.py +26 -0
- marcai/pl/similarity_vector_model.py +90 -0
- marcai/predict.py +75 -0
- marcai/process.py +269 -0
- marcai/processing/__init__.py +1 -0
- marcai/processing/comparisons.py +249 -0
- marcai/processing/normalizations.py +36 -0
- marcai/train.py +100 -0
- marcai/utils/__init__.py +1 -0
- marcai/utils/load_config.py +6 -0
- marcai/utils/parsing.py +93 -0
- requirements.txt +11 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
__pycache__
.ipynb_checkpoints
.DS_Store
config.yaml
ADDED
@@ -0,0 +1,31 @@
model:
  # Input features
  features:
    - title_tokenset
    - title_agg
    - author
    - publisher
    - pub_date
    - pub_place
    - pagination
  # Size of hidden layers
  hidden_sizes:
    - 32
    - 64

  # Training
  batch_size: 512
  weight_decay: 0.0
  max_epochs: -1

  # Disable early stopping with -1
  patience: 20

  lr: 0.006
  optimizer: Adam
  saved_models_dir: saved_models

  # Paths to dataset splits
  test_processed_path: data/202303_goldfinch_set_1.1/processed/test_processed.csv
  train_processed_path: data/202303_goldfinch_set_1.1/processed/train_processed.csv
  val_processed_path: data/202303_goldfinch_set_1.1/processed/val_processed.csv
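For reference, this config is read with marcai/utils/load_config.py (added below), and the model section drives process.py, train.py, and predict.py. A minimal sketch of how it is consumed, using the values above:

    from marcai.utils import load_config

    config = load_config("config.yaml")
    features = config["model"]["features"]
    print(features)       # ['title_tokenset', 'title_agg', 'author', 'publisher', 'pub_date', 'pub_place', 'pagination']
    print(len(features))  # 7 -> input size of the similarity model's first linear layer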
marcai/__init__.py
ADDED
File without changes
marcai/find_matches.py
ADDED
@@ -0,0 +1,73 @@
import argparse
from process import multiprocess_pairs
from predict import predict_onnx
from tqdm import tqdm
import pandas as pd

from marcai.utils.parsing import load_records, record_dict
from marcai.utils import load_config

import csv


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", nargs="+", help="MARC files", required=True)
    parser.add_argument(
        "-p",
        "--pair-indices",
        help="File containing indices of comparisons",
        required=True,
    )
    parser.add_argument("-C", "--chunksize", help="Chunk size", type=int, default=50000)
    parser.add_argument(
        "-P", "--processes", help="Number of processes", type=int, default=1
    )
    parser.add_argument(
        "-m",
        "--model-dir",
        help="Directory containing model ONNX and YAML files",
        required=True,
    )
    parser.add_argument("-o", "--output", help="Output file", required=True)
    parser.add_argument("-t", "--threshold", help="Threshold for matching", type=float)

    args = parser.parse_args()

    config_path = f"{args.model_dir}/config.yaml"
    model_onnx = f"{args.model_dir}/model.onnx"

    config = load_config(config_path)

    # Load records
    print("Loading records...")
    records = []
    for path in args.inputs:
        records.extend([record_dict(r) for r in load_records(path)])

    records_df = pd.DataFrame(records)

    print(f"Loaded {len(records)} records.")

    print("Processing and comparing records...")
    written = False
    with open(args.pair_indices, "r") as indices_file:
        reader = csv.reader(indices_file)
        # Process records
        for df in tqdm(multiprocess_pairs(
            records_df, reader, args.chunksize, args.processes
        )):
            input_df = df[config["model"]["features"]]
            prediction = predict_onnx(model_onnx, input_df)
            df.loc[:, "prediction"] = prediction.squeeze()

            df = df[df["prediction"] >= args.threshold]

            if not df.empty:
                if not written:
                    df.to_csv(args.output, index=False)
                    written = True
                else:
                    df.to_csv(args.output, index=False, mode="a", header=False)


if __name__ == "__main__":
    main()
marcai/pl/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .similarity_vector_model import SimilarityVectorModel
from .marc_data_module import MARCDataModule
marcai/pl/attribute_selector.py
ADDED
@@ -0,0 +1,12 @@
import torch.nn as nn


class AttributeSelector(nn.Module):
    def __init__(self, attrs):
        super().__init__()

        self.attrs = attrs

    def forward(self, sim: dict) -> dict:
        sim = {key: sim[key] for key in self.attrs if key in sim.keys()}
        return sim
marcai/pl/marc_data_module.py
ADDED
@@ -0,0 +1,51 @@
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch
from .attribute_selector import AttributeSelector
from .similarity_vector_dataset import SimilarityVectorDataset
from typing import List


class MARCDataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_processed_path: str,
        val_processed_path: str,
        test_processed_path: str,
        attrs: List[str],
        batch_size: int,
    ):
        super().__init__()

        self.train_processed_path = train_processed_path
        self.val_processed_path = val_processed_path
        self.test_processed_path = test_processed_path

        self.batch_size = batch_size
        self.transform = torch.nn.Sequential(AttributeSelector(attrs))

        self.train_set = None
        self.val_set = None
        self.test_set = None

    def setup(self, stage=None):
        self.train_set = SimilarityVectorDataset(
            self.train_processed_path, transform=self.transform
        )
        self.val_set = SimilarityVectorDataset(
            self.val_processed_path, transform=self.transform
        )
        self.test_set = SimilarityVectorDataset(
            self.test_processed_path, transform=self.transform
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_set, batch_size=self.batch_size, num_workers=0, shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=0)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=0)
marcai/pl/similarity_vector_dataset.py
ADDED
@@ -0,0 +1,26 @@
from torch.utils.data import Dataset
import numpy as np
import pandas as pd


class SimilarityVectorDataset(Dataset):
    def __init__(self, processed_path: str, transform=None):
        self.transform = transform
        self.data = pd.read_csv(processed_path)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        row = self.data.iloc[idx].to_dict()

        label = float(float(row['cid']) == 1.0)

        if self.transform:
            row = self.transform(row)

        row = np.array(list(row.values())).astype(float)

        return row, label
marcai/pl/similarity_vector_model.py
ADDED
@@ -0,0 +1,90 @@
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import Accuracy


class SimilarityVectorModel(pl.LightningModule):
    def __init__(self, lr, weight_decay, optimizer, batch_size, attrs, hidden_sizes):
        super().__init__()

        # Hyperparameters
        self.attrs = attrs
        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.save_hyperparameters()

        # Create model layers
        layer_sizes = [len(attrs)] + hidden_sizes + [1]
        layers = []
        for i in range(len(layer_sizes) - 1):
            in_size, out_size = layer_sizes[i], layer_sizes[i + 1]
            layers.append(nn.Linear(in_size, out_size))

            if i < len(layer_sizes) - 2:
                layers.append(nn.ReLU())

        self.layers = nn.Sequential(*layers)

        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy = Accuracy(task="binary")

    def forward(self, x):
        return self.layers(x)

    def predict(self, x):
        return self.sigmoid(self(x))

    def training_step(self, batch, batch_idx):
        sim, label = batch
        pred = self(sim.float())
        label = label.unsqueeze(1)

        loss = self.criterion(pred, label)
        acc = self.accuracy(pred, label.long())

        self.log("train_loss", loss, on_step=False, on_epoch=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        sim, label = batch
        pred = self(sim.float())
        label = label.unsqueeze(1)

        loss = self.criterion(pred, label)
        acc = self.accuracy(pred, label.long())

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        sim, label = batch
        pred = self(sim.float())
        label = label.unsqueeze(1)

        loss = self.criterion(pred, label)
        acc = self.accuracy(pred, label.long())

        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.log("test_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        optimizers = {
            "Adadelta": torch.optim.Adadelta,
            "Adagrad": torch.optim.Adagrad,
            "Adam": torch.optim.Adam,
            "RMSprop": torch.optim.RMSprop,
            "SGD": torch.optim.SGD,
        }
        return optimizers[self.optimizer](
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
marcai/predict.py
ADDED
@@ -0,0 +1,75 @@
import argparse

import numpy as np
import onnxruntime
import pandas as pd

from marcai.utils import load_config


def sigmoid(x):
    return 1 / (1 + np.exp(-1 * x))


def predict_onnx(model_onnx_path, data):
    ort_session = onnxruntime.InferenceSession(model_onnx_path)

    x = data.to_numpy(dtype=np.float32)

    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: x}
    ort_outs = np.array(ort_session.run(None, ort_inputs))
    ort_outs = sigmoid(ort_outs)

    return ort_outs


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input", help="Path to preprocessed data file", required=True
    )
    parser.add_argument("-o", "--output", help="Output path", required=True)
    parser.add_argument(
        "-m",
        "--model-dir",
        help="Directory containing model ONNX and YAML files",
        required=True,
    )
    parser.add_argument(
        "--chunksize",
        help="Chunk size for reading and predicting",
        default=1024,
        type=int,
    )

    args = parser.parse_args()

    config_path = f"{args.model_dir}/config.yaml"
    model_onnx = f"{args.model_dir}/model.onnx"

    config = load_config(config_path)

    # Load data
    data = pd.read_csv(args.input, chunksize=args.chunksize)

    written = False
    for chunk in data:
        # Limit columns to model input features
        input_df = chunk[config["model"]["features"]]

        prediction = predict_onnx(model_onnx, input_df)

        # Add prediction to chunk
        chunk["prediction"] = prediction.squeeze()

        # Append to CSV
        if not written:
            chunk.to_csv(args.output, index=False)
            written = True
        else:
            chunk.to_csv(args.output, mode="a", header=False, index=False)


if __name__ == "__main__":
    main()
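As an illustration of the prediction path above, predict_onnx can also be called directly on a small DataFrame of similarity features; the model directory and the feature values here are hypothetical:

    import pandas as pd
    from marcai.predict import predict_onnx

    # One already-processed comparison row (made-up similarity values, config feature order)
    row = pd.DataFrame([{
        "title_tokenset": 0.92, "title_agg": 0.88, "author": 0.75,
        "publisher": 0.60, "pub_date": 0.95, "pub_place": 0.5, "pagination": 1.0,
    }])

    # saved_models/run1 is a hypothetical training run directory
    score = predict_onnx("saved_models/run1/model.onnx", row)
    print(score.squeeze())  # sigmoid output in [0, 1]; higher means more likely a match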
marcai/process.py
ADDED
@@ -0,0 +1,269 @@
import argparse
import concurrent.futures
import csv
import itertools
import time

import numpy as np
import pandas as pd
from more_itertools import chunked

import marcai.processing.comparisons as comps
import marcai.processing.normalizations as norms
from marcai.utils.parsing import load_records, record_dict

from multiprocessing import get_context


def multiprocess_pairs(
    records_df,
    pair_indices,
    chunksize=50000,
    processes=1,
):
    # Create chunked iterator
    pairs_chunked = chunked(pair_indices, chunksize)

    # Create processing jobs
    max_jobs = processes * 2

    context = get_context("fork")

    with concurrent.futures.ProcessPoolExecutor(
        max_workers=processes, mp_context=context
    ) as executor:
        futures = set()
        done = set()
        first_spawn = True

        while futures or first_spawn:
            if first_spawn:
                spawn_count = max_jobs
                first_spawn = False
            else:
                # Wait for a job to complete
                done, futures = concurrent.futures.wait(
                    futures, return_when=concurrent.futures.FIRST_COMPLETED
                )
                spawn_count = max_jobs - len(futures)

            for future in done:
                # Get job's output
                df = future.result()

                # Yield output
                yield df

            # Spawn jobs
            for _ in range(spawn_count):
                pairs_chunk = next(pairs_chunked, None)

                if pairs_chunk is None:
                    break

                indices = np.array(pairs_chunk).astype(int)

                left_indices = indices[:, 0]
                right_indices = indices[:, 1]

                left_records = records_df.iloc[left_indices].reset_index(drop=True)
                right_records = records_df.iloc[right_indices].reset_index(drop=True)

                futures.add(executor.submit(process, left_records, right_records))


def process(df0, df1):
    normalize_fields = [
        "author_names",
        "corporate_names",
        "meeting_names",
        "publisher",
        "title",
        "title_a",
        "title_b",
        "title_c",
        "title_p",
    ]

    # Normalize text fields
    for field in normalize_fields:
        df0[field] = norms.lowercase(df0[field])
        df1[field] = norms.lowercase(df1[field])

        df0[field] = norms.remove_punctuation(df0[field])
        df1[field] = norms.remove_punctuation(df1[field])

        df0[field] = norms.remove_diacritics(df0[field])
        df1[field] = norms.remove_diacritics(df1[field])

        df0[field] = norms.normalize_whitespace(df0[field])
        df1[field] = norms.normalize_whitespace(df1[field])

    # Compare fields
    result_df = pd.DataFrame()

    result_df["id_0"] = df0["id"]
    result_df["id_1"] = df1["id"]

    result_df["raw_tokenset"] = comps.token_set_similarity(
        df0["raw"], df1["raw"], null_value=0.5
    )

    # Token sort ratio
    result_df["publisher"] = comps.token_sort_similarity(
        df0["publisher"], df1["publisher"], null_value=0.5
    )

    author_names = comps.token_sort_similarity(
        df0["author_names"], df1["author_names"], null_value=np.nan
    )
    corporate_names = comps.token_sort_similarity(
        df0["corporate_names"], df1["corporate_names"], null_value=np.nan
    )
    meeting_names = comps.token_sort_similarity(
        df0["meeting_names"], df1["meeting_names"], null_value=np.nan
    )
    authors = pd.concat([author_names, corporate_names, meeting_names], axis=1)

    # Take max of author comparisons
    result_df["author"] = comps.maximum(authors, null_value=0.5)

    # Weighted title comparison
    weights = {
        "title_a": 1,
        "raw": 0,
        "title_p": 1,
    }

    result_df["title_agg"] = comps.column_aggregate_similarity(
        df0[weights.keys()], df1[weights.keys()], weights.values(), null_value=0
    )

    # Phonetic difference
    result_df["title_phonetic"] = comps.phonetic_similarity(
        df0["title"], df1["title"], null_value=0
    )

    # Length difference
    result_df["title_length"] = comps.length_similarity(
        df0["title"], df1["title"], null_value=0.5
    )

    # Token set similarity
    result_df["title_tokenset"] = comps.token_set_similarity(
        df0["title"], df1["title"], null_value=0
    )

    # Token sort ratio
    result_df["title_tokensort"] = comps.token_sort_similarity(
        df0["title"], df1["title"], null_value=0
    )

    # Levenshtein
    result_df["title_levenshtein"] = comps.levenshtein_similarity(
        df0["title"], df1["title"], null_value=0
    )

    # Jaro
    result_df["title_jaro"] = comps.jaro_similarity(
        df0["title"], df1["title"], null_value=0
    )

    # Jaro Winkler
    result_df["title_jaro_winkler"] = comps.jaro_winkler_similarity(
        df0["title"], df1["title"], null_value=0
    )

    # Pagination
    result_df["pagination"] = comps.pagination_match(
        df0["pagination"], df1["pagination"], null_value=0.5
    )

    # Dates
    result_df["pub_date"] = comps.year_similarity(
        df0["pub_date"], df1["pub_date"], null_value=0.5, exp_coeff=0.15
    )

    # Pub place
    result_df["pub_place"] = comps.equal(
        df0["pub_place"], df1["pub_place"], null_value=0.5
    )

    # CID/Label
    result_df["cid"] = comps.equal(df0["cid"], df1["cid"], null_value=0.5)

    return result_df


def parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    required = parser.add_argument_group("required arguments")
    required.add_argument("-i", "--inputs", nargs="+", help="MARC files", required=True)
    required.add_argument("-o", "--output", help="Output file", required=True)

    parser.add_argument(
        "-C",
        "--chunksize",
        type=int,
        help="Number of comparisons per job",
        default=50000,
    )
    parser.add_argument(
        "-p", "--pair-indices", help="File containing indices of comparisons"
    )
    parser.add_argument(
        "-P",
        "--processes",
        type=int,
        help="Number of processes to run in parallel.",
        default=1,
    )

    return parser.parse_args()


def main():
    start = time.time()
    args = parse_args()

    # Load records
    print("Loading records...")
    records = []
    for path in args.inputs:
        records.extend([record_dict(r) for r in load_records(path)])

    records_df = pd.DataFrame(records)

    print(f"Loaded {len(records)} records.")

    print("Processing records...")
    # Process records
    written = False
    with open(args.pair_indices, "r") as indices_file:
        reader = csv.reader(indices_file)

        for df in multiprocess_pairs(
            records_df, reader, args.chunksize, args.processes
        ):
            if not written:
                # Write header
                df.to_csv(args.output, mode="w", header=True, index=False)
                written = True
            else:
                # Write rows of df to output CSV
                df.to_csv(args.output, mode="a", header=False, index=False)

    end = time.time()
    print(f"Processed {len(records)} records.")
    print(f"Time elapsed: {end - start:.2f} seconds.")


if __name__ == "__main__":
    main()
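The --pair-indices argument is a headerless CSV of integer pairs indexing into the loaded records (each row is parsed above as left_index,right_index). A sketch of generating an all-pairs index file, with a made-up record count and filename:

    import csv
    import itertools

    n_records = 1000  # hypothetical; in practice len(records) after loading the MARC files
    with open("pairs.csv", "w", newline="") as f:
        writer = csv.writer(f)
        # every unordered pair (i, j) with i < j
        writer.writerows(itertools.combinations(range(n_records), 2))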
marcai/processing/__init__.py
ADDED
@@ -0,0 +1 @@
marcai/processing/comparisons.py
ADDED
@@ -0,0 +1,249 @@
import numpy as np
import re
import pandas as pd
from thefuzz import fuzz
import textdistance
import fuzzy
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer


HAND_COUNT_PAGE_PATTERN = re.compile(r"\[(?P<hand_count>\d+)\]\s*p(ages)?[^\w]")
PAGE_PATTERN = re.compile(r"(?P<pages>\d+)\s*p(ages)?[^\w]")


def equal(se0, se1, null_value):
    se0_np = se0.to_numpy(dtype=str)
    se1_np = se1.to_numpy(dtype=str)

    col = (se0_np == se1_np).astype(float)

    se0_nulls = np.argwhere(np.char.strip(se0_np, " ") == "")
    se1_nulls = np.argwhere(np.char.strip(se1_np, " ") == "")

    col[se0_nulls] = null_value
    col[se1_nulls] = null_value

    return pd.Series(col)


def maximum(df, null_value, ignore_value=np.nan):
    df_np = df.to_numpy(dtype=float)

    df_np[df_np == ignore_value] = np.nan

    # Mask ignore_value
    masked = np.ma.masked_invalid(df_np)

    # Get the max, ignoring NaNs
    col = np.max(masked, axis=1)

    # Replace NaNs with null_value
    col = col.filled(fill_value=null_value)

    return pd.Series(col)


def minimum(se0, se1, null_value, ignore_value=np.nan):
    se0_np = se0.to_numpy(dtype=float)
    se1_np = se1.to_numpy(dtype=float)

    # Replace ignore_value with np.nans
    se0_np[se0_np == ignore_value] = np.nan
    se1_np[se1_np == ignore_value] = np.nan

    # Get the min, ignoring NaNs
    col = np.nanmin(np.stack([se0_np, se1_np], axis=1), axis=1)

    # Replace NaNs with null_value
    col[np.isnan(col)] = null_value

    return pd.Series(col)


def pagination_match(se0, se1, null_value):
    def group_values(pat, group, s):
        return {m.groupdict()[group] for m in pat.finditer(s)}

    def compare(pag0, pag1):
        hand_counts0 = group_values(HAND_COUNT_PAGE_PATTERN, "hand_count", pag0)
        hand_counts1 = group_values(HAND_COUNT_PAGE_PATTERN, "hand_count", pag1)

        # Remove bracketed digits
        pag0 = re.sub(r"\[\d+\]", "", pag0)
        pag1 = re.sub(r"\[\d+\]", " ", pag1)

        # Remove punctuation
        pag0 = re.sub(r"[^\w\s]", " ", pag0)
        pag1 = re.sub(r"[^\w\s]", " ", pag1)

        # Extract page counts
        counts0 = group_values(PAGE_PATTERN, "pages", pag0 + " ")
        counts1 = group_values(PAGE_PATTERN, "pages", pag1 + " ")

        page_counts0 = counts0 | hand_counts0
        page_counts1 = counts1 | hand_counts1

        # Check if any pages are in common.
        if page_counts0 and page_counts1:
            for pg0 in page_counts0:
                for pg1 in page_counts1:
                    pg0 = int(pg0)
                    pg1 = int(pg1)

                    if pg0 == pg1:
                        return 1.0
            return 0.0

        return null_value

    se0_np = se0.to_numpy(dtype=str)
    se1_np = se1.to_numpy(dtype=str)

    col = np.vectorize(compare)(se0_np, se1_np)
    return pd.Series(col)


def year_similarity(se0, se1, null_value, exp_coeff):
    def compare(yr0, yr1):
        if yr0.isnumeric() and yr1.isnumeric():
            x = abs(int(yr0) - int(yr1))

            # Sigmoid where x = 0, y = 1, tail to the right
            return 2 / (1 + np.exp(exp_coeff * x))

        return null_value

    se0_np = se0.to_numpy(dtype=str)
    se1_np = se1.to_numpy(dtype=str)

    return np.vectorize(compare)(se0_np, se1_np)


def column_aggregate_similarity(df0, df1, column_weights, null_value):
    weights_dict = {k: v for k, v in zip(df0.columns, column_weights)}

    def get_word_weights(row):
        word_weights = {}
        for i, value in enumerate(row):
            column = df0.columns[i]
            if column in weights_dict:
                current_weight = weights_dict[column]
            else:
                current_weight = 0

            for w in value.split():
                if w not in word_weights:
                    word_weights[w] = current_weight
                else:
                    word_weights[w] = max(current_weight, word_weights[w])
        return word_weights

    def compare(row0, row1):
        weights0 = get_word_weights(row0)
        weights1 = get_word_weights(row1)

        total_weight = 0
        missing_weight = 0

        for w in weights0:
            weight = weights0[w]
            if w not in weights1:
                missing_weight += weights0[w]
            else:
                weight = max(weight, weights1[w])
            total_weight += weight

        for w in weights1:
            weight = weights1[w]
            if w not in weights0:
                missing_weight += weights1[w]
            else:
                weight = max(weight, weights0[w])
            total_weight += weight

        if total_weight == 0:
            return null_value

        return float((total_weight - missing_weight) / total_weight)

    if df0.columns.to_list() != df1.columns.to_list():
        raise ValueError("DataFrames must have the same columns")

    # Run compare on rows of each df
    col = np.array(
        [compare(row0, row1) for row0, row1 in zip(df0.to_numpy(), df1.to_numpy())]
    )

    return pd.Series(col)


def length_similarity(se0, se1, null_value):
    se0_np = se0.to_numpy(dtype=str)
    se1_np = se1.to_numpy(dtype=str)

    col = np.array(
        [
            # max(..., 1) avoids division by zero when both strings are empty;
            # such rows are overwritten with null_value below anyway
            1 - abs(len(s0) - len(s1)) / max(len(s0), len(s1), 1)
            for s0, s1 in zip(se0_np, se1_np)
        ]
    )

    # If either string is empty, set similarity to null_value
    col[(se0_np == "") | (se1_np == "")] = null_value

    return pd.Series(col)


def phonetic_similarity(se0, se1, null_value):
    soundex = fuzzy.Soundex(4)

    se0_np = se0.to_numpy(dtype=str)
    se1_np = se1.to_numpy(dtype=str)

    def compare_words(str0, str1):
        words0 = str0.split()
        words1 = str1.split()

        sounds0 = [soundex(word) for word in words0]
        sounds1 = [soundex(word) for word in words1]

        # Fall back to null_value when either string has no words,
        # which would otherwise divide by zero below
        if not sounds0 or not sounds1:
            return float(null_value)

        return sum(s0 == s1 for s0, s1 in zip(sounds0, sounds1)) / max(
            len(sounds0), len(sounds1)
        )

    col = np.vectorize(compare_words)(se0_np, se1_np)

    return pd.Series(col)


def jaccard_similarity(se0, se1, null_value):
    se0_np = se0.to_numpy(dtype=str)
    se1_np = se1.to_numpy(dtype=str)

    col = np.array(
        [
            textdistance.jaccard.normalized_similarity(set(s0.split()), set(s1.split()))
            for s0, s1 in zip(se0_np, se1_np)
        ]
    )

    # If either string is empty, set similarity to null_value
    col[(se0_np == "") | (se1_np == "")] = null_value

    return pd.Series(col)


def similarity_factory(similarity_function):
    def similarity(se0, se1, null_value):
        se0_np = se0.to_numpy(dtype=str)
        se1_np = se1.to_numpy(dtype=str)

        col = np.vectorize(similarity_function)(se0_np, se1_np)

        # Replace original null values with null_value
        col[se0_np == ""] = null_value
        col[se1_np == ""] = null_value

        return pd.Series(col)

    return similarity


token_set_similarity = similarity_factory(
    lambda s0, s1: fuzz.token_set_ratio(s0, s1) / 100
)
token_sort_similarity = similarity_factory(
    lambda s0, s1: fuzz.token_sort_ratio(s0, s1) / 100
)
levenshtein_similarity = similarity_factory(lambda s0, s1: fuzz.ratio(s0, s1) / 100)
jaro_winkler_similarity = similarity_factory(
    lambda s0, s1: textdistance.jaro_winkler.similarity(s0, s1)
)
jaro_similarity = similarity_factory(
    lambda s0, s1: textdistance.jaro.similarity(s0, s1)
)
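A quick sketch of the factory-built comparators above on toy pandas Series (titles and values made up for illustration):

    import pandas as pd
    from marcai.processing.comparisons import token_set_similarity, year_similarity

    a = pd.Series(["the great gatsby", ""])
    b = pd.Series(["great gatsby the", "moby dick"])

    # Identical token sets score 1.0; the empty string falls back to null_value
    print(token_set_similarity(a, b, null_value=0.5).tolist())  # [1.0, 0.5]

    # Year similarity decays smoothly with the gap between numeric years
    print(year_similarity(pd.Series(["1925"]), pd.Series(["1927"]),
                          null_value=0.5, exp_coeff=0.15))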
marcai/processing/normalizations.py
ADDED
@@ -0,0 +1,36 @@
from unidecode import unidecode
import numpy as np
import pandas as pd


def remove_diacritics(series):
    se_np = series.to_numpy()
    se_np = np.vectorize(unidecode)(se_np)
    return pd.Series(se_np)


def lowercase(series):
    return series.str.lower()


def remove_punctuation(series):
    # regex=True: recent pandas no longer treats patterns as regex by default
    return series.str.replace(r"[^\w\s]", "", regex=True)


def normalize_whitespace(series):
    # Replace all whitespace with a single space
    s = series.str.replace(r"\s", " ", regex=True)
    # Remove leading and trailing whitespace
    s = s.str.strip()
    # Remove double spaces
    return s.str.replace(r"\s+", " ", regex=True)


def substring(series, start, end):
    return series.str[start:end]


def apply_normalizers(series, transforms):
    for transform in transforms:
        series = transform(series)
    return series
marcai/train.py
ADDED
@@ -0,0 +1,100 @@
import pytorch_lightning as lightning
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import warnings
import yaml
import argparse
import os
import torch
from marcai.pl import MARCDataModule, SimilarityVectorModel
from marcai.utils import load_config
import tarfile


def train(name=None):
    config_path = "config.yaml"
    config = load_config(config_path)
    model_config = load_config(config_path)["model"]

    # Create data module from processed data
    warnings.filterwarnings("ignore", ".*does not have many workers.*")
    data = MARCDataModule(
        model_config["train_processed_path"],
        model_config["val_processed_path"],
        model_config["test_processed_path"],
        model_config["features"],
        model_config["batch_size"],
    )

    # Create model
    model = SimilarityVectorModel(
        model_config["lr"],
        model_config["weight_decay"],
        model_config["optimizer"],
        model_config["batch_size"],
        model_config["features"],
        model_config["hidden_sizes"],
    )

    save_dir = os.path.join(model_config["saved_models_dir"], name)
    os.makedirs(save_dir, exist_ok=True)

    # Save best models
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc", mode="max", dirpath=save_dir, filename="model"
    )
    callbacks = [checkpoint_callback]

    if model_config["patience"] != -1:
        early_stop_callback = EarlyStopping(
            monitor="val_acc",
            min_delta=0.00,
            patience=model_config["patience"],
            verbose=False,
            mode="max",
        )
        callbacks.append(early_stop_callback)

    trainer = lightning.Trainer(
        max_epochs=model_config["max_epochs"], callbacks=callbacks, accelerator="cpu"
    )
    trainer.fit(model, data)

    # Save ONNX
    onnx_path = os.path.join(save_dir, "model.onnx")
    input_sample = torch.randn((1, len(model.attrs)))
    torch.onnx.export(
        model,
        input_sample,
        onnx_path,
        export_params=True,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # Save config
    config_filename = os.path.join(save_dir, "config.yaml")

    with open(config_filename, "w") as f:
        dump = yaml.dump(config)
        f.write(dump)

    # Compress model directory files
    tar_path = f"{save_dir}/{name}.tar.gz"
    with tarfile.open(tar_path, mode="w:gz") as archive:
        archive.add(save_dir, arcname=os.path.basename(save_dir))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--run-name", help="Name for training run")
    args = parser.parse_args()

    train(args.run_name)


if __name__ == "__main__":
    main()
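Training reads config.yaml from the working directory; a hypothetical invocation (the run name "run1" is made up):

    from marcai.train import train

    # Writes saved_models/run1/model.ckpt, model.onnx, config.yaml, and run1.tar.gz
    train("run1")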
marcai/utils/__init__.py
ADDED
@@ -0,0 +1 @@
from .load_config import load_config
marcai/utils/load_config.py
ADDED
@@ -0,0 +1,6 @@
import yaml


def load_config(filename):
    with open(filename, 'r') as file:
        return yaml.safe_load(file)
marcai/utils/parsing.py
ADDED
@@ -0,0 +1,93 @@
from collections import OrderedDict

import pymarc


def get_record_values(record, location):
    split = location.split("$")

    if len(split) == 1:
        tag = split[0]
        code = None
    elif len(split) == 2:
        tag, code = split
    else:
        raise ValueError("Invalid location")

    # Find fields matching tag
    fields = record.get_fields(tag)

    results = []
    for current_value in fields:
        if current_value is not None:
            if code is not None:
                values = current_value.get_subfields(code)
                results.extend(values)
            elif isinstance(current_value, pymarc.Field):
                results.append(current_value.value())

    return " ".join(results)


def record_dict(record):
    d = OrderedDict()

    # Dump every field value into a string
    d["raw"] = " ".join([f.value() for f in record.fields])

    d["cid"] = get_record_values(record, "CID")
    d["id"] = get_record_values(record, "001")

    fixed_data = get_record_values(record, "008")
    d["pub_date"] = fixed_data[7:11]
    d["pub_place"] = fixed_data[15:18]
    d["language"] = fixed_data[35:38]

    d["title_a"] = get_record_values(record, "245$a")
    d["title_b"] = get_record_values(record, "245$b")
    d["title_c"] = get_record_values(record, "245$c")
    d["title_p"] = get_record_values(record, "245$p")

    d["title"] = " ".join([d["title_a"], d["title_b"], d["title_p"]])

    d["title_variation_a"] = get_record_values(record, "246$a")
    d["title_variation_b"] = get_record_values(record, "246$b")

    d["subject_headings"] = " ".join(
        [get_record_values(record, "650$a"), get_record_values(record, "650$x")]
    )

    d["author_names"] = " ".join(
        [get_record_values(record, "100$a"), get_record_values(record, "700$a")]
    )
    d["corporate_names"] = " ".join(
        [get_record_values(record, "110$a"), get_record_values(record, "710$a")]
    )
    d["meeting_names"] = " ".join(
        [get_record_values(record, "111$a"), get_record_values(record, "711$a")]
    )

    d["publisher"] = record.publisher or ""

    d["pagination"] = get_record_values(record, "300$a")
    d["dimensions"] = get_record_values(record, "300$c")

    return d


def load_records(path):
    records = []
    extension = path.split(".")[-1]
    if extension == "mrc" or extension == "marc":
        with open(path, "rb") as marcfile:
            reader = pymarc.MARCReader(marcfile)
            records.extend(list(reader))
    elif extension == "json":
        with open(path, "r") as jsonfile:
            for line in jsonfile:
                record = pymarc.parse_json_to_array(line)[0]
                records.append(record)
    else:
        raise ValueError(f"Unsupported file extension: {extension}")

    return records
requirements.txt
ADDED
@@ -0,0 +1,11 @@
pymarc
thefuzz
pandas
unidecode
python-levenshtein
onnxruntime
textdistance
more-itertools
pyyaml
onnx
tqdm