Litellm dev 01 2025 p4 (#7776)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 13s

* fix(gemini/): support gemini 'frequency_penalty' and 'presence_penalty'

Closes https://github.com/BerriAI/litellm/issues/7748

* feat(proxy_server.py): new env var to disable prisma health check on startup

* test: fix test
This commit is contained in:
Krish Dholakia 2025-01-14 21:49:25 -08:00 committed by GitHub
parent 8353caa485
commit fe60a38c8e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 47 additions and 22 deletions

View file

@ -12,9 +12,7 @@ from ...vertex_ai.gemini.transformation import _gemini_convert_messages_with_his
from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
): # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
"""
Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
@ -82,6 +80,7 @@ class GoogleAIStudioGeminiConfig(
"n",
"stop",
"logprobs",
"frequency_penalty",
]
def map_openai_params(
@ -92,11 +91,6 @@ class GoogleAIStudioGeminiConfig(
drop_params: bool,
) -> Dict:
# drop frequency_penalty and presence_penalty
if "frequency_penalty" in non_default_params:
del non_default_params["frequency_penalty"]
if "presence_penalty" in non_default_params:
del non_default_params["presence_penalty"]
if litellm.vertex_ai_safety_settings is not None:
optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
return super().map_openai_params(

View file

@ -3233,6 +3233,10 @@ class ProxyStartupEvent:
) # set the spend logs row count in proxy state. Don't block execution
# run a health check to ensure the DB is ready
if (
get_secret_bool("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", False)
is not True
):
await prisma_client.health_check()
return prisma_client

View file

@ -217,19 +217,6 @@ def test_databricks_optional_params():
assert "user" not in optional_params
def test_gemini_optional_params():
litellm.drop_params = True
optional_params = get_optional_params(
model="",
custom_llm_provider="gemini",
max_tokens=10,
frequency_penalty=10,
)
print(f"optional_params: {optional_params}")
assert len(optional_params) == 1
assert "frequency_penalty" not in optional_params
def test_azure_ai_mistral_optional_params():
litellm.drop_params = True
optional_params = get_optional_params(
@ -1063,6 +1050,7 @@ def test_is_vertex_anthropic_model():
is False
)
def test_groq_response_format_json_schema():
optional_params = get_optional_params(
model="llama-3.1-70b-versatile",
@ -1072,3 +1060,10 @@ def test_groq_response_format_json_schema():
assert optional_params is not None
assert "response_format" in optional_params
assert optional_params["response_format"]["type"] == "json_object"
def test_gemini_frequency_penalty():
optional_params = get_optional_params(
model="gemini-1.5-flash", custom_llm_provider="gemini", frequency_penalty=0.5
)
assert optional_params["frequency_penalty"] == 0.5

View file

@ -1447,3 +1447,35 @@ def test_update_key_budget_with_temp_budget_increase():
},
)
assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200
from unittest.mock import MagicMock, AsyncMock
@pytest.mark.asyncio
async def test_health_check_not_called_when_disabled(monkeypatch):
from litellm.proxy.proxy_server import ProxyStartupEvent
# Mock environment variable
monkeypatch.setenv("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", "true")
# Create mock prisma client
mock_prisma = MagicMock()
mock_prisma.connect = AsyncMock()
mock_prisma.health_check = AsyncMock()
mock_prisma.check_view_exists = AsyncMock()
mock_prisma._set_spend_logs_row_count_in_proxy_state = AsyncMock()
# Mock PrismaClient constructor
monkeypatch.setattr(
"litellm.proxy.proxy_server.PrismaClient", lambda **kwargs: mock_prisma
)
# Call the setup function
await ProxyStartupEvent._setup_prisma_client(
database_url="mock_url",
proxy_logging_obj=MagicMock(),
user_api_key_cache=MagicMock(),
)
# Verify health check wasn't called
mock_prisma.health_check.assert_not_called()