mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(feat) add usage / cost tracking for Anthropic passthrough routes (#6835)
* move _process_response in transformation * fix AnthropicConfig test * add AnthropicConfig * fix anthropic_passthrough_handler * fix get_response_body * fix check for streaming response * use 1 helper to return stream_response on passthrough
This commit is contained in:
parent
2ee4fbb0a5
commit
c991864d69
3 changed files with 142 additions and 30 deletions
|
@ -45,11 +45,11 @@ router = APIRouter()
|
|||
pass_through_endpoint_logging = PassThroughEndpointLogging()
|
||||
|
||||
|
||||
def get_response_body(response: httpx.Response):
|
||||
def get_response_body(response: httpx.Response) -> Optional[dict]:
|
||||
try:
|
||||
return response.json()
|
||||
except Exception:
|
||||
return response.text
|
||||
return None
|
||||
|
||||
|
||||
async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
|
||||
|
@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict:
|
|||
def get_endpoint_type(url: str) -> EndpointType:
|
||||
if ("generateContent") in url or ("streamGenerateContent") in url:
|
||||
return EndpointType.VERTEX_AI
|
||||
elif ("api.anthropic.com") in url:
|
||||
return EndpointType.ANTHROPIC
|
||||
return EndpointType.GENERIC
|
||||
|
||||
|
||||
async def stream_response(
|
||||
response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
endpoint_type: EndpointType,
|
||||
start_time: datetime,
|
||||
url: str,
|
||||
) -> AsyncIterable[bytes]:
|
||||
async for chunk in chunk_processor(
|
||||
response.aiter_bytes(),
|
||||
litellm_logging_obj=logging_obj,
|
||||
endpoint_type=endpoint_type,
|
||||
start_time=start_time,
|
||||
passthrough_success_handler_obj=pass_through_endpoint_logging,
|
||||
url_route=str(url),
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
async def pass_through_request( # noqa: PLR0915
|
||||
request: Request,
|
||||
target: str,
|
||||
|
@ -445,19 +465,14 @@ async def pass_through_request( # noqa: PLR0915
|
|||
status_code=e.response.status_code, detail=await e.response.aread()
|
||||
)
|
||||
|
||||
async def stream_response() -> AsyncIterable[bytes]:
|
||||
async for chunk in chunk_processor(
|
||||
response.aiter_bytes(),
|
||||
litellm_logging_obj=logging_obj,
|
||||
return StreamingResponse(
|
||||
stream_response(
|
||||
response=response,
|
||||
logging_obj=logging_obj,
|
||||
endpoint_type=endpoint_type,
|
||||
start_time=start_time,
|
||||
passthrough_success_handler_obj=pass_through_endpoint_logging,
|
||||
url_route=str(url),
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
stream_response(),
|
||||
url=str(url),
|
||||
),
|
||||
headers=get_response_headers(response.headers),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
@ -478,10 +493,9 @@ async def pass_through_request( # noqa: PLR0915
|
|||
json=_parsed_body,
|
||||
)
|
||||
|
||||
if (
|
||||
response.headers.get("content-type") is not None
|
||||
and response.headers["content-type"] == "text/event-stream"
|
||||
):
|
||||
verbose_proxy_logger.debug("response.headers= %s", response.headers)
|
||||
|
||||
if _is_streaming_response(response) is True:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
|
@ -489,19 +503,14 @@ async def pass_through_request( # noqa: PLR0915
|
|||
status_code=e.response.status_code, detail=await e.response.aread()
|
||||
)
|
||||
|
||||
async def stream_response() -> AsyncIterable[bytes]:
|
||||
async for chunk in chunk_processor(
|
||||
response.aiter_bytes(),
|
||||
litellm_logging_obj=logging_obj,
|
||||
return StreamingResponse(
|
||||
stream_response(
|
||||
response=response,
|
||||
logging_obj=logging_obj,
|
||||
endpoint_type=endpoint_type,
|
||||
start_time=start_time,
|
||||
passthrough_success_handler_obj=pass_through_endpoint_logging,
|
||||
url_route=str(url),
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
stream_response(),
|
||||
url=str(url),
|
||||
),
|
||||
headers=get_response_headers(response.headers),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
@ -519,10 +528,12 @@ async def pass_through_request( # noqa: PLR0915
|
|||
content = await response.aread()
|
||||
|
||||
## LOG SUCCESS
|
||||
passthrough_logging_payload["response_body"] = get_response_body(response)
|
||||
response_body: Optional[dict] = get_response_body(response)
|
||||
passthrough_logging_payload["response_body"] = response_body
|
||||
end_time = datetime.now()
|
||||
await pass_through_endpoint_logging.pass_through_async_success_handler(
|
||||
httpx_response=response,
|
||||
response_body=response_body,
|
||||
url_route=str(url),
|
||||
result="",
|
||||
start_time=start_time,
|
||||
|
@ -619,6 +630,13 @@ def create_pass_through_route(
|
|||
return endpoint_func
|
||||
|
||||
|
||||
def _is_streaming_response(response: httpx.Response) -> bool:
|
||||
_content_type = response.headers.get("content-type")
|
||||
if _content_type is not None and "text/event-stream" in _content_type:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
||||
|
||||
verbose_proxy_logger.debug("initializing pass through endpoints")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue