|
---
license: bigcode-openrail-m
base_model:
- bigcode/starcoder2-15b
tags:
- code
- int4_awq
---
|
# Model Card for starcoder2-15b-w4-autoawq-gemm
|
|
|
|
|
|
This is an `int4_awq` quantized checkpoint of bigcode/starcoder2-15b. It requires roughly 10 GB of VRAM.
|
|
|
## Running this Model |
|
vLLM does not currently support AutoAWQ natively (nor any a4w8 scheme at the time of writing), so the simplest option is to serve the model directly from the AutoAWQ backend.
|
|
|
If you want to run this in a container, start one with your Hugging Face cache mounted and port 8000 published:

`docker run --gpus all -it --name=starcoder2-15b-int4-awq -p 8000:8000 -v ~/.cache:/root/.cache nvcr.io/nvidia/pytorch:24.12-py3 bash`
|
|
|
Inside the container (or your own environment), install the dependencies; the quotes around `fastapi[all]` keep shells such as zsh from expanding the brackets:

`pip install "fastapi[all]" torch transformers autoawq`
|
|
|
Then run the following with Python 3:
|
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
import uvicorn

# Define the FastAPI app
app = FastAPI()

# Define the request body model
class TextRequest(BaseModel):
    text: str

# Load the quantized model and tokenizer.
# This is the local Hugging Face cache path; adjust the snapshot hash to match your own download.
model_path = '/root/.cache/huggingface/hub/models--shavera--starcoder2-15b-w4-autoawq-gemm/snapshots/13fab46ef237de327397549f427106890e0dec67'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoAWQForCausalLM.from_quantized(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Ensure the model is in evaluation mode
model.eval()

# Create the inference function
def generate_text(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Define the API endpoint for text generation
@app.post("/generate")
async def generate(request: TextRequest):
    try:
        generated_text = generate_text(request.text)
        return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Run the server (port 8000)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
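
Once the server is running, you can exercise the `/generate` endpoint from any HTTP client. Below is a minimal sketch using the `requests` library (install it with `pip install requests` if needed); the prompt string is just an illustrative example:

```python
import requests

# Assumes the FastAPI server above is running locally on port 8000.
resp = requests.post(
    "http://localhost:8000/generate",
    json={"text": "def fibonacci(n):"},  # example prompt; substitute your own
    timeout=300,  # generation can take a while on long prompts
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```

The response is a JSON object with a single `generated_text` field, as defined by the endpoint above.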
|
|
|
|