fix(databricks/chat.py): handle the max_retries optional param for openai-like calls

Fixes an issue with calling fine-tuned Vertex AI models via the Databricks route.
Krrish Dholakia 2024-11-20 04:56:18 +05:30
parent e89dcccdd9
commit 07ba537970
3 changed files with 19 additions and 2 deletions
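
For context, a minimal sketch of the pattern this change applies, assuming a generic OpenAI-compatible handler (build_request_body is a hypothetical helper, not litellm's actual code): such handlers spread optional_params into the JSON request body, so client-only settings like max_retries have to be popped first or they get forwarded to the provider API.

def build_request_body(model, messages, optional_params):
    # copy so the caller's dict is untouched (illustration only)
    optional_params = dict(optional_params)
    optional_params.pop(
        "max_retries", None
    )  # retry handling stays client-side; it is not a provider payload field
    optional_params["stream"] = optional_params.get("stream") or False
    return {"model": model, "messages": messages, **optional_params}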

View file

@@ -470,6 +470,9 @@ class DatabricksChatCompletion(BaseLLM):
                optional_params[k] = v

        stream: bool = optional_params.get("stream", None) or False
        optional_params.pop(
            "max_retries", None
        )  # [TODO] add max retry support at llm api call level
        optional_params["stream"] = stream

        data = {
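
A hedged usage example of the call path this hunk affects: Vertex AI fine-tuned / model-garden endpoints are served through litellm's OpenAI-like handler, so before this change a max_retries argument could end up forwarded along with the request. The endpoint ID and project below are placeholders, not values from the commit.

import litellm

response = litellm.completion(
    model="vertex_ai/<your-endpoint-id>",  # placeholder fine-tuned endpoint ID
    messages=[{"role": "user", "content": "hello"}],
    vertex_project="<your-gcp-project>",   # placeholder
    vertex_location="us-central1",
    max_retries=3,  # now popped before the provider request body is built
)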

View file

@@ -923,7 +923,6 @@ def test_watsonx_text_top_k():
    assert optional_params["top_k"] == 10


def test_together_ai_model_params():
    optional_params = get_optional_params(
        model="together_ai", custom_llm_provider="together_ai", logprobs=1
@@ -931,6 +930,7 @@ def test_together_ai_model_params():
    print(optional_params)
    assert optional_params["logprobs"] == 1


def test_forward_user_param():
    from litellm.utils import get_supported_openai_params, get_optional_params
@@ -943,6 +943,7 @@ def test_forward_user_param():
    assert optional_params["metadata"]["user_id"] == "test_user"


def test_lm_studio_embedding_params():
    optional_params = get_optional_params_embeddings(
        model="lm_studio/gemma2-9b-it",
@@ -951,3 +952,12 @@ def test_lm_studio_embedding_params():
        drop_params=True,
    )
    assert len(optional_params) == 0


def test_vertex_ft_models_optional_params():
    optional_params = get_optional_params(
        model="meta-llama/Llama-3.1-8B-Instruct",
        custom_llm_provider="vertex_ai",
        max_retries=3,
    )
    assert "max_retries" not in optional_params

View file

@@ -3129,9 +3129,12 @@ async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
    assert all(isinstance(x, float) for x in embedding["embedding"])


@pytest.mark.parametrize("max_retries", [None, 3])
@pytest.mark.asyncio
@pytest.mark.respx
async def test_vertexai_model_garden_model_completion(respx_mock: MockRouter):
async def test_vertexai_model_garden_model_completion(
    respx_mock: MockRouter, max_retries
):
    """
    Relevant issue: https://github.com/BerriAI/litellm/issues/6480
@@ -3189,6 +3192,7 @@ async def test_vertexai_model_garden_model_completion(respx_mock: MockRouter):
        messages=messages,
        vertex_project="633608382793",
        vertex_location="us-central1",
        max_retries=max_retries,
    )

    # Assert request was made correctly
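
Finally, a standalone sketch of the parametrize pattern used above, with a hypothetical test name: pytest runs the body once per value, so both the None and 3 cases exercise the code path that now strips max_retries.

import pytest
from litellm.utils import get_optional_params


@pytest.mark.parametrize("max_retries", [None, 3])
def test_max_retries_never_forwarded(max_retries):
    # hypothetical standalone test, mirroring the [None, 3] cases above
    optional_params = get_optional_params(
        model="meta-llama/Llama-3.1-8B-Instruct",
        custom_llm_provider="vertex_ai",
        max_retries=max_retries,
    )
    assert "max_retries" not in optional_params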