(feat) Add usage tracking for streaming /anthropic passthrough routes (#6842)

* use 1 file for AnthropicPassthroughLoggingHandler

* add support for anthropic streaming usage tracking

* ci/cd run again

* fix - add real streaming for anthropic pass through

* remove unused function stream_response

* working anthropic streaming logging

* fix code quality

* fix use 1 file for vertex success handler

* use helper for _handle_logging_vertex_collected_chunks

* enforce vertex streaming to use sse for streaming

* test test_basic_vertex_ai_pass_through_streaming_with_spendlog

* fix type hints

* add comment

* fix linting

* add pass through logging unit testing
Ishaan Jaff 2024-11-21 19:36:03 -08:00 committed by GitHub
parent 920f4c9f82
commit b8af46e1a2
12 changed files with 688 additions and 295 deletions
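
For context on the headline change: Anthropic streaming responses report usage across SSE events, with input_tokens arriving on the message_start event and the final output_tokens on the closing message_delta event. The sketch below illustrates how usage can be recovered from collected stream chunks; collect_anthropic_stream_usage is a hypothetical helper used for illustration, not the actual AnthropicPassthroughLoggingHandler API added by this commit.

# Illustrative sketch only -- not litellm's AnthropicPassthroughLoggingHandler.
import json
from typing import Dict, List

def collect_anthropic_stream_usage(raw_lines: List[str]) -> Dict[str, int]:
    usage = {"input_tokens": 0, "output_tokens": 0}
    for line in raw_lines:
        if not line.startswith("data:"):
            continue  # skip "event:" lines and keep-alive comments
        try:
            event = json.loads(line[len("data:"):].strip())
        except json.JSONDecodeError:
            continue
        if event.get("type") == "message_start":
            # prompt tokens are reported once, on the opening event
            usage["input_tokens"] = (
                event.get("message", {}).get("usage", {}).get("input_tokens", 0)
            )
        elif event.get("type") == "message_delta":
            # output tokens are cumulative; the last delta carries the final count
            usage["output_tokens"] = event.get("usage", {}).get("output_tokens", 0)
    return usage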


@@ -4,7 +4,7 @@ import json
import traceback
from base64 import b64encode
from datetime import datetime
from typing import AsyncIterable, List, Optional
from typing import AsyncIterable, List, Optional, Union
import httpx
from fastapi import (
@@ -308,24 +308,6 @@ def get_endpoint_type(url: str) -> EndpointType:
    return EndpointType.GENERIC
async def stream_response(
    response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    endpoint_type: EndpointType,
    start_time: datetime,
    url: str,
) -> AsyncIterable[bytes]:
    async for chunk in chunk_processor(
        response.aiter_bytes(),
        litellm_logging_obj=logging_obj,
        endpoint_type=endpoint_type,
        start_time=start_time,
        passthrough_success_handler_obj=pass_through_endpoint_logging,
        url_route=str(url),
    ):
        yield chunk
async def pass_through_request( # noqa: PLR0915
    request: Request,
    target: str,
@@ -446,7 +428,6 @@ async def pass_through_request( # noqa: PLR0915
            "headers": headers,
        },
    )
    if stream:
        req = async_client.build_request(
            "POST",
@@ -466,12 +447,14 @@ async def pass_through_request( # noqa: PLR0915
            )
        return StreamingResponse(
            stream_response(
            chunk_processor(
                response=response,
                logging_obj=logging_obj,
                request_body=_parsed_body,
                litellm_logging_obj=logging_obj,
                endpoint_type=endpoint_type,
                start_time=start_time,
                url=str(url),
                passthrough_success_handler_obj=pass_through_endpoint_logging,
                url_route=str(url),
            ),
            headers=get_response_headers(response.headers),
            status_code=response.status_code,
@@ -504,12 +487,14 @@ async def pass_through_request( # noqa: PLR0915
            )
        return StreamingResponse(
            stream_response(
            chunk_processor(
                response=response,
                logging_obj=logging_obj,
                request_body=_parsed_body,
                litellm_logging_obj=logging_obj,
                endpoint_type=endpoint_type,
                start_time=start_time,
                url=str(url),
                passthrough_success_handler_obj=pass_through_endpoint_logging,
                url_route=str(url),
            ),
            headers=get_response_headers(response.headers),
            status_code=response.status_code,
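
Both streaming call sites above now pass the response and logging context straight to chunk_processor instead of the removed stream_response wrapper, so the processor can both forward bytes and retain them for logging. Below is a minimal sketch of that forward-while-buffering shape, assuming logging runs only after the upstream stream is exhausted; stream_and_collect and on_stream_end are illustrative names, not litellm internals.

# Illustrative pattern only -- not litellm's actual chunk_processor.
from typing import AsyncIterable, Awaitable, Callable, List

import httpx

async def stream_and_collect(
    response: httpx.Response,
    on_stream_end: Callable[[List[bytes]], Awaitable[None]],
) -> AsyncIterable[bytes]:
    collected: List[bytes] = []
    async for chunk in response.aiter_bytes():
        collected.append(chunk)  # keep a copy for post-stream usage logging
        yield chunk              # forward to the client immediately
    # stream finished: hand the buffered chunks to the success/spend logger
    await on_stream_end(collected)

Deferring the logging call until after the final chunk keeps time-to-first-byte unchanged for the client while still letting the proxy parse usage from the collected chunks for spend tracking.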