test_pass_through_request_logging_failure_with_stream

This commit is contained in:
Ishaan Jaff 2024-11-25 16:51:53 -08:00
parent 068f1af120
commit 904ece6757

View file

@ -3,11 +3,13 @@ import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from typing import Optional
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import fastapi
import httpx
import pytest
import litellm
@ -21,6 +23,9 @@ from litellm.proxy.pass_through_endpoints.streaming_handler import (
PassThroughStreamingHandler,
)
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
pass_through_request,
)
from fastapi import Request
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
@ -33,9 +38,21 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
@pytest.fixture
def mock_request():
# Create a mock request with headers
class QueryParams:
def __init__(self):
self._dict = {}
class MockRequest:
def __init__(self, headers=None):
def __init__(
self, headers=None, method="POST", request_body: Optional[dict] = None
):
self.headers = headers or {}
self.query_params = QueryParams()
self.method = method
self.request_body = request_body or {}
async def body(self) -> bytes:
return bytes(json.dumps(self.request_body), "utf-8")
return MockRequest
@ -163,3 +180,85 @@ def test_init_kwargs_with_tags_in_header(mock_request, mock_user_api_key_dict):
metadata = result["litellm_params"]["metadata"]
print("metadata", metadata)
assert metadata["tags"] == ["tag1", "tag2"]
athropic_request_body = {
"model": "claude-3-5-sonnet-20241022",
"max_tokens": 256,
"messages": [{"role": "user", "content": "Hello, world tell me 2 sentences "}],
"litellm_metadata": {"tags": ["hi", "hello"]},
}
@pytest.mark.asyncio
async def test_pass_through_request_logging_failure(
mock_request, mock_user_api_key_dict
):
"""
Test that pass_through_request still returns a response even if logging raises an Exception
"""
# Mock the logging handler to raise an error
async def mock_logging_failure(*args, **kwargs):
raise Exception("Logging failed!")
# Patch only the logging handler
with patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler",
new=mock_logging_failure,
):
request = mock_request(
headers={}, method="POST", request_body=athropic_request_body
)
response = await pass_through_request(
request=request,
target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
custom_headers={},
user_api_key_dict=mock_user_api_key_dict,
)
# Assert response was returned successfully despite logging failure
assert response.status_code == 200
print("response", response)
print(vars(response))
@pytest.mark.asyncio
async def test_pass_through_request_logging_failure_with_stream(
mock_request, mock_user_api_key_dict
):
"""
Test that pass_through_request still returns a response even if logging raises an Exception
"""
# Mock the logging handler to raise an error
async def mock_logging_failure(*args, **kwargs):
raise Exception("Logging failed!")
athropic_request_body["stream"] = True
# Patch only the logging handler
with patch(
"litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler",
new=mock_logging_failure,
):
request = mock_request(
headers={}, method="POST", request_body=athropic_request_body
)
response = await pass_through_request(
request=request,
target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
custom_headers={},
user_api_key_dict=mock_user_api_key_dict,
)
# Assert response was returned successfully despite logging failure
assert response.status_code == 200
print(vars(response))
print(dir(response))
body_iterator = response.body_iterator
async for chunk in body_iterator:
assert chunk
print("response", response)
print(vars(response))