Spaces:
Running
Running
import json
import math
import os
import pickle
import random
import tempfile
import time
from collections import defaultdict
from datetime import datetime, timedelta
from threading import Lock
from typing import Dict, List

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
# --- Request limits ---------------------------------------------------------
# Inclusive bounds on how many prompts a single generate request may ask for.
MIN_PROMPTS, MAX_PROMPTS = 1, 1000
# slowapi limit string, used as the Limiter's default limit.
RATE_LIMIT = "100/minute"

# --- Files and caching ------------------------------------------------------
CACHE_TTL = 300  # seconds between re-reads of CATEGORIES_FILE
CATEGORIES_FILE = "categories.json"
CACHE_FILE = "cache.pkl"

# --- Shared mutable state ---------------------------------------------------
LOCK = Lock()  # held by load_categories while it reads CATEGORIES_FILE
IP_REQUESTS = defaultdict(list)  # client ip -> timestamps of its recent requests

# Only ERROR-level (and above) records go to app.log.
logger.add("app.log", rotation="500 MB", retention="2 days", level="ERROR")

# In-memory categories cache, managed by get_cached_categories.
categorias_cache = None
last_cache_update = 0
app = FastAPI(
    title="API de Generación de Prompts",
    version="1.0.0",
    # Interactive docs and the OpenAPI schema are deliberately not served.
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)

# A wildcard origin must not be combined with allow_credentials=True: the CORS
# spec forbids "*" on credentialed responses, and Starlette then reflects the
# caller's Origin instead, letting any site make cookie-bearing requests.
# No endpoint in this file reads cookies or auth headers, so credentials are off.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

limiter = Limiter(key_func=get_remote_address, default_limits=[RATE_LIMIT])
app.state.limiter = limiter
# Convert slowapi's RateLimitExceeded into an HTTP 429 response.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
def load_categories() -> Dict[str, List[str]]:
    """Read and parse ``CATEGORIES_FILE`` while holding the module ``LOCK``.

    Returns:
        The parsed JSON document (declared as a mapping of category name to a
        list of option strings).

    Raises:
        HTTPException: 500 when the file cannot be opened, decoded or parsed.
    """
    with LOCK:
        try:
            # JSON is UTF-8 by spec (RFC 8259). Without an explicit encoding,
            # open() uses the platform locale and can fail on accented text.
            # "utf-8-sig" also tolerates a leading BOM, which json rejects.
            with open(CATEGORIES_FILE, "r", encoding="utf-8-sig") as file:
                return json.load(file)
        except (OSError, ValueError) as e:
            # OSError: missing/unreadable file (FileNotFoundError, PermissionError...).
            # ValueError: json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Error al cargar '{CATEGORIES_FILE}': {e}")
            raise HTTPException(
                status_code=500, detail="Error al cargar categorías."
            ) from e
