From 3f9c58507e7136a60bfa67a422d8b51dd635ebda Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 2 Sep 2024 16:11:20 -0700 Subject: [PATCH] pass through track usage for streaming endpoints --- .../pass_through_endpoints.py | 4 +++ .../streaming_handler.py | 29 ++++++++++++++----- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index e138df0096..99c6faad0f 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -426,6 +426,8 @@ async def pass_through_request( litellm_logging_obj=logging_obj, iterator_type=ModelIteratorType.VERTEX_AI, start_time=start_time, + passthrough_success_handler_obj=pass_through_endpoint_logging, + url_route=str(url), ): yield chunk @@ -468,6 +470,8 @@ async def pass_through_request( litellm_logging_obj=logging_obj, iterator_type=ModelIteratorType.VERTEX_AI, start_time=start_time, + passthrough_success_handler_obj=pass_through_endpoint_logging, + url_route=str(url), ): yield chunk diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index ba0359317d..8513e2702b 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -11,6 +11,8 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu ) from litellm.types.utils import GenericStreamingChunk +from .success_handler import PassThroughEndpointLogging + class ModelIteratorType(Enum): VERTEX_AI = "vertexAI" @@ -28,6 +30,7 @@ def get_litellm_chunk( custom_stream_wrapper: litellm.utils.CustomStreamWrapper, chunk_dict: Dict, ) -> Optional[Dict]: + generic_chunk: GenericStreamingChunk = model_iterator.chunk_parser(chunk_dict) if generic_chunk: return custom_stream_wrapper.chunk_creator(chunk=generic_chunk) @@ -39,6 +42,8 @@ async def chunk_processor( litellm_logging_obj: LiteLLMLoggingObj, iterator_type: ModelIteratorType, start_time: datetime, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, ) -> AsyncIterable[bytes]: IteratorClass = MODEL_ITERATORS[iterator_type] @@ -84,11 +89,21 @@ async def chunk_processor( except json.JSONDecodeError: pass - complete_streaming_response = litellm.stream_chunk_builder(chunks=all_chunks) - - end_time = datetime.now() - await litellm_logging_obj.async_success_handler( - result=complete_streaming_response, - start_time=start_time, - end_time=end_time, + complete_streaming_response: litellm.ModelResponse = litellm.stream_chunk_builder( + chunks=all_chunks + ) + end_time = datetime.now() + + if passthrough_success_handler_obj.is_vertex_route(url_route): + _model = passthrough_success_handler_obj.extract_model_from_url(url_route) + complete_streaming_response.model = _model + litellm_logging_obj.model = _model + litellm_logging_obj.model_call_details["model"] = _model + + asyncio.create_task( + litellm_logging_obj.async_success_handler( + result=complete_streaming_response, + start_time=start_time, + end_time=end_time, + ) )