From 088532082eb43effb2f7791c7f175c19f868e6b0 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 21 Nov 2024 17:57:16 -0800
Subject: [PATCH] Use helper for _handle_logging_vertex_collected_chunks

---
 .../vertex_passthrough_logging_handler.py | 76 ++++++++++++++++++++++
 .../streaming_handler.py                  | 14 +++-
 2 files changed, 89 insertions(+), 1 deletion(-)

diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py
index 4902ed8be..5a49daa58 100644
--- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py
+++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py
@@ -11,6 +11,9 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.litellm_core_utils.litellm_logging import (
     get_standard_logging_object_payload,
 )
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
+    ModelResponseIterator as VertexModelResponseIterator,
+)
 
 if TYPE_CHECKING:
     from ..success_handler import PassThroughEndpointLogging
@@ -111,6 +114,79 @@
             **kwargs,
         )
 
+    @staticmethod
+    async def _handle_logging_vertex_collected_chunks(
+        litellm_logging_obj: LiteLLMLoggingObj,
+        passthrough_success_handler_obj: PassThroughEndpointLogging,
+        url_route: str,
+        request_body: dict,
+        endpoint_type: EndpointType,
+        start_time: datetime,
+        all_chunks: List[str],
+        end_time: datetime,
+    ):
+        """
+        Takes raw chunks from the Vertex passthrough endpoint and logs them via litellm callbacks.
+
+        - Builds the complete response from the collected chunks
+        - Creates the standard logging object
+        - Logs the result in litellm callbacks
+        """
+        kwargs = {}
+        model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
+        complete_streaming_response = (
+            VertexPassthroughLoggingHandler._build_complete_streaming_response(
+                all_chunks=all_chunks,
+                litellm_logging_obj=litellm_logging_obj,
+                model=model,
+            )
+        )
+
+        if complete_streaming_response is None:
+            verbose_proxy_logger.error(
+                "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
+            )
+            return
+        await litellm_logging_obj.async_success_handler(
+            result=complete_streaming_response,
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=False,
+            **kwargs,
+        )
+
+    @staticmethod
+    def _build_complete_streaming_response(
+        all_chunks: List[str],
+        litellm_logging_obj: LiteLLMLoggingObj,
+        model: str,
+    ) -> Optional[Union[litellm.ModelResponse, litellm.TextCompletionResponse]]:
+        """Replay raw Vertex chunks through litellm's stream parser and merge them into one response."""
+        vertex_iterator = VertexModelResponseIterator(
+            streaming_response=None,
+            sync_stream=False,
+        )
+        litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
+            completion_stream=vertex_iterator,
+            model=model,
+            logging_obj=litellm_logging_obj,
+            custom_llm_provider="vertex_ai",
+        )
+        all_openai_chunks = []
+        for chunk in all_chunks:
+            generic_chunk = vertex_iterator._common_chunk_parsing_logic(chunk)
+            litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
+                chunk=generic_chunk
+            )
+            if litellm_chunk is not None:
+                all_openai_chunks.append(litellm_chunk)
+
+        complete_streaming_response = litellm.stream_chunk_builder(
+            chunks=all_openai_chunks
+        )
+
+        return complete_streaming_response
+
     @staticmethod
     def extract_model_from_url(url: str) -> str:
         pattern = r"/models/([^:]+)"
diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py
index 9917d88c3..67c5d7201 100644
--- a/litellm/proxy/pass_through_endpoints/streaming_handler.py
+++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py
@@ -20,6 +20,9 @@ from litellm.types.utils import GenericStreamingChunk
 from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
     AnthropicPassthroughLoggingHandler,
 )
+from .llm_provider_handlers.vertex_passthrough_logging_handler import (
+    VertexPassthroughLoggingHandler,
+)
 from .success_handler import PassThroughEndpointLogging
 from .types import EndpointType
 
@@ -100,7 +103,16 @@
             end_time=end_time,
         )
     elif endpoint_type == EndpointType.VERTEX_AI:
-        pass
+        await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
+            litellm_logging_obj=litellm_logging_obj,
+            passthrough_success_handler_obj=passthrough_success_handler_obj,
+            url_route=url_route,
+            request_body=request_body,
+            endpoint_type=endpoint_type,
+            start_time=start_time,
+            all_chunks=all_chunks,
+            end_time=end_time,
+        )
     elif endpoint_type == EndpointType.GENERIC:
         # No logging is supported for generic streaming endpoints
         pass
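
A minimal sketch of the model-extraction step the new logging helper relies on. Only the
pattern r"/models/([^:]+)" is visible in the diff context, so the assumption that
extract_model_from_url returns the first capture group, and the URL below, are illustrative,
not part of the patch:

    # Hedged sketch: exercises the model-extraction step used by
    # _handle_logging_vertex_collected_chunks. The return convention
    # (first capture group of r"/models/([^:]+)") is an assumption.
    from litellm.proxy.pass_through_endpoints.llm_provider_handlers.vertex_passthrough_logging_handler import (
        VertexPassthroughLoggingHandler,
    )

    def test_extract_model_from_url():
        # Illustrative Vertex streaming passthrough route; project and location are made up.
        url = (
            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project"
            "/locations/us-central1/publishers/google/models/gemini-1.5-pro:streamGenerateContent"
        )
        # The [^:]+ capture stops before the ":streamGenerateContent" verb suffix.
        assert VertexPassthroughLoggingHandler.extract_model_from_url(url) == "gemini-1.5-pro"

Downstream, litellm.stream_chunk_builder only sees OpenAI-format chunks, which is why
_build_complete_streaming_response filters out None results from chunk_creator before merging.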