test_pass_through_request_logging_failure_with_stream

This commit is contained in:
Ishaan Jaff 2024-11-25 16:51:53 -08:00
parent 068f1af120
commit 904ece6757

View file

@ -3,11 +3,13 @@ import os
import sys import sys
from datetime import datetime from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch, MagicMock from unittest.mock import AsyncMock, Mock, patch, MagicMock
from typing import Optional
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import fastapi
import httpx import httpx
import pytest import pytest
import litellm import litellm
@ -21,6 +23,9 @@ from litellm.proxy.pass_through_endpoints.streaming_handler import (
PassThroughStreamingHandler, PassThroughStreamingHandler,
) )
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
pass_through_request,
)
from fastapi import Request from fastapi import Request
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
@ -33,9 +38,21 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
@pytest.fixture @pytest.fixture
def mock_request(): def mock_request():
# Create a mock request with headers # Create a mock request with headers
class QueryParams:
def __init__(self):
self._dict = {}
class MockRequest: class MockRequest:
def __init__(self, headers=None): def __init__(
self, headers=None, method="POST", request_body: Optional[dict] = None
):
self.headers = headers or {} self.headers = headers or {}
self.query_params = QueryParams()
self.method = method
self.request_body = request_body or {}
async def body(self) -> bytes:
return bytes(json.dumps(self.request_body), "utf-8")
return MockRequest return MockRequest
@ -163,3 +180,85 @@ def test_init_kwargs_with_tags_in_header(mock_request, mock_user_api_key_dict):
metadata = result["litellm_params"]["metadata"] metadata = result["litellm_params"]["metadata"]
print("metadata", metadata) print("metadata", metadata)
assert metadata["tags"] == ["tag1", "tag2"] assert metadata["tags"] == ["tag1", "tag2"]
athropic_request_body = {
"model": "claude-3-5-sonnet-20241022",
"max_tokens": 256,
"messages": [{"role": "user", "content": "Hello, world tell me 2 sentences "}],
"litellm_metadata": {"tags": ["hi", "hello"]},
}
@pytest.mark.asyncio
async def test_pass_through_request_logging_failure(
mock_request, mock_user_api_key_dict
):
"""
Test that pass_through_request still returns a response even if logging raises an Exception
"""
# Mock the logging handler to raise an error
async def mock_logging_failure(*args, **kwargs):
raise Exception("Logging failed!")
# Patch only the logging handler
with patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.PassThroughEndpointLogging.pass_through_async_success_handler",
new=mock_logging_failure,
):
request = mock_request(
headers={}, method="POST", request_body=athropic_request_body
)
response = await pass_through_request(
request=request,
target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
custom_headers={},
user_api_key_dict=mock_user_api_key_dict,
)
# Assert response was returned successfully despite logging failure
assert response.status_code == 200
print("response", response)
print(vars(response))
@pytest.mark.asyncio
async def test_pass_through_request_logging_failure_with_stream(
mock_request, mock_user_api_key_dict
):
"""
Test that pass_through_request still returns a response even if logging raises an Exception
"""
# Mock the logging handler to raise an error
async def mock_logging_failure(*args, **kwargs):
raise Exception("Logging failed!")
athropic_request_body["stream"] = True
# Patch only the logging handler
with patch(
"litellm.proxy.pass_through_endpoints.streaming_handler.PassThroughStreamingHandler._route_streaming_logging_to_handler",
new=mock_logging_failure,
):
request = mock_request(
headers={}, method="POST", request_body=athropic_request_body
)
response = await pass_through_request(
request=request,
target="https://exampleopenaiendpoint-production.up.railway.app/v1/messages",
custom_headers={},
user_api_key_dict=mock_user_api_key_dict,
)
# Assert response was returned successfully despite logging failure
assert response.status_code == 200
print(vars(response))
print(dir(response))
body_iterator = response.body_iterator
async for chunk in body_iterator:
assert chunk
print("response", response)
print(vars(response))