fix(proxy_server.py): support for streaming
parent 76541c4a0a
commit 8eb7dc6393
4 changed files with 219 additions and 142 deletions
@@ -3,7 +3,7 @@ import os, subprocess, hashlib, importlib, asyncio
 import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
-from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count
+from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
 
 def print_verbose(print_statement):
     if litellm.set_verbose:
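Note: the import swap above folds the two module-level helpers into a single MaxParallelRequestsHandler class. A minimal sketch of what such a handler could look like, inferred from the call sites later in this diff; the real implementation lives in litellm/proxy/hooks/parallel_request_limiter.py, and the get_cache/set_cache calls on the cache object are an assumption:

    from fastapi import HTTPException

    class SketchMaxParallelRequestsHandler:
        # Tracks in-flight requests per API key in a shared cache (sketch only).
        async def max_parallel_request_allow_request(
            self, max_parallel_requests, api_key, user_api_key_cache
        ):
            current = user_api_key_cache.get_cache(key=api_key) or 0
            if current >= max_parallel_requests:
                # this 429 + message pair is what post_call_failure_hook
                # matches on later in this diff
                raise HTTPException(
                    status_code=429, detail="Max parallel request limit reached"
                )
            user_api_key_cache.set_cache(key=api_key, value=current + 1)

        async def async_log_failure_call(self, api_key, user_api_key_cache):
            # release the slot reserved above when a tracked call fails
            current = user_api_key_cache.get_cache(key=api_key) or 1
            user_api_key_cache.set_cache(key=api_key, value=current - 1)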
@@ -11,32 +11,35 @@ def print_verbose(print_statement):
 ### LOGGING ###
 class ProxyLogging:
     """
-    Logging for proxy.
+    Logging/Custom Handlers for proxy.
 
-    Implemented mainly to log successful/failed db read/writes.
-
-    Currently just logs this to a provided sentry integration.
+    Implemented mainly to:
+    - log successful/failed db read/writes
+    - support the max parallel request integration
     """
 
-    def __init__(self,):
+    def __init__(self, user_api_key_cache: DualCache):
+        self.call_details: dict = {}
+        self.call_details["user_api_key_cache"] = user_api_key_cache
+        self.max_parallel_request_limiter = MaxParallelRequestsHandler()
         ## INITIALIZE LITELLM CALLBACKS ##
         self._init_litellm_callbacks()
-        pass
 
     def _init_litellm_callbacks(self):
-        if len(litellm.callbacks) > 0:
-            for callback in litellm.callbacks:
-                if callback not in litellm.input_callback:
-                    litellm.input_callback.append(callback)
-                if callback not in litellm.success_callback:
-                    litellm.success_callback.append(callback)
-                if callback not in litellm.failure_callback:
-                    litellm.failure_callback.append(callback)
-                if callback not in litellm._async_success_callback:
-                    litellm._async_success_callback.append(callback)
-                if callback not in litellm._async_failure_callback:
-                    litellm._async_failure_callback.append(callback)
+        litellm.callbacks.append(self.max_parallel_request_limiter)
+
+        for callback in litellm.callbacks:
+            if callback not in litellm.input_callback:
+                litellm.input_callback.append(callback)
+            if callback not in litellm.success_callback:
+                litellm.success_callback.append(callback)
+            if callback not in litellm.failure_callback:
+                litellm.failure_callback.append(callback)
+            if callback not in litellm._async_success_callback:
+                litellm._async_success_callback.append(callback)
+            if callback not in litellm._async_failure_callback:
+                litellm._async_failure_callback.append(callback)
 
         if (
             len(litellm.input_callback) > 0
             or len(litellm.success_callback) > 0
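Note: with the new constructor, the proxy builds one ProxyLogging object around the shared key cache, and _init_litellm_callbacks() pushes the limiter onto litellm.callbacks so it is invoked on every routed call. A minimal usage sketch, assuming DualCache() is constructible with its in-memory defaults and that ProxyLogging lives at litellm.proxy.utils:

    from litellm.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    user_api_key_cache = DualCache()  # shared across auth and hooks
    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
    # __init__ has now registered proxy_logging_obj.max_parallel_request_limiter
    # on litellm.callbacks and mirrored litellm.callbacks into the
    # input/success/failure callback lists.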
@@ -53,6 +56,30 @@ class ProxyLogging:
                 callback_list=callback_list
             )
 
+    async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
+        """
+        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
+
+        Covers:
+        1. /chat/completions
+        2. /embeddings
+        """
+        try:
+            self.call_details["data"] = data
+            self.call_details["call_type"] = call_type
+
+            ## check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## if set, check if request allowed
+                await self.max_parallel_request_limiter.max_parallel_request_allow_request(
+                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return data
+        except Exception as e:
+            raise e
+
     async def success_handler(self, *args, **kwargs):
         """
         Log successful db read/writes
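Note: the proxy server is expected to await this hook before dispatching a request. A sketch of that flow, with a hypothetical key and request body (the UserAPIKeyAuth field names are taken from the call sites above):

    from litellm.proxy._types import UserAPIKeyAuth

    async def handle_chat_completion(proxy_logging_obj, data: dict):
        user_api_key_dict = UserAPIKeyAuth(
            api_key="sk-1234",        # hypothetical key
            max_parallel_requests=2,  # limit enforced by the hook
        )
        # raises the limiter's 429 if 2 requests are already in flight;
        # otherwise returns the (possibly modified) request body
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            data=data,
            call_type="completion",
        )
        return data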
@@ -67,6 +94,27 @@ class ProxyLogging:
         """
         if litellm.utils.capture_exception:
             litellm.utils.capture_exception(error=original_exception)
 
+    async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
+        """
+        Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
+
+        Covers:
+        1. /chat/completions
+        2. /embeddings
+        """
+        # check if max parallel requests set
+        if user_api_key_dict.max_parallel_requests is not None:
+            ## decrement call count if call failed
+            if (hasattr(original_exception, "status_code")
+                and original_exception.status_code == 429
+                and "Max parallel request limit reached" in str(original_exception)):
+                pass # ignore failed calls due to max limit being reached
+            else:
+                await self.max_parallel_request_limiter.async_log_failure_call(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+        return
+
 
 ### DB CONNECTOR ###
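Note: together with pre_call_hook, this gives the failure path a place to release the reserved slot. A sketch of how the proxy might wrap a call (make_llm_call is a hypothetical stand-in for the actual dispatch):

    async def call_with_limit(proxy_logging_obj, user_api_key_dict, data: dict):
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
        )
        try:
            return await make_llm_call(data)  # hypothetical downstream dispatch
        except Exception as e:
            # decrements the in-flight count, unless the exception is the
            # limiter's own "Max parallel request limit reached" 429, in which
            # case no slot was ever reserved
            await proxy_logging_obj.post_call_failure_hook(
                original_exception=e, user_api_key_dict=user_api_key_dict
            )
            raise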
@@ -290,65 +338,4 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e
 
-### CALL HOOKS ###
-class CallHooks:
-    """
-    Allows users to modify the incoming request / output to the proxy, without having to deal with parsing Request body.
-
-    Covers:
-    1. /chat/completions
-    2. /embeddings
-    """
-
-    def __init__(self, user_api_key_cache: DualCache):
-        self.call_details: dict = {}
-        self.call_details["user_api_key_cache"] = user_api_key_cache
-
-    def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
-        self.call_details["litellm_settings"] = litellm_settings
-        self.call_details["general_settings"] = general_settings
-        self.call_details["model_list"] = model_list
-
-    async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
-        try:
-            self.call_details["data"] = data
-            self.call_details["call_type"] = call_type
-
-            ## check if max parallel requests set
-            if user_api_key_dict.max_parallel_requests is not None:
-                ## if set, check if request allowed
-                await max_parallel_request_allow_request(
-                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-
-            return data
-        except Exception as e:
-            raise e
-
-    async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
-        try:
-            # check if max parallel requests set
-            if user_api_key_dict.max_parallel_requests is not None:
-                ## decrement call, once complete
-                await max_parallel_request_update_count(
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-
-            return response
-        except Exception as e:
-            raise e
-
-    async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
-        # check if max parallel requests set
-        if user_api_key_dict.max_parallel_requests is not None:
-            ## decrement call count if call failed
-            if (hasattr(original_exception, "status_code")
-                and original_exception.status_code == 429
-                and "Max parallel request limit reached" in str(original_exception)):
-                pass # ignore failed calls due to max limit being reached
-            else:
-                await max_parallel_request_update_count(
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-        return
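Note: the removal above deletes CallHooks wholesale; nothing of it survives under its old name. The mapping to the new surface, as far as this diff shows it (the success-path decrement is presumably handled by MaxParallelRequestsHandler's registered litellm callbacks rather than an explicit proxy hook):

    # CallHooks.pre_call(...)            -> ProxyLogging.pre_call_hook(...)
    # CallHooks.post_call_failure(...)   -> ProxyLogging.post_call_failure_hook(...)
    # CallHooks.post_call_success(...)   -> no direct replacement in this diff;
    #                                       the limiter now sits on litellm.callbacks
    # CallHooks.update_router_config(...) -> no replacement shown in this diff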