forked from phoenix/litellm-mirror
* build(model_prices_and_context_window.json): add bedrock llama3.2 pricing
* build(model_prices_and_context_window.json): add bedrock cross region inference pricing
* Revert "(perf) move s3 logging to Batch logging + async [94% faster perf under 100 RPS on 1 litellm instance] (#6165)"
This reverts commit 2a5624af47.
* add azure/gpt-4o-2024-05-13 (#6174)
* LiteLLM Minor Fixes & Improvements (10/10/2024) (#6158)
* refactor(vertex_ai_partner_models/anthropic): refactor anthropic to use partner model logic
* fix(vertex_ai/): support passing custom api base to partner models
Fixes https://github.com/BerriAI/litellm/issues/4317
* fix(proxy_server.py): Fix prometheus premium user check logic
* docs(prometheus.md): update quick start docs
* fix(custom_llm.py): support passing dynamic api key + api base
* fix(realtime_api/main.py): Add request/response logging for realtime api endpoints
Closes https://github.com/BerriAI/litellm/issues/6081
* feat(openai/realtime): add openai realtime api logging
Closes https://github.com/BerriAI/litellm/issues/6081
* fix(realtime_streaming.py): fix linting errors
* fix(realtime_streaming.py): fix linting errors
* fix: fix linting errors
* fix pattern match router
* Add literalai in the sidebar observability category (#6163)
* fix: add literalai in the sidebar
* fix: typo
* update (#6160)
* Feat: Add Langtrace integration (#5341)
* Feat: Add Langtrace integration
* add langtrace service name
* fix timestamps for traces
* add tests
* Discard Callback + use existing otel logger
* cleanup
* remove print statements
* remove callback
* add docs
* docs
* add logging docs
* format logging
* remove emoji and add litellm proxy example
* format logging
* format `logging.md`
* add langtrace docs to logging.md
* sync conflict
* docs fix
* (perf) move s3 logging to Batch logging + async [94% faster perf under 100 RPS on 1 litellm instance] (#6165)
* fix move s3 to use customLogger
* add basic s3 logging test
* add s3 to custom logger compatible
* use batch logger for s3
* s3 set flush interval and batch size
* fix s3 logging
* add notes on s3 logging
* fix s3 logging
* add basic s3 logging test
* fix s3 type errors
* add test for sync logging on s3
* fix: fix to debug log
---------
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Willy Douhard <willy.douhard@gmail.com>
Co-authored-by: yujonglee <yujonglee.dev@gmail.com>
Co-authored-by: Ali Waleed <ali@scale3labs.com>
* docs(custom_llm_server.md): update doc on passing custom params
* fix(pass_through_endpoints.py): don't require headers
Fixes https://github.com/BerriAI/litellm/issues/6128
* feat(utils.py): add support for caching rerank endpoints
Closes https://github.com/BerriAI/litellm/issues/6144
* feat(litellm_logging.py'): add response headers for failed requests
Closes https://github.com/BerriAI/litellm/issues/6159
---------
Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Willy Douhard <willy.douhard@gmail.com>
Co-authored-by: yujonglee <yujonglee.dev@gmail.com>
Co-authored-by: Ali Waleed <ali@scale3labs.com>
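The custom_llm entries above (dynamic api key / api base support, the custom_llm_server.md doc update) are covered by the unit-test file reproduced below. As a quick orientation, here is a minimal sketch of the registration pattern every test in that file uses; TinyHandler is a hypothetical stand-in for the file's MyCustomLLM, and the mock_response="Hi!" shortcut is the same one the tests rely on:

import litellm
from litellm import CustomLLM, ModelResponse, completion


class TinyHandler(CustomLLM):
    # Minimal handler modeled on MyCustomLLM below: it ignores the incoming
    # kwargs and returns a mocked "Hi!" completion.
    def completion(self, *args, **kwargs) -> ModelResponse:
        return completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )


# Register the handler under a provider name, then route calls through it.
litellm.custom_provider_map = [
    {"provider": "custom_llm", "custom_handler": TinyHandler()}
]

resp = completion(
    model="custom_llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
)
print(resp.choices[0].message.content)  # -> "Hi!"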
399 lines · 11 KiB · Python
# What is this?
## Unit tests for the CustomLLM class

import asyncio
import os
import sys
import time
import traceback

import openai
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Coroutine,
    Iterator,
    Optional,
    Union,
)
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
from dotenv import load_dotenv

import litellm
from litellm import (
    ChatCompletionDeltaChunk,
    ChatCompletionUsageBlock,
    CustomLLM,
    GenericStreamingChunk,
    ModelResponse,
    acompletion,
    completion,
    get_llm_provider,
    image_generation,
)
from litellm.utils import ModelResponseIterator
from litellm.types.utils import ImageResponse, ImageObject
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler


class CustomModelResponseIterator:
    def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
        self.streaming_response = streaming_response

    def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
        return GenericStreamingChunk(
            text="hello world",
            tool_use=None,
            is_finished=True,
            finish_reason="stop",
            usage=ChatCompletionUsageBlock(
                prompt_tokens=10, completion_tokens=20, total_tokens=30
            ),
            index=0,
        )

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self) -> GenericStreamingChunk:
        try:
            chunk: Any = self.streaming_response.__next__()  # type: ignore
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.chunk_parser(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
        return self  # return the wrapper so __anext__ (and chunk_parser) are used

    async def __anext__(self) -> GenericStreamingChunk:
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.chunk_parser(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


class MyCustomLLM(CustomLLM):
    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.HTTPHandler] = None,
    ) -> ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    async def acompletion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.AsyncHTTPHandler] = None,
    ) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    def streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.HTTPHandler] = None,
    ) -> Iterator[GenericStreamingChunk]:
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }

        completion_stream = ModelResponseIterator(
            model_response=generic_streaming_chunk  # type: ignore
        )
        custom_iterator = CustomModelResponseIterator(
            streaming_response=completion_stream
        )
        return custom_iterator

    async def astreaming(  # type: ignore
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable[..., Any],
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, openai.Timeout]] = None,
        client: Optional[litellm.AsyncHTTPHandler] = None,
    ) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }

        yield generic_streaming_chunk  # type: ignore

    def image_generation(
        self,
        model: str,
        prompt: str,
        api_key: Optional[str],
        api_base: Optional[str],
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout=None,
        client: Optional[HTTPHandler] = None,
    ):
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url="https://example.com/image.png")],
            response_ms=1000,
        )

    async def aimage_generation(
        self,
        model: str,
        prompt: str,
        api_key: Optional[str],
        api_base: Optional[str],
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout=None,
        client: Optional[AsyncHTTPHandler] = None,
    ):
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url="https://example.com/image.png")],
            response_ms=1000,
        )


def test_get_llm_provider():
    """Assert get_llm_provider() resolves a custom_provider_map entry to the custom provider."""
    from litellm.utils import custom_llm_setup

    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]

    custom_llm_setup()

    model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")

    assert provider == "custom_llm"


def test_simple_completion():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )

    assert resp.choices[0].message.content == "Hi!"


@pytest.mark.asyncio
async def test_simple_acompletion():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )

    assert resp.choices[0].message.content == "Hi!"


def test_simple_completion_streaming():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )

    for chunk in resp:
        print(chunk)
        if chunk.choices[0].finish_reason is None:
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"


@pytest.mark.asyncio
async def test_simple_completion_async_streaming():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await litellm.acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )

    async for chunk in resp:
        print(chunk)
        if chunk.choices[0].finish_reason is None:
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"


def test_simple_image_generation():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = image_generation(
        model="custom_llm/my-fake-model",
        prompt="Hello world",
    )

    print(resp)


@pytest.mark.asyncio
async def test_simple_image_generation_async():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await litellm.aimage_generation(
        model="custom_llm/my-fake-model",
        prompt="Hello world",
    )

    print(resp)


@pytest.mark.asyncio
async def test_image_generation_async_additional_params():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]

    with patch.object(
        my_custom_llm, "aimage_generation", new=AsyncMock()
    ) as mock_client:
        try:
            resp = await litellm.aimage_generation(
                model="custom_llm/my-fake-model",
                prompt="Hello world",
                api_key="my-api-key",
                api_base="my-api-base",
                my_custom_param="my-custom-param",
            )

            print(resp)
        except Exception as e:
            print(e)

        mock_client.assert_awaited_once()

        # Verify the dynamic auth values and custom params were forwarded to the handler.
        assert mock_client.call_args.kwargs["api_key"] == "my-api-key"
        assert mock_client.call_args.kwargs["api_base"] == "my-api-base"
        assert mock_client.call_args.kwargs["optional_params"] == {
            "my_custom_param": "my-custom-param"
        }
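
To run this suite locally, plain pytest plus the pytest-asyncio plugin (required by the @pytest.mark.asyncio tests) is enough. A hypothetical self-run hook, assuming the file is saved as an ordinary pytest module; the flags are an assumption, not taken from the file above:

if __name__ == "__main__":
    import sys

    import pytest

    # -s surfaces the print() output the streaming and image-generation tests emit.
    sys.exit(pytest.main(["-x", "-v", "-s", __file__]))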