forked from phoenix/litellm-mirror
fix(proxy_server.py): support for streaming
parent 0c8b8200b8
commit 6ef0e8485e
4 changed files with 219 additions and 142 deletions
@@ -1,4 +1,4 @@
-import sys, os, platform, time, copy, re, asyncio
+import sys, os, platform, time, copy, re, asyncio, inspect
 import threading, ast
 import shutil, random, traceback, requests
 from datetime import datetime, timedelta
@@ -94,7 +94,6 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    CallHooks,
     ProxyLogging
 )
 import pydantic
@@ -198,8 +197,8 @@ user_custom_auth = None
 use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
-call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
-proxy_logging_obj: Optional[ProxyLogging] = None
+### INITIALIZE GLOBAL LOGGING OBJECT ###
+proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -309,10 +308,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
     )

 def prisma_setup(database_url: Optional[str]):
-    global prisma_client, proxy_logging_obj
-    ### INITIALIZE GLOBAL LOGGING OBJECT ###
-    proxy_logging_obj = ProxyLogging()
+    global prisma_client, proxy_logging_obj, user_api_key_cache

+    proxy_logging_obj._init_litellm_callbacks()
     if database_url is not None:
         try:
             prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
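Taken together, the two hunks above move construction of the logging/hook object out of prisma_setup and up to module scope, next to the user_api_key_cache it shares, so request hooks exist even when no DATABASE_URL is configured; prisma_setup now only registers callbacks and attaches the DB client. A compact sketch of that pattern, using a simplified stand-in for the real ProxyLogging class in litellm.proxy.utils:

from typing import Optional

class ProxyLoggingSketch:
    # simplified stand-in: the real ProxyLogging does far more (hooks, alerting, callbacks)
    def __init__(self, user_api_key_cache: dict):
        # share the cache used by the auth layer so hooks can see per-key state
        self.user_api_key_cache = user_api_key_cache
        self.litellm_callbacks: list = []

    def _init_litellm_callbacks(self) -> None:
        # placeholder: in the proxy this wires the object into litellm's callback lists
        self.litellm_callbacks.append(self)

### INITIALIZE GLOBAL LOGGING OBJECT ### (module scope - runs once at import time)
user_api_key_cache: dict = {}
proxy_logging_obj = ProxyLoggingSketch(user_api_key_cache=user_api_key_cache)

def prisma_setup(database_url: Optional[str]) -> None:
    # later setup only registers callbacks / attaches a DB client;
    # it no longer re-creates the logging object
    proxy_logging_obj._init_litellm_callbacks()
    if database_url is not None:
        print("would connect PrismaClient here")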
@@ -390,6 +388,10 @@ async def track_cost_callback(
                 completion=output_text
             )
             print("streaming response_cost", response_cost)
+            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
+            if user_api_key and prisma_client:
+                await update_prisma_database(token=user_api_key, response_cost=response_cost)
         elif kwargs["stream"] == False: # for non streaming responses
             input_text = kwargs.get("messages", "")
             print(f"type of input_text: {type(input_text)}")
@@ -400,10 +402,10 @@ async def track_cost_callback(
             print(f"received completion response: {completion_response}")

             print(f"regular response_cost: {response_cost}")
-            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
-            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
-            if user_api_key and prisma_client:
-                await update_prisma_database(token=user_api_key, response_cost=response_cost)
+            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
+            if user_api_key and prisma_client:
+                await update_prisma_database(token=user_api_key, response_cost=response_cost)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")
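For a streamed call there is no single response object carrying usage data, so the callback rebuilds the completion text from the chunks and prices the request from prompt plus output before charging it to the key; non-streaming calls can be priced straight from the response. A rough sketch of that flow - compute_cost and update_spend below are hypothetical stand-ins for litellm's cost calculation and the proxy's update_prisma_database, and the kwargs fields mirror the ones used in the diff:

from typing import Any, Optional

def compute_cost(prompt: Any = "", completion: str = "", completion_response: Any = None) -> float:
    # stand-in pricing (per character) just to make the sketch executable
    text = completion or (str(completion_response) if completion_response else "")
    return 2e-6 * (len(str(prompt)) + len(text))

async def update_spend(token: str, response_cost: float) -> None:
    # stand-in for update_prisma_database(token=..., response_cost=...)
    print(f"key {token}: spend += {response_cost:.6f}")

async def track_cost(kwargs: dict, completion_response: Any, streamed_text: str, prisma_client: Optional[Any]) -> None:
    user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
    if kwargs["stream"] == True:
        # streaming: price from the prompt plus the text re-assembled from the chunks
        response_cost = compute_cost(prompt=kwargs.get("messages", ""), completion=streamed_text)
    else:
        # non-streaming: the full response (with usage) is available directly
        response_cost = compute_cost(completion_response=completion_response)
    if user_api_key and prisma_client:
        await update_spend(token=user_api_key, response_cost=response_cost)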
@@ -475,9 +477,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):

     print_verbose(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")

-    ## ROUTER CONFIG
-    cache_responses = False
-
     ## ENVIRONMENT VARIABLES
     environment_variables = config.get('environment_variables', None)
     if environment_variables:
@@ -554,6 +553,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         else:
             setattr(litellm, key, value)


+
+    ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
     general_settings = config.get("general_settings", {})
     if general_settings is None:
@@ -589,18 +590,41 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         use_background_health_checks = general_settings.get("background_health_checks", False)
         health_check_interval = general_settings.get("health_check_interval", 300)

+    router_params: dict = {
+        "num_retries": 3,
+        "cache_responses": litellm.cache != None # cache if user passed in cache values
+    }
     ## MODEL LIST
     model_list = config.get('model_list', None)
     if model_list:
-        router = litellm.Router(model_list=model_list, num_retries=3, cache_responses=cache_responses)
+        router_params["model_list"] = model_list
         print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
         for model in model_list:
             print(f"\033[32m {model.get('model_name', '')}\033[0m")
             litellm_model_name = model["litellm_params"]["model"]
-            if "ollama" in litellm_model_name:
+            litellm_model_api_base = model["litellm_params"].get("api_base", None)
+            if "ollama" in litellm_model_name and litellm_model_api_base is None:
                 run_ollama_serve()

-    call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
+    ## ROUTER SETTINGS (e.g. routing_strategy, ...)
+    router_settings = config.get("router_settings", None)
+    if router_settings and isinstance(router_settings, dict):
+        arg_spec = inspect.getfullargspec(litellm.Router)
+        # model list already set
+        exclude_args = {
+            "self",
+            "model_list",
+        }
+
+        available_args = [
+            x for x in arg_spec.args if x not in exclude_args
+        ]
+
+        for k, v in router_settings.items():
+            if k in available_args:
+                router_params[k] = v
+
+    router = litellm.Router(**router_params) # type:ignore
     return router, model_list, general_settings

 async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
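Router construction now goes through a router_params dict, and keys under router_settings in the config are only forwarded when they match a parameter of litellm.Router's constructor - that is what the inspect.getfullargspec filtering above does, so unknown settings keys are silently dropped instead of raising a TypeError. A self-contained sketch of the same filtering against a dummy class (DummyRouter stands in for litellm.Router):

import inspect

class DummyRouter:
    # stand-in for litellm.Router, just to demonstrate the signature filtering
    def __init__(self, model_list=None, num_retries=0, routing_strategy="simple-shuffle"):
        self.model_list = model_list
        self.num_retries = num_retries
        self.routing_strategy = routing_strategy

router_params: dict = {"num_retries": 3, "model_list": [{"model_name": "gpt-3.5-turbo"}]}
router_settings = {"routing_strategy": "least-busy", "not_a_constructor_arg": 42}

arg_spec = inspect.getfullargspec(DummyRouter)  # inspects DummyRouter.__init__
exclude_args = {"self", "model_list"}           # model list already set above
available_args = [x for x in arg_spec.args if x not in exclude_args]

for k, v in router_settings.items():
    if k in available_args:   # unknown keys are ignored rather than passed through
        router_params[k] = v

router = DummyRouter(**router_params)
print(router.routing_strategy)  # "least-busy"; "not_a_constructor_arg" was dropped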
@@ -772,10 +796,13 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"

 async def async_data_generator(response, user_api_key_dict):
-    global call_hooks
-
     print_verbose("inside generator")
     async for chunk in response:
+        # try:
+        #     await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=None, call_type="completion")
+        # except Exception as e:
+        #     print(f"An exception occurred - {str(e)}")
+
         print_verbose(f"returned chunk: {chunk}")
         try:
             yield f"data: {json.dumps(chunk.dict())}\n\n"
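async_data_generator turns each chunk coming back from the model into a Server-Sent Events data: frame, and the endpoint wraps it in a StreamingResponse with media_type='text/event-stream'. A minimal, self-contained FastAPI sketch of that pattern - fake_stream below is a stand-in for the async chunk iterator a real model call returns:

import asyncio, json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_stream():
    # stand-in for the async iterator of chunks returned by the model call
    for token in ["Hello", ",", " world"]:
        await asyncio.sleep(0.05)
        yield {"choices": [{"delta": {"content": token}}]}

async def async_data_generator(response):
    # serialize each chunk into an SSE "data:" frame as soon as it arrives
    async for chunk in response:
        try:
            yield f"data: {json.dumps(chunk)}\n\n"
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

@app.post("/chat/completions")
async def chat_completion():
    return StreamingResponse(async_data_generator(fake_stream()), media_type="text/event-stream")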
@@ -946,7 +973,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global general_settings, user_debug, call_hooks
+    global general_settings, user_debug, proxy_logging_obj
     try:
         data = {}
         data = await request.json() # type: ignore
@@ -992,7 +1019,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             data["api_base"] = user_api_base

         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")

         ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
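pre_call_hook receives the request body before the model is called, so it can rewrite the payload or reject the call outright (a 429 when a key is over its parallel-request limit is the typical case). A hedged sketch of what such a hook might look like - the class, counter, and checks are illustrative, not the actual ProxyLogging implementation:

from typing import Optional
from fastapi import HTTPException

class UserAPIKeyAuth:
    # simplified stand-in for the proxy's auth object
    def __init__(self, api_key: str, max_parallel_requests: Optional[int] = None):
        self.api_key = api_key
        self.max_parallel_requests = max_parallel_requests

class MiniProxyLogging:
    def __init__(self):
        self.in_flight: dict = {}

    async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: str) -> dict:
        # reject if the key is already at its parallel-request limit
        current = self.in_flight.get(user_api_key_dict.api_key, 0)
        limit = user_api_key_dict.max_parallel_requests
        if limit is not None and current >= limit:
            raise HTTPException(status_code=429, detail="Max parallel request limit reached")
        self.in_flight[user_api_key_dict.api_key] = current + 1
        # hooks may also rewrite the payload before it reaches the model
        data.setdefault("metadata", {})["user_api_key"] = user_api_key_dict.api_key
        return data

    async def post_call_failure_hook(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception) -> None:
        # release the slot and (in the real proxy) log / alert on the failure
        self.in_flight[user_api_key_dict.api_key] = max(0, self.in_flight.get(user_api_key_dict.api_key, 0) - 1)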
@@ -1009,15 +1036,10 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
             return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
-
-        ### CALL HOOKS ### - modify outgoing response
-        response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")

         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
         return response
     except Exception as e:
-        print(f"Exception received: {str(e)}")
-        raise e
-        await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
+        await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data.get("model", "") in router_model_names:
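With streaming wired through, a client can read /chat/completions as an ordinary OpenAI-style SSE stream. A small client-side sketch, assuming the proxy is running locally on port 8000 and sk-1234 is a placeholder virtual key:

import json
import requests

# Assumptions: proxy at localhost:8000, "sk-1234" is a placeholder key.
resp = requests.post(
    "http://localhost:8000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "say hi"}],
        "stream": True,
    },
    stream=True,
)

for raw in resp.iter_lines():
    if not raw:
        continue
    line = raw.decode("utf-8")
    if not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload.strip() == "[DONE]":  # some servers send an OpenAI-style terminator
        break
    chunk = json.loads(payload)
    # streamed chunks follow the OpenAI delta format
    print(chunk["choices"][0]["delta"].get("content") or "", end="", flush=True)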
@@ -1052,7 +1074,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global call_hooks
+    global proxy_logging_obj
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
@@ -1095,8 +1117,8 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
                     data["input"] = input_list
                     break

-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")

         ## ROUTE TO CORRECT ENDPOINT ##
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1107,12 +1129,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
             response = await litellm.aembedding(**data)
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL


-        ### CALL HOOKS ### - modify outgoing response
-        data = call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")
-
         return response
     except Exception as e:
-        await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
+        await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
         traceback.print_exc()
         raise e
@@ -1139,7 +1158,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
     - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
     """
     # data = await request.json()
-    data_json = json.loads(data.json()) # type: ignore
+    data_json = data.json() # type: ignore
     response = await generate_key_helper_fn(**data_json)
     return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
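generate_key_fn expands the parsed request body into keyword arguments for generate_key_helper_fn, so it ultimately needs a dict, not a string. A quick illustration of the pydantic v1 distinction at play on that line: .json() returns a JSON string, while .dict() (or json.loads(model.json())) returns a mapping that can be unpacked with **. The field defaults below are illustrative; the field names mirror generate_key_helper_fn's parameters:

import json
from typing import Optional
from pydantic import BaseModel  # pydantic v1 style API (.json() / .dict())

class GenerateKeyRequest(BaseModel):
    duration: Optional[str] = "1h"
    models: list = []
    aliases: dict = {}
    config: dict = {}
    spend: float = 0.0

def generate_key_helper_fn(duration=None, models=None, aliases=None, config=None, spend=0.0, **kwargs):
    return f"new key, valid for {duration}"

data = GenerateKeyRequest(duration="24h")

as_text = data.json()               # JSON string: '{"duration": "24h", ...}'
as_dict = json.loads(data.json())   # plain dict - safe to unpack as keyword arguments
also_dict = data.dict()             # same result without the serialize/parse round trip

print(generate_key_helper_fn(**as_dict))
# generate_key_helper_fn(**as_text)  # TypeError: argument after ** must be a mapping, not str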