From fe60a38c8e43e908f44d8c668a5ba9fae1dca762 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Tue, 14 Jan 2025 21:49:25 -0800 Subject: [PATCH] Litellm dev 01 2025 p4 (#7776) * fix(gemini/): support gemini 'frequency_penalty' and 'presence_penalty' Closes https://github.com/BerriAI/litellm/issues/7748 * feat(proxy_server.py): new env var to disable prisma health check on startup * test: fix test --- litellm/llms/gemini/chat/transformation.py | 10 ++---- litellm/proxy/proxy_server.py | 6 +++- tests/llm_translation/test_optional_params.py | 21 +++++------- tests/proxy_unit_tests/test_proxy_utils.py | 32 +++++++++++++++++++ 4 files changed, 47 insertions(+), 22 deletions(-) diff --git a/litellm/llms/gemini/chat/transformation.py b/litellm/llms/gemini/chat/transformation.py index fb891ae0ef..313bb99af7 100644 --- a/litellm/llms/gemini/chat/transformation.py +++ b/litellm/llms/gemini/chat/transformation.py @@ -12,9 +12,7 @@ from ...vertex_ai.gemini.transformation import _gemini_convert_messages_with_his from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig -class GoogleAIStudioGeminiConfig( - VertexGeminiConfig -): # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported +class GoogleAIStudioGeminiConfig(VertexGeminiConfig): """ Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig @@ -82,6 +80,7 @@ class GoogleAIStudioGeminiConfig( "n", "stop", "logprobs", + "frequency_penalty", ] def map_openai_params( @@ -92,11 +91,6 @@ class GoogleAIStudioGeminiConfig( drop_params: bool, ) -> Dict: - # drop frequency_penalty and presence_penalty - if "frequency_penalty" in non_default_params: - del non_default_params["frequency_penalty"] - if "presence_penalty" in non_default_params: - del non_default_params["presence_penalty"] if litellm.vertex_ai_safety_settings is not None: optional_params["safety_settings"] = litellm.vertex_ai_safety_settings return super().map_openai_params( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d9c3ea8760..c40bdb2b2e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -3233,7 +3233,11 @@ class ProxyStartupEvent: ) # set the spend logs row count in proxy state. Don't block execution # run a health check to ensure the DB is ready - await prisma_client.health_check() + if ( + get_secret_bool("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", False) + is not True + ): + await prisma_client.health_check() return prisma_client @classmethod diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py index dcb38f6241..75fe7aa5b1 100644 --- a/tests/llm_translation/test_optional_params.py +++ b/tests/llm_translation/test_optional_params.py @@ -217,19 +217,6 @@ def test_databricks_optional_params(): assert "user" not in optional_params -def test_gemini_optional_params(): - litellm.drop_params = True - optional_params = get_optional_params( - model="", - custom_llm_provider="gemini", - max_tokens=10, - frequency_penalty=10, - ) - print(f"optional_params: {optional_params}") - assert len(optional_params) == 1 - assert "frequency_penalty" not in optional_params - - def test_azure_ai_mistral_optional_params(): litellm.drop_params = True optional_params = get_optional_params( @@ -1063,6 +1050,7 @@ def test_is_vertex_anthropic_model(): is False ) + def test_groq_response_format_json_schema(): optional_params = get_optional_params( model="llama-3.1-70b-versatile", @@ -1072,3 +1060,10 @@ def test_groq_response_format_json_schema(): assert optional_params is not None assert "response_format" in optional_params assert optional_params["response_format"]["type"] == "json_object" + + +def test_gemini_frequency_penalty(): + optional_params = get_optional_params( + model="gemini-1.5-flash", custom_llm_provider="gemini", frequency_penalty=0.5 + ) + assert optional_params["frequency_penalty"] == 0.5 diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index bb71b2a1b7..6934521718 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -1447,3 +1447,35 @@ def test_update_key_budget_with_temp_budget_increase(): }, ) assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200 + + +from unittest.mock import MagicMock, AsyncMock + + +@pytest.mark.asyncio +async def test_health_check_not_called_when_disabled(monkeypatch): + from litellm.proxy.proxy_server import ProxyStartupEvent + + # Mock environment variable + monkeypatch.setenv("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", "true") + + # Create mock prisma client + mock_prisma = MagicMock() + mock_prisma.connect = AsyncMock() + mock_prisma.health_check = AsyncMock() + mock_prisma.check_view_exists = AsyncMock() + mock_prisma._set_spend_logs_row_count_in_proxy_state = AsyncMock() + # Mock PrismaClient constructor + monkeypatch.setattr( + "litellm.proxy.proxy_server.PrismaClient", lambda **kwargs: mock_prisma + ) + + # Call the setup function + await ProxyStartupEvent._setup_prisma_client( + database_url="mock_url", + proxy_logging_obj=MagicMock(), + user_api_key_cache=MagicMock(), + ) + + # Verify health check wasn't called + mock_prisma.health_check.assert_not_called()