forked from phoenix/litellm-mirror
(fix) pass through endpoints - run logging async + use thread pool executor for sync logging callbacks (#6907)
* run pass through logging async
* fix use thread_pool_executor for pass through logging
* test_pass_through_request_logging_failure_with_stream
* fix anthropic pt logging test
* test_pass_through_request_logging_failure
This commit is contained in:
parent
d52aae4e82
commit
552c0dd7a4
6 changed files with 201 additions and 33 deletions
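At a high level, this commit swaps a blocking `await` on the pass-through success handler for a fire-and-forget `asyncio.create_task(...)`, and replaces a per-call `threading.Thread` with a shared thread pool for sync logging callbacks. A minimal, self-contained sketch of the two patterns (the handler names here are illustrative, not litellm APIs):

import asyncio
from concurrent.futures import ThreadPoolExecutor

# Shared pool, created once (mirrors the `litellm.utils` executor imported below).
executor = ThreadPoolExecutor(max_workers=4)

async def async_success_logging(payload: dict) -> None:
    await asyncio.sleep(0.1)  # stand-in for a slow async logging callback

def sync_success_logging(payload: dict) -> None:
    print("sync logged:", payload)

async def handle_request(payload: dict) -> str:
    # Before: `await async_success_logging(payload)` blocked the response path.
    # After: schedule it concurrently and return immediately.
    asyncio.create_task(async_success_logging(payload))
    # Before: threading.Thread(target=..., args=...).start() spawned a fresh
    # thread per request; a shared pool reuses a fixed set of workers.
    executor.submit(sync_success_logging, payload)
    return "response"

async def main() -> None:
    print(await handle_request({"model": "claude-3-5-sonnet-20241022"}))
    await asyncio.sleep(0.2)  # give the fire-and-forget task time to finish

asyncio.run(main())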
@@ -529,16 +529,18 @@ async def pass_through_request(  # noqa: PLR0915
         response_body: Optional[dict] = get_response_body(response)
         passthrough_logging_payload["response_body"] = response_body
         end_time = datetime.now()
-        await pass_through_endpoint_logging.pass_through_async_success_handler(
-            httpx_response=response,
-            response_body=response_body,
-            url_route=str(url),
-            result="",
-            start_time=start_time,
-            end_time=end_time,
-            logging_obj=logging_obj,
-            cache_hit=False,
-            **kwargs,
+        asyncio.create_task(
+            pass_through_endpoint_logging.pass_through_async_success_handler(
+                httpx_response=response,
+                response_body=response_body,
+                url_route=str(url),
+                result="",
+                start_time=start_time,
+                end_time=end_time,
+                logging_obj=logging_obj,
+                cache_hit=False,
+                **kwargs,
+            )
         )

         return Response(
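One general asyncio caveat with fire-and-forget tasks (not specific to this diff): the event loop holds only a weak reference to tasks, so an unreferenced task can in principle be garbage-collected before it completes. The asyncio docs recommend keeping a strong reference until the task is done, e.g.:

import asyncio
from typing import Coroutine, Set

_background_tasks: Set[asyncio.Task] = set()

def fire_and_forget(coro: Coroutine) -> asyncio.Task:
    # Keep a strong reference until the task finishes, then drop it.
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task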
@@ -58,15 +58,17 @@ class PassThroughStreamingHandler:
             # After all chunks are processed, handle post-processing
             end_time = datetime.now()

-            await PassThroughStreamingHandler._route_streaming_logging_to_handler(
-                litellm_logging_obj=litellm_logging_obj,
-                passthrough_success_handler_obj=passthrough_success_handler_obj,
-                url_route=url_route,
-                request_body=request_body or {},
-                endpoint_type=endpoint_type,
-                start_time=start_time,
-                raw_bytes=raw_bytes,
-                end_time=end_time,
+            asyncio.create_task(
+                PassThroughStreamingHandler._route_streaming_logging_to_handler(
+                    litellm_logging_obj=litellm_logging_obj,
+                    passthrough_success_handler_obj=passthrough_success_handler_obj,
+                    url_route=url_route,
+                    request_body=request_body or {},
+                    endpoint_type=endpoint_type,
+                    start_time=start_time,
+                    raw_bytes=raw_bytes,
+                    end_time=end_time,
+                )
             )
         except Exception as e:
             verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")

@@ -108,9 +110,9 @@ class PassThroughStreamingHandler:
                 all_chunks=all_chunks,
                 end_time=end_time,
             )
-            standard_logging_response_object = anthropic_passthrough_logging_handler_result[
-                "result"
-            ]
+            standard_logging_response_object = (
+                anthropic_passthrough_logging_handler_result["result"]
+            )
             kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
         elif endpoint_type == EndpointType.VERTEX_AI:
             vertex_passthrough_logging_handler_result = (

@@ -125,9 +127,9 @@ class PassThroughStreamingHandler:
                     end_time=end_time,
                 )
             )
-            standard_logging_response_object = vertex_passthrough_logging_handler_result[
-                "result"
-            ]
+            standard_logging_response_object = (
+                vertex_passthrough_logging_handler_result["result"]
+            )
             kwargs = vertex_passthrough_logging_handler_result["kwargs"]

         if standard_logging_response_object is None:

@@ -168,4 +170,4 @@ class PassThroughStreamingHandler:
         # Split by newlines and filter out empty lines
         lines = [line.strip() for line in combined_str.split("\n") if line.strip()]

-        return lines
+        return lines
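A knock-on effect of wrapping `_route_streaming_logging_to_handler` in a task: an exception it raises no longer lands in the `except` block shown above, because the caller never awaits the task; it surfaces through the loop's unhandled-task machinery instead. A sketch of a done-callback that reports such failures explicitly (illustrative helper, not part of this commit):

import asyncio
import logging

logger = logging.getLogger(__name__)

def report_task_errors(task: asyncio.Task) -> None:
    # Retrieve the exception so the loop does not print
    # "Task exception was never retrieved" at shutdown.
    if not task.cancelled() and task.exception() is not None:
        logger.error("pass-through logging task failed: %s", task.exception())

# Usage: asyncio.create_task(coro).add_done_callback(report_task_errors)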
@@ -18,6 +18,7 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
 from litellm.proxy._types import PassThroughEndpointLoggingResultValues
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.types.utils import StandardPassThroughResponseObject
+from litellm.utils import executor as thread_pool_executor

 from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
     AnthropicPassthroughLoggingHandler,

@@ -93,15 +94,16 @@ class PassThroughEndpointLogging:
             standard_logging_response_object = StandardPassThroughResponseObject(
                 response=httpx_response.text
             )
-            threading.Thread(
-                target=logging_obj.success_handler,
+            thread_pool_executor.submit(
+                logging_obj.success_handler,
                 args=(
                     standard_logging_response_object,
                     start_time,
                     end_time,
                     cache_hit,
                 ),
-            ).start()
+            )
+
             await logging_obj.async_success_handler(
                 result=(
                     json.dumps(result)
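One thing to watch in the `submit` hunk: `ThreadPoolExecutor.submit(fn, args=(...))` is not a drop-in for `threading.Thread(target=fn, args=(...))`. `Thread` unpacks the `args` tuple into positional arguments, while `submit` forwards everything after the callable verbatim, so `args=(...)` arrives as a single keyword argument named `args`. Assuming `success_handler` does not declare an `args` parameter, the positionally equivalent call would be `submit(fn, *values)`. A quick demonstration:

import threading
from concurrent.futures import ThreadPoolExecutor

def probe(*args, **kwargs):
    print("args:", args, "kwargs:", kwargs)

t = threading.Thread(target=probe, args=(1, 2))
t.start()
t.join()
# -> args: (1, 2) kwargs: {}

with ThreadPoolExecutor(max_workers=1) as pool:
    pool.submit(probe, args=(1, 2)).result()  # -> args: () kwargs: {'args': (1, 2)}
    pool.submit(probe, 1, 2).result()         # -> args: (1, 2) -- Thread-equivalent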
@@ -21,4 +21,5 @@ router_settings:
   redis_password: os.environ/REDIS_PASSWORD

 litellm_settings:
   callbacks: ["prometheus"]
+  success_callback: ["langfuse"]
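For reference, the settings block in the proxy config after this hunk reads:

litellm_settings:
  callbacks: ["prometheus"]
  success_callback: ["langfuse"]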
@@ -141,7 +141,9 @@ async def test_anthropic_basic_completion_with_headers():
         ), "Start time should be before end time"

         # Metadata assertions
-        assert log_entry["cache_hit"] == "False", "Cache should be off"
+        assert (
+            str(log_entry["cache_hit"]).lower() != "true"
+        ), "Cache should be off"
         assert log_entry["request_tags"] == [
             "test-tag-1",
             "test-tag-2",

@@ -251,7 +253,9 @@ async def test_anthropic_streaming_with_headers():
         ), "Start time should be before end time"

         # Metadata assertions
-        assert log_entry["cache_hit"] == "False", "Cache should be off"
+        assert (
+            str(log_entry["cache_hit"]).lower() != "true"
+        ), "Cache should be off"
         assert log_entry["request_tags"] == [
             "test-tag-stream-1",
             "test-tag-stream-2",
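The reworked assertion is deliberately looser than the old exact match: `log_entry["cache_hit"] == "False"` passes only when the field is literally the string "False", but the logger may emit a boolean `False` or different casing. Coercing through `str(...).lower()` treats every non-true value as a cache miss:

# Each of these values counts as "cache off" under the new check.
for value in (False, "False", "false", None):
    assert str(value).lower() != "true"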
@@ -3,11 +3,13 @@ import os
 import sys
 from datetime import datetime
 from unittest.mock import AsyncMock, Mock, patch, MagicMock
+from typing import Optional

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path

+import fastapi
 import httpx
 import pytest
 import litellm

@@ -21,6 +23,9 @@ from litellm.proxy.pass_through_endpoints.streaming_handler import (
     PassThroughStreamingHandler,
 )

+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    pass_through_request,
+)
 from fastapi import Request
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (

@@ -33,9 +38,21 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
 @pytest.fixture
 def mock_request():
     # Create a mock request with headers
+    class QueryParams:
+        def __init__(self):
+            self._dict = {}
+
     class MockRequest:
-        def __init__(self, headers=None):
+        def __init__(
+            self, headers=None, method="POST", request_body: Optional[dict] = None
+        ):
             self.headers = headers or {}
+            self.query_params = QueryParams()
+            self.method = method
+            self.request_body = request_body or {}
+
+        async def body(self) -> bytes:
+            return bytes(json.dumps(self.request_body), "utf-8")

     return MockRequest

@@ -163,3 +180,143 @@ def test_init_kwargs_with_tags_in_header(mock_request, mock_user_api_key_dict):
     metadata = result["litellm_params"]["metadata"]
     print("metadata", metadata)
     assert metadata["tags"] == ["tag1", "tag2"]
+
+
+athropic_request_body = {
+    "model": "claude-3-5-sonnet-20241022",
+    "max_tokens": 256,
+    "messages": [{"role": "user", "content": "Hello, world tell me 2 sentences "}],
+    "litellm_metadata": {"tags": ["hi", "hello"]},
+}
+
+
+@pytest.mark.asyncio
+async def test_pass_through_request_logging_failure(
+    mock_request, mock_user_api_key_dict
+):
+    """
+    Test that pass_through_request still returns a response even if logging raises an Exception
+    """
+
+    # Mock the logging handler to raise an error
+    async def mock_logging_failure(*args, **kwargs):
+        raise Exception("Logging failed!")
+
+    # Create a mock response
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"content-type": "application/json"}
+
+    # Add mock content
+    mock_response._content = b'{"mock": "response"}'
+
+    async def mock_aread():
+        return mock_response._content
+
+    mock_response.aread = mock_aread
+
+    # Patch both the logging handler and the httpx client
+    with patch(
+        "litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler",
+        new=mock_logging_failure,
+    ), patch(
+        "httpx.AsyncClient.send",
+        return_value=mock_response,
+    ), patch(
+        "httpx.AsyncClient.request",
+        return_value=mock_response,
+    ):
+        request = mock_request(
+            headers={}, method="POST", request_body=athropic_request_body
+        )
+        response = await pass_through_request(
+            request=request,
+            target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
+            custom_headers={},
+            user_api_key_dict=mock_user_api_key_dict,
+        )
+
+        # Assert response was returned successfully despite logging failure
+        assert response.status_code == 200
+
+        # Verify we got the mock response content
+        if hasattr(response, "body"):
+            content = response.body
+        else:
+            content = await response.aread()
+
+        assert content == b'{"mock": "response"}'
+
+
+@pytest.mark.asyncio
+async def test_pass_through_request_logging_failure_with_stream(
+    mock_request, mock_user_api_key_dict
+):
+    """
+    Test that pass_through_request still returns a response even if logging raises an Exception
+    """
+
+    # Mock the logging handler to raise an error
+    async def mock_logging_failure(*args, **kwargs):
+        raise Exception("Logging failed!")
+
+    # Create a mock response
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+
+    # Add headers property to mock response
+    mock_response.headers = {
+        "content-type": "application/json",  # Not streaming
+    }
+
+    # Create mock chunks for streaming
+    mock_chunks = [b'{"chunk": 1}', b'{"chunk": 2}']
+    mock_response.body_iterator = AsyncMock()
+    mock_response.body_iterator.__aiter__.return_value = mock_chunks
+
+    # Add aread method to mock response
+    mock_response._content = b'{"mock": "response"}'
+
+    async def mock_aread():
+        return mock_response._content
+
+    mock_response.aread = mock_aread
+
+    # Patch both the logging handler and the httpx client
+    with patch(
+        "litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler",
+        new=mock_logging_failure,
+    ), patch(
+        "httpx.AsyncClient.send",
+        return_value=mock_response,
+    ), patch(
+        "httpx.AsyncClient.request",
+        return_value=mock_response,
+    ):
+        request = mock_request(
+            headers={}, method="POST", request_body=athropic_request_body
+        )
+        response = await pass_through_request(
+            request=request,
+            target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
+            custom_headers={},
+            user_api_key_dict=mock_user_api_key_dict,
+        )
+
+        # Assert response was returned successfully despite logging failure
+        assert response.status_code == 200
+
+        # For non-streaming responses, we can access the content directly
+        if hasattr(response, "body"):
+            content = response.body
+        else:
+            # For streaming responses, we need to read the chunks
+            chunks = []
+            async for chunk in response.body_iterator:
+                chunks.append(chunk)
+            content = b"".join(chunks)
+
+        # Verify we got some response content
+        assert content is not None
+        if isinstance(content, bytes):
+            assert len(content) > 0
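Because the success handlers now run in background tasks, these tests can only assert that the response path survives a logging failure; they cannot assert the handler finished, since `pass_through_request` returns before the task runs. A test that needed that guarantee would first have to yield control to the event loop; a minimal sketch (hypothetical helper, not in this PR):

import asyncio

async def drain_background_tasks() -> None:
    # Await every task except the current one so fire-and-forget
    # logging tasks scheduled via asyncio.create_task can complete.
    pending = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)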