litellm-mirror/litellm/proxy/pass_through_endpoints/streaming_handler.py
Ishaan Jaff b2b3e40d13
(feat) use @google-cloud/vertexai js sdk with litellm (#6873)
* stash gemini JS test

* add vertex js sdk example

* handle vertex pass through separately

* test vertex JS sdk

* fix vertex_proxy_route

* use PassThroughStreamingHandler

* fix PassThroughStreamingHandler

* use common _create_vertex_response_logging_payload_for_generate_content

* test vertex js

* add working vertex jest tests

* move basic pass through test

* use good name for test

* test vertex

* test_chunk_processor_yields_raw_bytes

* unit tests for streaming

* test_convert_raw_bytes_to_str_lines

* run unit tests 1st

* simplify local

* docs add usage example for js

* use get_litellm_virtual_key

* add unit tests for vertex pass through
2024-11-22 16:50:10 -08:00

133 lines
4.7 KiB
Python

import asyncio
import json
from datetime import datetime
from enum import Enum
from typing import AsyncIterable, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicIterator,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as VertexAIIterator,
)
from litellm.types.utils import GenericStreamingChunk
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
from .success_handler import PassThroughEndpointLogging
from .types import EndpointType
class PassThroughStreamingHandler:
    """
    Streams a pass-through response back to the caller while collecting the raw
    bytes, then routes the collected chunks to a provider-specific logging
    handler once the upstream stream has been fully consumed.
    """

    @staticmethod
    async def chunk_processor(
        response: httpx.Response,
        request_body: Optional[dict],
        litellm_logging_obj: LiteLLMLoggingObj,
        endpoint_type: EndpointType,
        start_time: datetime,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
    ):
        """
        - Yields chunks from the response unchanged
        - Collects every chunk for post-processing (logging) after the stream ends

        Args:
            response: upstream response whose body is read via aiter_bytes()
            request_body: original request body; None is forwarded as {}
            litellm_logging_obj: forwarded to the provider logging handler
            endpoint_type: selects which provider logging handler is used
            start_time: request start time, forwarded to the logging handler
            passthrough_success_handler_obj: forwarded to the logging handler
            url_route: forwarded to the logging handler

        Raises:
            Any exception raised while reading the upstream stream is logged and
            re-raised. Exceptions from the post-stream logging step are logged
            but NOT re-raised: by then every chunk has already been yielded.
        """
        raw_bytes: List[bytes] = []
        try:
            async for chunk in response.aiter_bytes():
                raw_bytes.append(chunk)
                yield chunk
        except Exception as e:
            verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
            raise

        # After all chunks are processed, handle post-processing
        end_time = datetime.now()
        try:
            await PassThroughStreamingHandler._route_streaming_logging_to_handler(
                litellm_logging_obj=litellm_logging_obj,
                passthrough_success_handler_obj=passthrough_success_handler_obj,
                url_route=url_route,
                request_body=request_body or {},
                endpoint_type=endpoint_type,
                start_time=start_time,
                raw_bytes=raw_bytes,
                end_time=end_time,
            )
        except Exception as e:
            # Logging is best-effort: the client already has the full body, so a
            # logging failure must not turn a successful stream into an error.
            verbose_proxy_logger.error(
                f"Error logging pass-through streaming response: {str(e)}"
            )

    @staticmethod
    async def _route_streaming_logging_to_handler(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        raw_bytes: List[bytes],
        end_time: datetime,
    ):
        """
        Route the logging for the collected chunks to the appropriate handler

        Supported endpoint types:
        - Anthropic
        - Vertex AI

        Any other endpoint type (e.g. EndpointType.GENERIC) is not logged, and
        its bytes are never decoded, so non-UTF-8 / binary generic streams
        cannot fail here.
        """
        if endpoint_type == EndpointType.ANTHROPIC:
            handler = (
                AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks
            )
        elif endpoint_type == EndpointType.VERTEX_AI:
            handler = (
                VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks
            )
        else:
            # No logging is supported for generic streaming endpoints
            return

        # Only decode once we know a handler will consume the lines
        all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
            raw_bytes
        )
        await handler(
            litellm_logging_obj=litellm_logging_obj,
            passthrough_success_handler_obj=passthrough_success_handler_obj,
            url_route=url_route,
            request_body=request_body,
            endpoint_type=endpoint_type,
            start_time=start_time,
            all_chunks=all_chunks,
            end_time=end_time,
        )

    @staticmethod
    def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:
        """
        Converts a list of raw bytes into a list of string lines, similar to aiter_lines()

        Args:
            raw_bytes: List of bytes chunks from aiter_bytes()

        Returns:
            List of non-empty, whitespace-stripped lines, with each line being a
            complete data: {} chunk

        Raises:
            UnicodeDecodeError: if the combined bytes are not valid UTF-8
        """
        # Join before decoding so multi-byte UTF-8 characters split across chunk
        # boundaries decode correctly
        combined_str = b"".join(raw_bytes).decode("utf-8")
        # Split on "\n" only (not str.splitlines(), which also breaks on U+2028
        # etc. that may legally appear inside JSON payloads); strip() removes any
        # trailing "\r" from CRLF line endings
        stripped_lines = (line.strip() for line in combined_str.split("\n"))
        return [line for line in stripped_lines if line]