feat(completion()): enable setting prompt templates via completion()

Krrish Dholakia 2023-11-02 16:23:51 -07:00
parent 1fc726d5dd
commit 512a1637eb
9 changed files with 94 additions and 37 deletions
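
For context on the commit title: the change lets callers pass prompt-template settings directly to completion() instead of registering them separately. A minimal usage sketch, assuming the kwargs mirror litellm.register_prompt_template() (roles, initial_prompt_value, final_prompt_value); the exact parameter names in this release are an assumption, so verify against the docs:

    import litellm

    # Hedged sketch: pass a custom prompt template straight to completion().
    # The roles/initial_prompt_value/final_prompt_value kwargs are assumed to
    # mirror litellm.register_prompt_template() -- verify against the docs.
    response = litellm.completion(
        model="together_ai/togethercomputer/llama-2-70b-chat",
        messages=[{"role": "user", "content": "What's the capital of France?"}],
        roles={
            "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
            "user": {"pre_message": "[INST] ", "post_message": " [/INST]"},
            "assistant": {"pre_message": "\n", "post_message": "\n"},
        },  # Llama-2-style delimiters
        initial_prompt_value="You are a helpful assistant.",  # prepended once
        final_prompt_value="Answer concisely:",               # appended once
    )
    print(response.choices[0].message.content)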


@@ -1,11 +1,14 @@
-import litellm, os, traceback
+import os, traceback
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
-import json
-import os
+import json, sys
 from typing import Optional
+# sys.path.insert(
+#     0, os.path.abspath("../")
+# )  # Adds the parent directory to the system path - for litellm local dev
+import litellm
 try:
     from utils import set_callbacks, load_router_config, print_verbose
 except ImportError:
@@ -31,6 +34,7 @@ llm_model_list: Optional[list] = None
 set_callbacks() # sets litellm callbacks for logging if they exist in the environment
 if "CONFIG_FILE_PATH" in os.environ:
+    print(f"CONFIG FILE DETECTED")
     llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
 else:
     llm_router, llm_model_list = load_router_config(router=llm_router)
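
The hunk above has the proxy load a router config at startup when CONFIG_FILE_PATH is set. A minimal sketch of exercising that branch; the path and YAML shape below are assumptions based on litellm's router config format:

    import os

    # Hypothetical: set the path before the proxy module runs its startup code,
    # so load_router_config() (called above) picks up the file.
    os.environ["CONFIG_FILE_PATH"] = "/app/config.yaml"

    # /app/config.yaml (assumed router-config shape):
    # model_list:
    #   - model_name: gpt-3.5-turbo
    #     litellm_params:
    #       model: azure/chatgpt-v-2
    #       api_base: https://example-endpoint.openai.azure.com/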
@@ -104,7 +108,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         ## CHECK KEYS ##
         # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
         env_validation = litellm.validate_environment(model=data["model"])
-        print(f"request headers: {request.headers}")
         if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
             if "authorization" in request.headers:
                 api_key = request.headers.get("authorization")
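
Per the branch above, when the required keys are missing from the proxy's environment or AUTH_STRATEGY is DYNAMIC, the provider key is read from the request's authorization (or api-key) header. A client-side sketch; the host, port, and route here are assumptions:

    import requests

    # Hypothetical client call: forward the provider API key per-request via the
    # "authorization" header, which the proxy branch above reads.
    resp = requests.post(
        "http://localhost:8000/chat/completions",    # assumed host/route
        headers={"authorization": "Bearer sk-..."},  # provider key (placeholder)
        json={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hi"}],
        },
    )
    print(resp.json())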
@@ -125,7 +128,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
                     for key, value in m["litellm_params"].items():
                         data[key] = value
                     break
-        print(f"data going into litellm completion: {data}")
         response = litellm.completion(
             **data
         )
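
The loop above overlays the matched router entry's litellm_params onto the request body before litellm.completion(**data) runs. A standalone sketch of that merge, with a hypothetical model list:

    # Standalone sketch of the litellm_params merge performed above.
    llm_model_list = [
        {
            "model_name": "gpt-3.5-turbo",  # public alias clients send
            "litellm_params": {"model": "azure/chatgpt-v-2", "api_key": "azure-key"},
        }
    ]
    data = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}

    for m in llm_model_list:
        if m["model_name"] == data["model"]:
            for key, value in m["litellm_params"].items():
                data[key] = value  # router params override/extend the request body
            break

    # data["model"] is now "azure/chatgpt-v-2"; data is ready for litellm.completion(**data)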
@@ -134,6 +136,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return response
     except Exception as e:
         error_traceback = traceback.format_exc()
+        print(f"{error_traceback}")
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}
         # raise HTTPException(status_code=500, detail=error_msg)