From 07ba5379709f3a9f6e46ee7015c3b57d3a140bbc Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 20 Nov 2024 04:56:18 +0530
Subject: [PATCH] fix(databricks/chat.py): handle max_retries optional param
 for openai-like calls

Fixes issue with calling fine-tuned Vertex AI models via the Databricks route
---
 litellm/llms/databricks/chat.py                       |  3 +++
 tests/llm_translation/test_optional_params.py         | 12 +++++++++++-
 .../local_testing/test_amazing_vertex_completion.py   |  6 +++++-
 3 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/litellm/llms/databricks/chat.py b/litellm/llms/databricks/chat.py
index eb0cb341e..79e885646 100644
--- a/litellm/llms/databricks/chat.py
+++ b/litellm/llms/databricks/chat.py
@@ -470,6 +470,9 @@ class DatabricksChatCompletion(BaseLLM):
                 optional_params[k] = v
 
         stream: bool = optional_params.get("stream", None) or False
+        optional_params.pop(
+            "max_retries", None
+        )  # [TODO] add max retry support at llm api call level
         optional_params["stream"] = stream
 
         data = {
diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py
index 029e91513..33e708ff4 100644
--- a/tests/llm_translation/test_optional_params.py
+++ b/tests/llm_translation/test_optional_params.py
@@ -923,7 +923,6 @@ def test_watsonx_text_top_k():
     assert optional_params["top_k"] == 10
 
 
-
 def test_together_ai_model_params():
     optional_params = get_optional_params(
         model="together_ai", custom_llm_provider="together_ai", logprobs=1
@@ -931,6 +930,7 @@ def test_together_ai_model_params():
     print(optional_params)
     assert optional_params["logprobs"] == 1
 
+
 def test_forward_user_param():
     from litellm.utils import get_supported_openai_params, get_optional_params
 
@@ -943,6 +943,7 @@ def test_forward_user_param():
 
     assert optional_params["metadata"]["user_id"] == "test_user"
 
+
 def test_lm_studio_embedding_params():
     optional_params = get_optional_params_embeddings(
         model="lm_studio/gemma2-9b-it",
@@ -951,3 +952,12 @@ def test_lm_studio_embedding_params():
         drop_params=True,
     )
     assert len(optional_params) == 0
+
+
+def test_vertex_ft_models_optional_params():
+    optional_params = get_optional_params(
+        model="meta-llama/Llama-3.1-8B-Instruct",
+        custom_llm_provider="vertex_ai",
+        max_retries=3,
+    )
+    assert "max_retries" not in optional_params
diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py
index 3bf36dda8..f801a53ce 100644
--- a/tests/local_testing/test_amazing_vertex_completion.py
+++ b/tests/local_testing/test_amazing_vertex_completion.py
@@ -3129,9 +3129,12 @@ async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
     assert all(isinstance(x, float) for x in embedding["embedding"])
 
 
+@pytest.mark.parametrize("max_retries", [None, 3])
 @pytest.mark.asyncio
 @pytest.mark.respx
-async def test_vertexai_model_garden_model_completion(respx_mock: MockRouter):
+async def test_vertexai_model_garden_model_completion(
+    respx_mock: MockRouter, max_retries
+):
     """
     Relevant issue: https://github.com/BerriAI/litellm/issues/6480
 
@@ -3189,6 +3192,7 @@ async def test_vertexai_model_garden_model_completion(
         messages=messages,
         vertex_project="633608382793",
         vertex_location="us-central1",
+        max_retries=max_retries,
     )
 
     # Assert request was made correctly
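
A minimal sketch of the behavior this patch pins down, assuming a litellm
checkout with the patch applied; the model and provider values mirror the new
test_vertex_ft_models_optional_params test above:

    from litellm.utils import get_optional_params

    # max_retries is accepted at the call site but popped before the request
    # body is built, so the OpenAI-like (Databricks-style) handler never
    # forwards it to the provider API.
    optional_params = get_optional_params(
        model="meta-llama/Llama-3.1-8B-Instruct",
        custom_llm_provider="vertex_ai",
        max_retries=3,
    )
    assert "max_retries" not in optional_params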