Compare commits

...

5 commits

Author | SHA1 | Message | Date
Ishaan Jaff | 6715905994 | test_pass_through_request_logging_failure | 2024-11-25 21:51:00 -08:00
Ishaan Jaff | f12141be44 | fix anthropic pt logging test | 2024-11-25 21:44:12 -08:00
Ishaan Jaff | 904ece6757 | test_pass_through_request_logging_failure_with_stream | 2024-11-25 16:51:53 -08:00
Ishaan Jaff | 068f1af120 | fix use thread_pool_executor for pass through logging | 2024-11-25 16:13:06 -08:00
Ishaan Jaff | 68408c4d77 | run pass through logging async | 2024-11-25 13:17:25 -08:00
6 changed files with 201 additions and 33 deletions
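
The thrust of these commits is to take pass-through success logging off the request path: the handlers that used to be awaited (or started in a fresh thread) before the proxied response was returned are now scheduled in the background with asyncio.create_task and a shared thread pool. A minimal sketch of the before/after shape, using illustrative names only (slow_success_logger and the handle_request_* helpers are not LiteLLM code):

import asyncio
import time

async def slow_success_logger(payload: dict) -> None:
    # Stand-in for a pass-through success handler (e.g. pushing a log to a callback).
    await asyncio.sleep(1.0)

async def handle_request_blocking(payload: dict) -> dict:
    await slow_success_logger(payload)  # old shape: the response waits on logging
    return {"status": "ok"}

async def handle_request_non_blocking(payload: dict) -> dict:
    asyncio.create_task(slow_success_logger(payload))  # new shape: logging runs in the background
    return {"status": "ok"}

async def main() -> None:
    t0 = time.perf_counter()
    await handle_request_blocking({})
    print(f"blocking: {time.perf_counter() - t0:.2f}s")  # ~1.00s

    t0 = time.perf_counter()
    await handle_request_non_blocking({})
    print(f"non-blocking: {time.perf_counter() - t0:.2f}s")  # ~0.00s
    await asyncio.sleep(1.1)  # keep the loop alive so the background task can finish

asyncio.run(main())

The trade-off, exercised by the new unit tests at the bottom of this diff, is that a failure inside the logging task can no longer fail the client's request.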

View file

@@ -529,16 +529,18 @@ async def pass_through_request( # noqa: PLR0915
     response_body: Optional[dict] = get_response_body(response)
     passthrough_logging_payload["response_body"] = response_body
     end_time = datetime.now()
-    await pass_through_endpoint_logging.pass_through_async_success_handler(
-        httpx_response=response,
-        response_body=response_body,
-        url_route=str(url),
-        result="",
-        start_time=start_time,
-        end_time=end_time,
-        logging_obj=logging_obj,
-        cache_hit=False,
-        **kwargs,
+    asyncio.create_task(
+        pass_through_endpoint_logging.pass_through_async_success_handler(
+            httpx_response=response,
+            response_body=response_body,
+            url_route=str(url),
+            result="",
+            start_time=start_time,
+            end_time=end_time,
+            logging_obj=logging_obj,
+            cache_hit=False,
+            **kwargs,
+        )
     )

     return Response(

View file

@@ -58,15 +58,17 @@ class PassThroughStreamingHandler:
             # After all chunks are processed, handle post-processing
             end_time = datetime.now()

-            await PassThroughStreamingHandler._route_streaming_logging_to_handler(
-                litellm_logging_obj=litellm_logging_obj,
-                passthrough_success_handler_obj=passthrough_success_handler_obj,
-                url_route=url_route,
-                request_body=request_body or {},
-                endpoint_type=endpoint_type,
-                start_time=start_time,
-                raw_bytes=raw_bytes,
-                end_time=end_time,
+            asyncio.create_task(
+                PassThroughStreamingHandler._route_streaming_logging_to_handler(
+                    litellm_logging_obj=litellm_logging_obj,
+                    passthrough_success_handler_obj=passthrough_success_handler_obj,
+                    url_route=url_route,
+                    request_body=request_body or {},
+                    endpoint_type=endpoint_type,
+                    start_time=start_time,
+                    raw_bytes=raw_bytes,
+                    end_time=end_time,
+                )
             )
         except Exception as e:
             verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
@@ -108,9 +110,9 @@ class PassThroughStreamingHandler:
                 all_chunks=all_chunks,
                 end_time=end_time,
             )
-            standard_logging_response_object = anthropic_passthrough_logging_handler_result[
-                "result"
-            ]
+            standard_logging_response_object = (
+                anthropic_passthrough_logging_handler_result["result"]
+            )
             kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
         elif endpoint_type == EndpointType.VERTEX_AI:
             vertex_passthrough_logging_handler_result = (
@@ -125,9 +127,9 @@ class PassThroughStreamingHandler:
                     end_time=end_time,
                 )
             )
-            standard_logging_response_object = vertex_passthrough_logging_handler_result[
-                "result"
-            ]
+            standard_logging_response_object = (
+                vertex_passthrough_logging_handler_result["result"]
+            )
             kwargs = vertex_passthrough_logging_handler_result["kwargs"]

         if standard_logging_response_object is None:
@@ -168,4 +170,4 @@ class PassThroughStreamingHandler:
         # Split by newlines and filter out empty lines
         lines = [line.strip() for line in combined_str.split("\n") if line.strip()]

        return lines

View file

@@ -18,6 +18,7 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
 from litellm.proxy._types import PassThroughEndpointLoggingResultValues
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.types.utils import StandardPassThroughResponseObject
+from litellm.utils import executor as thread_pool_executor

 from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
     AnthropicPassthroughLoggingHandler,
@@ -93,15 +94,16 @@ class PassThroughEndpointLogging:
             standard_logging_response_object = StandardPassThroughResponseObject(
                 response=httpx_response.text
             )
-            threading.Thread(
-                target=logging_obj.success_handler,
+            thread_pool_executor.submit(
+                logging_obj.success_handler,
                 args=(
                     standard_logging_response_object,
                     start_time,
                     end_time,
                     cache_hit,
                 ),
-            ).start()
+            )
             await logging_obj.async_success_handler(
                 result=(
                     json.dumps(result)
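
This hunk replaces a per-request threading.Thread with a shared executor (imported above as thread_pool_executor), so concurrent logging work is bounded by a pool instead of spawning a new thread per response. A rough sketch of that idea, assuming a module-level ThreadPoolExecutor like the one litellm.utils exposes; success_handler and on_response here are illustrative stand-ins, not LiteLLM functions:

from concurrent.futures import ThreadPoolExecutor

# One shared pool for all synchronous logging callbacks, instead of an
# unbounded thread-per-request with threading.Thread(...).start().
executor = ThreadPoolExecutor(max_workers=4)

def success_handler(result: dict, start_time: float, end_time: float) -> None:
    # Stand-in for a synchronous logging callback.
    print(f"logged {result} in {end_time - start_time:.3f}s")

def on_response(result: dict, start_time: float, end_time: float) -> None:
    # submit() queues the call and returns a Future immediately;
    # the request path never blocks on the logging work.
    executor.submit(success_handler, result, start_time, end_time)

on_response({"status": "ok"}, 0.0, 0.25)
executor.shutdown(wait=True)  # only needed in this standalone demo

Note that ThreadPoolExecutor.submit forwards positional and keyword arguments directly to the callable; unlike threading.Thread, it has no args= tuple parameter of its own.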

View file

@@ -21,4 +21,5 @@ router_settings:
   redis_password: os.environ/REDIS_PASSWORD

 litellm_settings:
   callbacks: ["prometheus"]
+  success_callback: ["langfuse"]

View file

@@ -141,7 +141,9 @@ async def test_anthropic_basic_completion_with_headers():
         ), "Start time should be before end time"

         # Metadata assertions
-        assert log_entry["cache_hit"] == "False", "Cache should be off"
+        assert (
+            str(log_entry["cache_hit"]).lower() != "true"
+        ), "Cache should be off"
         assert log_entry["request_tags"] == [
             "test-tag-1",
             "test-tag-2",
@@ -251,7 +253,9 @@ async def test_anthropic_streaming_with_headers():
         ), "Start time should be before end time"

         # Metadata assertions
-        assert log_entry["cache_hit"] == "False", "Cache should be off"
+        assert (
+            str(log_entry["cache_hit"]).lower() != "true"
+        ), "Cache should be off"
         assert log_entry["request_tags"] == [
             "test-tag-stream-1",
             "test-tag-stream-2",

View file

@@ -3,11 +3,13 @@ import os
 import sys
 from datetime import datetime
 from unittest.mock import AsyncMock, Mock, patch, MagicMock
+from typing import Optional

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+import fastapi
 import httpx
 import pytest

 import litellm
@@ -21,6 +23,9 @@ from litellm.proxy.pass_through_endpoints.streaming_handler import (
     PassThroughStreamingHandler,
 )

+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    pass_through_request,
+)
 from fastapi import Request
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
@@ -33,9 +38,21 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin

 @pytest.fixture
 def mock_request():
     # Create a mock request with headers
+    class QueryParams:
+        def __init__(self):
+            self._dict = {}
+
     class MockRequest:
-        def __init__(self, headers=None):
+        def __init__(
+            self, headers=None, method="POST", request_body: Optional[dict] = None
+        ):
             self.headers = headers or {}
+            self.query_params = QueryParams()
+            self.method = method
+            self.request_body = request_body or {}
+
+        async def body(self) -> bytes:
+            return bytes(json.dumps(self.request_body), "utf-8")

     return MockRequest
@@ -163,3 +180,143 @@ def test_init_kwargs_with_tags_in_header(mock_request, mock_user_api_key_dict):
     metadata = result["litellm_params"]["metadata"]
     print("metadata", metadata)
     assert metadata["tags"] == ["tag1", "tag2"]
+
+
+athropic_request_body = {
+    "model": "claude-3-5-sonnet-20241022",
+    "max_tokens": 256,
+    "messages": [{"role": "user", "content": "Hello, world tell me 2 sentences "}],
+    "litellm_metadata": {"tags": ["hi", "hello"]},
+}
+
+
+@pytest.mark.asyncio
+async def test_pass_through_request_logging_failure(
+    mock_request, mock_user_api_key_dict
+):
+    """
+    Test that pass_through_request still returns a response even if logging raises an Exception
+    """
+
+    # Mock the logging handler to raise an error
+    async def mock_logging_failure(*args, **kwargs):
+        raise Exception("Logging failed!")
+
+    # Create a mock response
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"content-type": "application/json"}
+
+    # Add mock content
+    mock_response._content = b'{"mock": "response"}'
+
+    async def mock_aread():
+        return mock_response._content
+
+    mock_response.aread = mock_aread
+
+    # Patch both the logging handler and the httpx client
+    with patch(
+        "litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler",
+        new=mock_logging_failure,
+    ), patch(
+        "httpx.AsyncClient.send",
+        return_value=mock_response,
+    ), patch(
+        "httpx.AsyncClient.request",
+        return_value=mock_response,
+    ):
+        request = mock_request(
+            headers={}, method="POST", request_body=athropic_request_body
+        )
+        response = await pass_through_request(
+            request=request,
+            target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
+            custom_headers={},
+            user_api_key_dict=mock_user_api_key_dict,
+        )
+
+        # Assert response was returned successfully despite logging failure
+        assert response.status_code == 200
+
+        # Verify we got the mock response content
+        if hasattr(response, "body"):
+            content = response.body
+        else:
+            content = await response.aread()
+
+        assert content == b'{"mock": "response"}'
+
+
+@pytest.mark.asyncio
+async def test_pass_through_request_logging_failure_with_stream(
+    mock_request, mock_user_api_key_dict
+):
+    """
+    Test that pass_through_request still returns a response even if logging raises an Exception
+    """
+
+    # Mock the logging handler to raise an error
+    async def mock_logging_failure(*args, **kwargs):
+        raise Exception("Logging failed!")
+
+    # Create a mock response
+    mock_response = AsyncMock()
+    mock_response.status_code = 200
+
+    # Add headers property to mock response
+    mock_response.headers = {
+        "content-type": "application/json",  # Not streaming
+    }
+
+    # Create mock chunks for streaming
+    mock_chunks = [b'{"chunk": 1}', b'{"chunk": 2}']
+    mock_response.body_iterator = AsyncMock()
+    mock_response.body_iterator.__aiter__.return_value = mock_chunks
+
+    # Add aread method to mock response
+    mock_response._content = b'{"mock": "response"}'
+
+    async def mock_aread():
+        return mock_response._content
+
+    mock_response.aread = mock_aread
+
+    # Patch both the logging handler and the httpx client
+    with patch(
+        "litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler",
+        new=mock_logging_failure,
+    ), patch(
+        "httpx.AsyncClient.send",
+        return_value=mock_response,
+    ), patch(
+        "httpx.AsyncClient.request",
+        return_value=mock_response,
+    ):
+        request = mock_request(
+            headers={}, method="POST", request_body=athropic_request_body
+        )
+        response = await pass_through_request(
+            request=request,
+            target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
+            custom_headers={},
+            user_api_key_dict=mock_user_api_key_dict,
+        )
+
+        # Assert response was returned successfully despite logging failure
+        assert response.status_code == 200
+
+        # For non-streaming responses, we can access the content directly
+        if hasattr(response, "body"):
+            content = response.body
+        else:
+            # For streaming responses, we need to read the chunks
+            chunks = []
+            async for chunk in response.body_iterator:
+                chunks.append(chunk)
+            content = b"".join(chunks)
+
+        # Verify we got some response content
+        assert content is not None
+        if isinstance(content, bytes):
+            assert len(content) > 0
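
These tests lean on the behavior the earlier hunks introduce: once the success handler runs in a background task, an exception it raises stays in that task and never reaches the code that returns the HTTP response. A small self-contained illustration of that asyncio behavior (not LiteLLM code; failing_logger and handler are made-up names):

import asyncio

async def failing_logger() -> None:
    raise Exception("Logging failed!")

async def handler() -> int:
    asyncio.create_task(failing_logger())  # the exception stays inside the task
    return 200  # the "response" is produced regardless

async def main() -> None:
    status = await handler()
    assert status == 200
    # Let the failing task run; asyncio reports the unretrieved exception in its
    # logs, but it never propagates to handler()'s caller.
    await asyncio.sleep(0)

asyncio.run(main())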