# Scraped from a Hugging Face Spaces file view (Space status: Running,
# commit 8b19901, 5,641 bytes, 173 lines in the original listing).
import json
import math
import os
import pickle
import random
import tempfile
import time
from collections import defaultdict
from datetime import datetime, timedelta
from threading import Lock
from typing import Dict, List

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
# --- Request limits ---------------------------------------------------------
MIN_PROMPTS = 1  # smallest value accepted by /generate/{cantidad}
MAX_PROMPTS = 1000  # largest value accepted by /generate/{cantidad}
RATE_LIMIT = "100/minute"  # slowapi limit string used by every route

# --- Files and caching ------------------------------------------------------
CATEGORIES_FILE = "categories.json"  # JSON source of the prompt categories
CACHE_FILE = "cache.pkl"  # pickle file the loaded categories are written to
CACHE_TTL = 300  # seconds before the in-memory categories are re-read

# --- Shared mutable state ---------------------------------------------------
LOCK = Lock()  # serializes reads of CATEGORIES_FILE in load_categories
IP_REQUESTS = defaultdict(list)  # client IP -> timestamps of recent requests
categorias_cache = None  # loaded categories; None until the first load
last_cache_update = 0  # timestamp of the last cache refresh; 0 = never

# Only ERROR-level records go to app.log, rotated at 500 MB and kept 2 days.
logger.add("app.log", rotation="500 MB", retention="2 days", level="ERROR")
# Limiter keyed on the client address. Each route also applies RATE_LIMIT
# explicitly through @limiter.limit(RATE_LIMIT).
# NOTE(review): slowapi's default_limits normally only take effect through
# SlowAPIMiddleware, which is not added in this file — confirm intent.
limiter = Limiter(key_func=get_remote_address, default_limits=[RATE_LIMIT])

# The interactive docs and the OpenAPI schema are disabled (URLs set to None).
app = FastAPI(
    title="API de Generación de Prompts",
    version="1.0.0",
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)

# slowapi's setup expects the limiter on app.state, plus a handler that turns
# RateLimitExceeded into an HTTP response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# NOTE(review): Starlette's CORS docs say credentials cannot be combined with
# wildcard origins/methods/headers. No cookie or auth use is visible here, so
# allow_credentials=False is probably sufficient. Confirm no frontend relies on
# credentialed requests before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def load_categories() -> Dict[str, List[str]]:
    """Read and parse CATEGORIES_FILE while holding LOCK.

    Assumes the file holds a JSON object mapping category names to lists of
    option strings (the shape the annotation and callers expect). The content
    is not validated here.

    Returns:
        The parsed JSON document.

    Raises:
        HTTPException: status 500 if the file cannot be opened, decoded as
            UTF-8, or parsed as JSON. The underlying error is chained and
            logged.
    """
    with LOCK:
        try:
            # Explicit UTF-8. The platform default encoding (e.g. cp1252 on
            # Windows) would mis-decode or reject non-ASCII category values.
            with open(CATEGORIES_FILE, "r", encoding="utf-8") as file:
                return json.load(file)
        except (OSError, ValueError) as e:
            # OSError covers missing, unreadable, or permission-denied files.
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Error al cargar '{CATEGORIES_FILE}': {e}")
            raise HTTPException(
                status_code=500, detail="Error al cargar categorías."
            ) from e
def save_cache(data, path=None):
    """Pickle *data* to *path*, replacing any previous file in one step.

    The payload is first written to a temporary file in the same directory.
    It is then moved over the target with os.replace(), which is atomic on
    POSIX. A crash mid-write therefore leaves the old cache intact instead of
    a truncated file.

    Args:
        data: Any picklable object.
        path: Destination file. None (the default) means CACHE_FILE.

    Raises:
        OSError: if the temporary file cannot be created, written, or moved.
    """
    target = CACHE_FILE if path is None else path
    directory = os.path.dirname(os.path.abspath(target))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".cache-", suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(data, f)
        os.replace(tmp_path, target)
    except BaseException:
        # Best-effort removal of the temporary file. The original error is
        # what the caller needs to see, so it is re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
def load_cache(path=None):
    """Return the object pickled at *path*, or None if it is unusable.

    None is returned when the file does not exist or its contents cannot be
    unpickled (empty, truncated, or otherwise corrupt).

    SECURITY: pickle.load can execute arbitrary code from the file. Only
    point this at files this process wrote itself, never at external data.

    Args:
        path: File to read. None (the default) means CACHE_FILE.
    """
    source = CACHE_FILE if path is None else path
    try:
        with open(source, "rb") as f:
            return pickle.load(f)
    except (
        FileNotFoundError,
        pickle.UnpicklingError,
        # The pickle docs list these as other errors raised by bad input.
        # An empty file in particular raises EOFError ("Ran out of input").
        EOFError,
        AttributeError,
        ImportError,
        IndexError,
    ):
        return None
def get_cached_categories() -> Dict[str, List[str]]:
    """Return the categories, reloading them once CACHE_TTL seconds pass.

    Freshly loaded categories are also written out with save_cache(). That
    write is best-effort: nothing in this module reads the file back on the
    request path, so an OSError is logged rather than failing the request.

    Raises:
        HTTPException: propagated from load_categories() when reloading fails.
    """
    global categorias_cache, last_cache_update
    # monotonic() is unaffected by wall-clock adjustments (NTP, manual
    # changes), which would otherwise stretch or shrink the TTL.
    current_time = time.monotonic()
    if categorias_cache is None or (current_time - last_cache_update) > CACHE_TTL:
        categorias_cache = load_categories()
        # Record the refresh before persisting, so a failed write does not
        # force a reload on every subsequent request.
        last_cache_update = current_time
        try:
            save_cache(categorias_cache)
        except OSError as e:
            logger.error(f"No se pudo guardar la caché en '{CACHE_FILE}': {e}")
    return categorias_cache
