Spaces:
Runtime error
Runtime error
File size: 3,837 Bytes
3d9b66a eb4b465 3d9b66a eb4b465 3d9b66a 8e9c357 dac8d18 3d9b66a 7f39fb2 3d9b66a 7f39fb2 3d9b66a 2b8b93a 3d9b66a 2636ace 3d9b66a ac62d88 eb4b465 e1e1d6a eb4b465 9b11bf4 eb4b465 ac62d88 e01d0a4 3d9b66a 8e9c357 3d9b66a 2b8b93a 3d9b66a 2636ace 8e9c357 eb4b465 3d9b66a 8e9c357 e01d0a4 d4533bc 8e9c357 e01d0a4 d4533bc 8e9c357 d4533bc 3d9b66a 7f39fb2 8e9c357 7f39fb2 3d9b66a 7f39fb2 3d9b66a 2636ace 3d9b66a eb4b465 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import gradio as gr
import torch
import ecco
import requests
from transformers import AutoTokenizer
from torch.nn import functional as F
# Prompt scaffold fed to the code models. Three literal markers are rewritten
# by code_from_prompts(): CONN -> the connection string, MIDDLE -> an optional
# pre-existing function, PROMPT -> the user's comment. The \n\t escapes inside
# the triple-quoted string become real newline/tab characters in the prompt.
header = """
import psycopg2
conn = psycopg2.connect("CONN")
cur = conn.cursor()
MIDDLE
def rename_customer(id, newName):\n\t# PROMPT\n\tcur.execute("UPDATE customer SET name =
"""

# Display name -> Hugging Face model id. Larger models are kept commented out,
# presumably to fit the Space's memory/startup budget — TODO confirm.
modelPath = {
    # "GPT2-Medium": "gpt2-medium",
    "CodeParrot-mini": "codeparrot/codeparrot-small",
    "CodeGen-350-Mono": "Salesforce/codegen-350M-mono",
    # "GPT-Neo-1.3B": "EleutherAI/gpt-neo-1.3B",
    # "CodeParrot": "codeparrot/codeparrot",
    # "CodeGen-2B-Mono": "Salesforce/codegen-2B-mono",
}

# Load every model once at startup so each UI request skips the slow load.
# (Dict comprehension replaces the manual for-loop over list(modelPath.keys()).)
preloadModels = {name: ecco.from_pretrained(path) for name, path in modelPath.items()}
def generation(tokenizer, model, content):
    """Generate a completion of `content` and estimate the probability that
    the model continues with a risky string-concatenation SQL pattern.

    Parameters
    ----------
    tokenizer : transformers tokenizer matching `model`.
    model : ecco LM wrapper (as returned by ecco.from_pretrained).
    content : str — the code prompt to complete.

    Returns
    -------
    list[str] of length 2: [generated_code, probability_message].
    """
    # Token sequences that begin a concatenation continuation such as
    # `name = '" + newName`. The [1:] slice drops the tokenizer's leading
    # token for the initial "=" fragment so only the tail is scored.
    seek_token_ids = [
        tokenizer.encode('= \'" +')[1:],
        tokenizer.encode('= " +')[1:],
    ]

    # Greedy 6-token continuation shown to the user.
    full_output = model.generate(content, generate=6, do_sample=False)

    def next_words(code, position, seek_token_ids):
        """Probability that the model emits `seek_token_ids` in order starting
        at `position`, computed one token at a time via the chain rule."""
        op_model = model.generate(code, generate=1, do_sample=False)
        hidden_states = op_model.hidden_states
        # Last layer's hidden state at the position preceding the token we
        # score. NOTE(review): assumes ecco exposes hidden_states indexed as
        # [layer][position] — confirm against the ecco documentation.
        h = hidden_states[-1]
        hidden_state = h[position - 1]
        logits = op_model.lm_head(op_model.to(hidden_state))
        softmax = F.softmax(logits, dim=-1)
        my_token_prob = softmax[seek_token_ids[0]]
        if len(seek_token_ids) > 1:
            # Chain rule: P(t1 t2 ...) = P(t1) * P(t2 | t1) * ...
            newprompt = code + tokenizer.decode(seek_token_ids[0])
            return my_token_prob * next_words(newprompt, position + 1, seek_token_ids[1:])
        return my_token_prob

    # Sum the probability over the two alternative risky spellings.
    prob = 0
    for opt in seek_token_ids:
        prob += next_words(content, len(tokenizer(content)['input_ids']), opt)

    return ["".join(full_output.tokens), str(prob.item() * 100) + '% chance of risky concatenation']
def code_from_prompts(prompt, model, type_hints, pre_content):
    """Build the code prompt from the Gradio inputs and run the model.

    Parameters
    ----------
    prompt : str — the user's comment, inserted at the PROMPT marker.
    model : str — display name; a key of modelPath / preloadModels.
    type_hints : bool — add type annotations to the scaffold's signature.
    pre_content : str — radio-button label selecting the MIDDLE function.

    Returns
    -------
    list[str] of length 2 from generation(): [code, probability_message].
    """
    tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
    # Use the model preloaded at startup; avoid shadowing the `model` name.
    lm = preloadModels[model]

    code = header.strip().replace('CONN', "dbname='store'").replace('PROMPT', prompt)

    if type_hints:
        # Annotate the scaffold signature. These replacements run before the
        # MIDDLE substitution, so the injected get_customer body is untouched.
        code = code.replace('id,', 'id: int,')
        code = code.replace('id)', 'id: int)')
        code = code.replace('newName)', 'newName: str) -> None')

    # The radio labels are matched by substring ('Concatenation'/'composition').
    if pre_content == 'None':
        code = code.replace('MIDDLE\n', '')
    elif 'Concatenation' in pre_content:
        code = code.replace('MIDDLE', """
def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = ' + str(id))\n\treturn cur.fetchall()
""".strip() + "\n")
    elif 'composition' in pre_content:
        code = code.replace('MIDDLE', """
def get_customer(id):\n\tcur.execute('SELECT * FROM customers WHERE id = %s', str(id))\n\treturn cur.fetchall()
""".strip() + "\n")

    return generation(tokenizer, lm, code)
# Gradio UI wiring: four inputs mapped onto code_from_prompts, two text
# outputs for the generated code and the concatenation-probability estimate.
iface = gr.Interface(
    fn=code_from_prompts,
    inputs=[
        gr.components.Textbox(label="Insert comment"),
        gr.components.Radio(list(modelPath.keys()), label="Code Model"),
        gr.components.Checkbox(label="Include type hints"),
        # code_from_prompts matches these labels by the substrings
        # 'Concatenation' and 'composition' — keep labels in sync with it.
        gr.components.Radio([
            "None",
            "Proper composition: Include function 'WHERE id = %s'",
            "Concatenation: Include a function with 'WHERE id = ' + id",
        ], label="Has user already written a function?")
    ],
    outputs=[
        gr.components.Textbox(label="Most probable code"),
        gr.components.Textbox(label="Probability of concat"),
    ],
    description="Prompt the code model to write a SQL query with string concatenation.",
)
iface.launch()
|