forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (11/19/2024) (#6820)
* fix(anthropic/chat/transformation.py): add the JSON schema as values: json_schema. Fixes passing a Pydantic obj to Anthropic (https://github.com/BerriAI/litellm/issues/6766)
* feat: add timestamp_granularities parameter to the transcription API (#6457); add the param to the local test
* fix(databricks/chat.py): handle the max_retries optional param for openai-like calls. Fixes calling finetuned Vertex AI models via the Databricks route
* build(ui/): add team admins via the proxy UI
* fix: fix linting error
* test: fix test
* docs(vertex.md): refactor docs
* test: handle overloaded Anthropic model error
* test: remove duplicate test
* test: fix test
* test: update test to handle model overloaded error

Co-authored-by: Show <35062952+BrunooShow@users.noreply.github.com>
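The structured-output fix in the first item changes how a Pydantic response_format is mapped onto Anthropic's tool input schema (the user schema is now nested under a "values" key, per the updated assertion further down). A minimal caller-side sketch, assuming an illustrative model name and schema:

    from pydantic import BaseModel
    import litellm

    class EventDetails(BaseModel):  # hypothetical schema for illustration
        name: str
        date: str

    response = litellm.completion(
        model="claude-3-5-sonnet-20240620",  # assumed Anthropic model name
        messages=[{"role": "user", "content": "Alice's launch party is on June 5th. Extract the event details."}],
        response_format=EventDetails,  # Pydantic class; litellm converts it to a JSON schema
    )
    print(response.choices[0].message.content)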
parent 7d0e1f05ac
commit b0be5bf3a1
15 changed files with 200 additions and 193 deletions
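The new timestamp_granularities parameter (second item in the commit message) is exercised by the updated test_transcription parametrization below. A hedged usage sketch, assuming an OpenAI Whisper model and a local audio file:

    import litellm

    with open("sample.wav", "rb") as audio_file:  # assumed local file for illustration
        transcript = litellm.transcription(
            model="whisper-1",                   # assumed OpenAI transcription model
            file=audio_file,
            response_format="verbose_json",      # word-level timestamps require verbose_json
            timestamp_granularities=["word"],    # parameter added in this commit
            drop_params=True,                    # drop the param for providers that don't support it
        )
    print(transcript)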
@@ -42,11 +42,14 @@ class BaseLLMChatTest(ABC):
                 "content": [{"type": "text", "text": "Hello, how are you?"}],
             }
         ]
-        response = litellm.completion(
-            **base_completion_call_args,
-            messages=messages,
-        )
-        assert response is not None
+        try:
+            response = litellm.completion(
+                **base_completion_call_args,
+                messages=messages,
+            )
+            assert response is not None
+        except litellm.InternalServerError:
+            pass
 
         # for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
         assert response.choices[0].message.content is not None
@@ -89,6 +92,36 @@ class BaseLLMChatTest(ABC):
         # relevant issue: https://github.com/BerriAI/litellm/issues/6741
         assert response.choices[0].message.content is not None
 
+    def test_json_response_pydantic_obj(self):
+        from pydantic import BaseModel
+        from litellm.utils import supports_response_schema
+
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+        class TestModel(BaseModel):
+            first_response: str
+
+        base_completion_call_args = self.get_base_completion_call_args()
+        if not supports_response_schema(base_completion_call_args["model"], None):
+            pytest.skip("Model does not support response schema")
+
+        try:
+            res = litellm.completion(
+                **base_completion_call_args,
+                messages=[
+                    {"role": "system", "content": "You are a helpful assistant."},
+                    {
+                        "role": "user",
+                        "content": "What is the capital of France?",
+                    },
+                ],
+                response_format=TestModel,
+            )
+            assert res is not None
+        except litellm.InternalServerError:
+            pytest.skip("Model is overloaded")
+
     def test_json_response_format_stream(self):
         """
         Test that the JSON response format with streaming is supported by the LLM API
@@ -657,7 +657,7 @@ def test_create_json_tool_call_for_response_format():
     _input_schema = tool.get("input_schema")
     assert _input_schema is not None
     assert _input_schema.get("type") == "object"
-    assert _input_schema.get("properties") == custom_schema
+    assert _input_schema.get("properties") == {"values": custom_schema}
     assert "additionalProperties" not in _input_schema
 
 
@@ -923,7 +923,6 @@ def test_watsonx_text_top_k():
     assert optional_params["top_k"] == 10
 
 
-
 def test_together_ai_model_params():
     optional_params = get_optional_params(
         model="together_ai", custom_llm_provider="together_ai", logprobs=1
@@ -931,6 +930,7 @@ def test_together_ai_model_params():
     print(optional_params)
     assert optional_params["logprobs"] == 1
 
+
 def test_forward_user_param():
     from litellm.utils import get_supported_openai_params, get_optional_params
 
@@ -943,6 +943,7 @@ def test_forward_user_param():
 
     assert optional_params["metadata"]["user_id"] == "test_user"
 
+
 def test_lm_studio_embedding_params():
     optional_params = get_optional_params_embeddings(
         model="lm_studio/gemma2-9b-it",
@@ -3129,9 +3129,12 @@ async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
     assert all(isinstance(x, float) for x in embedding["embedding"])
 
 
+@pytest.mark.parametrize("max_retries", [None, 3])
 @pytest.mark.asyncio
 @pytest.mark.respx
-async def test_vertexai_model_garden_model_completion(respx_mock: MockRouter):
+async def test_vertexai_model_garden_model_completion(
+    respx_mock: MockRouter, max_retries
+):
     """
     Relevant issue: https://github.com/BerriAI/litellm/issues/6480
@@ -3189,6 +3192,7 @@ async def test_vertexai_model_garden_model_completion(respx_mock: MockRouter):
         messages=messages,
         vertex_project="633608382793",
         vertex_location="us-central1",
+        max_retries=max_retries,
     )
 
     # Assert request was made correctly
@@ -1222,32 +1222,6 @@ def test_completion_mistral_api_modified_input():
         pytest.fail(f"Error occurred: {e}")
 
 
-def test_completion_claude2_1():
-    try:
-        litellm.set_verbose = True
-        print("claude2.1 test request")
-        messages = [
-            {
-                "role": "system",
-                "content": "Your goal is generate a joke on the topic user gives.",
-            },
-            {"role": "user", "content": "Generate a 3 liner joke for me"},
-        ]
-        # test without max tokens
-        response = completion(model="claude-2.1", messages=messages)
-        # Add any assertions here to check the response
-        print(response)
-        print(response.usage)
-        print(response.usage.completion_tokens)
-        print(response["usage"]["completion_tokens"])
-        # print("new cost tracking")
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
-# test_completion_claude2_1()
-
-
 @pytest.mark.asyncio
 async def test_acompletion_claude2_1():
     try:
@@ -1268,6 +1242,8 @@ async def test_acompletion_claude2_1():
         print(response.usage.completion_tokens)
         print(response["usage"]["completion_tokens"])
         # print("new cost tracking")
+    except litellm.InternalServerError:
+        pytest.skip("model is overloaded.")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
@@ -4514,19 +4490,22 @@ async def test_dynamic_azure_params(stream, sync_mode):
 @pytest.mark.flaky(retries=3, delay=1)
 async def test_completion_ai21_chat():
     litellm.set_verbose = True
-    response = await litellm.acompletion(
-        model="jamba-1.5-large",
-        user="ishaan",
-        tool_choice="auto",
-        seed=123,
-        messages=[{"role": "user", "content": "what does the document say"}],
-        documents=[
-            {
-                "content": "hello world",
-                "metadata": {"source": "google", "author": "ishaan"},
-            }
-        ],
-    )
+    try:
+        response = await litellm.acompletion(
+            model="jamba-1.5-large",
+            user="ishaan",
+            tool_choice="auto",
+            seed=123,
+            messages=[{"role": "user", "content": "what does the document say"}],
+            documents=[
+                {
+                    "content": "hello world",
+                    "metadata": {"source": "google", "author": "ishaan"},
+                }
+            ],
+        )
+    except litellm.InternalServerError:
+        pytest.skip("Model is overloaded")
 
 
 @pytest.mark.parametrize(
@@ -51,10 +51,15 @@ from litellm import Router
         ),
     ],
 )
-@pytest.mark.parametrize("response_format", ["json", "vtt"])
+@pytest.mark.parametrize(
+    "response_format, timestamp_granularities",
+    [("json", None), ("vtt", None), ("verbose_json", ["word"])],
+)
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_transcription(model, api_key, api_base, response_format, sync_mode):
+async def test_transcription(
+    model, api_key, api_base, response_format, sync_mode, timestamp_granularities
+):
     if sync_mode:
         transcript = litellm.transcription(
             model=model,
@@ -62,6 +67,7 @@ async def test_transcription(model, api_key, api_base, response_format, sync_mode):
             api_key=api_key,
             api_base=api_base,
             response_format=response_format,
+            timestamp_granularities=timestamp_granularities,
             drop_params=True,
         )
     else: