(feat) add usage / cost tracking for Anthropic passthrough routes (#6835)
* move _process_response in transformation
* fix AnthropicConfig test
* add AnthropicConfig
* fix anthropic_passthrough_handler
* fix get_response_body
* fix check for streaming response
* use 1 helper to return stream_response on passthrough
parent 434b1d3d86
commit c107bae7ae
3 changed files with 142 additions and 30 deletions
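The core of the change: a successful Anthropic /messages passthrough response is transformed into a litellm.ModelResponse and priced with litellm.completion_cost, so spend flows into the standard logging pipeline. A minimal standalone sketch of that pricing step (the model name is a placeholder, and mock_response stands in for a transformed passthrough response):

import litellm

# Sketch of the cost-tracking step this commit wires in: once the raw
# Anthropic response is converted to a litellm ModelResponse, pricing is
# derived from the model name + token usage in the response.
# mock_response avoids a real network call for illustration.
response = litellm.completion(
    model="claude-3-5-sonnet-20241022",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    mock_response="Hi there!",
)

cost = litellm.completion_cost(
    completion_response=response,
    model="claude-3-5-sonnet-20241022",
)
print(f"response_cost: ${cost:.6f}")  # the handler stores this as kwargs["response_cost"]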
@@ -45,11 +45,11 @@ router = APIRouter()
 pass_through_endpoint_logging = PassThroughEndpointLogging()
 
 
-def get_response_body(response: httpx.Response):
+def get_response_body(response: httpx.Response) -> Optional[dict]:
     try:
         return response.json()
     except Exception:
-        return response.text
+        return None
 
 
 async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:

@@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict:
 def get_endpoint_type(url: str) -> EndpointType:
     if ("generateContent") in url or ("streamGenerateContent") in url:
         return EndpointType.VERTEX_AI
+    elif ("api.anthropic.com") in url:
+        return EndpointType.ANTHROPIC
     return EndpointType.GENERIC
 
 
+async def stream_response(
+    response: httpx.Response,
+    logging_obj: LiteLLMLoggingObj,
+    endpoint_type: EndpointType,
+    start_time: datetime,
+    url: str,
+) -> AsyncIterable[bytes]:
+    async for chunk in chunk_processor(
+        response.aiter_bytes(),
+        litellm_logging_obj=logging_obj,
+        endpoint_type=endpoint_type,
+        start_time=start_time,
+        passthrough_success_handler_obj=pass_through_endpoint_logging,
+        url_route=str(url),
+    ):
+        yield chunk
+
+
 async def pass_through_request(  # noqa: PLR0915
     request: Request,
     target: str,

@@ -445,19 +465,14 @@ async def pass_through_request(  # noqa: PLR0915
                 status_code=e.response.status_code, detail=await e.response.aread()
             )
 
-        async def stream_response() -> AsyncIterable[bytes]:
-            async for chunk in chunk_processor(
-                response.aiter_bytes(),
-                litellm_logging_obj=logging_obj,
+        return StreamingResponse(
+            stream_response(
+                response=response,
+                logging_obj=logging_obj,
                 endpoint_type=endpoint_type,
                 start_time=start_time,
-                passthrough_success_handler_obj=pass_through_endpoint_logging,
-                url_route=str(url),
-            ):
-                yield chunk
-
-        return StreamingResponse(
-            stream_response(),
+                url=str(url),
+            ),
             headers=get_response_headers(response.headers),
             status_code=response.status_code,
         )

@@ -478,10 +493,9 @@ async def pass_through_request(  # noqa: PLR0915
         json=_parsed_body,
     )
 
-    if (
-        response.headers.get("content-type") is not None
-        and response.headers["content-type"] == "text/event-stream"
-    ):
+    verbose_proxy_logger.debug("response.headers= %s", response.headers)
+
+    if _is_streaming_response(response) is True:
         try:
             response.raise_for_status()
         except httpx.HTTPStatusError as e:

@@ -489,19 +503,14 @@ async def pass_through_request(  # noqa: PLR0915
                 status_code=e.response.status_code, detail=await e.response.aread()
             )
 
-        async def stream_response() -> AsyncIterable[bytes]:
-            async for chunk in chunk_processor(
-                response.aiter_bytes(),
-                litellm_logging_obj=logging_obj,
+        return StreamingResponse(
+            stream_response(
+                response=response,
+                logging_obj=logging_obj,
                 endpoint_type=endpoint_type,
                 start_time=start_time,
-                passthrough_success_handler_obj=pass_through_endpoint_logging,
-                url_route=str(url),
-            ):
-                yield chunk
-
-        return StreamingResponse(
-            stream_response(),
+                url=str(url),
+            ),
             headers=get_response_headers(response.headers),
             status_code=response.status_code,
         )

@@ -519,10 +528,12 @@ async def pass_through_request(  # noqa: PLR0915
     content = await response.aread()
 
     ## LOG SUCCESS
-    passthrough_logging_payload["response_body"] = get_response_body(response)
+    response_body: Optional[dict] = get_response_body(response)
+    passthrough_logging_payload["response_body"] = response_body
     end_time = datetime.now()
     await pass_through_endpoint_logging.pass_through_async_success_handler(
         httpx_response=response,
+        response_body=response_body,
         url_route=str(url),
         result="",
         start_time=start_time,

@@ -619,6 +630,13 @@ def create_pass_through_route(
     return endpoint_func
 
 
+def _is_streaming_response(response: httpx.Response) -> bool:
+    _content_type = response.headers.get("content-type")
+    if _content_type is not None and "text/event-stream" in _content_type:
+        return True
+    return False
+
+
 async def initialize_pass_through_endpoints(pass_through_endpoints: list):
 
     verbose_proxy_logger.debug("initializing pass through endpoints")

@@ -2,12 +2,17 @@ import json
 import re
 import threading
 from datetime import datetime
-from typing import Union
+from typing import Optional, Union
 
 import httpx
 
 import litellm
+from litellm._logging import verbose_proxy_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.litellm_logging import (
+    get_standard_logging_object_payload,
+)
+from litellm.llms.anthropic.chat.transformation import AnthropicConfig
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )

@@ -23,9 +28,13 @@ class PassThroughEndpointLogging:
             "predict",
         ]
 
+        # Anthropic
+        self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
+
     async def pass_through_async_success_handler(
         self,
         httpx_response: httpx.Response,
+        response_body: Optional[dict],
         logging_obj: LiteLLMLoggingObj,
         url_route: str,
         result: str,

@@ -45,6 +54,18 @@ class PassThroughEndpointLogging:
                 cache_hit=cache_hit,
                 **kwargs,
             )
+        elif self.is_anthropic_route(url_route):
+            await self.anthropic_passthrough_handler(
+                httpx_response=httpx_response,
+                response_body=response_body or {},
+                logging_obj=logging_obj,
+                url_route=url_route,
+                result=result,
+                start_time=start_time,
+                end_time=end_time,
+                cache_hit=cache_hit,
+                **kwargs,
+            )
         else:
             standard_logging_response_object = StandardPassThroughResponseObject(
                 response=httpx_response.text

@@ -76,6 +97,12 @@ class PassThroughEndpointLogging:
                 return True
         return False
 
+    def is_anthropic_route(self, url_route: str):
+        for route in self.TRACKED_ANTHROPIC_ROUTES:
+            if route in url_route:
+                return True
+        return False
+
     def extract_model_from_url(self, url: str) -> str:
         pattern = r"/models/([^:]+)"
         match = re.search(pattern, url)

@@ -83,6 +110,72 @@ class PassThroughEndpointLogging:
             return match.group(1)
         return "unknown"
 
+    async def anthropic_passthrough_handler(
+        self,
+        httpx_response: httpx.Response,
+        response_body: dict,
+        logging_obj: LiteLLMLoggingObj,
+        url_route: str,
+        result: str,
+        start_time: datetime,
+        end_time: datetime,
+        cache_hit: bool,
+        **kwargs,
+    ):
+        """
+        Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
+        """
+        model = response_body.get("model", "")
+        litellm_model_response: litellm.ModelResponse = (
+            AnthropicConfig._process_response(
+                response=httpx_response,
+                model_response=litellm.ModelResponse(),
+                model=model,
+                stream=False,
+                messages=[],
+                logging_obj=logging_obj,
+                optional_params={},
+                api_key="",
+                data={},
+                print_verbose=litellm.print_verbose,
+                encoding=None,
+                json_mode=False,
+            )
+        )
+
+        response_cost = litellm.completion_cost(
+            completion_response=litellm_model_response,
+            model=model,
+        )
+        kwargs["response_cost"] = response_cost
+        kwargs["model"] = model
+
+        # Make standard logging object for Vertex AI
+        standard_logging_object = get_standard_logging_object_payload(
+            kwargs=kwargs,
+            init_response_obj=litellm_model_response,
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=logging_obj,
+            status="success",
+        )
+
+        # pretty print standard logging object
+        verbose_proxy_logger.debug(
+            "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
+        )
+        kwargs["standard_logging_object"] = standard_logging_object
+
+        await logging_obj.async_success_handler(
+            result=litellm_model_response,
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=cache_hit,
+            **kwargs,
+        )
+
+        pass
+
     async def vertex_passthrough_handler(
         self,
         httpx_response: httpx.Response,

@@ -4,6 +4,7 @@ from typing import Optional, TypedDict
 
 
 class EndpointType(str, Enum):
     VERTEX_AI = "vertex-ai"
+    ANTHROPIC = "anthropic"
     GENERIC = "generic"