diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index f34efdcf39..e138df0096 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -22,6 +22,9 @@ import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
+    ModelResponseIterator,
+)
 from litellm.proxy._types import (
     ConfigFieldInfo,
     ConfigFieldUpdate,
@@ -32,6 +35,7 @@ from litellm.proxy._types import (
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 
+from .streaming_handler import ModelIteratorType, chunk_processor
 from .success_handler import PassThroughEndpointLogging
 
 router = APIRouter()
@@ -416,9 +420,13 @@ async def pass_through_request(
                 status_code=e.response.status_code, detail=await e.response.aread()
             )
 
-        # Create an async generator to yield the response content
         async def stream_response() -> AsyncIterable[bytes]:
-            async for chunk in response.aiter_bytes():
+            async for chunk in chunk_processor(
+                response.aiter_bytes(),
+                litellm_logging_obj=logging_obj,
+                iterator_type=ModelIteratorType.VERTEX_AI,
+                start_time=start_time,
+            ):
                 yield chunk
 
         return StreamingResponse(
@@ -454,10 +462,13 @@ async def pass_through_request(
                 status_code=e.response.status_code, detail=await e.response.aread()
             )
 
-        # streaming response
-        # Create an async generator to yield the response content
         async def stream_response() -> AsyncIterable[bytes]:
-            async for chunk in response.aiter_bytes():
+            async for chunk in chunk_processor(
+                response.aiter_bytes(),
+                litellm_logging_obj=logging_obj,
+                iterator_type=ModelIteratorType.VERTEX_AI,
+                start_time=start_time,
+            ):
                 yield chunk
 
         return StreamingResponse(
diff --git a/litellm/proxy/tests/test_vertex_sdk_forward_headers.py b/litellm/proxy/tests/test_vertex_sdk_forward_headers.py
index 0799ef8eb8..7aa87905ab 100644
--- a/litellm/proxy/tests/test_vertex_sdk_forward_headers.py
+++ b/litellm/proxy/tests/test_vertex_sdk_forward_headers.py
@@ -10,7 +10,12 @@ vertexai.init(
     api_transport="rest",
 )
 
-model = GenerativeModel(model_name="gemini-1.0-pro")
-response = model.generate_content("hi")
+model = GenerativeModel(model_name="gemini-1.5-flash-001")
+response = model.generate_content(
+    "hi tell me a joke and a very long story", stream=True
+)
 print("response", response)
+
+for chunk in response:
+    print(chunk)
 
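
Reviewer note: both rewritten `stream_response` generators assume that `chunk_processor` tees the raw upstream bytes into the LiteLLM logging object while yielding them to the client unchanged. The sketch below illustrates that contract only, under stated assumptions; `demo_chunk_processor`, `_fake_upstream`, and the logging hook are hypothetical names, not the actual `.streaming_handler` implementation (which is not part of this diff and presumably uses the newly imported `ModelResponseIterator` to parse Vertex AI chunks).

```python
import asyncio
from datetime import datetime
from enum import Enum
from typing import AsyncIterable, List, Optional


class ModelIteratorType(Enum):
    # Mirrors the enum imported from .streaming_handler in the diff above;
    # the member value is an assumption.
    VERTEX_AI = "vertexAI"


async def demo_chunk_processor(
    aiter_bytes: AsyncIterable[bytes],
    litellm_logging_obj: Optional[object] = None,
    # Accepted only to mirror the call signature used in the diff.
    iterator_type: ModelIteratorType = ModelIteratorType.VERTEX_AI,
    start_time: Optional[datetime] = None,
) -> AsyncIterable[bytes]:
    """Yield each upstream chunk unchanged while buffering a copy for logging."""
    buffered: List[bytes] = []
    async for chunk in aiter_bytes:
        buffered.append(chunk)
        yield chunk  # the client receives the same bytes, with no re-chunking
    # Stream exhausted: hand the accumulated body to the logger. The real
    # handler would parse the buffered Vertex AI chunks here; this print is
    # a placeholder for that hypothetical hook.
    total = sum(len(c) for c in buffered)
    print(f"logged {total} streamed bytes (request started at {start_time})")


async def _fake_upstream() -> AsyncIterable[bytes]:
    # Stand-in for httpx's response.aiter_bytes().
    for part in (b"hello ", b"world"):
        yield part


async def _demo() -> None:
    async for chunk in demo_chunk_processor(
        _fake_upstream(), start_time=datetime.now()
    ):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(_demo())
```

Teeing at the byte layer keeps the pass-through transparent: the proxy can log the stream after it completes without re-encoding or delaying what the Vertex SDK client receives, which is what the updated test exercises with `stream=True`.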