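# FastAPI application for llama-cpp-python's OpenAI-compatible server:
# completion, chat completion, embedding, and tokenizer endpoints backed by
# llama_cpp models behind a LlamaProxy.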
from __future__ import annotations

import os
import json

from threading import Lock
from functools import partial
from typing import Iterator, List, Optional, Union, Dict

import llama_cpp

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer
from sse_starlette.sse import EventSourceResponse
from starlette_context.plugins import RequestIdPlugin  # type: ignore
from starlette_context.middleware import RawContextMiddleware

from llama_cpp.server.model import (
    LlamaProxy,
)
from llama_cpp.server.settings import (
    ConfigFileSettings,
    Settings,
    ModelSettings,
    ServerSettings,
)
from llama_cpp.server.types import (
    CreateCompletionRequest,
    CreateEmbeddingRequest,
    CreateChatCompletionRequest,
    ModelList,
    TokenizeInputRequest,
    TokenizeInputResponse,
    TokenizeInputCountResponse,
    DetokenizeInputRequest,
    DetokenizeInputResponse,
)
from llama_cpp.server.errors import RouteErrorHandler

router = APIRouter(route_class=RouteErrorHandler)

_server_settings: Optional[ServerSettings] = None


def set_server_settings(server_settings: ServerSettings):
    global _server_settings
    _server_settings = server_settings


def get_server_settings():
    yield _server_settings

_llama_proxy: Optional[LlamaProxy] = None

llama_outer_lock = Lock()
llama_inner_lock = Lock()


def set_llama_proxy(model_settings: List[ModelSettings]):
    global _llama_proxy
    _llama_proxy = LlamaProxy(models=model_settings)


def get_llama_proxy():
    # NOTE: This double lock allows the currently streaming llama model to
    # check if any other requests are pending in the same thread and cancel
    # the stream if so.
    llama_outer_lock.acquire()
    release_outer_lock = True
    try:
        llama_inner_lock.acquire()
        try:
            llama_outer_lock.release()
            release_outer_lock = False
            yield _llama_proxy
        finally:
            llama_inner_lock.release()
    finally:
        if release_outer_lock:
            llama_outer_lock.release()

_ping_message_factory = None


def set_ping_message_factory(factory):
    global _ping_message_factory
    _ping_message_factory = factory

def create_app(
    settings: Settings | None = None,
    server_settings: ServerSettings | None = None,
    model_settings: List[ModelSettings] | None = None,
):
    config_file = os.environ.get("CONFIG_FILE", None)
    if config_file is not None:
        if not os.path.exists(config_file):
            raise ValueError(f"Config file {config_file} not found!")
        with open(config_file, "rb") as f:
            # Check if yaml file
            if config_file.endswith(".yaml") or config_file.endswith(".yml"):
                import yaml

                config_file_settings = ConfigFileSettings.model_validate_json(
                    json.dumps(yaml.safe_load(f))
                )
            else:
                config_file_settings = ConfigFileSettings.model_validate_json(f.read())
            server_settings = ServerSettings.model_validate(config_file_settings)
            model_settings = config_file_settings.models

    if server_settings is None and model_settings is None:
        if settings is None:
            settings = Settings()
        server_settings = ServerSettings.model_validate(settings)
        model_settings = [ModelSettings.model_validate(settings)]

    assert (
        server_settings is not None and model_settings is not None
    ), "server_settings and model_settings must be provided together"

    set_server_settings(server_settings)
    middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))]
    app = FastAPI(
        middleware=middleware,
        title="🦙 llama.cpp Python API",
        version=llama_cpp.__version__,
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)

    assert model_settings is not None
    set_llama_proxy(model_settings=model_settings)

    if server_settings.disable_ping_events:
        set_ping_message_factory(lambda: bytes())

    return app
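

# Stream chunks from a blocking llama_cpp iterator into the SSE send channel,
# bailing out early if the client disconnects or (when interrupt_requests is
# enabled) another request is already waiting on the outer llama lock.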
async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Iterator,
):
    async with inner_send_chan:
        try:
            async for chunk in iterate_in_threadpool(iterator):
                await inner_send_chan.send(dict(data=json.dumps(chunk)))
                if await request.is_disconnected():
                    raise anyio.get_cancelled_exc_class()()
                if (
                    next(get_server_settings()).interrupt_requests
                    and llama_outer_lock.locked()
                ):
                    await inner_send_chan.send(dict(data="[DONE]"))
                    raise anyio.get_cancelled_exc_class()()
            await inner_send_chan.send(dict(data="[DONE]"))
        except anyio.get_cancelled_exc_class() as e:
            print("disconnected")
            with anyio.move_on_after(1, shield=True):
                print(f"Disconnected from client (via refresh/close) {request.client}")
            raise e
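

# Convert a logit_bias mapping keyed by token text into one keyed by token id,
# using the model's own tokenizer for each key.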
def _logit_bias_tokens_to_input_ids(
    llama: llama_cpp.Llama,
    logit_bias: Dict[str, float],
) -> Dict[str, float]:
    to_bias: Dict[str, float] = {}
    for token, score in logit_bias.items():
        token = token.encode("utf-8")
        for input_id in llama.tokenize(token, add_bos=False, special=True):
            to_bias[str(input_id)] = score
    return to_bias


# Setup Bearer authentication scheme
bearer_scheme = HTTPBearer(auto_error=False)


async def authenticate(
    settings: Settings = Depends(get_server_settings),
    authorization: Optional[str] = Depends(bearer_scheme),
):
    # Skip API key check if it's not set in settings
    if settings.api_key is None:
        return True

    # check bearer credentials against the api_key
    if authorization and authorization.credentials == settings.api_key:
        # api key is valid
        return authorization.credentials

    # raise http error 401
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
    )


openai_v1_tag = "OpenAI V1"
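

# NOTE: The route decorators below are reconstructed as an assumption; the
# copilot-codex path is the one this handler checks against, and the other
# paths follow the OpenAI-compatible layout used by llama_cpp.server.
@router.post(
    "/v1/completions",
    summary="Completion",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
@router.post(
    "/v1/engines/copilot-codex/completions",
    include_in_schema=False,
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)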
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.Completion:
    if isinstance(body.prompt, list):
        assert len(body.prompt) <= 1
        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

    llama = llama_proxy(
        body.model
        if request.url.path != "/v1/engines/copilot-codex/completions"
        else "copilot-codex"
    )

    exclude = {
        "n",
        "best_of",
        "logit_bias_type",
        "user",
    }
    kwargs = body.model_dump(exclude=exclude)

    if body.logit_bias is not None:
        kwargs["logit_bias"] = (
            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
            if body.logit_bias_type == "tokens"
            else body.logit_bias
        )

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    iterator_or_completion: Union[
        llama_cpp.CreateCompletionResponse,
        Iterator[llama_cpp.CreateCompletionStreamResponse],
    ] = await run_in_threadpool(llama, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: It's easier to ask for forgiveness than permission
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised from first_response, we can assume that
        # the iterator is valid and we can use it to stream the response.
        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
            yield first_response
            yield from iterator_or_completion

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )
    else:
        return iterator_or_completion
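

# Assumed route path for the embeddings endpoint.
@router.post(
    "/v1/embeddings",
    summary="Embedding",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)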
async def create_embedding(
    request: CreateEmbeddingRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
):
    return await run_in_threadpool(
        llama_proxy(request.model).create_embedding,
        **request.model_dump(exclude={"user"}),
    )
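

# Assumed route path for the chat completions endpoint.
@router.post(
    "/v1/chat/completions",
    summary="Chat",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)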
async def create_chat_completion(
    request: Request,
    body: CreateChatCompletionRequest = Body(
        openapi_examples={
            "normal": {
                "summary": "Chat Completion",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                },
            },
            "json_mode": {
                "summary": "JSON Mode",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Who won the world series in 2020"},
                    ],
                    "response_format": {"type": "json_object"},
                },
            },
            "tool_calling": {
                "summary": "Tool Calling",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Extract Jason is 30 years old."},
                    ],
                    "tools": [
                        {
                            "type": "function",
                            "function": {
                                "name": "User",
                                "description": "User record",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "name": {"type": "string"},
                                        "age": {"type": "number"},
                                    },
                                    "required": ["name", "age"],
                                },
                            },
                        }
                    ],
                    "tool_choice": {
                        "type": "function",
                        "function": {
                            "name": "User",
                        },
                    },
                },
            },
            "logprobs": {
                "summary": "Logprobs",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                    "logprobs": True,
                    "top_logprobs": 10,
                },
            },
        }
    ),
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.ChatCompletion:
    exclude = {
        "n",
        "logit_bias_type",
        "user",
    }
    kwargs = body.model_dump(exclude=exclude)
    llama = llama_proxy(body.model)

    if body.logit_bias is not None:
        kwargs["logit_bias"] = (
            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
            if body.logit_bias_type == "tokens"
            else body.logit_bias
        )

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    iterator_or_completion: Union[
        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: It's easier to ask for forgiveness than permission
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised from first_response, we can assume that
        # the iterator is valid and we can use it to stream the response.
        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
            yield first_response
            yield from iterator_or_completion

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )
    else:
        return iterator_or_completion
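

# Assumed route path for the model listing endpoint.
@router.get(
    "/v1/models",
    summary="Models",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)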
async def get_models(
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> ModelList:
    return {
        "object": "list",
        "data": [
            {
                "id": model_alias,
                "object": "model",
                "owned_by": "me",
                "permissions": [],
            }
            for model_alias in llama_proxy
        ],
    }


extras_tag = "Extras"
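

# Assumed route path for the tokenize endpoint.
@router.post(
    "/extras/tokenize",
    summary="Tokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)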
async def tokenize(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputResponse:
    tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
    return TokenizeInputResponse(tokens=tokens)
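

# Assumed route path for the token-count endpoint.
@router.post(
    "/extras/tokenize/count",
    summary="Tokenize Count",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)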
async def count_query_tokens(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputCountResponse:
    tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
    return TokenizeInputCountResponse(count=len(tokens))
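

# Assumed route path for the detokenize endpoint.
@router.post(
    "/extras/detokenize",
    summary="Detokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)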
async def detokenize(
    body: DetokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> DetokenizeInputResponse:
    text = llama_proxy(body.model).detokenize(body.tokens).decode("utf-8")
    return DetokenizeInputResponse(text=text)
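

# Minimal usage sketch (an assumption, not part of the module above): serving
# the app directly with uvicorn; the packaged entry point is normally
# `python -m llama_cpp.server`.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(create_app(), host="0.0.0.0", port=8000)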