forked from phoenix/litellm-mirror
feat(completion()): enable setting prompt templates via completion()
This commit is contained in:
parent
1fc726d5dd
commit
512a1637eb
9 changed files with 94 additions and 37 deletions
|
@ -1,11 +1,14 @@
|
|||
import litellm, os, traceback
|
||||
import os, traceback
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import json
|
||||
import os
|
||||
import json, sys
|
||||
from typing import Optional
|
||||
# sys.path.insert(
|
||||
# 0, os.path.abspath("../")
|
||||
# ) # Adds the parent directory to the system path - for litellm local dev
|
||||
import litellm
|
||||
try:
|
||||
from utils import set_callbacks, load_router_config, print_verbose
|
||||
except ImportError:
|
||||
|
@ -31,6 +34,7 @@ llm_model_list: Optional[list] = None
|
|||
set_callbacks() # sets litellm callbacks for logging if they exist in the environment
|
||||
|
||||
if "CONFIG_FILE_PATH" in os.environ:
|
||||
print(f"CONFIG FILE DETECTED")
|
||||
llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
|
||||
else:
|
||||
llm_router, llm_model_list = load_router_config(router=llm_router)
|
||||
|
@ -104,7 +108,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
|
|||
## CHECK KEYS ##
|
||||
# default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
|
||||
env_validation = litellm.validate_environment(model=data["model"])
|
||||
print(f"request headers: {request.headers}")
|
||||
if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
|
||||
if "authorization" in request.headers:
|
||||
api_key = request.headers.get("authorization")
|
||||
|
@ -125,7 +128,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
|
|||
for key, value in m["litellm_params"].items():
|
||||
data[key] = value
|
||||
break
|
||||
print(f"data going into litellm completion: {data}")
|
||||
response = litellm.completion(
|
||||
**data
|
||||
)
|
||||
|
@ -134,6 +136,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
|
|||
return response
|
||||
except Exception as e:
|
||||
error_traceback = traceback.format_exc()
|
||||
print(f"{error_traceback}")
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
return {"error": error_msg}
|
||||
# raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue