(feat) Add usage tracking for streaming /anthropic passthrough routes (#6842)

* use 1 file for AnthropicPassthroughLoggingHandler

* add support for anthropic streaming usage tracking

* ci/cd run again

* fix - add real streaming for anthropic pass through

* remove unused function stream_response

* working anthropic streaming logging

* fix code quality

* fix use 1 file for vertex success handler

* use helper for _handle_logging_vertex_collected_chunks

* enforce vertex streaming to use sse for streaming

* test test_basic_vertex_ai_pass_through_streaming_with_spendlog

* fix type hints

* add comment

* fix linting

* add pass through logging unit testing
This commit is contained in:
Ishaan Jaff 2024-11-21 19:36:03 -08:00 committed by GitHub
parent 920f4c9f82
commit b8af46e1a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 688 additions and 295 deletions

View file

@ -779,3 +779,32 @@ class ModelResponseIterator:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
"""
Convert a string chunk to a GenericStreamingChunk
Note: This is used for Anthropic pass through streaming logging
We can move __anext__, and __next__ to use this function since it's common logic.
Did not migrate them to minmize changes made in 1 PR.
"""
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
if str_line.startswith("data:"):
data_json = json.loads(str_line[5:])
return self.chunk_parser(chunk=data_json)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)

View file

@ -178,7 +178,10 @@ async def anthropic_proxy_route(
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
# anthropic is streaming when 'stream' = True is in the body
if request.method == "POST":
_request_body = await request.json()
if _request_body.get("stream"):
is_streaming_request = True
## CREATE PASS-THROUGH

View file

@ -0,0 +1,206 @@
import json
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
from ..types import EndpointType
else:
PassThroughEndpointLogging = Any
EndpointType = Any
class AnthropicPassthroughLoggingHandler:
@staticmethod
async def anthropic_passthrough_handler(
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""
Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
"""
model = response_body.get("model", "")
litellm_model_response: litellm.ModelResponse = (
AnthropicConfig._process_response(
response=httpx_response,
model_response=litellm.ModelResponse(),
model=model,
stream=False,
messages=[],
logging_obj=logging_obj,
optional_params={},
api_key="",
data={},
print_verbose=litellm.print_verbose,
encoding=None,
json_mode=False,
)
)
kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
litellm_model_response=litellm_model_response,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
)
await logging_obj.async_success_handler(
result=litellm_model_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
pass
@staticmethod
def _create_anthropic_response_logging_payload(
litellm_model_response: Union[
litellm.ModelResponse, litellm.TextCompletionResponse
],
model: str,
kwargs: dict,
start_time: datetime,
end_time: datetime,
logging_obj: LiteLLMLoggingObj,
):
"""
Create the standard logging object for Anthropic passthrough
handles streaming and non-streaming responses
"""
response_cost = litellm.completion_cost(
completion_response=litellm_model_response,
model=model,
)
kwargs["response_cost"] = response_cost
kwargs["model"] = model
# Make standard logging object for Vertex AI
standard_logging_object = get_standard_logging_object_payload(
kwargs=kwargs,
init_response_obj=litellm_model_response,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
status="success",
)
# pretty print standard logging object
verbose_proxy_logger.debug(
"standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
)
kwargs["standard_logging_object"] = standard_logging_object
return kwargs
@staticmethod
async def _handle_logging_anthropic_collected_chunks(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
):
"""
Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks
- Builds complete response from chunks
- Creates standard logging object
- Logs in litellm callbacks
"""
model = request_body.get("model", "")
complete_streaming_response = (
AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
all_chunks=all_chunks,
litellm_logging_obj=litellm_logging_obj,
model=model,
)
)
if complete_streaming_response is None:
verbose_proxy_logger.error(
"Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
)
return
kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
litellm_model_response=complete_streaming_response,
model=model,
kwargs={},
start_time=start_time,
end_time=end_time,
logging_obj=litellm_logging_obj,
)
await litellm_logging_obj.async_success_handler(
result=complete_streaming_response,
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
@staticmethod
def _build_complete_streaming_response(
all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[litellm.ModelResponse, litellm.TextCompletionResponse]]:
"""
Builds complete response from raw Anthropic chunks
- Converts str chunks to generic chunks
- Converts generic chunks to litellm chunks (OpenAI format)
- Builds complete response from litellm chunks
"""
anthropic_model_response_iterator = AnthropicModelResponseIterator(
streaming_response=None,
sync_stream=False,
)
litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
completion_stream=anthropic_model_response_iterator,
model=model,
logging_obj=litellm_logging_obj,
custom_llm_provider="anthropic",
)
all_openai_chunks = []
for _chunk_str in all_chunks:
try:
generic_chunk = anthropic_model_response_iterator.convert_str_chunk_to_generic_chunk(
chunk=_chunk_str
)
litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
chunk=generic_chunk
)
if litellm_chunk is not None:
all_openai_chunks.append(litellm_chunk)
except (StopIteration, StopAsyncIteration):
break
complete_streaming_response = litellm.stream_chunk_builder(
chunks=all_openai_chunks
)
return complete_streaming_response

View file

@ -0,0 +1,195 @@
import json
import re
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as VertexModelResponseIterator,
)
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
from ..types import EndpointType
else:
PassThroughEndpointLogging = Any
EndpointType = Any
class VertexPassthroughLoggingHandler:
@staticmethod
async def vertex_passthrough_handler(
httpx_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
if "generateContent" in url_route:
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
instance_of_vertex_llm = litellm.VertexGeminiConfig()
litellm_model_response: litellm.ModelResponse = (
instance_of_vertex_llm._transform_response(
model=model,
messages=[
{"role": "user", "content": "no-message-pass-through-endpoint"}
],
response=httpx_response,
model_response=litellm.ModelResponse(),
logging_obj=logging_obj,
optional_params={},
litellm_params={},
api_key="",
data={},
print_verbose=litellm.print_verbose,
encoding=None,
)
)
logging_obj.model = litellm_model_response.model or model
logging_obj.model_call_details["model"] = logging_obj.model
await logging_obj.async_success_handler(
result=litellm_model_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
elif "predict" in url_route:
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
VertexImageGeneration,
)
from litellm.types.utils import PassthroughCallTypes
vertex_image_generation_class = VertexImageGeneration()
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
_json_response = httpx_response.json()
litellm_prediction_response: Union[
litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse
] = litellm.ModelResponse()
if vertex_image_generation_class.is_image_generation_response(
_json_response
):
litellm_prediction_response = (
vertex_image_generation_class.process_image_generation_response(
_json_response,
model_response=litellm.ImageResponse(),
model=model,
)
)
logging_obj.call_type = (
PassthroughCallTypes.passthrough_image_generation.value
)
else:
litellm_prediction_response = litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
response=_json_response,
model=model,
model_response=litellm.EmbeddingResponse(),
)
if isinstance(litellm_prediction_response, litellm.EmbeddingResponse):
litellm_prediction_response.model = model
logging_obj.model = model
logging_obj.model_call_details["model"] = logging_obj.model
await logging_obj.async_success_handler(
result=litellm_prediction_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
@staticmethod
async def _handle_logging_vertex_collected_chunks(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
):
"""
Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks
- Builds complete response from chunks
- Creates standard logging object
- Logs in litellm callbacks
"""
kwargs: Dict[str, Any] = {}
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
complete_streaming_response = (
VertexPassthroughLoggingHandler._build_complete_streaming_response(
all_chunks=all_chunks,
litellm_logging_obj=litellm_logging_obj,
model=model,
)
)
if complete_streaming_response is None:
verbose_proxy_logger.error(
"Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
)
return
await litellm_logging_obj.async_success_handler(
result=complete_streaming_response,
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
@staticmethod
def _build_complete_streaming_response(
all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[litellm.ModelResponse, litellm.TextCompletionResponse]]:
vertex_iterator = VertexModelResponseIterator(
streaming_response=None,
sync_stream=False,
)
litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
completion_stream=vertex_iterator,
model=model,
logging_obj=litellm_logging_obj,
custom_llm_provider="vertex_ai",
)
all_openai_chunks = []
for chunk in all_chunks:
generic_chunk = vertex_iterator._common_chunk_parsing_logic(chunk)
litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
chunk=generic_chunk
)
if litellm_chunk is not None:
all_openai_chunks.append(litellm_chunk)
complete_streaming_response = litellm.stream_chunk_builder(
chunks=all_openai_chunks
)
return complete_streaming_response
@staticmethod
def extract_model_from_url(url: str) -> str:
pattern = r"/models/([^:]+)"
match = re.search(pattern, url)
if match:
return match.group(1)
return "unknown"

View file

@ -4,7 +4,7 @@ import json
import traceback
from base64 import b64encode
from datetime import datetime
from typing import AsyncIterable, List, Optional
from typing import AsyncIterable, List, Optional, Union
import httpx
from fastapi import (
@ -308,24 +308,6 @@ def get_endpoint_type(url: str) -> EndpointType:
return EndpointType.GENERIC
async def stream_response(
response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
endpoint_type: EndpointType,
start_time: datetime,
url: str,
) -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
async def pass_through_request( # noqa: PLR0915
request: Request,
target: str,
@ -446,7 +428,6 @@ async def pass_through_request( # noqa: PLR0915
"headers": headers,
},
)
if stream:
req = async_client.build_request(
"POST",
@ -466,12 +447,14 @@ async def pass_through_request( # noqa: PLR0915
)
return StreamingResponse(
stream_response(
chunk_processor(
response=response,
logging_obj=logging_obj,
request_body=_parsed_body,
litellm_logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
url=str(url),
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
),
headers=get_response_headers(response.headers),
status_code=response.status_code,
@ -504,12 +487,14 @@ async def pass_through_request( # noqa: PLR0915
)
return StreamingResponse(
stream_response(
chunk_processor(
response=response,
logging_obj=logging_obj,
request_body=_parsed_body,
litellm_logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
url=str(url),
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
),
headers=get_response_headers(response.headers),
status_code=response.status_code,

View file

@ -4,114 +4,116 @@ from datetime import datetime
from enum import Enum
from typing import AsyncIterable, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicIterator,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as VertexAIIterator,
)
from litellm.types.utils import GenericStreamingChunk
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
from .success_handler import PassThroughEndpointLogging
from .types import EndpointType
def get_litellm_chunk(
model_iterator: VertexAIIterator,
custom_stream_wrapper: litellm.utils.CustomStreamWrapper,
chunk_dict: Dict,
) -> Optional[Dict]:
generic_chunk: GenericStreamingChunk = model_iterator.chunk_parser(chunk_dict)
if generic_chunk:
return custom_stream_wrapper.chunk_creator(chunk=generic_chunk)
return None
def get_iterator_class_from_endpoint_type(
endpoint_type: EndpointType,
) -> Optional[type]:
if endpoint_type == EndpointType.VERTEX_AI:
return VertexAIIterator
return None
async def chunk_processor(
aiter_bytes: AsyncIterable[bytes],
response: httpx.Response,
request_body: Optional[dict],
litellm_logging_obj: LiteLLMLoggingObj,
endpoint_type: EndpointType,
start_time: datetime,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
) -> AsyncIterable[bytes]:
iteratorClass = get_iterator_class_from_endpoint_type(endpoint_type)
if iteratorClass is None:
# Generic endpoint - litellm does not do any tracking / logging for this
async for chunk in aiter_bytes:
yield chunk
else:
# known streaming endpoint - litellm will do tracking / logging for this
model_iterator = iteratorClass(
sync_stream=False, streaming_response=aiter_bytes
)
custom_stream_wrapper = litellm.utils.CustomStreamWrapper(
completion_stream=aiter_bytes, model=None, logging_obj=litellm_logging_obj
)
buffer = b""
all_chunks = []
async for chunk in aiter_bytes:
buffer += chunk
):
"""
- Yields chunks from the response
- Collect non-empty chunks for post-processing (logging)
"""
collected_chunks: List[str] = [] # List to store all chunks
try:
_decoded_chunk = chunk.decode("utf-8")
_chunk_dict = json.loads(_decoded_chunk)
litellm_chunk = get_litellm_chunk(
model_iterator, custom_stream_wrapper, _chunk_dict
)
if litellm_chunk:
all_chunks.append(litellm_chunk)
except json.JSONDecodeError:
pass
finally:
yield chunk # Yield the original bytes
async for chunk in response.aiter_lines():
verbose_proxy_logger.debug(f"Processing chunk: {chunk}")
if not chunk:
continue
# Process any remaining data in the buffer
if buffer:
try:
_chunk_dict = json.loads(buffer.decode("utf-8"))
# Handle SSE format - pass through the raw SSE format
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
if isinstance(_chunk_dict, list):
for _chunk in _chunk_dict:
litellm_chunk = get_litellm_chunk(
model_iterator, custom_stream_wrapper, _chunk
)
if litellm_chunk:
all_chunks.append(litellm_chunk)
elif isinstance(_chunk_dict, dict):
litellm_chunk = get_litellm_chunk(
model_iterator, custom_stream_wrapper, _chunk_dict
)
if litellm_chunk:
all_chunks.append(litellm_chunk)
except json.JSONDecodeError:
pass
# Store the chunk for post-processing
if chunk.strip(): # Only store non-empty chunks
collected_chunks.append(chunk)
yield f"{chunk}\n"
complete_streaming_response: Optional[
Union[litellm.ModelResponse, litellm.TextCompletionResponse]
] = litellm.stream_chunk_builder(chunks=all_chunks)
if complete_streaming_response is None:
complete_streaming_response = litellm.ModelResponse()
# After all chunks are processed, handle post-processing
end_time = datetime.now()
if passthrough_success_handler_obj.is_vertex_route(url_route):
_model = passthrough_success_handler_obj.extract_model_from_url(url_route)
complete_streaming_response.model = _model
litellm_logging_obj.model = _model
litellm_logging_obj.model_call_details["model"] = _model
asyncio.create_task(
litellm_logging_obj.async_success_handler(
result=complete_streaming_response,
await _route_streaming_logging_to_handler(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body or {},
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=collected_chunks,
end_time=end_time,
)
except Exception as e:
verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
raise
async def _route_streaming_logging_to_handler(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
):
"""
Route the logging for the collected chunks to the appropriate handler
Supported endpoint types:
- Anthropic
- Vertex AI
"""
if endpoint_type == EndpointType.ANTHROPIC:
await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
)
elif endpoint_type == EndpointType.VERTEX_AI:
await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
)
elif endpoint_type == EndpointType.GENERIC:
# No logging is supported for generic streaming endpoints
pass

View file

@ -12,13 +12,19 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import StandardPassThroughResponseObject
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
class PassThroughEndpointLogging:
def __init__(self):
@ -44,7 +50,7 @@ class PassThroughEndpointLogging:
**kwargs,
):
if self.is_vertex_route(url_route):
await self.vertex_passthrough_handler(
await VertexPassthroughLoggingHandler.vertex_passthrough_handler(
httpx_response=httpx_response,
logging_obj=logging_obj,
url_route=url_route,
@ -55,7 +61,7 @@ class PassThroughEndpointLogging:
**kwargs,
)
elif self.is_anthropic_route(url_route):
await self.anthropic_passthrough_handler(
await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
@ -102,166 +108,3 @@ class PassThroughEndpointLogging:
if route in url_route:
return True
return False
def extract_model_from_url(self, url: str) -> str:
pattern = r"/models/([^:]+)"
match = re.search(pattern, url)
if match:
return match.group(1)
return "unknown"
async def anthropic_passthrough_handler(
self,
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""
Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
"""
model = response_body.get("model", "")
litellm_model_response: litellm.ModelResponse = (
AnthropicConfig._process_response(
response=httpx_response,
model_response=litellm.ModelResponse(),
model=model,
stream=False,
messages=[],
logging_obj=logging_obj,
optional_params={},
api_key="",
data={},
print_verbose=litellm.print_verbose,
encoding=None,
json_mode=False,
)
)
response_cost = litellm.completion_cost(
completion_response=litellm_model_response,
model=model,
)
kwargs["response_cost"] = response_cost
kwargs["model"] = model
# Make standard logging object for Vertex AI
standard_logging_object = get_standard_logging_object_payload(
kwargs=kwargs,
init_response_obj=litellm_model_response,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
status="success",
)
# pretty print standard logging object
verbose_proxy_logger.debug(
"standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
)
kwargs["standard_logging_object"] = standard_logging_object
await logging_obj.async_success_handler(
result=litellm_model_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
pass
async def vertex_passthrough_handler(
self,
httpx_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
if "generateContent" in url_route:
model = self.extract_model_from_url(url_route)
instance_of_vertex_llm = litellm.VertexGeminiConfig()
litellm_model_response: litellm.ModelResponse = (
instance_of_vertex_llm._transform_response(
model=model,
messages=[
{"role": "user", "content": "no-message-pass-through-endpoint"}
],
response=httpx_response,
model_response=litellm.ModelResponse(),
logging_obj=logging_obj,
optional_params={},
litellm_params={},
api_key="",
data={},
print_verbose=litellm.print_verbose,
encoding=None,
)
)
logging_obj.model = litellm_model_response.model or model
logging_obj.model_call_details["model"] = logging_obj.model
await logging_obj.async_success_handler(
result=litellm_model_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
elif "predict" in url_route:
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
VertexImageGeneration,
)
from litellm.types.utils import PassthroughCallTypes
vertex_image_generation_class = VertexImageGeneration()
model = self.extract_model_from_url(url_route)
_json_response = httpx_response.json()
litellm_prediction_response: Union[
litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse
] = litellm.ModelResponse()
if vertex_image_generation_class.is_image_generation_response(
_json_response
):
litellm_prediction_response = (
vertex_image_generation_class.process_image_generation_response(
_json_response,
model_response=litellm.ImageResponse(),
model=model,
)
)
logging_obj.call_type = (
PassthroughCallTypes.passthrough_image_generation.value
)
else:
litellm_prediction_response = litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
response=_json_response,
model=model,
model_response=litellm.EmbeddingResponse(),
)
if isinstance(litellm_prediction_response, litellm.EmbeddingResponse):
litellm_prediction_response.model = model
logging_obj.model = model
logging_obj.model_call_details["model"] = logging_obj.model
await logging_obj.async_success_handler(
result=litellm_prediction_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)

View file

@ -4,15 +4,6 @@ model_list:
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
router_settings:
provider_budget_config:
openai:
budget_limit: 0.000000000001 # float of $ value budget for time period
time_period: 1d # can be 1d, 2d, 30d
azure:
budget_limit: 100
time_period: 1d
litellm_settings:
callbacks: ["prometheus"]
default_vertex_config:
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"

View file

@ -194,14 +194,16 @@ async def vertex_proxy_route(
verbose_proxy_logger.debug("updated url %s", updated_url)
## check for streaming
target = str(updated_url)
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
target += "?alt=sse"
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
target=target,
custom_headers=headers,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(

View file

@ -1,5 +1,6 @@
"""
This test ensures that the proxy can passthrough anthropic requests
"""
import pytest

View file

@ -121,6 +121,7 @@ async def test_basic_vertex_ai_pass_through_with_spendlog():
@pytest.mark.asyncio()
@pytest.mark.skip(reason="skip flaky test - vertex pass through streaming is flaky")
async def test_basic_vertex_ai_pass_through_streaming_with_spendlog():
spend_before = await call_spend_logs_endpoint() or 0.0

View file

@ -0,0 +1,135 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import httpx
import pytest
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
# Import the class we're testing
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
@pytest.fixture
def mock_response():
return {
"model": "claude-3-opus-20240229",
"content": [{"text": "Hello, world!", "type": "text"}],
"role": "assistant",
}
@pytest.fixture
def mock_httpx_response():
mock_resp = Mock(spec=httpx.Response)
mock_resp.json.return_value = {
"content": [{"text": "Hi! My name is Claude.", "type": "text"}],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-5-sonnet-20241022",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": None,
"type": "message",
"usage": {"input_tokens": 2095, "output_tokens": 503},
}
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/json"}
return mock_resp
@pytest.fixture
def mock_logging_obj():
logging_obj = LiteLLMLoggingObj(
model="claude-3-opus-20240229",
messages=[],
stream=False,
call_type="completion",
start_time=datetime.now(),
litellm_call_id="123",
function_id="456",
)
logging_obj.async_success_handler = AsyncMock()
return logging_obj
@pytest.mark.asyncio
async def test_anthropic_passthrough_handler(
mock_httpx_response, mock_response, mock_logging_obj
):
"""
Unit test - Assert that the anthropic passthrough handler calls the litellm logging object's async_success_handler
"""
start_time = datetime.now()
end_time = datetime.now()
await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
httpx_response=mock_httpx_response,
response_body=mock_response,
logging_obj=mock_logging_obj,
url_route="/v1/chat/completions",
result="success",
start_time=start_time,
end_time=end_time,
cache_hit=False,
)
# Assert that async_success_handler was called
assert mock_logging_obj.async_success_handler.called
call_args = mock_logging_obj.async_success_handler.call_args
call_kwargs = call_args.kwargs
print("call_kwargs", call_kwargs)
# Assert required fields are present in call_kwargs
assert "result" in call_kwargs
assert "start_time" in call_kwargs
assert "end_time" in call_kwargs
assert "cache_hit" in call_kwargs
assert "response_cost" in call_kwargs
assert "model" in call_kwargs
assert "standard_logging_object" in call_kwargs
# Assert specific values and types
assert isinstance(call_kwargs["result"], litellm.ModelResponse)
assert isinstance(call_kwargs["start_time"], datetime)
assert isinstance(call_kwargs["end_time"], datetime)
assert isinstance(call_kwargs["cache_hit"], bool)
assert isinstance(call_kwargs["response_cost"], float)
assert call_kwargs["model"] == "claude-3-opus-20240229"
assert isinstance(call_kwargs["standard_logging_object"], dict)
def test_create_anthropic_response_logging_payload(mock_logging_obj):
# Test the logging payload creation
model_response = litellm.ModelResponse()
model_response.choices = [{"message": {"content": "Test response"}}]
start_time = datetime.now()
end_time = datetime.now()
result = (
AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
litellm_model_response=model_response,
model="claude-3-opus-20240229",
kwargs={},
start_time=start_time,
end_time=end_time,
logging_obj=mock_logging_obj,
)
)
assert isinstance(result, dict)
assert "model" in result
assert "response_cost" in result
assert "standard_logging_object" in result