(feat) add usage / cost tracking for Anthropic passthrough routes (#6835)

* move _process_response into the transformation module

* fix AnthropicConfig test

* add AnthropicConfig

* fix anthropic_passthrough_handler

* fix get_response_body

* fix check for streaming response

* use a single stream_response helper on passthrough routes
Ishaan Jaff authored 2024-11-20 17:25:12 -08:00, committed by GitHub
parent 434b1d3d86
commit c107bae7ae
3 changed files with 142 additions and 30 deletions
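
In effect, non-streaming responses on Anthropic passthrough routes are now converted to an OpenAI-format litellm.ModelResponse and priced. A minimal sketch of the pricing call the new handler leans on (the model name and strings below are illustrative, not from the diff):

    import litellm

    # completion_cost() looks up per-token pricing for the model and
    # token-counts the prompt/completion strings passed to it.
    cost = litellm.completion_cost(
        model="claude-3-5-sonnet-20241022",  # illustrative model choice
        prompt="Hello, Claude",
        completion="Hello! How can I help you today?",
    )
    print(f"tracked cost: ${cost:.8f}")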

litellm/proxy/pass_through_endpoints/pass_through_endpoints.py

@@ -45,11 +45,11 @@ router = APIRouter()
 
 pass_through_endpoint_logging = PassThroughEndpointLogging()
 
-def get_response_body(response: httpx.Response):
+def get_response_body(response: httpx.Response) -> Optional[dict]:
     try:
         return response.json()
     except Exception:
-        return response.text
+        return None
 
 
 async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
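
Note the behavior change in get_response_body: a non-JSON body (for example an event-stream) now comes back as None instead of raw text, so callers must handle the None case. A quick standalone check of the new contract, using hand-built httpx responses:

    import httpx

    def get_response_body(response: httpx.Response):
        # new contract: dict for JSON bodies, None for everything else
        try:
            return response.json()
        except Exception:
            return None

    assert get_response_body(httpx.Response(200, json={"ok": True})) == {"ok": True}
    assert get_response_body(httpx.Response(200, text="not json")) is None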
@@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict:
 
 def get_endpoint_type(url: str) -> EndpointType:
     if ("generateContent") in url or ("streamGenerateContent") in url:
         return EndpointType.VERTEX_AI
+    elif ("api.anthropic.com") in url:
+        return EndpointType.ANTHROPIC
     return EndpointType.GENERIC
 
 
+async def stream_response(
+    response: httpx.Response,
+    logging_obj: LiteLLMLoggingObj,
+    endpoint_type: EndpointType,
+    start_time: datetime,
+    url: str,
+) -> AsyncIterable[bytes]:
+    async for chunk in chunk_processor(
+        response.aiter_bytes(),
+        litellm_logging_obj=logging_obj,
+        endpoint_type=endpoint_type,
+        start_time=start_time,
+        passthrough_success_handler_obj=pass_through_endpoint_logging,
+        url_route=str(url),
+    ):
+        yield chunk
+
+
 async def pass_through_request(  # noqa: PLR0915
     request: Request,
     target: str,
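
The routing change is easiest to see in isolation. A standalone re-implementation of the check above, with the EndpointType values stubbed as plain strings:

    def get_endpoint_type(url: str) -> str:
        # mirrors the diff above; enum members replaced by their string values
        if "generateContent" in url or "streamGenerateContent" in url:
            return "vertex-ai"
        elif "api.anthropic.com" in url:
            return "anthropic"
        return "generic"

    assert get_endpoint_type("https://api.anthropic.com/v1/messages") == "anthropic"
    assert get_endpoint_type("https://example.com/chat/completions") == "generic"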
@@ -445,19 +465,14 @@ async def pass_through_request(  # noqa: PLR0915
                     status_code=e.response.status_code, detail=await e.response.aread()
                 )
 
-            async def stream_response() -> AsyncIterable[bytes]:
-                async for chunk in chunk_processor(
-                    response.aiter_bytes(),
-                    litellm_logging_obj=logging_obj,
-                    endpoint_type=endpoint_type,
-                    start_time=start_time,
-                    passthrough_success_handler_obj=pass_through_endpoint_logging,
-                    url_route=str(url),
-                ):
-                    yield chunk
-
             return StreamingResponse(
-                stream_response(),
+                stream_response(
+                    response=response,
+                    logging_obj=logging_obj,
+                    endpoint_type=endpoint_type,
+                    start_time=start_time,
+                    url=str(url),
+                ),
                 headers=get_response_headers(response.headers),
                 status_code=response.status_code,
             )
@@ -478,10 +493,9 @@ async def pass_through_request(  # noqa: PLR0915
             json=_parsed_body,
         )
 
-        if (
-            response.headers.get("content-type") is not None
-            and response.headers["content-type"] == "text/event-stream"
-        ):
+        verbose_proxy_logger.debug("response.headers= %s", response.headers)
+
+        if _is_streaming_response(response) is True:
             try:
                 response.raise_for_status()
             except httpx.HTTPStatusError as e:
@@ -489,19 +503,14 @@ async def pass_through_request(  # noqa: PLR0915
                     status_code=e.response.status_code, detail=await e.response.aread()
                 )
 
-            async def stream_response() -> AsyncIterable[bytes]:
-                async for chunk in chunk_processor(
-                    response.aiter_bytes(),
-                    litellm_logging_obj=logging_obj,
-                    endpoint_type=endpoint_type,
-                    start_time=start_time,
-                    passthrough_success_handler_obj=pass_through_endpoint_logging,
-                    url_route=str(url),
-                ):
-                    yield chunk
-
             return StreamingResponse(
-                stream_response(),
+                stream_response(
+                    response=response,
+                    logging_obj=logging_obj,
+                    endpoint_type=endpoint_type,
+                    start_time=start_time,
+                    url=str(url),
+                ),
                 headers=get_response_headers(response.headers),
                 status_code=response.status_code,
             )
@@ -519,10 +528,12 @@ async def pass_through_request(  # noqa: PLR0915
         content = await response.aread()
 
         ## LOG SUCCESS
-        passthrough_logging_payload["response_body"] = get_response_body(response)
+        response_body: Optional[dict] = get_response_body(response)
+        passthrough_logging_payload["response_body"] = response_body
         end_time = datetime.now()
         await pass_through_endpoint_logging.pass_through_async_success_handler(
             httpx_response=response,
+            response_body=response_body,
             url_route=str(url),
             result="",
             start_time=start_time,
@@ -619,6 +630,13 @@ def create_pass_through_route(
     return endpoint_func
 
 
+def _is_streaming_response(response: httpx.Response) -> bool:
+    _content_type = response.headers.get("content-type")
+    if _content_type is not None and "text/event-stream" in _content_type:
+        return True
+    return False
+
+
 async def initialize_pass_through_endpoints(pass_through_endpoints: list):
     verbose_proxy_logger.debug("initializing pass through endpoints")
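
The extracted _is_streaming_response helper also loosens the old check: exact equality on the content-type header missed values carrying parameters, while the substring test catches them. A standalone sketch of the same logic:

    import httpx

    def _is_streaming_response(response: httpx.Response) -> bool:
        # substring match tolerates parameters like "; charset=utf-8",
        # which the old equality check would have missed
        _content_type = response.headers.get("content-type")
        return _content_type is not None and "text/event-stream" in _content_type

    assert _is_streaming_response(
        httpx.Response(200, headers={"content-type": "text/event-stream; charset=utf-8"})
    )
    assert not _is_streaming_response(httpx.Response(200, json={}))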

litellm/proxy/pass_through_endpoints/success_handler.py

@@ -2,12 +2,17 @@ import json
 import re
 import threading
 from datetime import datetime
-from typing import Union
+from typing import Optional, Union
 
 import httpx
 
 import litellm
+from litellm._logging import verbose_proxy_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.litellm_logging import (
+    get_standard_logging_object_payload,
+)
+from litellm.llms.anthropic.chat.transformation import AnthropicConfig
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
@@ -23,9 +28,13 @@ class PassThroughEndpointLogging:
             "predict",
         ]
 
+        # Anthropic
+        self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
+
     async def pass_through_async_success_handler(
         self,
         httpx_response: httpx.Response,
+        response_body: Optional[dict],
         logging_obj: LiteLLMLoggingObj,
         url_route: str,
         result: str,
@@ -45,6 +54,18 @@ class PassThroughEndpointLogging:
                 cache_hit=cache_hit,
                 **kwargs,
             )
+        elif self.is_anthropic_route(url_route):
+            await self.anthropic_passthrough_handler(
+                httpx_response=httpx_response,
+                response_body=response_body or {},
+                logging_obj=logging_obj,
+                url_route=url_route,
+                result=result,
+                start_time=start_time,
+                end_time=end_time,
+                cache_hit=cache_hit,
+                **kwargs,
+            )
         else:
             standard_logging_response_object = StandardPassThroughResponseObject(
                 response=httpx_response.text
@@ -76,6 +97,12 @@ class PassThroughEndpointLogging:
                 return True
         return False
 
+    def is_anthropic_route(self, url_route: str):
+        for route in self.TRACKED_ANTHROPIC_ROUTES:
+            if route in url_route:
+                return True
+        return False
+
     def extract_model_from_url(self, url: str) -> str:
         pattern = r"/models/([^:]+)"
         match = re.search(pattern, url)
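
is_anthropic_route uses substring matching, so the single tracked route "/messages" also catches versioned paths. A quick check of that behavior (URLs below are illustrative):

    TRACKED_ANTHROPIC_ROUTES = ["/messages"]

    def is_anthropic_route(url_route: str) -> bool:
        # substring match: "/v1/messages" and "/anthropic/v1/messages" both hit
        return any(route in url_route for route in TRACKED_ANTHROPIC_ROUTES)

    assert is_anthropic_route("https://api.anthropic.com/v1/messages")
    assert not is_anthropic_route("https://api.anthropic.com/v1/complete")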
@@ -83,6 +110,72 @@ class PassThroughEndpointLogging:
             return match.group(1)
         return "unknown"
 
+    async def anthropic_passthrough_handler(
+        self,
+        httpx_response: httpx.Response,
+        response_body: dict,
+        logging_obj: LiteLLMLoggingObj,
+        url_route: str,
+        result: str,
+        start_time: datetime,
+        end_time: datetime,
+        cache_hit: bool,
+        **kwargs,
+    ):
+        """
+        Transforms the Anthropic response to an OpenAI-format response and
+        generates a standard logging object so downstream logging can be handled.
+        """
+        model = response_body.get("model", "")
+        litellm_model_response: litellm.ModelResponse = (
+            AnthropicConfig._process_response(
+                response=httpx_response,
+                model_response=litellm.ModelResponse(),
+                model=model,
+                stream=False,
+                messages=[],
+                logging_obj=logging_obj,
+                optional_params={},
+                api_key="",
+                data={},
+                print_verbose=litellm.print_verbose,
+                encoding=None,
+                json_mode=False,
+            )
+        )
+
+        response_cost = litellm.completion_cost(
+            completion_response=litellm_model_response,
+            model=model,
+        )
+        kwargs["response_cost"] = response_cost
+        kwargs["model"] = model
+
+        # Make a standard logging object for the Anthropic passthrough request
+        standard_logging_object = get_standard_logging_object_payload(
+            kwargs=kwargs,
+            init_response_obj=litellm_model_response,
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=logging_obj,
+            status="success",
+        )
+
+        # pretty print the standard logging object
+        verbose_proxy_logger.debug(
+            "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
+        )
+        kwargs["standard_logging_object"] = standard_logging_object
+
+        await logging_obj.async_success_handler(
+            result=litellm_model_response,
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=cache_hit,
+            **kwargs,
+        )
+
     async def vertex_passthrough_handler(
         self,
         httpx_response: httpx.Response,
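
The heart of the handler is the Anthropic-to-OpenAI conversion, after which the usual completion-cost machinery applies. The usage-field mapping that conversion performs, sketched by hand (illustrative numbers; this is not the library code):

    # an Anthropic /v1/messages response reports usage like this
    anthropic_usage = {"input_tokens": 10, "output_tokens": 5}

    # the OpenAI-style ModelResponse it becomes reports usage like this
    openai_usage = {
        "prompt_tokens": anthropic_usage["input_tokens"],
        "completion_tokens": anthropic_usage["output_tokens"],
        "total_tokens": anthropic_usage["input_tokens"]
        + anthropic_usage["output_tokens"],
    }
    assert openai_usage["total_tokens"] == 15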

litellm/proxy/pass_through_endpoints/types.py

@@ -4,6 +4,7 @@ from typing import Optional, TypedDict
 
 class EndpointType(str, Enum):
     VERTEX_AI = "vertex-ai"
+    ANTHROPIC = "anthropic"
     GENERIC = "generic"
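
Because EndpointType subclasses str, members compare equal to their raw values, which keeps the enum painless to serialize into logging payloads. A quick standalone check:

    from enum import Enum

    class EndpointType(str, Enum):
        VERTEX_AI = "vertex-ai"
        ANTHROPIC = "anthropic"
        GENERIC = "generic"

    assert EndpointType.ANTHROPIC == "anthropic"
    assert f"endpoint={EndpointType.ANTHROPIC.value}" == "endpoint=anthropic"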