[Feat] Pass through endpoints - ensure PassthroughStandardLoggingPayload is logged and contains method, url, request/response body (#10194)

* ensure passthrough_logging_payload is filled in kwargs

* test_assistants_passthrough_logging

* test_threads_passthrough_logging

* test _init_kwargs_for_pass_through_endpoint

* _init_kwargs_for_pass_through_endpoint
Ishaan Jaff 2025-04-21 19:46:22 -07:00 committed by GitHub
parent 8cf4042161
commit 9314c633ed
11 changed files with 244 additions and 18 deletions
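In practice, the change means any registered success callback now receives the pass-through details in kwargs under the key passthrough_logging_payload. A minimal consumer sketch, modeled on the tests added in this commit (the logger class name here is illustrative, not part of the commit):

from typing import Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
    PassthroughStandardLoggingPayload,
)

class PassthroughAuditLogger(CustomLogger):
    # Illustrative callback that records pass-through request/response details.
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        payload: Optional[PassthroughStandardLoggingPayload] = kwargs.get(
            "passthrough_logging_payload"
        )
        if payload is None:
            return
        # The TypedDict is declared with total=False, so read keys with .get().
        print(payload.get("request_method"), payload.get("url"))
        print(payload.get("request_body"), payload.get("response_body"))

# Register the logger, as the tests below do with their TestCustomLogger.
litellm._async_success_callback = [PassthroughAuditLogger()]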

View file

@@ -1,6 +1,6 @@
"""
- call /messages on Anthropic API
- Make streaming + non-streaming request - just pass it through direct to Anthropic. No need to do anything special here
- Ensure requests are logged in the DB - stream + non-stream
"""
@@ -43,7 +43,9 @@ class AnthropicMessagesHandler:
from litellm.proxy.pass_through_endpoints.success_handler import (
PassThroughEndpointLogging,
)
from litellm.proxy.pass_through_endpoints.types import EndpointType
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
EndpointType,
)
# Create success handler object
passthrough_success_handler_obj = PassThroughEndpointLogging()
@@ -98,11 +100,11 @@ async def anthropic_messages(
api_base=optional_params.api_base,
api_key=optional_params.api_key,
)
anthropic_messages_provider_config: Optional[
BaseAnthropicMessagesConfig
] = ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = (
ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
)
)
if anthropic_messages_provider_config is None:
raise ValueError(

View file

@@ -13,7 +13,9 @@ from litellm.llms.anthropic.chat.handler import (
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import ModelResponse, TextCompletionResponse
if TYPE_CHECKING:
@@ -122,9 +124,9 @@ class AnthropicPassthroughLoggingHandler:
litellm_model_response.id = logging_obj.litellm_call_id
litellm_model_response.model = model
logging_obj.model_call_details["model"] = model
logging_obj.model_call_details[
"custom_llm_provider"
] = litellm.LlmProviders.ANTHROPIC.value
logging_obj.model_call_details["custom_llm_provider"] = (
litellm.LlmProviders.ANTHROPIC.value
)
return kwargs
except Exception as e:
verbose_proxy_logger.exception(

View file

@@ -14,11 +14,13 @@ from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
from litellm.types.passthrough_endpoints.assembly_ai import (
ASSEMBLY_AI_MAX_POLLING_ATTEMPTS,
ASSEMBLY_AI_POLLING_INTERVAL,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
class AssemblyAITranscriptResponse(TypedDict, total=False):

View file

@@ -13,7 +13,9 @@ from litellm.litellm_core_utils.litellm_logging import (
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
if TYPE_CHECKING:

View file

@@ -23,6 +23,7 @@ from starlette.datastructures import UploadFile as StarletteUploadFile
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import (
@@ -38,11 +39,14 @@ from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessin
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
EndpointType,
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import StandardLoggingUserAPIKeyMetadata
from .streaming_handler import PassThroughStreamingHandler
from .success_handler import PassThroughEndpointLogging
from .types import EndpointType, PassthroughStandardLoggingPayload
router = APIRouter()
@@ -530,6 +534,7 @@ async def pass_through_request( # noqa: PLR0915
passthrough_logging_payload = PassthroughStandardLoggingPayload(
url=str(url),
request_body=_parsed_body,
request_method=getattr(request, "method", None),
)
kwargs = _init_kwargs_for_pass_through_endpoint(
user_api_key_dict=user_api_key_dict,
@@ -537,6 +542,7 @@
passthrough_logging_payload=passthrough_logging_payload,
litellm_call_id=litellm_call_id,
request=request,
logging_obj=logging_obj,
)
# done for supporting 'parallel_request_limiter.py' with pass-through endpoints
logging_obj.update_environment_variables(
@@ -741,6 +747,7 @@ def _init_kwargs_for_pass_through_endpoint(
request: Request,
user_api_key_dict: UserAPIKeyAuth,
passthrough_logging_payload: PassthroughStandardLoggingPayload,
logging_obj: LiteLLMLoggingObj,
_parsed_body: Optional[dict] = None,
litellm_call_id: Optional[str] = None,
) -> dict:
@@ -775,6 +782,11 @@ def _init_kwargs_for_pass_through_endpoint(
"litellm_call_id": litellm_call_id,
"passthrough_logging_payload": passthrough_logging_payload,
}
logging_obj.model_call_details["passthrough_logging_payload"] = (
passthrough_logging_payload
)
return kwargs
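
Distilled, the hunk above attaches the same payload in two places: the kwargs dict that callbacks receive, and the logging object's model_call_details, which the success and streaming handlers read later. A simplified sketch of that pattern, not the full function:

def attach_passthrough_payload(kwargs: dict, logging_obj, passthrough_logging_payload: dict) -> dict:
    # Callbacks read it from kwargs...
    kwargs["passthrough_logging_payload"] = passthrough_logging_payload
    # ...while the logging object carries it to the success handler.
    logging_obj.model_call_details["passthrough_logging_payload"] = passthrough_logging_payload
    return kwargs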

View file

@@ -8,6 +8,7 @@ import httpx
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
from litellm.types.utils import StandardPassThroughResponseObject
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
@@ -17,7 +18,6 @@ from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
from .success_handler import PassThroughEndpointLogging
from .types import EndpointType
class PassThroughStreamingHandler:

View file

@@ -7,6 +7,9 @@ import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import StandardPassThroughResponseObject
from litellm.utils import executor as thread_pool_executor
@@ -92,11 +95,15 @@ class PassThroughEndpointLogging:
end_time: datetime,
cache_hit: bool,
request_body: dict,
passthrough_logging_payload: PassthroughStandardLoggingPayload,
**kwargs,
):
standard_logging_response_object: Optional[
PassThroughEndpointLoggingResultValues
] = None
logging_obj.model_call_details["passthrough_logging_payload"] = (
passthrough_logging_payload
)
if self.is_vertex_route(url_route):
vertex_passthrough_logging_handler_result = (
VertexPassthroughLoggingHandler.vertex_passthrough_handler(

View file

@@ -14,5 +14,21 @@ class PassthroughStandardLoggingPayload(TypedDict, total=False):
"""
url: str
"""
The full URL of the request
"""
request_method: Optional[str]
"""
The method of the request
"GET", "POST", "PUT", "DELETE", etc.
"""
request_body: Optional[dict]
"""
The body of the request
"""
response_body: Optional[dict] # only tracked for non-streaming responses
"""
The body of the response
"""

View file

@@ -0,0 +1,155 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from typing import Optional
from fastapi import Request
import pytest
import asyncio
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.passthrough_endpoints.pass_through_endpoints import PassthroughStandardLoggingPayload
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import pass_through_request
class TestCustomLogger(CustomLogger):
def __init__(self):
self.logged_kwargs: Optional[dict] = None
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print("in async log success event kwargs", json.dumps(kwargs, indent=4, default=str))
self.logged_kwargs = kwargs
@pytest.mark.asyncio
async def test_assistants_passthrough_logging():
test_custom_logger = TestCustomLogger()
litellm._async_success_callback = [test_custom_logger]
TARGET_URL = "https://api.openai.com/v1/assistants"
REQUEST_BODY = {
"instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
"name": "Math Tutor",
"tools": [{"type": "code_interpreter"}],
"model": "gpt-4o"
}
TARGET_METHOD = "POST"
result = await pass_through_request(
request=Request(
scope={
"type": "http",
"method": TARGET_METHOD,
"path": "/v1/assistants",
"query_string": b"",
"headers": [
(b"content-type", b"application/json"),
(b"authorization", f"Bearer {os.getenv('OPENAI_API_KEY')}".encode()),
(b"openai-beta", b"assistants=v2")
]
},
),
target=TARGET_URL,
custom_headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}",
"OpenAI-Beta": "assistants=v2"
},
user_api_key_dict=UserAPIKeyAuth(
api_key="test",
user_id="test",
team_id="test",
end_user_id="test",
),
custom_body=REQUEST_BODY,
forward_headers=False,
merge_query_params=False,
)
print("got result", result)
print("result status code", result.status_code)
print("result content", result.body)
await asyncio.sleep(1)
assert test_custom_logger.logged_kwargs is not None
passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = test_custom_logger.logged_kwargs["passthrough_logging_payload"]
assert passthrough_logging_payload is not None
assert passthrough_logging_payload["url"] == TARGET_URL
assert passthrough_logging_payload["request_body"] == REQUEST_BODY
# assert that the logged response body matches the client-facing response body
client_facing_response_body = json.loads(result.body)
assert passthrough_logging_payload["response_body"] == client_facing_response_body
# assert that the request method is correct
assert passthrough_logging_payload["request_method"] == TARGET_METHOD
@pytest.mark.asyncio
async def test_threads_passthrough_logging():
test_custom_logger = TestCustomLogger()
litellm._async_success_callback = [test_custom_logger]
TARGET_URL = "https://api.openai.com/v1/threads"
REQUEST_BODY = {}
TARGET_METHOD = "POST"
result = await pass_through_request(
request=Request(
scope={
"type": "http",
"method": TARGET_METHOD,
"path": "/v1/threads",
"query_string": b"",
"headers": [
(b"content-type", b"application/json"),
(b"authorization", f"Bearer {os.getenv('OPENAI_API_KEY')}".encode()),
(b"openai-beta", b"assistants=v2")
]
},
),
target=TARGET_URL,
custom_headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}",
"OpenAI-Beta": "assistants=v2"
},
user_api_key_dict=UserAPIKeyAuth(
api_key="test",
user_id="test",
team_id="test",
end_user_id="test",
),
custom_body=REQUEST_BODY,
forward_headers=False,
merge_query_params=False,
)
print("got result", result)
print("result status code", result.status_code)
print("result content", result.body)
await asyncio.sleep(1)
assert test_custom_logger.logged_kwargs is not None
passthrough_logging_payload = test_custom_logger.logged_kwargs["passthrough_logging_payload"]
assert passthrough_logging_payload is not None
# Fix for TypedDict access errors
assert passthrough_logging_payload.get("url") == TARGET_URL
assert passthrough_logging_payload.get("request_body") == REQUEST_BODY
# Fix for json.loads error with potential memoryview
response_body = result.body
client_facing_response_body = json.loads(response_body)
assert passthrough_logging_payload.get("response_body") == client_facing_response_body
assert passthrough_logging_payload.get("request_method") == TARGET_METHOD

View file

@@ -17,7 +17,7 @@ import pytest
import litellm
from typing import AsyncGenerator
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy.pass_through_endpoints.types import EndpointType
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
from litellm.proxy.pass_through_endpoints.success_handler import (
PassThroughEndpointLogging,
)
@@ -34,7 +34,7 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
_init_kwargs_for_pass_through_endpoint,
_update_metadata_with_tags_in_header,
)
from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
from litellm.types.passthrough_endpoints.pass_through_endpoints import PassthroughStandardLoggingPayload
@pytest.fixture
@@ -115,6 +115,15 @@ def test_init_kwargs_for_pass_through_endpoint_basic(
user_api_key_dict=mock_user_api_key_dict,
passthrough_logging_payload=passthrough_payload,
litellm_call_id="test-call-id",
logging_obj=LiteLLMLoggingObj(
model="test-model",
messages=[],
stream=False,
call_type="test-call-type",
start_time=datetime.now(),
litellm_call_id="test-call-id",
function_id="test-function-id",
),
)
assert result["call_type"] == "pass_through_endpoint"
@@ -158,6 +167,15 @@ def test_init_kwargs_with_litellm_metadata(mock_request, mock_user_api_key_dict)
passthrough_logging_payload=passthrough_payload,
_parsed_body=parsed_body,
litellm_call_id="test-call-id",
logging_obj=LiteLLMLoggingObj(
model="test-model",
messages=[],
stream=False,
call_type="test-call-type",
start_time=datetime.now(),
litellm_call_id="test-call-id",
function_id="test-function-id",
),
)
# Check that litellm_metadata was merged with default metadata
@@ -183,6 +201,15 @@ def test_init_kwargs_with_tags_in_header(mock_request, mock_user_api_key_dict):
user_api_key_dict=mock_user_api_key_dict,
passthrough_logging_payload=passthrough_payload,
litellm_call_id="test-call-id",
logging_obj=LiteLLMLoggingObj(
model="test-model",
messages=[],
stream=False,
call_type="test-call-type",
start_time=datetime.now(),
litellm_call_id="test-call-id",
function_id="test-function-id",
),
)
# Check that tags were added to metadata

View file

@@ -13,7 +13,8 @@ import pytest
import litellm
from typing import AsyncGenerator
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy.pass_through_endpoints.types import EndpointType
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
from litellm.types.passthrough_endpoints.pass_through_endpoints import PassthroughStandardLoggingPayload
from litellm.proxy.pass_through_endpoints.success_handler import (
PassThroughEndpointLogging,
)