# generate_prompts / main.py
# Author: JairoDanielMT
# Commit 8b19901 ("Upload 5 files", verified)
import json
import math
import pickle
import random
import time
from collections import defaultdict
from datetime import datetime, timedelta
from threading import Lock
from typing import Dict, List

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
# Inclusive bounds on the number of prompts /generate/{cantidad} will produce.
MIN_PROMPTS: int = 1
MAX_PROMPTS: int = 1000
# slowapi limit string applied to every route through @limiter.limit.
RATE_LIMIT: str = "100/minute"
# Seconds after which the in-memory categories cache is considered stale.
CACHE_TTL: int = 300
# JSON file holding the prompt categories (category name -> list of options).
CATEGORIES_FILE: str = "categories.json"
# Pickle file that receives a copy of the categories every time they are reloaded.
CACHE_FILE: str = "cache.pkl"
# Serializes reads of CATEGORIES_FILE inside load_categories().
LOCK = Lock()
# Client IP -> timestamps of its recent accepted requests (used by the throttling middleware).
IP_REQUESTS: Dict[str, List[datetime]] = defaultdict(list)
# Persist only ERROR-and-above records to app.log; rotate at 500 MB, keep files for 2 days.
logger.add("app.log", rotation="500 MB", retention="2 days", level="ERROR")
# In-memory categories cache (None until the first load) and when it was last refreshed.
categorias_cache = None
last_cache_update: float = 0
app = FastAPI(
    title="API de Generación de Prompts",
    version="1.0.0",
    # Interactive docs and the OpenAPI schema are deliberately not exposed.
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# risky combination. CORS forbids "*" for credentialed requests, and Starlette
# works around that by echoing the caller's Origin when cookies are present,
# so any site can make cookie-bearing requests. Nothing in this module uses
# cookies or auth; confirm whether credentials are needed at all.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Rate limiting keyed on the client address. Each route below applies
# RATE_LIMIT explicitly via @limiter.limit.
# NOTE(review): slowapi only enforces `default_limits` through its
# SlowAPIMiddleware, which is not added here. Confirm whether that is intended.
limiter = Limiter(key_func=get_remote_address, default_limits=[RATE_LIMIT])
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
def load_categories() -> Dict[str, List[str]]:
    """Read and parse the category definitions stored in CATEGORIES_FILE.

    Returns:
        The decoded JSON object. It is expected, though not validated here, to
        map each category name to a list of option strings.

    Raises:
        HTTPException: status 500 when the file cannot be opened/read, is not
            valid UTF-8, or does not contain valid JSON.
    """
    with LOCK:
        try:
            # Explicit UTF-8: the default text encoding is platform/locale
            # dependent and could garble or reject non-ASCII content.
            with open(CATEGORIES_FILE, "r", encoding="utf-8") as file:
                return json.load(file)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
            logger.error(f"Error al cargar 'categories.json': {str(e)}")
            # Chain the cause so the original error stays in tracebacks.
            raise HTTPException(
                status_code=500, detail="Error al cargar categorías."
            ) from e
def save_cache(data):
    """Serialize ``data`` with pickle into CACHE_FILE, overwriting its previous contents."""
    with open(CACHE_FILE, mode="wb") as destino:
        # Equivalent to pickle.dump(data, destino) with the default protocol.
        pickle.Pickler(destino).dump(data)
def load_cache(path=None):
    """Load the pickled categories copy written by save_cache().

    Args:
        path: File to read. Defaults to CACHE_FILE when omitted, so existing
            zero-argument calls behave as before.

    Returns:
        The unpickled object, or None when the file is missing, empty,
        truncated, or not a valid pickle.
    """
    # SECURITY: unpickling can execute arbitrary code embedded in the file.
    # This is only acceptable because the file is produced locally by
    # save_cache(); never point this at user-supplied or downloaded data.
    target = CACHE_FILE if path is None else path
    try:
        with open(target, "rb") as f:
            return pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # EOFError covers an empty file, e.g. one left behind by an
        # interrupted write.
        return None
def get_cached_categories() -> Dict[str, List[str]]:
    """Return the categories mapping, reloading it once the cached copy expires.

    The mapping lives in the module-level ``categorias_cache``. It is reloaded
    through load_categories() when nothing is cached yet or when more than
    CACHE_TTL seconds have passed since the last refresh. Each successful
    reload is also written to disk with save_cache().

    Raises:
        HTTPException: propagated from load_categories(), but only when there
            is no previously loaded copy to fall back on.
    """
    global categorias_cache, last_cache_update
    # monotonic() is unaffected by wall-clock adjustments (NTP sync, manual
    # changes), which could otherwise freeze or thrash the TTL check.
    current_time = time.monotonic()
    if categorias_cache is not None and (current_time - last_cache_update) <= CACHE_TTL:
        return categorias_cache
    try:
        fresh = load_categories()
    except HTTPException:
        if categorias_cache is None:
            raise
        # Keep serving the last good copy rather than failing every request.
        # The reload is retried in the next TTL window.
        logger.warning(
            "No se pudieron recargar las categorías; se usa la copia en caché."
        )
        last_cache_update = current_time
        return categorias_cache
    categorias_cache = fresh
    last_cache_update = current_time
    try:
        save_cache(categorias_cache)
    except OSError as e:
        # Requests are served from the in-memory copy, so failing to write
        # the pickle must not fail the request that triggered the reload.
        logger.error(f"No se pudo guardar la caché en disco: {e}")
    return categorias_cache
def calcular_combinaciones(record_count: Dict[str, int]) -> int:
    """Return how many distinct prompts exist: the product of all per-category counts.

    An empty mapping yields 1 (the empty product); any zero count yields 0.
    """
    return math.prod(record_count.values())
@app.get("/")
@limiter.limit(RATE_LIMIT)
async def read_root(request: Request):
    """Welcome endpoint that returns a static greeting.

    ``request`` is not used in the body but must remain: slowapi's
    @limiter.limit needs the endpoint to accept a ``request: Request`` argument.
    """
    logger.info("Endpoint raíz consultado.")
    return {"message": "Bienvenido a la API de generación de prompts"}
@app.get("/generate")
@limiter.limit(RATE_LIMIT)
async def detail_generate_prompts(request: Request):
    """Reject /generate calls that omit the count, pointing the caller at /generate/{n}.

    Always raises HTTPException(400).
    """
    mensaje = "Debe especificar la cantidad de prompts a generar, por ejemplo: /generate/10"
    logger.warning("Intento de generar prompts sin cantidad especificada.")
    raise HTTPException(status_code=400, detail=mensaje)
@app.get("/generate/{cantidad}")
@limiter.limit(RATE_LIMIT)
async def generate_prompts(request: Request, cantidad: int):
    """Return ``{"prompts": [...]}`` containing ``cantidad`` freshly generated prompts.

    Raises:
        HTTPException: status 400 when ``cantidad`` lies outside the inclusive
            range [MIN_PROMPTS, MAX_PROMPTS].
    """
    dentro_de_rango = MIN_PROMPTS <= cantidad <= MAX_PROMPTS
    if not dentro_de_rango:
        logger.warning(
            f"Intento de generar {cantidad} prompts: fuera del rango permitido."
        )
        raise HTTPException(
            status_code=400,
            detail=f"La cantidad debe estar entre {MIN_PROMPTS} y {MAX_PROMPTS}.",
        )
    resultado = []
    for _ in range(cantidad):
        resultado.append(generar_base_prompt())
    logger.info(f"Generados {cantidad} prompts exitosamente.")
    return {"prompts": resultado}
@app.get("/count_records")
@limiter.limit(RATE_LIMIT)
async def count_records(request: Request):
    """Report the number of options per category and the total prompt combinations.

    Returns:
        ``{"record_count": {category: option_count}, "total_combinations": int}``,
        where the total is the product of all option counts.

    An HTTPException raised while loading the categories propagates unchanged.
    """
    # The former `except HTTPException as e: raise e` only re-raised the same
    # exception, so it has been dropped; propagation is identical.
    categorias = get_cached_categories()
    record_count = {key: len(value) for key, value in categorias.items()}
    logger.info(f"Número de registros por etiqueta: {record_count}")
    total_combinations = calcular_combinaciones(record_count)
    logger.info(f"Total de combinaciones posibles: {total_combinations}")
    return {"record_count": record_count, "total_combinations": total_combinations}
@app.middleware("http")
async def limit_request_frequency(request: Request, call_next):
    """Per-IP sliding-window throttle: at most 100 accepted requests per rolling minute.

    Accepted request timestamps are stored in IP_REQUESTS[ip]. Entries older
    than one minute are discarded whenever that IP makes a request. Rejected
    requests are not recorded, so they do not extend the block.

    Returns:
        The downstream response, or a 429 JSON response ``{"detail": ...}`` when
        the IP has reached the limit.
    """
    # request.client is None when the server does not report the peer
    # address; group such requests under one key instead of crashing.
    ip = request.client.host if request.client is not None else "unknown"
    now = datetime.now()
    ventana = timedelta(minutes=1)
    recientes = [ts for ts in IP_REQUESTS[ip] if now - ts < ventana]
    IP_REQUESTS[ip] = recientes
    # NOTE(review): 100 duplicates the count in RATE_LIMIT. Also, IP_REQUESTS
    # never evicts an IP key, so memory grows with the number of distinct
    # client addresses; consider periodic eviction.
    if len(recientes) >= 100:
        logger.warning(f"Bloqueo temporal para la IP {ip}, demasiadas solicitudes.")
        # An HTTPException raised in an @app.middleware function is outside
        # FastAPI's exception handlers and surfaces as a 500. Build the 429
        # response directly, in the same shape FastAPI's handler would use.
        return JSONResponse(
            status_code=429,
            content={"detail": "Demasiadas solicitudes. Espere 1 minuto."},
        )
    recientes.append(now)
    return await call_next(request)
def generar_base_prompt() -> str:
    """Compose one random English prompt describing a character.

    One option is drawn with random.choice from each category returned by
    get_cached_categories(), always in the same order, and inserted into a
    fixed sentence template. A missing category raises KeyError; an empty
    category list raises IndexError.
    """
    opciones = get_cached_categories()
    # The draw order is significant: it fixes how the shared `random` stream
    # is consumed, so seeded runs stay reproducible.
    orden = (
        "edad", "sexo", "tipo", "peinado", "color_cabello", "ojos",
        "piel", "ropa", "escenario", "pose", "emocion", "extras",
    )
    (
        edad, sexo, tipo, peinado, color_cabello, ojos,
        piel, ropa, escenario, pose, emocion, extras,
    ) = (random.choice(opciones[clave]) for clave in orden)
    return (
        f"A {edad} {sexo} {tipo} with {peinado} ({color_cabello}) and {ojos}, "
        f"having {piel}, wearing {ropa}, in a {escenario}, {pose} while feeling "
        f"{emocion}, adorned with {extras}."
    )