fix(proxy_server.py): support for streaming

This commit is contained in:
Krrish Dholakia 2023-12-09 16:22:53 -08:00
parent 0c8b8200b8
commit 6ef0e8485e
4 changed files with 219 additions and 142 deletions

View file

@ -1,17 +1,32 @@
from typing import Optional
import litellm
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
async def max_parallel_request_allow_request(max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
class MaxParallelRequestsHandler(CustomLogger):
# Class variables or attributes
def __init__(self):
pass
def print_verbose(self, print_statement):
if litellm.set_verbose is True:
print(print_statement) # noqa
async def max_parallel_request_allow_request(self, max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
if api_key is None:
return
if max_parallel_requests is None:
return
self.user_api_key_cache = user_api_key_cache # save the api key cache for updating the value
# CHECK IF REQUEST ALLOWED
request_count_api_key = f"{api_key}_request_count"
current = user_api_key_cache.get_cache(key=request_count_api_key)
self.print_verbose(f"current: {current}")
if current is None:
user_api_key_cache.set_cache(request_count_api_key, 1)
elif int(current) < max_parallel_requests:
@ -21,13 +36,35 @@ async def max_parallel_request_allow_request(max_parallel_requests: Optional[int
raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
async def max_parallel_request_update_count(api_key: Optional[str], user_api_key_cache: DualCache):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
if user_api_key is None:
return
request_count_api_key = f"{user_api_key}_request_count"
# check if it has collected an entire stream response
self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}")
if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
# Decrease count for this token
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
new_val = current - 1
self.print_verbose(f"updated_value in success call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
except Exception as e:
self.print_verbose(e) # noqa
async def async_log_failure_call(self, api_key, user_api_key_cache):
try:
if api_key is None:
return
request_count_api_key = f"{api_key}_request_count"
# Decrease count for this token
current = user_api_key_cache.get_cache(key=request_count_api_key) or 1
user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
return
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
new_val = current - 1
self.print_verbose(f"updated_value in failure call: {new_val}")
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
except Exception as e:
self.print_verbose(f"An exception occurred - {str(e)}") # noqa

View file

@ -1,4 +1,4 @@
import sys, os, platform, time, copy, re, asyncio
import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta
@ -94,7 +94,6 @@ import litellm
from litellm.proxy.utils import (
PrismaClient,
get_instance_fn,
CallHooks,
ProxyLogging
)
import pydantic
@ -198,8 +197,8 @@ user_custom_auth = None
use_background_health_checks = None
health_check_interval = None
health_check_results = {}
call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
proxy_logging_obj: Optional[ProxyLogging] = None
### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
### REDIS QUEUE ###
async_result = None
celery_app_conn = None
@ -309,10 +308,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
)
def prisma_setup(database_url: Optional[str]):
global prisma_client, proxy_logging_obj
### INITIALIZE GLOBAL LOGGING OBJECT ###
proxy_logging_obj = ProxyLogging()
global prisma_client, proxy_logging_obj, user_api_key_cache
proxy_logging_obj._init_litellm_callbacks()
if database_url is not None:
try:
prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
@ -390,6 +388,10 @@ async def track_cost_callback(
completion=output_text
)
print("streaming response_cost", response_cost)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses
input_text = kwargs.get("messages", "")
print(f"type of input_text: {type(input_text)}")
@ -475,9 +477,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print_verbose(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
## ROUTER CONFIG
cache_responses = False
## ENVIRONMENT VARIABLES
environment_variables = config.get('environment_variables', None)
if environment_variables:
@ -554,6 +553,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
else:
setattr(litellm, key, value)
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
general_settings = config.get("general_settings", {})
if general_settings is None:
@ -589,18 +590,41 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
use_background_health_checks = general_settings.get("background_health_checks", False)
health_check_interval = general_settings.get("health_check_interval", 300)
router_params: dict = {
"num_retries": 3,
"cache_responses": litellm.cache != None # cache if user passed in cache values
}
## MODEL LIST
model_list = config.get('model_list', None)
if model_list:
router = litellm.Router(model_list=model_list, num_retries=3, cache_responses=cache_responses)
router_params["model_list"] = model_list
print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
for model in model_list:
print(f"\033[32m {model.get('model_name', '')}\033[0m")
litellm_model_name = model["litellm_params"]["model"]
if "ollama" in litellm_model_name:
litellm_model_api_base = model["litellm_params"].get("api_base", None)
if "ollama" in litellm_model_name and litellm_model_api_base is None:
run_ollama_serve()
call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
arg_spec = inspect.getfullargspec(litellm.Router)
# model list already set
exclude_args = {
"self",
"model_list",
}
available_args = [
x for x in arg_spec.args if x not in exclude_args
]
for k, v in router_settings.items():
if k in available_args:
router_params[k] = v
router = litellm.Router(**router_params) # type:ignore
return router, model_list, general_settings
async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
@ -772,10 +796,13 @@ def data_generator(response):
yield f"data: {json.dumps(chunk)}\n\n"
async def async_data_generator(response, user_api_key_dict):
global call_hooks
print_verbose("inside generator")
async for chunk in response:
# try:
# await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=None, call_type="completion")
# except Exception as e:
# print(f"An exception occurred - {str(e)}")
print_verbose(f"returned chunk: {chunk}")
try:
yield f"data: {json.dumps(chunk.dict())}\n\n"
@ -946,7 +973,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global general_settings, user_debug, call_hooks
global general_settings, user_debug, proxy_logging_obj
try:
data = {}
data = await request.json() # type: ignore
@ -992,7 +1019,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
data["api_base"] = user_api_base
### CALL HOOKS ### - modify incoming data before calling the model
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
### ROUTE THE REQUEST ###
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
@ -1009,15 +1036,10 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
### CALL HOOKS ### - modify outgoing response
response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
return response
except Exception as e:
print(f"Exception received: {str(e)}")
raise e
await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data.get("model", "") in router_model_names:
@ -1052,7 +1074,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global call_hooks
global proxy_logging_obj
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
@ -1095,8 +1117,8 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
data["input"] = input_list
break
### CALL HOOKS ### - modify incoming data before calling the model
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
### CALL HOOKS ### - modify incoming data / reject request before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list
@ -1107,12 +1129,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
response = await litellm.aembedding(**data)
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
### CALL HOOKS ### - modify outgoing response
data = call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")
return response
except Exception as e:
await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
traceback.print_exc()
raise e
@ -1139,7 +1158,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
"""
# data = await request.json()
data_json = json.loads(data.json()) # type: ignore
data_json = data.json() # type: ignore
response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])

View file

@ -3,7 +3,7 @@ import os, subprocess, hashlib, importlib, asyncio
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
def print_verbose(print_statement):
if litellm.set_verbose:
@ -11,20 +11,23 @@ def print_verbose(print_statement):
### LOGGING ###
class ProxyLogging:
"""
Logging for proxy.
Logging/Custom Handlers for proxy.
Implemented mainly to log successful/failed db read/writes.
Currently just logs this to a provided sentry integration.
Implemented mainly to:
- log successful/failed db read/writes
- support the max parallel request integration
"""
def __init__(self,):
def __init__(self, user_api_key_cache: DualCache):
## INITIALIZE LITELLM CALLBACKS ##
self._init_litellm_callbacks()
self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache
self.max_parallel_request_limiter = MaxParallelRequestsHandler()
pass
def _init_litellm_callbacks(self):
if len(litellm.callbacks) > 0:
litellm.callbacks.append(self.max_parallel_request_limiter)
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
@ -53,6 +56,30 @@ class ProxyLogging:
callback_list=callback_list
)
async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
"""
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
Covers:
1. /chat/completions
2. /embeddings
"""
try:
self.call_details["data"] = data
self.call_details["call_type"] = call_type
## check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## if set, check if request allowed
await self.max_parallel_request_limiter.max_parallel_request_allow_request(
max_parallel_requests=user_api_key_dict.max_parallel_requests,
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
return data
except Exception as e:
raise e
async def success_handler(self, *args, **kwargs):
"""
Log successful db read/writes
@ -68,6 +95,27 @@ class ProxyLogging:
if litellm.utils.capture_exception:
litellm.utils.capture_exception(error=original_exception)
async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
"""
Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
Covers:
1. /chat/completions
2. /embeddings
"""
# check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## decrement call count if call failed
if (hasattr(original_exception, "status_code")
and original_exception.status_code == 429
and "Max parallel request limit reached" in str(original_exception)):
pass # ignore failed calls due to max limit being reached
else:
await self.max_parallel_request_limiter.async_log_failure_call(
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
return
### DB CONNECTOR ###
# Define the retry decorator with backoff strategy
@ -290,65 +338,4 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
except Exception as e:
raise e
### CALL HOOKS ###
class CallHooks:
"""
Allows users to modify the incoming request / output to the proxy, without having to deal with parsing Request body.
Covers:
1. /chat/completions
2. /embeddings
"""
def __init__(self, user_api_key_cache: DualCache):
self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache
def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
self.call_details["litellm_settings"] = litellm_settings
self.call_details["general_settings"] = general_settings
self.call_details["model_list"] = model_list
async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
try:
self.call_details["data"] = data
self.call_details["call_type"] = call_type
## check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## if set, check if request allowed
await max_parallel_request_allow_request(
max_parallel_requests=user_api_key_dict.max_parallel_requests,
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
return data
except Exception as e:
raise e
async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
try:
# check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## decrement call, once complete
await max_parallel_request_update_count(
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
return response
except Exception as e:
raise e
async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
# check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## decrement call count if call failed
if (hasattr(original_exception, "status_code")
and original_exception.status_code == 429
and "Max parallel request limit reached" in str(original_exception)):
pass # ignore failed calls due to max limit being reached
else:
await max_parallel_request_update_count(
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
return

View file

@ -25,7 +25,7 @@ from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
config_fp = f"{filepath}/test_configs/test_config.yaml"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
app = FastAPI()
app.include_router(router) # Include your router in the test app
@ -100,3 +100,37 @@ def test_add_new_key_max_parallel_limit(client):
_run_in_parallel()
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
def test_add_new_key_max_parallel_limit_streaming(client):
try:
# Your test data
test_data = {"duration": "20m", "max_parallel_requests": 1}
# Your bearer token
token = os.getenv('PROXY_MASTER_KEY')
headers = {
"Authorization": f"Bearer {token}"
}
response = client.post("/key/generate", json=test_data, headers=headers)
print(f"response: {response.text}")
assert response.status_code == 200
result = response.json()
def _post_data():
json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}], "stream": True}
response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
return response
def _run_in_parallel():
with ThreadPoolExecutor(max_workers=2) as executor:
future1 = executor.submit(_post_data)
future2 = executor.submit(_post_data)
# Obtain the results from the futures
response1 = future1.result()
response2 = future2.result()
if response1.status_code == 429 or response2.status_code == 429:
pass
else:
raise Exception()
_run_in_parallel()
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")