def save_cache(data, path=None):
    """Pickle *data* to the cache file, replacing any previous contents.

    The object is first written to a temporary file in the same directory,
    which is then moved over the target with ``os.replace``. A failed pickle or
    an interrupted write therefore never leaves a truncated cache file behind,
    and the previous cache stays intact.

    Args:
        data: Any picklable object (here, the categories mapping).
        path: Destination file; defaults to ``CACHE_FILE``.
    """
    if path is None:
        path = CACHE_FILE
    fd, tmp_path = tempfile.mkstemp(
        dir=os.path.dirname(os.path.abspath(path)), suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(data, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave stray temp files around if pickling or the rename fails.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def load_cache(path=None):
    """Load the pickled cache written by ``save_cache``.

    Args:
        path: File to read; defaults to ``CACHE_FILE``.

    Returns:
        The unpickled object, or ``None`` when the file is missing, empty,
        truncated, or otherwise rejected by the unpickler.
    """
    # SECURITY: pickle.load can execute arbitrary code. Only point this at a
    # file this service wrote itself, never at externally supplied data.
    if path is None:
        path = CACHE_FILE
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except (
        FileNotFoundError,
        # An empty or truncated file raises EOFError, not UnpicklingError.
        EOFError,
        pickle.UnpicklingError,
        # Documented by `pickle` as other possible unpickling failures.
        AttributeError,
        ImportError,
        IndexError,
    ):
        return None
def get_cached_categories() -> Dict[str, List[str]]:
    """Return the categories mapping, re-reading it at most every ``CACHE_TTL`` seconds.

    The mapping is memoised in the module globals ``categorias_cache`` and
    ``last_cache_update``. On each refresh the new data is also persisted with
    ``save_cache``. That step is best-effort, so a disk error there is logged
    instead of failing the request.

    Raises:
        HTTPException: propagated from ``load_categories`` when the categories
            file cannot be read or parsed.
    """
    global categorias_cache, last_cache_update
    # time.monotonic() is unaffected by wall-clock adjustments (NTP, manual
    # changes) that could otherwise stall or spuriously trigger a refresh.
    now = time.monotonic()
    if categorias_cache is None or (now - last_cache_update) > CACHE_TTL:
        categorias_cache = load_categories()
        # Record the refresh before persisting, so a failed write does not
        # force a re-read of the categories file on every following call.
        last_cache_update = now
        try:
            save_cache(categorias_cache)
        except OSError as exc:
            logger.error(f"No se pudo guardar la caché: {exc}")
    return categorias_cache
def calcular_combinaciones(record_count: Dict[str, int]) -> int:
    """Return the product of all counts in *record_count*.

    It is used to report how many combinations the category sizes allow. An
    empty mapping yields 1 (the empty product), and any zero count yields 0.
    """
    return math.prod(record_count.values())
async def read_root(request: Request):
    """Return a static welcome message for the API root.

    ``request`` is not read by the body. It is presumably kept for slowapi's
    ``@limiter.limit``, which requires a ``request`` argument (no decorator is
    visible here; confirm).
    """
    logger.info("Endpoint raíz consultado.")
    return {"message": "Bienvenido a la API de generación de prompts"}
async def detail_generate_prompts(request: Request):
    """Reject a generate request that did not state how many prompts to build.

    This function never returns normally.

    Raises:
        HTTPException: always, with status 400 and a hint showing the expected
            ``/generate/<n>`` form.
    """
    pista = "Debe especificar la cantidad de prompts a generar, por ejemplo: /generate/10"
    logger.warning("Intento de generar prompts sin cantidad especificada.")
    raise HTTPException(status_code=400, detail=pista)
async def generate_prompts(request: Request, cantidad: int):
    """Return ``cantidad`` freshly generated random prompts.

    Args:
        request: The incoming request; not read by the body.
        cantidad: How many prompts to build; must lie within
            ``[MIN_PROMPTS, MAX_PROMPTS]`` (both ends inclusive).

    Returns:
        ``{"prompts": [...]}`` holding ``cantidad`` strings produced by
        ``generar_base_prompt``.

    Raises:
        HTTPException: 400 when ``cantidad`` falls outside the allowed range.
    """
    if cantidad < MIN_PROMPTS or MAX_PROMPTS < cantidad:
        logger.warning(
            f"Intento de generar {cantidad} prompts: fuera del rango permitido."
        )
        limites = f"La cantidad debe estar entre {MIN_PROMPTS} y {MAX_PROMPTS}."
        raise HTTPException(status_code=400, detail=limites)

    generados = [generar_base_prompt() for _ in range(cantidad)]
    logger.info(f"Generados {cantidad} prompts exitosamente.")
    return {"prompts": generados}
async def count_records(request: Request):
    """Report the size of every category and the product of those sizes.

    ``request`` is not read by the body. It is presumably kept for slowapi's
    limiter decorator (confirm).

    Returns:
        ``{"record_count": {category: number_of_entries, ...},
        "total_combinations": product of all number_of_entries}``

    Raises:
        HTTPException: propagated unchanged from ``get_cached_categories``.
    """
    # The former `except HTTPException as e: raise e` only re-raised the same
    # exception; letting it propagate is equivalent and keeps a cleaner traceback.
    categorias = get_cached_categories()
    record_count = {key: len(value) for key, value in categorias.items()}
    logger.info(f"Número de registros por etiqueta: {record_count}")
    total_combinations = calcular_combinaciones(record_count)
    logger.info(f"Total de combinaciones posibles: {total_combinations}")
    return {"record_count": record_count, "total_combinations": total_combinations}
async def limit_request_frequency(request: Request, call_next):
    """Per-IP sliding-window limiter: at most 100 requests per rolling minute.

    Each client's recent request timestamps are kept in the module-level
    ``IP_REQUESTS`` dict. Timestamps older than one minute are dropped whenever
    that client makes a new request.

    NOTE(review): keys are never removed, so ``IP_REQUESTS`` grows with the
    number of distinct client IPs seen. This also overlaps the slowapi limiter
    configured with ``RATE_LIMIT``, if that limiter is applied to the routes
    (not visible here).

    Returns:
        The downstream response. If the client is over the limit, a 429 JSON
        response ``{"detail": "Demasiadas solicitudes. Espere 1 minuto."}``
        is returned instead.
    """
    # Starlette's request.client is Optional (the ASGI "client" key may be
    # absent). Group such requests together instead of raising AttributeError.
    ip = request.client.host if request.client is not None else "unknown"
    now = datetime.now()
    window = timedelta(minutes=1)
    # Loop variable is `stamp`, not `time`, to avoid shadowing the time module.
    recent = [stamp for stamp in IP_REQUESTS[ip] if now - stamp < window]
    IP_REQUESTS[ip] = recent
    if len(recent) >= 100:
        logger.warning(f"Bloqueo temporal para la IP {ip}, demasiadas solicitudes.")
        # An HTTPException raised inside HTTP middleware is not handled by
        # FastAPI's exception handlers (they only wrap the router), so the
        # client would receive a 500. Return the 429 directly, in the same
        # {"detail": ...} shape that HTTPException produces.
        return JSONResponse(
            status_code=429,
            content={"detail": "Demasiadas solicitudes. Espere 1 minuto."},
        )
    recent.append(now)
    return await call_next(request)
def generar_base_prompt() -> str:
    """Build one random prompt sentence from the cached categories.

    One option is drawn with ``random.choice`` from each of twelve categories,
    in a fixed order, and placed into a fixed sentence template.

    Raises:
        KeyError: if one of the twelve categories is missing from the mapping.
        IndexError: if a category's option list is empty.
    """
    pools = get_cached_categories()

    def pick(category: str) -> str:
        return random.choice(pools[category])

    # Draws happen in the same order as the slots appear in the sentence, so a
    # seeded RNG yields the same prompt as before.
    edad, sexo, tipo = pick("edad"), pick("sexo"), pick("tipo")
    peinado, color_cabello, ojos = pick("peinado"), pick("color_cabello"), pick("ojos")
    piel, ropa = pick("piel"), pick("ropa")
    escenario, pose = pick("escenario"), pick("pose")
    emocion, extras = pick("emocion"), pick("extras")
    return (
        f"A {edad} {sexo} {tipo} with {peinado} ({color_cabello}) and {ojos}, "
        f"having {piel}, wearing {ropa}, in a {escenario}, {pose} while feeling "
        f"{emocion}, adorned with {extras}."
    )