(feat) add usage / cost tracking for Anthropic passthrough routes (#6835)

* move _process_response in transformation

* fix AnthropicConfig test

* add AnthropicConfig

* fix anthropic_passthrough_handler

* fix get_response_body

* fix check for streaming response

* use 1 helper to return stream_response on passthrough
Ishaan Jaff 2024-11-20 17:25:12 -08:00 committed by GitHub
parent 434b1d3d86
commit c107bae7ae
3 changed files with 142 additions and 30 deletions


@@ -45,11 +45,11 @@ router = APIRouter()
 pass_through_endpoint_logging = PassThroughEndpointLogging()
 
 
-def get_response_body(response: httpx.Response):
+def get_response_body(response: httpx.Response) -> Optional[dict]:
     try:
         return response.json()
     except Exception:
-        return response.text
+        return None
 
 
 async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
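The behavioral change in this hunk: bodies that are not valid JSON no longer leak back as raw text; callers now get a dict or None. A minimal sketch of the new contract (illustrative only; the import path is an assumption, since the diff does not show file names):

import httpx

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import get_response_body  # path assumed

# JSON bodies still parse to a dict.
ok = httpx.Response(200, json={"model": "claude-3-5-sonnet-20241022"})
assert get_response_body(ok) == {"model": "claude-3-5-sonnet-20241022"}

# Anything that is not valid JSON (e.g. an SSE stream body) now comes back as None
# instead of the raw text, matching the new Optional[dict] return type.
sse = httpx.Response(200, text="event: ping\ndata: {}\n\n")
assert get_response_body(sse) is None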
@@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict:
 def get_endpoint_type(url: str) -> EndpointType:
     if ("generateContent") in url or ("streamGenerateContent") in url:
         return EndpointType.VERTEX_AI
+    elif ("api.anthropic.com") in url:
+        return EndpointType.ANTHROPIC
     return EndpointType.GENERIC
 
 
+async def stream_response(
+    response: httpx.Response,
+    logging_obj: LiteLLMLoggingObj,
+    endpoint_type: EndpointType,
+    start_time: datetime,
+    url: str,
+) -> AsyncIterable[bytes]:
+    async for chunk in chunk_processor(
+        response.aiter_bytes(),
+        litellm_logging_obj=logging_obj,
+        endpoint_type=endpoint_type,
+        start_time=start_time,
+        passthrough_success_handler_obj=pass_through_endpoint_logging,
+        url_route=str(url),
+    ):
+        yield chunk
+
+
 async def pass_through_request(  # noqa: PLR0915
     request: Request,
     target: str,
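This is the hook that lets the rest of the passthrough stack treat Anthropic traffic specially. A quick illustration of the routing (import paths are an assumption; the diff does not show file names):

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import get_endpoint_type  # path assumed
from litellm.proxy.pass_through_endpoints.types import EndpointType  # path assumed

# Anthropic API URLs now get their own endpoint type instead of falling through to GENERIC.
assert get_endpoint_type("https://api.anthropic.com/v1/messages") == EndpointType.ANTHROPIC

# Existing behavior is unchanged for Vertex AI / Gemini and for everything else.
assert get_endpoint_type("https://aiplatform.googleapis.com/v1/models/gemini-1.5-pro:generateContent") == EndpointType.VERTEX_AI
assert get_endpoint_type("https://example.com/v1/embeddings") == EndpointType.GENERIC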
@@ -445,19 +465,14 @@ async def pass_through_request(  # noqa: PLR0915
                     status_code=e.response.status_code, detail=await e.response.aread()
                 )
 
-            async def stream_response() -> AsyncIterable[bytes]:
-                async for chunk in chunk_processor(
-                    response.aiter_bytes(),
-                    litellm_logging_obj=logging_obj,
+            return StreamingResponse(
+                stream_response(
+                    response=response,
+                    logging_obj=logging_obj,
                     endpoint_type=endpoint_type,
                     start_time=start_time,
-                    passthrough_success_handler_obj=pass_through_endpoint_logging,
-                    url_route=str(url),
-                ):
-                    yield chunk
-
-            return StreamingResponse(
-                stream_response(),
+                    url=str(url),
+                ),
                 headers=get_response_headers(response.headers),
                 status_code=response.status_code,
             )
@@ -478,10 +493,9 @@ async def pass_through_request(  # noqa: PLR0915
             json=_parsed_body,
         )
 
-        if (
-            response.headers.get("content-type") is not None
-            and response.headers["content-type"] == "text/event-stream"
-        ):
+        verbose_proxy_logger.debug("response.headers= %s", response.headers)
+
+        if _is_streaming_response(response) is True:
             try:
                 response.raise_for_status()
             except httpx.HTTPStatusError as e:
@@ -489,19 +503,14 @@ async def pass_through_request(  # noqa: PLR0915
                     status_code=e.response.status_code, detail=await e.response.aread()
                 )
 
-            async def stream_response() -> AsyncIterable[bytes]:
-                async for chunk in chunk_processor(
-                    response.aiter_bytes(),
-                    litellm_logging_obj=logging_obj,
+            return StreamingResponse(
+                stream_response(
+                    response=response,
+                    logging_obj=logging_obj,
                     endpoint_type=endpoint_type,
                     start_time=start_time,
-                    passthrough_success_handler_obj=pass_through_endpoint_logging,
-                    url_route=str(url),
-                ):
-                    yield chunk
-
-            return StreamingResponse(
-                stream_response(),
+                    url=str(url),
+                ),
                 headers=get_response_headers(response.headers),
                 status_code=response.status_code,
             )
@@ -519,10 +528,12 @@ async def pass_through_request(  # noqa: PLR0915
         content = await response.aread()
 
         ## LOG SUCCESS
-        passthrough_logging_payload["response_body"] = get_response_body(response)
+        response_body: Optional[dict] = get_response_body(response)
+        passthrough_logging_payload["response_body"] = response_body
         end_time = datetime.now()
         await pass_through_endpoint_logging.pass_through_async_success_handler(
             httpx_response=response,
+            response_body=response_body,
             url_route=str(url),
             result="",
             start_time=start_time,
@@ -619,6 +630,13 @@ def create_pass_through_route(
     return endpoint_func
 
 
+def _is_streaming_response(response: httpx.Response) -> bool:
+    _content_type = response.headers.get("content-type")
+    if _content_type is not None and "text/event-stream" in _content_type:
+        return True
+    return False
+
+
 async def initialize_pass_through_endpoints(pass_through_endpoints: list):
     verbose_proxy_logger.debug("initializing pass through endpoints")
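Moving the check into _is_streaming_response also switches it from strict equality to a substring match, so a parameterized content-type still counts as streaming. A small sketch of the difference (the import path is an assumption):

import httpx

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import _is_streaming_response  # path assumed

# A parameterized content-type, as commonly seen on SSE responses, would have failed the
# old `== "text/event-stream"` comparison but matches the new substring check.
sse = httpx.Response(200, headers={"content-type": "text/event-stream; charset=utf-8"})
assert _is_streaming_response(sse) is True

# Plain JSON responses still take the non-streaming path.
plain = httpx.Response(200, headers={"content-type": "application/json"})
assert _is_streaming_response(plain) is False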


@@ -2,12 +2,17 @@ import json
 import re
 import threading
 from datetime import datetime
-from typing import Union
+from typing import Optional, Union
 
 import httpx
 
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.litellm_logging import (
+    get_standard_logging_object_payload,
+)
+from litellm.llms.anthropic.chat.transformation import AnthropicConfig
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
@@ -23,9 +28,13 @@ class PassThroughEndpointLogging:
             "predict",
         ]
 
+        # Anthropic
+        self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
+
     async def pass_through_async_success_handler(
         self,
         httpx_response: httpx.Response,
+        response_body: Optional[dict],
         logging_obj: LiteLLMLoggingObj,
         url_route: str,
         result: str,
@@ -45,6 +54,18 @@ class PassThroughEndpointLogging:
                 cache_hit=cache_hit,
                 **kwargs,
             )
+        elif self.is_anthropic_route(url_route):
+            await self.anthropic_passthrough_handler(
+                httpx_response=httpx_response,
+                response_body=response_body or {},
+                logging_obj=logging_obj,
+                url_route=url_route,
+                result=result,
+                start_time=start_time,
+                end_time=end_time,
+                cache_hit=cache_hit,
+                **kwargs,
+            )
         else:
             standard_logging_response_object = StandardPassThroughResponseObject(
                 response=httpx_response.text
@@ -76,6 +97,12 @@ class PassThroughEndpointLogging:
                 return True
         return False
 
+    def is_anthropic_route(self, url_route: str):
+        for route in self.TRACKED_ANTHROPIC_ROUTES:
+            if route in url_route:
+                return True
+        return False
+
     def extract_model_from_url(self, url: str) -> str:
         pattern = r"/models/([^:]+)"
         match = re.search(pattern, url)
@@ -83,6 +110,72 @@ class PassThroughEndpointLogging:
             return match.group(1)
         return "unknown"
 
+    async def anthropic_passthrough_handler(
+        self,
+        httpx_response: httpx.Response,
+        response_body: dict,
+        logging_obj: LiteLLMLoggingObj,
+        url_route: str,
+        result: str,
+        start_time: datetime,
+        end_time: datetime,
+        cache_hit: bool,
+        **kwargs,
+    ):
+        """
+        Transforms the Anthropic response to an OpenAI-format response and generates a standard logging object so downstream logging can be handled.
+        """
+        model = response_body.get("model", "")
+        litellm_model_response: litellm.ModelResponse = (
+            AnthropicConfig._process_response(
+                response=httpx_response,
+                model_response=litellm.ModelResponse(),
+                model=model,
+                stream=False,
+                messages=[],
+                logging_obj=logging_obj,
+                optional_params={},
+                api_key="",
+                data={},
+                print_verbose=litellm.print_verbose,
+                encoding=None,
+                json_mode=False,
+            )
+        )
+
+        response_cost = litellm.completion_cost(
+            completion_response=litellm_model_response,
+            model=model,
+        )
+        kwargs["response_cost"] = response_cost
+        kwargs["model"] = model
+
+        # Make a standard logging object for the Anthropic passthrough call
+        standard_logging_object = get_standard_logging_object_payload(
+            kwargs=kwargs,
+            init_response_obj=litellm_model_response,
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=logging_obj,
+            status="success",
+        )
+
+        # pretty print standard logging object
+        verbose_proxy_logger.debug(
+            "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
+        )
+        kwargs["standard_logging_object"] = standard_logging_object
+
+        await logging_obj.async_success_handler(
+            result=litellm_model_response,
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=cache_hit,
+            **kwargs,
+        )
+        pass
+
     async def vertex_passthrough_handler(
         self,
         httpx_response: httpx.Response,

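On the cost side, the handler converts the raw Anthropic JSON into a litellm.ModelResponse so that litellm.completion_cost can price it like any other completion. A rough sanity check of the numbers it produces, using litellm.cost_per_token (the lower-level per-token helper that completion_cost relies on) with made-up token counts:

import litellm

# Token counts here are invented; in the handler they come from the "usage" block of the
# Anthropic /v1/messages response after it has been converted to a ModelResponse.
prompt_usd, completion_usd = litellm.cost_per_token(
    model="claude-3-5-sonnet-20241022",
    prompt_tokens=10,
    completion_tokens=5,
)
print(f"response_cost={prompt_usd + completion_usd:.8f} USD")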

@@ -4,6 +4,7 @@ from typing import Optional, TypedDict
 
 
 class EndpointType(str, Enum):
     VERTEX_AI = "vertex-ai"
+    ANTHROPIC = "anthropic"
     GENERIC = "generic"