mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge pull request #3797 from BerriAI/litellm_fix_post_call_streaming_hooks
[Fix] async_post_call_streaming_hook not triggered on proxy server
This commit is contained in:
commit
769070b3fe
5 changed files with 95 additions and 10 deletions
|
@ -17,6 +17,8 @@ This function is called just before a litellm completion call is made, and allow
|
|||
```python
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import litellm
|
||||
from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
|
||||
from typing import Optional, Literal
|
||||
|
||||
# This file includes the custom callbacks for LiteLLM Proxy
|
||||
# Once defined, these can be passed in proxy_config.yaml
|
||||
|
@ -34,7 +36,7 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
|
|||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
]) -> Optional[dict, str, Exception]:
|
||||
]):
|
||||
data["model"] = "my-new-model"
|
||||
return data
|
||||
|
||||
|
|
64
litellm/proxy/custom_callbacks1.py
Normal file
64
litellm/proxy/custom_callbacks1.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import litellm
|
||||
from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
|
||||
from typing import Optional, Literal
|
||||
|
||||
|
||||
# This file includes the custom callbacks for LiteLLM Proxy
|
||||
# Once defined, these can be passed in proxy_config.yaml
|
||||
class MyCustomHandler(
|
||||
CustomLogger
|
||||
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
],
|
||||
):
|
||||
return data
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
|
||||
):
|
||||
pass
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
# print("in async_post_call_success_hook")
|
||||
pass
|
||||
|
||||
async def async_moderation_hook( # call made in parallel to llm api call
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||
):
|
||||
pass
|
||||
|
||||
async def async_post_call_streaming_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: str,
|
||||
):
|
||||
# print("in async_post_call_streaming_hook")
|
||||
pass
|
||||
|
||||
|
||||
proxy_handler_instance = MyCustomHandler()
|
|
@ -20,16 +20,8 @@ model_list:
|
|||
api_base: https://exampleopenaiendpoint-production.up.railway.app/triton/embeddings
|
||||
|
||||
general_settings:
|
||||
store_model_in_db: true
|
||||
master_key: sk-1234
|
||||
alerting: ["slack"]
|
||||
|
||||
litellm_settings:
|
||||
success_callback: ["langfuse"]
|
||||
failure_callback: ["langfuse"]
|
||||
default_team_settings:
|
||||
- team_id: 7bf09cd5-217a-40d4-8634-fc31d9b88bf4
|
||||
success_callback: ["langfuse"]
|
||||
failure_callback: ["langfuse"]
|
||||
langfuse_public_key: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
|
||||
langfuse_secret_key: "os.environ/LANGFUSE_DEV_SK_KEY"
|
||||
callbacks: custom_callbacks1.proxy_handler_instance
|
|
@ -3402,6 +3402,12 @@ async def async_data_generator(
|
|||
try:
|
||||
start_time = time.time()
|
||||
async for chunk in response:
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=chunk
|
||||
)
|
||||
|
||||
chunk = chunk.model_dump_json(exclude_none=True)
|
||||
try:
|
||||
yield f"data: {chunk}\n\n"
|
||||
|
|
|
@ -516,6 +516,27 @@ class ProxyLogging:
|
|||
raise e
|
||||
return response
|
||||
|
||||
async def async_post_call_streaming_hook(
|
||||
self,
|
||||
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
):
|
||||
"""
|
||||
Allow user to modify outgoing streaming data -> per chunk
|
||||
|
||||
Covers:
|
||||
1. /chat/completions
|
||||
"""
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_post_call_streaming_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return response
|
||||
|
||||
async def post_call_streaming_hook(
|
||||
self,
|
||||
response: str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue