Hariharan Vijayachandran commited on
Commit
bc92274
·
1 Parent(s): 97b903f
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. app.py +79 -0
  3. requirements.txt +4 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ model_params filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from annotated_text import annotated_text
3
+
4
+ import os
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch
8
+ import torch.optim as optim
9
+ from transformers import DistilBertModel
10
+ from transformers import AutoTokenizer
11
+ import lightning.pytorch as pl
12
+ class Classifier(pl.LightningModule):
13
+
14
+ def __init__(self):
15
+ super().__init__()
16
+ self.ln1 = torch.nn.Linear(512*768, 3)
17
+ # self.ln2 = torch.nn.Linear(1000, 3 )
18
+ self.criterion = nn.CrossEntropyLoss()
19
+ def training_step(self, batch, batch_idx):
20
+ x, y = batch
21
+ with torch.no_grad():
22
+ x = get_bert()(input_ids = x[:,:512], attention_mask = x[:,512:]).last_hidden_state.reshape(-1, 512*768)
23
+ x = (x/torch.linalg.norm(x,2, 1)).reshape(-1,512*768)
24
+ x = self.ln1(x)
25
+ # x = self.ln2(x)
26
+ loss = self.criterion(x, y)
27
+ self.log("my_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
28
+ return loss
29
+ def configure_optimizers(self):
30
+ optimizer = optim.Adam(self.parameters(), lr=1e-3)
31
+ return optimizer
32
+ def preprocess(self, x):
33
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)
34
+ return tokenizer(x, padding='max_length', return_tensors="pt")
35
+ def forward(self, x):
36
+ print("here!", self.ln1.type)
37
+ with torch.no_grad():
38
+ x = get_bert()(**x).last_hidden_state.reshape(-1, 512*768)
39
+ x = (x/torch.linalg.norm(x,2, 1)).reshape(-1,512*768)
40
+ x = self.ln1(x)
41
+ # x = self.ln2(x)
42
+ return x
43
+
44
+ @st.cache
45
+ def get_bert():
46
+ return DistilBertModel.from_pretrained("distilbert-base-uncased")
47
+
48
+ @st.cache
49
+ def get_classifier():
50
+ os.system('gdown 1GxhHvg3lwlGpA7So06v3l43U8pSASy9L')
51
+ return Classifier.load_from_checkpoint(f"{os.getcwd()}/model_params")
52
+
53
+ def get_annotated_text(text):
54
+ model = get_classifier()
55
+ text = text.split(".")
56
+ l = []
57
+
58
+ for i in text:
59
+ if i == '' or i == " ":
60
+ continue
61
+ c = model(model.preprocess([i])).argmax()
62
+ print("class : ", c)
63
+ if c == 0:
64
+ l.append((i, "Leadership"))
65
+ if c == 1:
66
+ l.append((i, "Diversity"))
67
+ if c == 2:
68
+ l.append((i, "Integrity"))
69
+ l.append(".")
70
+ return tuple(l)
71
+
72
+ st.title("Code of Conduct Classifier")
73
+
74
+ input_text = st.text_area("enter code of conduct text" )
75
+
76
+ st.title("annotated text")
77
+ print(input_text)
78
+
79
+ annotated_text(*get_annotated_text(input_text))
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ st-annotated-text==3.0.0
2
+ torch==2.0.0.dev20230125
3
+ lightning==2.0.0
4
+ gdown==4.6.4