diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index 99c6faad0f..1dc9784350 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -35,8 +35,9 @@ from litellm.proxy._types import (
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 
-from .streaming_handler import ModelIteratorType, chunk_processor
+from .streaming_handler import chunk_processor
 from .success_handler import PassThroughEndpointLogging
+from .types import EndpointType
 
 router = APIRouter()
 
@@ -288,6 +289,12 @@ def get_response_headers(headers: httpx.Headers) -> dict:
     return return_headers
 
 
+def get_endpoint_type(url: str) -> EndpointType:
+    if ("generateContent") in url or ("streamGenerateContent") in url:
+        return EndpointType.VERTEX_AI
+    return EndpointType.GENERIC
+
+
 async def pass_through_request(
     request: Request,
     target: str,
@@ -311,6 +318,8 @@ async def pass_through_request(
         request=request, headers=headers, forward_headers=forward_headers
     )
 
+    endpoint_type: EndpointType = get_endpoint_type(str(url))
+
     _parsed_body = None
     if custom_body:
         _parsed_body = custom_body
@@ -424,7 +433,7 @@ async def pass_through_request(
                 async for chunk in chunk_processor(
                     response.aiter_bytes(),
                     litellm_logging_obj=logging_obj,
-                    iterator_type=ModelIteratorType.VERTEX_AI,
+                    endpoint_type=endpoint_type,
                     start_time=start_time,
                     passthrough_success_handler_obj=pass_through_endpoint_logging,
                     url_route=str(url),
@@ -468,7 +477,7 @@ async def pass_through_request(
             async for chunk in chunk_processor(
                 response.aiter_bytes(),
                 litellm_logging_obj=logging_obj,
-                iterator_type=ModelIteratorType.VERTEX_AI,
+                endpoint_type=endpoint_type,
                 start_time=start_time,
                 passthrough_success_handler_obj=pass_through_endpoint_logging,
                 url_route=str(url),
diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py
index 8513e2702b..ab1d5d813f 100644
--- a/litellm/proxy/pass_through_endpoints/streaming_handler.py
+++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py
@@ -12,17 +12,7 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
 from litellm.types.utils import GenericStreamingChunk
 
 from .success_handler import PassThroughEndpointLogging
-
-
-class ModelIteratorType(Enum):
-    VERTEX_AI = "vertexAI"
-    # Add more iterator types here as needed
-
-
-MODEL_ITERATORS: Dict[ModelIteratorType, type] = {
-    ModelIteratorType.VERTEX_AI: VertexAIIterator,
-    # Add more mappings here as needed
-}
+from .types import EndpointType
 
 
 def get_litellm_chunk(
@@ -37,73 +27,89 @@ def get_litellm_chunk(
     return None
 
 
+def get_iterator_class_from_endpoint_type(
+    endpoint_type: EndpointType,
+) -> Optional[type]:
+    if endpoint_type == EndpointType.VERTEX_AI:
+        return VertexAIIterator
+    return None
+
+
 async def chunk_processor(
     aiter_bytes: AsyncIterable[bytes],
     litellm_logging_obj: LiteLLMLoggingObj,
-    iterator_type: ModelIteratorType,
+    endpoint_type: EndpointType,
     start_time: datetime,
     passthrough_success_handler_obj: PassThroughEndpointLogging,
     url_route: str,
 ) -> AsyncIterable[bytes]:
-    IteratorClass = MODEL_ITERATORS[iterator_type]
-    model_iterator = IteratorClass(sync_stream=False, streaming_response=aiter_bytes)
-    custom_stream_wrapper = litellm.utils.CustomStreamWrapper(
-        completion_stream=aiter_bytes, model=None, logging_obj=litellm_logging_obj
-    )
-    buffer = b""
-    all_chunks = []
-    async for chunk in aiter_bytes:
-        buffer += chunk
-        try:
-            _decoded_chunk = chunk.decode("utf-8")
-            _chunk_dict = json.loads(_decoded_chunk)
-            litellm_chunk = get_litellm_chunk(
-                model_iterator, custom_stream_wrapper, _chunk_dict
-            )
-            if litellm_chunk:
-                all_chunks.append(litellm_chunk)
-        except json.JSONDecodeError:
-            pass
-        finally:
-            yield chunk  # Yield the original bytes
-
-    # Process any remaining data in the buffer
-    if buffer:
-        try:
-            _chunk_dict = json.loads(buffer.decode("utf-8"))
-
-            if isinstance(_chunk_dict, list):
-                for _chunk in _chunk_dict:
-                    litellm_chunk = get_litellm_chunk(
-                        model_iterator, custom_stream_wrapper, _chunk
-                    )
-                    if litellm_chunk:
-                        all_chunks.append(litellm_chunk)
-            elif isinstance(_chunk_dict, dict):
+    iteratorClass = get_iterator_class_from_endpoint_type(endpoint_type)
+    if iteratorClass is None:
+        # Generic endpoint - litellm does not do any tracking / logging for this
+        async for chunk in aiter_bytes:
+            yield chunk
+    else:
+        # known streaming endpoint - litellm will do tracking / logging for this
+        model_iterator = iteratorClass(
+            sync_stream=False, streaming_response=aiter_bytes
+        )
+        custom_stream_wrapper = litellm.utils.CustomStreamWrapper(
+            completion_stream=aiter_bytes, model=None, logging_obj=litellm_logging_obj
+        )
+        buffer = b""
+        all_chunks = []
+        async for chunk in aiter_bytes:
+            buffer += chunk
+            try:
+                _decoded_chunk = chunk.decode("utf-8")
+                _chunk_dict = json.loads(_decoded_chunk)
                 litellm_chunk = get_litellm_chunk(
                     model_iterator, custom_stream_wrapper, _chunk_dict
                 )
                 if litellm_chunk:
                     all_chunks.append(litellm_chunk)
-        except json.JSONDecodeError:
-            pass
+            except json.JSONDecodeError:
+                pass
+            finally:
+                yield chunk  # Yield the original bytes
 
-    complete_streaming_response: litellm.ModelResponse = litellm.stream_chunk_builder(
-        chunks=all_chunks
-    )
-    end_time = datetime.now()
+        # Process any remaining data in the buffer
+        if buffer:
+            try:
+                _chunk_dict = json.loads(buffer.decode("utf-8"))
 
-    if passthrough_success_handler_obj.is_vertex_route(url_route):
-        _model = passthrough_success_handler_obj.extract_model_from_url(url_route)
-        complete_streaming_response.model = _model
-        litellm_logging_obj.model = _model
-        litellm_logging_obj.model_call_details["model"] = _model
+                if isinstance(_chunk_dict, list):
+                    for _chunk in _chunk_dict:
+                        litellm_chunk = get_litellm_chunk(
+                            model_iterator, custom_stream_wrapper, _chunk
+                        )
+                        if litellm_chunk:
+                            all_chunks.append(litellm_chunk)
+                elif isinstance(_chunk_dict, dict):
+                    litellm_chunk = get_litellm_chunk(
+                        model_iterator, custom_stream_wrapper, _chunk_dict
+                    )
+                    if litellm_chunk:
+                        all_chunks.append(litellm_chunk)
+            except json.JSONDecodeError:
+                pass
 
-    asyncio.create_task(
-        litellm_logging_obj.async_success_handler(
-            result=complete_streaming_response,
-            start_time=start_time,
-            end_time=end_time,
+        complete_streaming_response: litellm.ModelResponse = (
+            litellm.stream_chunk_builder(chunks=all_chunks)
+        )
+        end_time = datetime.now()
+
+        if passthrough_success_handler_obj.is_vertex_route(url_route):
+            _model = passthrough_success_handler_obj.extract_model_from_url(url_route)
+            complete_streaming_response.model = _model
+            litellm_logging_obj.model = _model
+            litellm_logging_obj.model_call_details["model"] = _model
+
+        asyncio.create_task(
+            litellm_logging_obj.async_success_handler(
+                result=complete_streaming_response,
+                start_time=start_time,
+                end_time=end_time,
+            )
         )
-    )
 
diff --git a/litellm/proxy/pass_through_endpoints/types.py b/litellm/proxy/pass_through_endpoints/types.py
new file mode 100644
index 0000000000..662788af08
--- /dev/null
+++ b/litellm/proxy/pass_through_endpoints/types.py
@@ -0,0 +1,6 @@
+from enum import Enum
+
+
+class EndpointType(str, Enum):
+    VERTEX_AI = "vertex-ai"
+    GENERIC = "generic"
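
A minimal sketch (not part of the patch) of how the new endpoint-type routing behaves once the diff above is applied. get_endpoint_type, EndpointType, and the GENERIC fallback in chunk_processor come from the patch itself; the Vertex AI project, model, and example URLs below are illustrative placeholders.

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import get_endpoint_type
from litellm.proxy.pass_through_endpoints.types import EndpointType

# URLs containing "generateContent" / "streamGenerateContent" are classified as
# Vertex AI, so chunk_processor parses the stream with VertexAIIterator and logs usage.
vertex_url = (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/"
    "us-central1/publishers/google/models/gemini-1.5-pro:streamGenerateContent"
)
assert get_endpoint_type(vertex_url) == EndpointType.VERTEX_AI

# Every other pass-through URL falls back to GENERIC: the response bytes are streamed
# through unchanged and no chunk parsing or success logging is attempted.
assert get_endpoint_type("https://example.com/v1/some/other/endpoint") == EndpointType.GENERIC

This per-request detection replaces the previous behaviour, where every streaming pass-through was hard-coded to ModelIteratorType.VERTEX_AI.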