Merge pull request #3797 from BerriAI/litellm_fix_post_call_streaming_hooks

[Fix] async_post_call_streaming_hook not triggered on proxy server
Ishaan Jaff 2024-05-23 15:35:47 -07:00 committed by GitHub
commit 769070b3fe
5 changed files with 95 additions and 10 deletions

@@ -17,6 +17,8 @@ This function is called just before a litellm completion call is made, and allow
```python
from litellm.integrations.custom_logger import CustomLogger
import litellm
from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
from typing import Optional, Literal
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
@@ -34,7 +36,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
"image_generation",
"moderation",
"audio_transcription",
]) -> Optional[dict, str, Exception]:
]):
data["model"] = "my-new-model"
return data

@@ -0,0 +1,64 @@
from litellm.integrations.custom_logger import CustomLogger
import litellm
from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
from typing import Optional, Literal


# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    # Class variables or attributes
    def __init__(self):
        pass

    #### CALL HOOKS - proxy only ####

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ):
        return data

    async def async_post_call_failure_hook(
        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
    ):
        pass

    async def async_post_call_success_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        # print("in async_post_call_success_hook")
        pass

    async def async_moderation_hook(  # call made in parallel to llm api call
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        pass

    async def async_post_call_streaming_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: str,
    ):
        # print("in async_post_call_streaming_hook")
        pass


proxy_handler_instance = MyCustomHandler()
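
The streaming hook in the new file above is a stub. A handler that actually rewrites streamed content could look like the following sketch: a hypothetical `RedactingHandler` that assumes each chunk is an OpenAI-style object exposing `choices[i].delta.content` and edits it in place.

```python
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.proxy_server import UserAPIKeyAuth


class RedactingHandler(CustomLogger):
    """Hypothetical handler that masks an internal codename in every streamed chunk."""

    async def async_post_call_streaming_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        # Assumes OpenAI-style streaming chunks with `choices[i].delta.content`;
        # chunks that don't match this shape are passed through untouched.
        for choice in getattr(response, "choices", []) or []:
            delta = getattr(choice, "delta", None)
            if delta is not None and getattr(delta, "content", None):
                delta.content = delta.content.replace("project-zebra", "[redacted]")
        return response


proxy_handler_instance = RedactingHandler()
```

Because the proxy hands the same chunk object back to the response stream, in-place edits like this are what the client ultimately receives.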

@@ -20,16 +20,8 @@ model_list:
      api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
general_settings:
  store_model_in_db: true
  master_key: sk-1234
  alerting: ["slack"]
litellm_settings:
  success_callback: ["langfuse"]
  failure_callback: ["langfuse"]
  default_team_settings:
    - team_id: 7bf09cd5-217a-40d4-8634-fc31d9b88bf4
      success_callback: ["langfuse"]
      failure_callback: ["langfuse"]
      langfuse_public_key: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
      langfuse_secret_key: "os.environ/LANGFUSE_DEV_SK_KEY"
  callbacks: custom_callbacks1.proxy_handler_instance
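
`custom_callbacks1.proxy_handler_instance` is a dotted `module.attribute` path, so `custom_callbacks1.py` must be importable when the proxy starts. A minimal sketch of how such a path can be resolved (illustrative only, not necessarily the proxy's own loader):

```python
import importlib


def resolve_callback(path: str):
    """Resolve a 'module.attribute' string (e.g. 'custom_callbacks1.proxy_handler_instance')
    into the object it names. Illustrative sketch; the proxy's loader may differ."""
    module_name, attribute = path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, attribute)


handler = resolve_callback("custom_callbacks1.proxy_handler_instance")
```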

@@ -3402,6 +3402,12 @@ async def async_data_generator(
    try:
        start_time = time.time()
        async for chunk in response:
            ### CALL HOOKS ### - modify outgoing data
            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
                user_api_key_dict=user_api_key_dict, response=chunk
            )

            chunk = chunk.model_dump_json(exclude_none=True)
            try:
                yield f"data: {chunk}\n\n"
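
With this change the hook runs once per chunk, before the chunk is serialized into a server-sent-events `data:` line. A self-contained sketch of that per-chunk flow, using hypothetical stand-ins (`FakeChunk`, `fake_streaming_hook`) rather than the proxy's real types:

```python
import asyncio
import json


class FakeChunk:
    """Hypothetical stand-in for a streaming response chunk."""

    def __init__(self, content: str):
        self.content = content

    def model_dump_json(self, exclude_none: bool = True) -> str:
        return json.dumps({"content": self.content})


async def fake_streaming_hook(response: FakeChunk) -> FakeChunk:
    # Same shape as async_post_call_streaming_hook: take a chunk, optionally
    # modify it, hand it back before serialization.
    response.content = response.content.upper()
    return response


async def fake_generator():
    for chunk in (FakeChunk("hello"), FakeChunk("world")):
        chunk = await fake_streaming_hook(response=chunk)
        yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"


async def main():
    async for line in fake_generator():
        print(line, end="")  # data: {"content": "HELLO"} ... {"content": "WORLD"}


asyncio.run(main())
```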

@@ -516,6 +516,27 @@ class ProxyLogging:
                raise e

        return response

    async def async_post_call_streaming_hook(
        self,
        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """
        Allow user to modify outgoing streaming data -> per chunk

        Covers:
        1. /chat/completions
        """
        for callback in litellm.callbacks:
            try:
                if isinstance(callback, CustomLogger):
                    await callback.async_post_call_streaming_hook(
                        user_api_key_dict=user_api_key_dict, response=response
                    )
            except Exception as e:
                raise e

        return response

    async def post_call_streaming_hook(
        self,
        response: str,
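
Note that the dispatch loop added above awaits each registered `CustomLogger`'s hook but returns the original `response` object rather than the callback's return value, so a hook changes the outgoing chunk only by mutating it in place. A tiny sketch of that behavior, using hypothetical `Chunk` and `MaskingHook` classes:

```python
import asyncio


class Chunk:
    def __init__(self, text: str):
        self.text = text


class MaskingHook:
    async def async_post_call_streaming_hook(self, user_api_key_dict, response):
        response.text = response.text.replace("secret", "****")  # in-place edit survives
        return response  # the dispatcher below does not reassign this return value


async def dispatch(callbacks, response):
    # Mirrors the shape of ProxyLogging.async_post_call_streaming_hook:
    # call every hook, then hand back the same chunk object.
    for callback in callbacks:
        await callback.async_post_call_streaming_hook(
            user_api_key_dict=None, response=response
        )
    return response


chunk = Chunk("the secret token")
asyncio.run(dispatch([MaskingHook()], chunk))
print(chunk.text)  # "the **** token"
```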