forked from phoenix/litellm-mirror
add unit tests for vertex pass through
This commit is contained in:
parent
8aa18b3977
commit
0afdba0822
2 changed files with 84 additions and 135 deletions
|
@ -1,135 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from datetime import datetime
|
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
|
||||||
|
|
||||||
sys.path.insert(
|
|
||||||
0, os.path.abspath("../..")
|
|
||||||
) # Adds the parent directory to the system path
|
|
||||||
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import pytest
|
|
||||||
import litellm
|
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
|
||||||
|
|
||||||
# Import the class we're testing
|
|
||||||
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import (
|
|
||||||
AnthropicPassthroughLoggingHandler,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_response():
|
|
||||||
return {
|
|
||||||
"model": "claude-3-opus-20240229",
|
|
||||||
"content": [{"text": "Hello, world!", "type": "text"}],
|
|
||||||
"role": "assistant",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_httpx_response():
|
|
||||||
mock_resp = Mock(spec=httpx.Response)
|
|
||||||
mock_resp.json.return_value = {
|
|
||||||
"content": [{"text": "Hi! My name is Claude.", "type": "text"}],
|
|
||||||
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
|
|
||||||
"model": "claude-3-5-sonnet-20241022",
|
|
||||||
"role": "assistant",
|
|
||||||
"stop_reason": "end_turn",
|
|
||||||
"stop_sequence": None,
|
|
||||||
"type": "message",
|
|
||||||
"usage": {"input_tokens": 2095, "output_tokens": 503},
|
|
||||||
}
|
|
||||||
mock_resp.status_code = 200
|
|
||||||
mock_resp.headers = {"Content-Type": "application/json"}
|
|
||||||
return mock_resp
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_logging_obj():
|
|
||||||
logging_obj = LiteLLMLoggingObj(
|
|
||||||
model="claude-3-opus-20240229",
|
|
||||||
messages=[],
|
|
||||||
stream=False,
|
|
||||||
call_type="completion",
|
|
||||||
start_time=datetime.now(),
|
|
||||||
litellm_call_id="123",
|
|
||||||
function_id="456",
|
|
||||||
)
|
|
||||||
|
|
||||||
logging_obj.async_success_handler = AsyncMock()
|
|
||||||
return logging_obj
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_anthropic_passthrough_handler(
|
|
||||||
mock_httpx_response, mock_response, mock_logging_obj
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Unit test - Assert that the anthropic passthrough handler calls the litellm logging object's async_success_handler
|
|
||||||
"""
|
|
||||||
start_time = datetime.now()
|
|
||||||
end_time = datetime.now()
|
|
||||||
|
|
||||||
await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
|
|
||||||
httpx_response=mock_httpx_response,
|
|
||||||
response_body=mock_response,
|
|
||||||
logging_obj=mock_logging_obj,
|
|
||||||
url_route="/v1/chat/completions",
|
|
||||||
result="success",
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
cache_hit=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert that async_success_handler was called
|
|
||||||
assert mock_logging_obj.async_success_handler.called
|
|
||||||
|
|
||||||
call_args = mock_logging_obj.async_success_handler.call_args
|
|
||||||
call_kwargs = call_args.kwargs
|
|
||||||
print("call_kwargs", call_kwargs)
|
|
||||||
|
|
||||||
# Assert required fields are present in call_kwargs
|
|
||||||
assert "result" in call_kwargs
|
|
||||||
assert "start_time" in call_kwargs
|
|
||||||
assert "end_time" in call_kwargs
|
|
||||||
assert "cache_hit" in call_kwargs
|
|
||||||
assert "response_cost" in call_kwargs
|
|
||||||
assert "model" in call_kwargs
|
|
||||||
assert "standard_logging_object" in call_kwargs
|
|
||||||
|
|
||||||
# Assert specific values and types
|
|
||||||
assert isinstance(call_kwargs["result"], litellm.ModelResponse)
|
|
||||||
assert isinstance(call_kwargs["start_time"], datetime)
|
|
||||||
assert isinstance(call_kwargs["end_time"], datetime)
|
|
||||||
assert isinstance(call_kwargs["cache_hit"], bool)
|
|
||||||
assert isinstance(call_kwargs["response_cost"], float)
|
|
||||||
assert call_kwargs["model"] == "claude-3-opus-20240229"
|
|
||||||
assert isinstance(call_kwargs["standard_logging_object"], dict)
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_anthropic_response_logging_payload(mock_logging_obj):
|
|
||||||
# Test the logging payload creation
|
|
||||||
model_response = litellm.ModelResponse()
|
|
||||||
model_response.choices = [{"message": {"content": "Test response"}}]
|
|
||||||
|
|
||||||
start_time = datetime.now()
|
|
||||||
end_time = datetime.now()
|
|
||||||
|
|
||||||
result = (
|
|
||||||
AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
|
|
||||||
litellm_model_response=model_response,
|
|
||||||
model="claude-3-opus-20240229",
|
|
||||||
kwargs={},
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
logging_obj=mock_logging_obj,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert "model" in result
|
|
||||||
assert "response_cost" in result
|
|
||||||
assert "standard_logging_object" in result
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
|
||||||
|
|
||||||
|
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
|
||||||
|
get_litellm_virtual_key,
|
||||||
|
vertex_proxy_route,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_litellm_virtual_key():
|
||||||
|
"""
|
||||||
|
Test that the get_litellm_virtual_key function correctly handles the API key authentication
|
||||||
|
"""
|
||||||
|
# Test with x-litellm-api-key
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_request.headers = {"x-litellm-api-key": "test-key-123"}
|
||||||
|
result = get_litellm_virtual_key(mock_request)
|
||||||
|
assert result == "Bearer test-key-123"
|
||||||
|
|
||||||
|
# Test with Authorization header
|
||||||
|
mock_request.headers = {"Authorization": "Bearer auth-key-456"}
|
||||||
|
result = get_litellm_virtual_key(mock_request)
|
||||||
|
assert result == "Bearer auth-key-456"
|
||||||
|
|
||||||
|
# Test with both headers (x-litellm-api-key should take precedence)
|
||||||
|
mock_request.headers = {
|
||||||
|
"x-litellm-api-key": "test-key-123",
|
||||||
|
"Authorization": "Bearer auth-key-456",
|
||||||
|
}
|
||||||
|
result = get_litellm_virtual_key(mock_request)
|
||||||
|
assert result == "Bearer test-key-123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vertex_proxy_route_api_key_auth():
|
||||||
|
"""
|
||||||
|
Critical
|
||||||
|
|
||||||
|
This is how Vertex AI JS SDK will Auth to Litellm Proxy
|
||||||
|
"""
|
||||||
|
# Mock dependencies
|
||||||
|
mock_request = Mock()
|
||||||
|
mock_request.headers = {"x-litellm-api-key": "test-key-123"}
|
||||||
|
mock_request.method = "POST"
|
||||||
|
mock_response = Mock()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.vertex_ai_endpoints.vertex_endpoints.user_api_key_auth"
|
||||||
|
) as mock_auth:
|
||||||
|
mock_auth.return_value = {"api_key": "test-key-123"}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.vertex_ai_endpoints.vertex_endpoints.create_pass_through_route"
|
||||||
|
) as mock_pass_through:
|
||||||
|
mock_pass_through.return_value = AsyncMock(
|
||||||
|
return_value={"status": "success"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
result = await vertex_proxy_route(
|
||||||
|
endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent",
|
||||||
|
request=mock_request,
|
||||||
|
fastapi_response=mock_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify user_api_key_auth was called with the correct Bearer token
|
||||||
|
mock_auth.assert_called_once()
|
||||||
|
call_args = mock_auth.call_args[1]
|
||||||
|
assert call_args["api_key"] == "Bearer test-key-123"
|
Loading…
Add table
Add a link
Reference in a new issue