Spaces:
Running
Running
File size: 4,305 Bytes
4cec855 e63103b 9717e5b 68d3cc8 4cec855 e63103b bca06e1 e63103b deb05dc e63103b 3b3d8b9 e63103b 182e6fa 3b3d8b9 aeb38f0 3b3d8b9 d8142ab 3b3d8b9 9717e5b 4cec855 9717e5b bca06e1 9717e5b b700f35 9717e5b 3f199c2 9717e5b b700f35 3b3d8b9 b700f35 9717e5b b700f35 9717e5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import tempfile, os
from rest_framework import viewsets, filters
from django_filters.rest_framework import DjangoFilterBackend
from endpoint_teste.models import EndpointTesteModel
from endpoint_teste.serializer import EndpointTesteSerializer, PDFUploadSerializer
from setup.environment import default_model
from drf_spectacular.utils import extend_schema
from rest_framework.decorators import api_view, parser_classes
from rest_framework.parsers import MultiPartParser
from rest_framework.response import Response
from langchain_backend.main import get_llm_answer
from .serializer import TesteSerializer
from langchain_huggingface import HuggingFaceEndpoint
class EndpointTesteViewSet(viewsets.ModelViewSet):
    """Show all tasks.

    Model viewset over ``EndpointTesteModel`` rows, returned in ascending
    ``id`` order and serialized with ``EndpointTesteSerializer``. Results can
    be narrowed through the django-filter backend or through DRF's search
    filter, which only looks at the ``id`` field.
    """

    serializer_class = EndpointTesteSerializer
    # Explicit ordering on "id" is applied to the manager's base queryset.
    queryset = EndpointTesteModel.objects.all().order_by("id")

    search_fields = ["id"]
    filter_backends = [DjangoFilterBackend, filters.SearchFilter]
@api_view(["GET", "POST"])
def getTeste(request):
    """Test endpoint: answer a prompt (POST) or run a fixed LLM smoke call (GET).

    POST:
        Validates the request body with ``TesteSerializer``; invalid input
        raises the serializer's validation error (``raise_exception=True``).
        Then ``system_prompt``, ``user_message`` and the optional ``pdf_url``
        are read from the raw request data and forwarded to
        ``get_llm_answer``. The answer is returned as ``{"Resposta": ...}``.

    GET:
        Builds a ``HuggingFaceEndpoint`` for
        ``meta-llama/Meta-Llama-3-8B-Instruct``, invokes it with the fixed
        prompt ``"Hugging Face is"``, prints the result and returns it as the
        response body. The API token is read from the
        ``HUGGINGFACEHUB_API_TOKEN`` environment variable.
    """
    if request.method == "POST":
        serializer = TesteSerializer(data=request.data)
        # With raise_exception=True this either passes or raises, so there is
        # no "invalid" branch to handle here.
        serializer.is_valid(raise_exception=True)

        data = request.data
        # pdf_url is optional: a missing key or an empty value both mean
        # "no PDF". Using .get() avoids a KeyError (HTTP 500) when the client
        # omits the field altogether.
        pdf_url = data.get("pdf_url") or None

        # NOTE(review): these two keys are read from the raw payload, not from
        # serializer.validated_data; presumably TesteSerializer requires them
        # -- confirm against the serializer definition.
        resposta_llm = get_llm_answer(
            data["system_prompt"],
            data["user_message"],
            pdf_url,
        )
        return Response({"Resposta": resposta_llm})

    # Only GET and POST are registered by @api_view, so this is the GET path.
    llm = HuggingFaceEndpoint(
        repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
        task="text-generation",
        max_new_tokens=100,
        do_sample=False,
        huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
    )
    result = llm.invoke("Hugging Face is")
    # Debug trace written to stdout.
    print('result: ', result)
    return Response(result)
@extend_schema(
    request=PDFUploadSerializer,
)
@api_view(["POST"])
@parser_classes([MultiPartParser])
def getPDF(request):
    """Answer a prompt using one or more uploaded PDF files as input.

    The multipart body is validated with ``PDFUploadSerializer``; invalid
    input raises the serializer's validation error (``raise_exception=True``).
    Every file in ``validated_data["files"]`` is copied to its own temporary
    ``.pdf`` file on disk, and the list of those paths is passed to
    ``get_llm_answer`` together with ``system_prompt`` and ``user_message``
    from the raw request data, plus the ``model`` (default ``default_model``)
    and ``embedding`` (default ``"gpt"``) options from the validated data.

    The temporary files are removed before returning, including when copying
    or the LLM call raises.

    Returns:
        Response: ``{"Resposta": <value returned by get_llm_answer>}``.
    """
    # @api_view(["POST"]) only routes POST requests here, so no method check
    # is needed inside the handler.
    serializer = PDFUploadSerializer(data=request.data)
    serializer.is_valid(raise_exception=True)

    print('\n\n')
    data = request.data
    # Debug trace written to stdout.
    print('data: ', data)

    embedding = serializer.validated_data.get("embedding", "gpt")
    model = serializer.validated_data.get("model", default_model)

    listaPDFs = []
    try:
        for uploaded_file in serializer.validated_data['files']:
            print("file: ", uploaded_file)
            # Rewind in case the upload was already read during validation.
            uploaded_file.seek(0)

            # delete=False keeps the file on disk after the handle is closed,
            # so it can be reopened by path later; it is removed explicitly
            # in the finally block below.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
                # Register the path before writing so a failure midway through
                # the copy still gets the partial file cleaned up.
                listaPDFs.append(temp_file.name)
                for chunk in uploaded_file.chunks():
                    temp_file.write(chunk)

        print('listaPDFs: ', listaPDFs)
        resposta_llm = get_llm_answer(
            data["system_prompt"],
            data["user_message"],
            listaPDFs,
            model=model,
            embedding=embedding,
        )
    finally:
        # Always remove the temporary copies; without this, any exception
        # above would leave the files behind permanently.
        for temp_path in listaPDFs:
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                # Already gone: nothing left to clean up for this path.
                pass

    return Response({"Resposta": resposta_llm})