def calcular_combinaciones(record_count: Dict[str, int]) -> int:
    """Return the product of all per-category option counts.

    An empty mapping yields 1 (the empty product), and any category with zero
    options yields 0. Python ints do not overflow, so large products are exact.
    """
    return math.prod(record_count.values())
@app.get("/")
@limiter.limit(RATE_LIMIT)
async def read_root(request: Request):
    """Return a static welcome message.

    `request` is unused in the body. slowapi's @limiter.limit requires it in
    the signature.
    """
    welcome = {"message": "Bienvenido a la API de generación de prompts"}
    logger.info("Endpoint raíz consultado.")
    return welcome
@app.get("/generate")
@limiter.limit(RATE_LIMIT)
async def detail_generate_prompts(request: Request):
    """Reject /generate without a count and point the caller at /generate/{n}.

    Raises:
        HTTPException: always, with status 400 and a usage hint.
    """
    hint = "Debe especificar la cantidad de prompts a generar, por ejemplo: /generate/10"
    logger.warning("Intento de generar prompts sin cantidad especificada.")
    raise HTTPException(status_code=400, detail=hint)
@app.get("/generate/{cantidad}")
@limiter.limit(RATE_LIMIT)
async def generate_prompts(request: Request, cantidad: int):
    """Return ``{"prompts": [...]}`` with ``cantidad`` random prompts.

    Raises:
        HTTPException: 400 when ``cantidad`` is outside the inclusive range
            [MIN_PROMPTS, MAX_PROMPTS].
    """
    if MIN_PROMPTS <= cantidad <= MAX_PROMPTS:
        prompts = [generar_base_prompt() for _ in range(cantidad)]
        logger.info(f"Generados {cantidad} prompts exitosamente.")
        return {"prompts": prompts}
    logger.warning(
        f"Intento de generar {cantidad} prompts: fuera del rango permitido."
    )
    raise HTTPException(
        status_code=400,
        detail=f"La cantidad debe estar entre {MIN_PROMPTS} y {MAX_PROMPTS}.",
    )
@app.get("/count_records")
@limiter.limit(RATE_LIMIT)
async def count_records(request: Request):
    """Report each category's option count and the product of those counts.

    `request` is unused in the body. slowapi's @limiter.limit requires it in
    the signature.

    Raises:
        HTTPException: propagated unchanged from get_cached_categories() when
            the categories cannot be loaded.
    """
    categorias = get_cached_categories()
    record_count = {key: len(value) for key, value in categorias.items()}
    logger.info(f"Número de registros por etiqueta: {record_count}")
    total_combinations = calcular_combinaciones(record_count)
    logger.info(f"Total de combinaciones posibles: {total_combinations}")
    return {"record_count": record_count, "total_combinations": total_combinations}
@app.middleware("http")
async def limit_request_frequency(request: Request, call_next):
    """Return 429 to a client IP that made 100+ accepted requests in a minute.

    Timestamps of accepted requests are kept per IP in IP_REQUESTS. Entries
    older than one minute are discarded each time that IP sends a request.
    Rejected requests are not recorded.
    """
    # request.client is None when the server does not report a peer address
    # (e.g. some test clients or Unix-socket deployments).
    ip = request.client.host if request.client is not None else "unknown"
    now = datetime.now()
    window = timedelta(minutes=1)
    IP_REQUESTS[ip] = [stamp for stamp in IP_REQUESTS[ip] if now - stamp < window]
    if len(IP_REQUESTS[ip]) >= 100:
        logger.warning(f"Bloqueo temporal para la IP {ip}, demasiadas solicitudes.")
        # Return the response directly. An HTTPException raised inside HTTP
        # middleware bypasses FastAPI's exception handlers and surfaces as a
        # 500 Internal Server Error rather than the intended 429.
        return JSONResponse(
            status_code=429,
            content={"detail": "Demasiadas solicitudes. Espere 1 minuto."},
            headers={"Retry-After": "60"},
        )
    IP_REQUESTS[ip].append(now)
    # NOTE(review): keys for IPs that stop sending requests are never removed,
    # so IP_REQUESTS grows with the number of distinct clients. Also, this
    # middleware is registered after CORSMiddleware, so it presumably wraps it
    # and its 429 responses carry no CORS headers — confirm.
    return await call_next(request)
def generar_base_prompt() -> str:
    """Build one English prompt from a random option of each category.

    Options are drawn with random.choice, one key at a time in the fixed
    order below. A missing key raises KeyError, and an empty option list
    raises IndexError.
    """
    cats = get_cached_categories()
    keys = (
        "edad", "sexo", "tipo", "peinado", "color_cabello", "ojos",
        "piel", "ropa", "escenario", "pose", "emocion", "extras",
    )
    # Dict comprehensions iterate in order, so the draws happen in the order
    # of `keys`.
    picked = {key: random.choice(cats[key]) for key in keys}
    return (
        f"A {picked['edad']} {picked['sexo']} {picked['tipo']} with {picked['peinado']} "
        f"({picked['color_cabello']}) and {picked['ojos']}, having {picked['piel']}, "
        f"wearing {picked['ropa']}, in a {picked['escenario']}, {picked['pose']} "
        f"while feeling {picked['emocion']}, adorned with {picked['extras']}."
    )