(feat) add usage / cost tracking for Anthropic passthrough routes (#6835)

* move _process_response in transformation

* fix AnthropicConfig test

* add AnthropicConfig

* fix anthropic_passthrough_handler

* fix get_response_body

* fix check for streaming response

* use 1 helper to return stream_response on passthrough
This commit is contained in:
Ishaan Jaff 2024-11-20 17:25:12 -08:00 committed by GitHub
parent 2ee4fbb0a5
commit c991864d69
3 changed files with 142 additions and 30 deletions

View file

@ -45,11 +45,11 @@ router = APIRouter()
pass_through_endpoint_logging = PassThroughEndpointLogging()
def get_response_body(response: httpx.Response):
def get_response_body(response: httpx.Response) -> Optional[dict]:
try:
return response.json()
except Exception:
return response.text
return None
async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict:
def get_endpoint_type(url: str) -> EndpointType:
if ("generateContent") in url or ("streamGenerateContent") in url:
return EndpointType.VERTEX_AI
elif ("api.anthropic.com") in url:
return EndpointType.ANTHROPIC
return EndpointType.GENERIC
async def stream_response(
response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
endpoint_type: EndpointType,
start_time: datetime,
url: str,
) -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
async def pass_through_request( # noqa: PLR0915
request: Request,
target: str,
@ -445,19 +465,14 @@ async def pass_through_request( # noqa: PLR0915
status_code=e.response.status_code, detail=await e.response.aread()
)
async def stream_response() -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
return StreamingResponse(
stream_response(
response=response,
logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
return StreamingResponse(
stream_response(),
url=str(url),
),
headers=get_response_headers(response.headers),
status_code=response.status_code,
)
@ -478,10 +493,9 @@ async def pass_through_request( # noqa: PLR0915
json=_parsed_body,
)
if (
response.headers.get("content-type") is not None
and response.headers["content-type"] == "text/event-stream"
):
verbose_proxy_logger.debug("response.headers= %s", response.headers)
if _is_streaming_response(response) is True:
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
@ -489,19 +503,14 @@ async def pass_through_request( # noqa: PLR0915
status_code=e.response.status_code, detail=await e.response.aread()
)
async def stream_response() -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
return StreamingResponse(
stream_response(
response=response,
logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
return StreamingResponse(
stream_response(),
url=str(url),
),
headers=get_response_headers(response.headers),
status_code=response.status_code,
)
@ -519,10 +528,12 @@ async def pass_through_request( # noqa: PLR0915
content = await response.aread()
## LOG SUCCESS
passthrough_logging_payload["response_body"] = get_response_body(response)
response_body: Optional[dict] = get_response_body(response)
passthrough_logging_payload["response_body"] = response_body
end_time = datetime.now()
await pass_through_endpoint_logging.pass_through_async_success_handler(
httpx_response=response,
response_body=response_body,
url_route=str(url),
result="",
start_time=start_time,
@ -619,6 +630,13 @@ def create_pass_through_route(
return endpoint_func
def _is_streaming_response(response: httpx.Response) -> bool:
_content_type = response.headers.get("content-type")
if _content_type is not None and "text/event-stream" in _content_type:
return True
return False
async def initialize_pass_through_endpoints(pass_through_endpoints: list):
verbose_proxy_logger.debug("initializing pass through endpoints")