forked from phoenix/litellm-mirror
* feat(pass_through_endpoints/): support logging anthropic/gemini pass through calls to langfuse/s3/etc. * fix(utils.py): allow disabling end user cost tracking with new param Allows proxy admin to disable cost tracking for end user - keeps prometheus metrics small * docs(configs.md): add disable_end_user_cost_tracking reference to docs * feat(key_management_endpoints.py): add support for restricting access to `/key/generate` by team/proxy level role Enables admin to restrict key creation, and assign team admins to handle distributing keys * test(test_key_management.py): add unit testing for personal / team key restriction checks * docs: add docs on restricting key creation * docs(finetuned_models.md): add new guide on calling finetuned models * docs(input.md): cleanup anthropic supported params Closes https://github.com/BerriAI/litellm/issues/6856 * test(test_embedding.py): add test for passing extra headers via embedding * feat(cohere/embed): pass client to async embedding * feat(rerank.py): add `/v1/rerank` if missing for cohere base url Closes https://github.com/BerriAI/litellm/issues/6844 * fix(main.py): pass extra_headers param to openai Fixes https://github.com/BerriAI/litellm/issues/6836 * fix(litellm_logging.py): don't disable global callbacks when dynamic callbacks are set Fixes issue where global callbacks - e.g. prometheus were overriden when langfuse was set dynamically * fix(handler.py): fix linting error * fix: fix typing * build: add conftest to proxy_admin_ui_tests/ * test: fix test * fix: fix linting errors * test: fix test * fix: fix pass through testing
119 lines
4.1 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
from datetime import datetime
|
|
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
|
|
|
sys.path.insert(
|
|
0, os.path.abspath("../..")
|
|
) # Adds the parent directory to the system path
|
|
|
|
import httpx
|
|
import pytest
|
|
import litellm
|
|
from typing import AsyncGenerator
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
|
from litellm.proxy.pass_through_endpoints.types import EndpointType
|
|
from litellm.proxy.pass_through_endpoints.success_handler import (
|
|
PassThroughEndpointLogging,
|
|
)
|
|
from litellm.proxy.pass_through_endpoints.streaming_handler import (
|
|
PassThroughStreamingHandler,
|
|
)
|
|
|
|
|
|
# Helper function to mock async iteration
async def aiter_mock(iterable):
    """Wrap *iterable* in an async generator, yielding each element in order."""
    for element in iterable:
        yield element
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "endpoint_type,url_route",
    [
        (
            EndpointType.VERTEX_AI,
            "v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent",
        ),
        (EndpointType.ANTHROPIC, "/v1/messages"),
    ],
)
async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route):
    """
    Test that the chunk_processor yields raw bytes

    This is CRITICAL for pass throughs streaming with Vertex AI and Anthropic
    """
    # Byte payloads the mocked upstream response will stream back, including
    # one with embedded newlines/SSE framing to cover differing byte shapes.
    raw_chunks = [
        b'{"id": "1", "content": "Hello"}',
        b'{"id": "2", "content": "World"}',
        b'\n\ndata: {"id": "3"}',  # Testing different byte formats
    ]

    # Build an httpx.Response mock whose aiter_bytes() replays raw_chunks.
    response = AsyncMock(spec=httpx.Response)

    async def _replay_raw_bytes():
        for piece in raw_chunks:
            yield piece

    response.aiter_bytes = _replay_raw_bytes

    # Logging objects are mocked out — this test only cares about streaming.
    litellm_logging_obj = MagicMock()
    litellm_logging_obj.async_success_handler = AsyncMock()

    # Drain the processor, checking every yielded chunk along the way.
    received_chunks = []
    async for chunk in PassThroughStreamingHandler.chunk_processor(
        response=response,
        request_body={"key": "value"},
        litellm_logging_obj=litellm_logging_obj,
        endpoint_type=endpoint_type,
        start_time=datetime.now(),
        passthrough_success_handler_obj=MagicMock(),
        url_route=url_route,
    ):
        # Every yielded chunk must still be bytes...
        assert isinstance(chunk, bytes), f"Chunk should be bytes, got {type(chunk)}"
        # ...and byte-identical to one of the inputs (no decode/re-encode).
        assert (
            chunk in raw_chunks
        ), f"Chunk {chunk} was modified during processing. For pass throughs streaming, chunks should be raw bytes"
        received_chunks.append(chunk)

    # Nothing was dropped...
    assert len(received_chunks) == len(raw_chunks), "Not all chunks were processed"

    # ...and the concatenated stream matches the input byte-for-byte.
    assert b"".join(received_chunks) == b"".join(
        raw_chunks
    ), "Collected chunks do not match raw chunks"
|
|
|
|
|
|
def test_convert_raw_bytes_to_str_lines():
    """
    Test that the _convert_raw_bytes_to_str_lines method correctly converts raw bytes to a list of strings
    """
    convert = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines

    # Test case 1: Single chunk
    assert convert([b'data: {"content": "Hello"}\n']) == [
        'data: {"content": "Hello"}'
    ]

    # Test case 2: Multiple chunks
    assert convert(
        [b'data: {"content": "Hello"}\n', b'data: {"content": "World"}\n']
    ) == ['data: {"content": "Hello"}', 'data: {"content": "World"}']

    # Test case 3: Empty input
    assert convert([]) == []

    # Test case 4: Chunks with empty lines (blank lines must be dropped)
    assert convert(
        [b'data: {"content": "Hello"}\n\n', b'\ndata: {"content": "World"}\n']
    ) == ['data: {"content": "Hello"}', 'data: {"content": "World"}']