fix(proxy_server.py): support for streaming

This commit is contained in:
Krrish Dholakia 2023-12-09 16:22:53 -08:00
parent 0c8b8200b8
commit 6ef0e8485e
4 changed files with 219 additions and 142 deletions

View file

@ -1,4 +1,4 @@
import sys, os, platform, time, copy, re, asyncio
import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta
@ -94,7 +94,6 @@ import litellm
from litellm.proxy.utils import (
PrismaClient,
get_instance_fn,
CallHooks,
ProxyLogging
)
import pydantic
@ -198,8 +197,8 @@ user_custom_auth = None
use_background_health_checks = None
health_check_interval = None
health_check_results = {}
call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
proxy_logging_obj: Optional[ProxyLogging] = None
### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
### REDIS QUEUE ###
async_result = None
celery_app_conn = None
@ -309,10 +308,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
)
def prisma_setup(database_url: Optional[str]):
global prisma_client, proxy_logging_obj
### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging()
global prisma_client, proxy_logging_obj, user_api_key_cache
proxy_logging_obj._init_litellm_callbacks()
if database_url is not None:
try:
prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
@ -390,6 +388,10 @@ async def track_cost_callback(
completion=output_text
)
print("streaming response_cost", response_cost)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses
input_text = kwargs.get("messages", "")
print(f"type of input_text: {type(input_text)}")
@ -400,10 +402,10 @@ async def track_cost_callback(
print(f"received completion response: {completion_response}")
print(f"regular response_cost: {response_cost}")
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost)
except Exception as e:
print(f"error in tracking cost callback - {str(e)}")
@ -475,9 +477,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print_verbose(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
## ROUTER CONFIG
cache_responses = False
## ENVIRONMENT VARIABLES
environment_variables = config.get('environment_variables', None)
if environment_variables:
@ -554,6 +553,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
else:
setattr(litellm, key, value)
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
general_settings = config.get("general_settings", {})
if general_settings is None:
@ -589,18 +590,41 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
use_background_health_checks = general_settings.get("background_health_checks", False)
health_check_interval = general_settings.get("health_check_interval", 300)
router_params: dict = {
"num_retries": 3,
"cache_responses": litellm.cache != None # cache if user passed in cache values
}
## MODEL LIST
model_list = config.get('model_list', None)
if model_list:
router = litellm.Router(model_list=model_list, num_retries=3, cache_responses=cache_responses)
router_params["model_list"] = model_list
print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
for model in model_list:
print(f"\033[32m {model.get('model_name', '')}\033[0m")
litellm_model_name = model["litellm_params"]["model"]
if "ollama" in litellm_model_name:
litellm_model_api_base = model["litellm_params"].get("api_base", None)
if "ollama" in litellm_model_name and litellm_model_api_base is None:
run_ollama_serve()
call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
arg_spec = inspect.getfullargspec(litellm.Router)
# model list already set
exclude_args = {
"self",
"model_list",
}
available_args = [
x for x in arg_spec.args if x not in exclude_args
]
for k, v in router_settings.items():
if k in available_args:
router_params[k] = v
router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings
async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
@ -772,10 +796,13 @@ def data_generator(response):
yield f"data: {json.dumps(chunk)}\n\n"
async def async_data_generator(response, user_api_key_dict):
global call_hooks
print_verbose("inside generator")
async for chunk in response:
# try:
# await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=None, call_type="completion")
# except Exception as e:
# print(f"An exception occurred - {str(e)}")
print_verbose(f"returned chunk: {chunk}")
try:
yield f"data: {json.dumps(chunk.dict())}\n\n"
@ -946,7 +973,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global general_settings, user_debug, call_hooks
global general_settings, user_debug, proxy_logging_obj
try:
data = {}
data = await request.json() # type: ignore
@ -992,7 +1019,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
data["api_base"] = user_api_base
### CALL HOOKS ### - modify incoming data before calling the model
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
### ROUTE THE REQUEST ###
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
@ -1009,15 +1036,10 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
### CALL HOOKS ### - modify outgoing response
response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
return response
except Exception as e:
print(f"Exception received: {str(e)}")
raise e
await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data.get("model", "") in router_model_names:
@ -1052,7 +1074,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global call_hooks
global proxy_logging_obj
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
@ -1095,8 +1117,8 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
data["input"] = input_list
break
### CALL HOOKS ### - modify incoming data before calling the model
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list
@ -1107,12 +1129,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
response = await litellm.aembedding(**data)
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
### CALL HOOKS ### - modify outgoing response
data = call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")
return response
except Exception as e:
await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
traceback.print_exc()
raise e
@ -1139,7 +1158,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
"""
# data = await request.json()
data_json = json.loads(data.json()) # type: ignore
data_json = data.json() # type: ignore
response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])