forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvement (11/14/2024) (#6730)
* fix(ollama.py): fix get model info request Fixes https://github.com/BerriAI/litellm/issues/6703 * feat(anthropic/chat/transformation.py): support passing user id to anthropic via openai 'user' param * docs(anthropic.md): document all supported openai params for anthropic * test: fix tests * fix: fix tests * feat(jina_ai/): add rerank support Closes https://github.com/BerriAI/litellm/issues/6691 * test: handle service unavailable error * fix(handler.py): refactor together ai rerank call * test: update test to handle overloaded error * test: fix test * Litellm router trace (#6742) * feat(router.py): add trace_id to parent functions - allows tracking retry/fallbacks * feat(router.py): log trace id across retry/fallback logic allows grouping llm logs for the same request * test: fix tests * fix: fix test * fix(transformation.py): only set non-none stop_sequences * Litellm router disable fallbacks (#6743) * bump: version 1.52.6 → 1.52.7 * feat(router.py): enable dynamically disabling fallbacks Allows for enabling/disabling fallbacks per key * feat(litellm_pre_call_utils.py): support setting 'disable_fallbacks' on litellm key * test: fix test * fix(exception_mapping_utils.py): map 'model is overloaded' to internal server error * test: handle gemini error * test: fix test * fix: new run
This commit is contained in:
parent
f8e700064e
commit
e9aa492af3
35 changed files with 853 additions and 246 deletions
|
@ -13,8 +13,11 @@ sys.path.insert(
|
|||
import litellm
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
get_supported_openai_params,
|
||||
get_optional_params,
|
||||
)
|
||||
|
||||
# test_example.py
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
115
tests/llm_translation/base_rerank_unit_tests.py
Normal file
115
tests/llm_translation/base_rerank_unit_tests.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
import pytest
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
get_supported_openai_params,
|
||||
get_optional_params,
|
||||
)
|
||||
|
||||
# test_example.py
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
def assert_response_shape(response, custom_llm_provider):
|
||||
expected_response_shape = {"id": str, "results": list, "meta": dict}
|
||||
|
||||
expected_results_shape = {"index": int, "relevance_score": float}
|
||||
|
||||
expected_meta_shape = {"api_version": dict, "billed_units": dict}
|
||||
|
||||
expected_api_version_shape = {"version": str}
|
||||
|
||||
expected_billed_units_shape = {"search_units": int}
|
||||
|
||||
assert isinstance(response.id, expected_response_shape["id"])
|
||||
assert isinstance(response.results, expected_response_shape["results"])
|
||||
for result in response.results:
|
||||
assert isinstance(result["index"], expected_results_shape["index"])
|
||||
assert isinstance(
|
||||
result["relevance_score"], expected_results_shape["relevance_score"]
|
||||
)
|
||||
assert isinstance(response.meta, expected_response_shape["meta"])
|
||||
|
||||
if custom_llm_provider == "cohere":
|
||||
|
||||
assert isinstance(
|
||||
response.meta["api_version"], expected_meta_shape["api_version"]
|
||||
)
|
||||
assert isinstance(
|
||||
response.meta["api_version"]["version"],
|
||||
expected_api_version_shape["version"],
|
||||
)
|
||||
assert isinstance(
|
||||
response.meta["billed_units"], expected_meta_shape["billed_units"]
|
||||
)
|
||||
assert isinstance(
|
||||
response.meta["billed_units"]["search_units"],
|
||||
expected_billed_units_shape["search_units"],
|
||||
)
|
||||
|
||||
|
||||
class BaseLLMRerankTest(ABC):
|
||||
"""
|
||||
Abstract base test class that enforces a common test across all test classes.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_base_rerank_call_args(self) -> dict:
|
||||
"""Must return the base rerank call args"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||
"""Must return the custom llm provider"""
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
async def test_basic_rerank(self, sync_mode):
|
||||
rerank_call_args = self.get_base_rerank_call_args()
|
||||
custom_llm_provider = self.get_custom_llm_provider()
|
||||
if sync_mode is True:
|
||||
response = litellm.rerank(
|
||||
**rerank_call_args,
|
||||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=3,
|
||||
)
|
||||
|
||||
print("re rank response: ", response)
|
||||
|
||||
assert response.id is not None
|
||||
assert response.results is not None
|
||||
|
||||
assert_response_shape(
|
||||
response=response, custom_llm_provider=custom_llm_provider.value
|
||||
)
|
||||
else:
|
||||
response = await litellm.arerank(
|
||||
**rerank_call_args,
|
||||
query="hello",
|
||||
documents=["hello", "world"],
|
||||
top_n=3,
|
||||
)
|
||||
|
||||
print("async re rank response: ", response)
|
||||
|
||||
assert response.id is not None
|
||||
assert response.results is not None
|
||||
|
||||
assert_response_shape(
|
||||
response=response, custom_llm_provider=custom_llm_provider.value
|
||||
)
|
23
tests/llm_translation/test_jina_ai.py
Normal file
23
tests/llm_translation/test_jina_ai.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
from base_rerank_unit_tests import BaseLLMRerankTest
|
||||
import litellm
|
||||
|
||||
|
||||
class TestJinaAI(BaseLLMRerankTest):
|
||||
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||
return litellm.LlmProviders.JINA_AI
|
||||
|
||||
def get_base_rerank_call_args(self) -> dict:
|
||||
return {
|
||||
"model": "jina_ai/jina-reranker-v2-base-multilingual",
|
||||
}
|
|
@ -921,3 +921,16 @@ def test_watsonx_text_top_k():
|
|||
)
|
||||
print(optional_params)
|
||||
assert optional_params["top_k"] == 10
|
||||
|
||||
|
||||
def test_forward_user_param():
|
||||
from litellm.utils import get_supported_openai_params, get_optional_params
|
||||
|
||||
model = "claude-3-5-sonnet-20240620"
|
||||
optional_params = get_optional_params(
|
||||
model=model,
|
||||
user="test_user",
|
||||
custom_llm_provider="anthropic",
|
||||
)
|
||||
|
||||
assert optional_params["metadata"]["user_id"] == "test_user"
|
||||
|
|
|
@ -679,6 +679,8 @@ async def test_anthropic_no_content_error():
|
|||
frequency_penalty=0.8,
|
||||
)
|
||||
|
||||
pass
|
||||
except litellm.InternalServerError:
|
||||
pass
|
||||
except litellm.APIError as e:
|
||||
assert e.status_code == 500
|
||||
|
|
|
@ -1624,3 +1624,55 @@ async def test_standard_logging_payload_stream_usage(sync_mode):
|
|||
print(f"standard_logging_object usage: {built_response.usage}")
|
||||
except litellm.InternalServerError:
|
||||
pass
|
||||
|
||||
|
||||
def test_standard_logging_retries():
|
||||
"""
|
||||
know if a request was retried.
|
||||
"""
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.router import Router
|
||||
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-3.5-turbo",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
customHandler, "log_failure_event", new=MagicMock()
|
||||
) as mock_client:
|
||||
try:
|
||||
router.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
num_retries=1,
|
||||
mock_response="litellm.RateLimitError",
|
||||
)
|
||||
except litellm.RateLimitError:
|
||||
pass
|
||||
|
||||
assert mock_client.call_count == 2
|
||||
assert (
|
||||
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
|
||||
"trace_id"
|
||||
]
|
||||
is not None
|
||||
)
|
||||
assert (
|
||||
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
|
||||
"trace_id"
|
||||
]
|
||||
== mock_client.call_args_list[1].kwargs["kwargs"][
|
||||
"standard_logging_object"
|
||||
]["trace_id"]
|
||||
)
|
||||
|
|
|
@ -157,7 +157,7 @@ def test_get_llm_provider_jina_ai():
|
|||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||
model="jina_ai/jina-embeddings-v3",
|
||||
)
|
||||
assert custom_llm_provider == "openai_like"
|
||||
assert custom_llm_provider == "jina_ai"
|
||||
assert api_base == "https://api.jina.ai/v1"
|
||||
assert model == "jina-embeddings-v3"
|
||||
|
||||
|
|
|
@ -89,11 +89,16 @@ def test_get_model_info_ollama_chat():
|
|||
"template": "tools",
|
||||
}
|
||||
),
|
||||
):
|
||||
) as mock_client:
|
||||
info = OllamaConfig().get_model_info("mistral")
|
||||
print("info", info)
|
||||
assert info["supports_function_calling"] is True
|
||||
|
||||
info = get_model_info("ollama/mistral")
|
||||
print("info", info)
|
||||
|
||||
assert info["supports_function_calling"] is True
|
||||
|
||||
mock_client.assert_called()
|
||||
|
||||
print(mock_client.call_args.kwargs)
|
||||
|
||||
assert mock_client.call_args.kwargs["json"]["name"] == "mistral"
|
||||
|
|
|
@ -1455,3 +1455,46 @@ async def test_router_fallbacks_default_and_model_specific_fallbacks(sync_mode):
|
|||
assert isinstance(
|
||||
exc_info.value, litellm.AuthenticationError
|
||||
), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_disable_fallbacks_dynamically():
|
||||
from litellm.router import run_async_fallback
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "bad-model",
|
||||
"litellm_params": {
|
||||
"model": "openai/my-bad-model",
|
||||
"api_key": "my-bad-api-key",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "good-model",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
],
|
||||
fallbacks=[{"bad-model": ["good-model"]}],
|
||||
default_fallbacks=["good-model"],
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
router,
|
||||
"log_retry",
|
||||
new=MagicMock(return_value=None),
|
||||
) as mock_client:
|
||||
try:
|
||||
resp = await router.acompletion(
|
||||
model="bad-model",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
disable_fallbacks=True,
|
||||
)
|
||||
print(resp)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_client.assert_not_called()
|
||||
|
|
|
@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from dotenv import load_dotenv
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -83,3 +84,93 @@ def test_returned_settings():
|
|||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
pytest.fail("An error occurred - " + traceback.format_exc())
|
||||
|
||||
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
|
||||
def test_update_kwargs_before_fallbacks_unit_test():
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
kwargs = {"messages": [{"role": "user", "content": "write 1 sentence poem"}]}
|
||||
|
||||
router._update_kwargs_before_fallbacks(
|
||||
model="gpt-3.5-turbo",
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
assert kwargs["litellm_trace_id"] is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"call_type",
|
||||
[
|
||||
CallTypes.acompletion,
|
||||
CallTypes.atext_completion,
|
||||
CallTypes.aembedding,
|
||||
CallTypes.arerank,
|
||||
CallTypes.atranscription,
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_kwargs_before_fallbacks(call_type):
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
if call_type.value.startswith("a"):
|
||||
with patch.object(router, "async_function_with_fallbacks") as mock_client:
|
||||
if call_type.value == "acompletion":
|
||||
input_kwarg = {
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||
}
|
||||
elif (
|
||||
call_type.value == "atext_completion"
|
||||
or call_type.value == "aimage_generation"
|
||||
):
|
||||
input_kwarg = {
|
||||
"prompt": "Hello, how are you?",
|
||||
}
|
||||
elif call_type.value == "aembedding" or call_type.value == "arerank":
|
||||
input_kwarg = {
|
||||
"input": "Hello, how are you?",
|
||||
}
|
||||
elif call_type.value == "atranscription":
|
||||
input_kwarg = {
|
||||
"file": "path/to/file",
|
||||
}
|
||||
else:
|
||||
input_kwarg = {}
|
||||
|
||||
await getattr(router, call_type.value)(
|
||||
model="gpt-3.5-turbo",
|
||||
**input_kwarg,
|
||||
)
|
||||
|
||||
mock_client.assert_called_once()
|
||||
|
||||
print(mock_client.call_args.kwargs)
|
||||
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None
|
||||
|
|
|
@ -172,6 +172,8 @@ def test_stream_chunk_builder_litellm_usage_chunks():
|
|||
"""
|
||||
Checks if stream_chunk_builder is able to correctly rebuild with given metadata from streaming chunks
|
||||
"""
|
||||
from litellm.types.utils import Usage
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Tell me the funniest joke you know."},
|
||||
{
|
||||
|
@ -182,24 +184,28 @@ def test_stream_chunk_builder_litellm_usage_chunks():
|
|||
{"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
|
||||
{"role": "user", "content": "\nI am waiting...\n\n...\n"},
|
||||
]
|
||||
# make a regular gemini call
|
||||
response = completion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
usage: litellm.Usage = response.usage
|
||||
usage: litellm.Usage = Usage(
|
||||
completion_tokens=27,
|
||||
prompt_tokens=55,
|
||||
total_tokens=82,
|
||||
completion_tokens_details=None,
|
||||
prompt_tokens_details=None,
|
||||
)
|
||||
|
||||
gemini_pt = usage.prompt_tokens
|
||||
|
||||
# make a streaming gemini call
|
||||
response = completion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=messages,
|
||||
stream=True,
|
||||
complete_response=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
try:
|
||||
response = completion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=messages,
|
||||
stream=True,
|
||||
complete_response=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
except litellm.InternalServerError as e:
|
||||
pytest.skip(f"Skipping test due to internal server error - {str(e)}")
|
||||
|
||||
usage: litellm.Usage = response.usage
|
||||
|
||||
|
|
|
@ -736,6 +736,8 @@ async def test_acompletion_claude_2_stream():
|
|||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
print(f"completion_response: {complete_response}")
|
||||
except litellm.InternalServerError:
|
||||
pass
|
||||
except litellm.RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
@ -3272,7 +3274,7 @@ def test_completion_claude_3_function_call_with_streaming():
|
|||
], # "claude-3-opus-20240229"
|
||||
) #
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion_claude_3_function_call_with_streaming(model):
|
||||
async def test_acompletion_function_call_with_streaming(model):
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
{
|
||||
|
@ -3331,6 +3333,10 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
|
|||
validate_final_streaming_function_calling_chunk(chunk=chunk)
|
||||
idx += 1
|
||||
# raise Exception("it worked! ")
|
||||
except litellm.InternalServerError:
|
||||
pass
|
||||
except litellm.ServiceUnavailableError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
|
@ -188,7 +188,8 @@ def test_completion_claude_3_function_call_with_otel(model):
|
|||
)
|
||||
|
||||
print("response from LiteLLM", response)
|
||||
|
||||
except litellm.InternalServerError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
finally:
|
||||
|
|
|
@ -1500,6 +1500,31 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
|
|||
assert new_data["failure_callback"] == expected_failure_callbacks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"disable_fallbacks_set",
|
||||
[
|
||||
True,
|
||||
False,
|
||||
],
|
||||
)
|
||||
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
|
||||
key_metadata = {"disable_fallbacks": disable_fallbacks_set}
|
||||
existing_data = {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"messages": [{"role": "user", "content": "write 1 sentence poem"}],
|
||||
}
|
||||
data = LiteLLMProxyRequestSetup.add_key_level_controls(
|
||||
key_metadata=key_metadata,
|
||||
data=existing_data,
|
||||
_metadata_variable_name="metadata",
|
||||
)
|
||||
|
||||
assert data["disable_fallbacks"] == disable_fallbacks_set
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"callback_type, expected_success_callbacks, expected_failure_callbacks",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue