From 67179292060609e0983af7b85e35fafbe393742e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 21 Nov 2024 21:41:05 -0800 Subject: [PATCH] (Feat) Allow passing `litellm_metadata` to pass through endpoints + Add e2e tests for /anthropic/ usage tracking (#6864) * allow passing _litellm_metadata in pass through endpoints * fix _create_anthropic_response_logging_payload * include litellm_call_id in logging * add e2e testing for anthropic spend logs * add testing for spend logs payload * add example with anthropic python SDK --- .../docs/pass_through/anthropic_completion.md | 39 ++- .../anthropic_passthrough_logging_handler.py | 5 + .../pass_through_endpoints.py | 73 ++++-- .../test_anthropic_passthrough.py | 224 ++++++++++++++++++ 4 files changed, 321 insertions(+), 20 deletions(-) diff --git a/docs/my-website/docs/pass_through/anthropic_completion.md b/docs/my-website/docs/pass_through/anthropic_completion.md index 0c6a5f1b6..320527580 100644 --- a/docs/my-website/docs/pass_through/anthropic_completion.md +++ b/docs/my-website/docs/pass_through/anthropic_completion.md @@ -1,10 +1,18 @@ -# Anthropic `/v1/messages` +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Anthropic SDK Pass-through endpoints for Anthropic - call provider-specific endpoint, in native format (no translation). -Just replace `https://api.anthropic.com` with `LITELLM_PROXY_BASE_URL/anthropic` 🚀 +Just replace `https://api.anthropic.com` with `LITELLM_PROXY_BASE_URL/anthropic` #### **Example Usage** + + + + + ```bash curl --request POST \ --url http://0.0.0.0:4000/anthropic/v1/messages \ @@ -20,6 +28,33 @@ curl --request POST \ }' ``` + + + +```python +from anthropic import Anthropic + +# Initialize client with proxy base URL +client = Anthropic( + base_url="http://0.0.0.0:4000/anthropic", # /anthropic + api_key="sk-anything" # proxy virtual key +) + +# Make a completion request +response = client.messages.create( + model="claude-3-5-sonnet-20241022", + max_tokens=1024, + messages=[ + {"role": "user", "content": "Hello, world"} + ] +) + +print(response) +``` + + + + Supports **ALL** Anthropic Endpoints (including streaming). [**See All Anthropic Endpoints**](https://docs.anthropic.com/en/api/messages) diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 1b18c3ab0..35cff0db3 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -115,6 +115,11 @@ class AnthropicPassthroughLoggingHandler: "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4) ) kwargs["standard_logging_object"] = standard_logging_object + + # set litellm_call_id to logging response object + litellm_model_response.id = logging_obj.litellm_call_id + litellm_model_response.model = model + logging_obj.model_call_details["model"] = model return kwargs @staticmethod diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index fd676189e..baf107a16 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -289,13 +289,18 @@ def forward_headers_from_request( return headers -def get_response_headers(headers: httpx.Headers) -> dict: +def get_response_headers( + headers: httpx.Headers, litellm_call_id: Optional[str] = None +) -> dict: excluded_headers = {"transfer-encoding", "content-encoding"} + return_headers = { key: value for key, value in headers.items() if key.lower() not in excluded_headers } + if litellm_call_id: + return_headers["x-litellm-call-id"] = litellm_call_id return return_headers @@ -361,6 +366,8 @@ async def pass_through_request( # noqa: PLR0915 async_client = httpx.AsyncClient(timeout=600) + litellm_call_id = str(uuid.uuid4()) + # create logging object start_time = datetime.now() logging_obj = Logging( @@ -369,27 +376,20 @@ async def pass_through_request( # noqa: PLR0915 stream=False, call_type="pass_through_endpoint", start_time=start_time, - litellm_call_id=str(uuid.uuid4()), + litellm_call_id=litellm_call_id, function_id="1245", ) passthrough_logging_payload = PassthroughStandardLoggingPayload( url=str(url), request_body=_parsed_body, ) - + kwargs = _init_kwargs_for_pass_through_endpoint( + user_api_key_dict=user_api_key_dict, + _parsed_body=_parsed_body, + passthrough_logging_payload=passthrough_logging_payload, + litellm_call_id=litellm_call_id, + ) # done for supporting 'parallel_request_limiter.py' with pass-through endpoints - kwargs = { - "litellm_params": { - "metadata": { - "user_api_key": user_api_key_dict.api_key, - "user_api_key_user_id": user_api_key_dict.user_id, - "user_api_key_team_id": user_api_key_dict.team_id, - "user_api_key_end_user_id": user_api_key_dict.user_id, - } - }, - "call_type": "pass_through_endpoint", - "passthrough_logging_payload": passthrough_logging_payload, - } logging_obj.update_environment_variables( model="unknown", user="unknown", @@ -397,6 +397,7 @@ async def pass_through_request( # noqa: PLR0915 litellm_params=kwargs["litellm_params"], call_type="pass_through_endpoint", ) + logging_obj.model_call_details["litellm_call_id"] = litellm_call_id # combine url with query params for logging @@ -456,7 +457,10 @@ async def pass_through_request( # noqa: PLR0915 passthrough_success_handler_obj=pass_through_endpoint_logging, url_route=str(url), ), - headers=get_response_headers(response.headers), + headers=get_response_headers( + headers=response.headers, + litellm_call_id=litellm_call_id, + ), status_code=response.status_code, ) @@ -496,7 +500,10 @@ async def pass_through_request( # noqa: PLR0915 passthrough_success_handler_obj=pass_through_endpoint_logging, url_route=str(url), ), - headers=get_response_headers(response.headers), + headers=get_response_headers( + headers=response.headers, + litellm_call_id=litellm_call_id, + ), status_code=response.status_code, ) @@ -531,7 +538,10 @@ async def pass_through_request( # noqa: PLR0915 return Response( content=content, status_code=response.status_code, - headers=get_response_headers(response.headers), + headers=get_response_headers( + headers=response.headers, + litellm_call_id=litellm_call_id, + ), ) except Exception as e: verbose_proxy_logger.exception( @@ -556,6 +566,33 @@ async def pass_through_request( # noqa: PLR0915 ) +def _init_kwargs_for_pass_through_endpoint( + user_api_key_dict: UserAPIKeyAuth, + passthrough_logging_payload: PassthroughStandardLoggingPayload, + _parsed_body: Optional[dict] = None, + litellm_call_id: Optional[str] = None, +) -> dict: + _parsed_body = _parsed_body or {} + _litellm_metadata: Optional[dict] = _parsed_body.pop("litellm_metadata", None) + _metadata = { + "user_api_key": user_api_key_dict.api_key, + "user_api_key_user_id": user_api_key_dict.user_id, + "user_api_key_team_id": user_api_key_dict.team_id, + "user_api_key_end_user_id": user_api_key_dict.user_id, + } + if _litellm_metadata: + _metadata.update(_litellm_metadata) + kwargs = { + "litellm_params": { + "metadata": _metadata, + }, + "call_type": "pass_through_endpoint", + "litellm_call_id": litellm_call_id, + "passthrough_logging_payload": passthrough_logging_payload, + } + return kwargs + + def create_pass_through_route( endpoint, target: str, diff --git a/tests/pass_through_tests/test_anthropic_passthrough.py b/tests/pass_through_tests/test_anthropic_passthrough.py index 1e599b735..b062a025a 100644 --- a/tests/pass_through_tests/test_anthropic_passthrough.py +++ b/tests/pass_through_tests/test_anthropic_passthrough.py @@ -5,6 +5,8 @@ This test ensures that the proxy can passthrough anthropic requests import pytest import anthropic +import aiohttp +import asyncio client = anthropic.Anthropic( base_url="http://0.0.0.0:4000/anthropic", api_key="sk-1234" @@ -17,6 +19,11 @@ def test_anthropic_basic_completion(): model="claude-3-5-sonnet-20241022", max_tokens=1024, messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}], + extra_body={ + "litellm_metadata": { + "tags": ["test-tag-1", "test-tag-2"], + } + }, ) print(response) @@ -31,9 +38,226 @@ def test_anthropic_streaming(): {"role": "user", "content": "Say 'hello stream test' and nothing else"} ], model="claude-3-5-sonnet-20241022", + extra_body={ + "litellm_metadata": { + "tags": ["test-tag-stream-1", "test-tag-stream-2"], + } + }, ) as stream: for text in stream.text_stream: collected_output.append(text) full_response = "".join(collected_output) print(full_response) + + +@pytest.mark.asyncio +async def test_anthropic_basic_completion_with_headers(): + print("making basic completion request to anthropic passthrough with aiohttp") + + headers = { + "Authorization": f"Bearer sk-1234", + "Content-Type": "application/json", + "Anthropic-Version": "2023-06-01", + } + + payload = { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 10, + "messages": [{"role": "user", "content": "Say 'hello test' and nothing else"}], + "litellm_metadata": { + "tags": ["test-tag-1", "test-tag-2"], + }, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + "http://0.0.0.0:4000/anthropic/v1/messages", json=payload, headers=headers + ) as response: + response_text = await response.text() + print(f"Response text: {response_text}") + + response_json = await response.json() + response_headers = response.headers + litellm_call_id = response_headers.get("x-litellm-call-id") + + print(f"LiteLLM Call ID: {litellm_call_id}") + + # Wait for spend to be logged + await asyncio.sleep(15) + + # Check spend logs for this specific request + async with session.get( + f"http://0.0.0.0:4000/spend/logs?request_id={litellm_call_id}", + headers={"Authorization": "Bearer sk-1234"}, + ) as spend_response: + print("text spend response") + print(f"Spend response: {spend_response}") + spend_data = await spend_response.json() + print(f"Spend data: {spend_data}") + assert spend_data is not None, "Should have spend data for the request" + + log_entry = spend_data[ + 0 + ] # Get the first (and should be only) log entry + + # Basic existence checks + assert spend_data is not None, "Should have spend data for the request" + assert isinstance(log_entry, dict), "Log entry should be a dictionary" + + # Request metadata assertions + assert ( + log_entry["request_id"] == litellm_call_id + ), "Request ID should match" + assert ( + log_entry["call_type"] == "pass_through_endpoint" + ), "Call type should be pass_through_endpoint" + assert ( + log_entry["api_base"] == "https://api.anthropic.com/v1/messages" + ), "API base should be Anthropic's endpoint" + + # Token and spend assertions + assert log_entry["spend"] > 0, "Spend value should not be None" + assert isinstance( + log_entry["spend"], (int, float) + ), "Spend should be a number" + assert log_entry["total_tokens"] > 0, "Should have some tokens" + assert log_entry["prompt_tokens"] > 0, "Should have prompt tokens" + assert ( + log_entry["completion_tokens"] > 0 + ), "Should have completion tokens" + assert ( + log_entry["total_tokens"] + == log_entry["prompt_tokens"] + log_entry["completion_tokens"] + ), "Total tokens should equal prompt + completion" + + # Time assertions + assert all( + key in log_entry + for key in ["startTime", "endTime", "completionStartTime"] + ), "Should have all time fields" + assert ( + log_entry["startTime"] < log_entry["endTime"] + ), "Start time should be before end time" + + # Metadata assertions + assert log_entry["cache_hit"] == "False", "Cache should be off" + assert log_entry["request_tags"] == [ + "test-tag-1", + "test-tag-2", + ], "Tags should match input" + assert ( + "user_api_key" in log_entry["metadata"] + ), "Should have user API key in metadata" + + assert "claude" in log_entry["model"] + + +@pytest.mark.asyncio +async def test_anthropic_streaming_with_headers(): + print("making streaming request to anthropic passthrough with aiohttp") + + headers = { + "Authorization": f"Bearer sk-1234", + "Content-Type": "application/json", + "Anthropic-Version": "2023-06-01", + } + + payload = { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 10, + "messages": [ + {"role": "user", "content": "Say 'hello stream test' and nothing else"} + ], + "stream": True, + "litellm_metadata": { + "tags": ["test-tag-stream-1", "test-tag-stream-2"], + }, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + "http://0.0.0.0:4000/anthropic/v1/messages", json=payload, headers=headers + ) as response: + print("response status") + print(response.status) + assert response.status == 200, "Response should be successful" + response_headers = response.headers + print(f"Response headers: {response_headers}") + litellm_call_id = response_headers.get("x-litellm-call-id") + print(f"LiteLLM Call ID: {litellm_call_id}") + + collected_output = [] + async for line in response.content: + if line: + text = line.decode("utf-8").strip() + if text.startswith("data: "): + collected_output.append(text[6:]) # Remove 'data: ' prefix + + print("Collected output:", "".join(collected_output)) + + # Wait for spend to be logged + await asyncio.sleep(20) + + # Check spend logs for this specific request + async with session.get( + f"http://0.0.0.0:4000/spend/logs?request_id={litellm_call_id}", + headers={"Authorization": "Bearer sk-1234"}, + ) as spend_response: + spend_data = await spend_response.json() + print(f"Spend data: {spend_data}") + assert spend_data is not None, "Should have spend data for the request" + + log_entry = spend_data[ + 0 + ] # Get the first (and should be only) log entry + + # Basic existence checks + assert spend_data is not None, "Should have spend data for the request" + assert isinstance(log_entry, dict), "Log entry should be a dictionary" + + # Request metadata assertions + assert ( + log_entry["request_id"] == litellm_call_id + ), "Request ID should match" + assert ( + log_entry["call_type"] == "pass_through_endpoint" + ), "Call type should be pass_through_endpoint" + assert ( + log_entry["api_base"] == "https://api.anthropic.com/v1/messages" + ), "API base should be Anthropic's endpoint" + + # Token and spend assertions + assert log_entry["spend"] > 0, "Spend value should not be None" + assert isinstance( + log_entry["spend"], (int, float) + ), "Spend should be a number" + assert log_entry["total_tokens"] > 0, "Should have some tokens" + assert ( + log_entry["completion_tokens"] > 0 + ), "Should have completion tokens" + assert ( + log_entry["total_tokens"] + == log_entry["prompt_tokens"] + log_entry["completion_tokens"] + ), "Total tokens should equal prompt + completion" + + # Time assertions + assert all( + key in log_entry + for key in ["startTime", "endTime", "completionStartTime"] + ), "Should have all time fields" + assert ( + log_entry["startTime"] < log_entry["endTime"] + ), "Start time should be before end time" + + # Metadata assertions + assert log_entry["cache_hit"] == "False", "Cache should be off" + assert log_entry["request_tags"] == [ + "test-tag-stream-1", + "test-tag-stream-2", + ], "Tags should match input" + assert ( + "user_api_key" in log_entry["metadata"] + ), "Should have user API key in metadata" + + assert "claude" in log_entry["model"]