mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Litellm dev 01 2025 p4 (#7776)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 13s
* fix(gemini/): support gemini 'frequency_penalty' and 'presence_penalty'
  Closes https://github.com/BerriAI/litellm/issues/7748
* feat(proxy_server.py): new env var to disable prisma health check on startup
* test: fix test
This commit is contained in:
parent 8353caa485
commit fe60a38c8e

4 changed files with 47 additions and 22 deletions
```diff
@@ -12,9 +12,7 @@ from ...vertex_ai.gemini.transformation import _gemini_convert_messages_with_his
 from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
 
 
-class GoogleAIStudioGeminiConfig(
-    VertexGeminiConfig
-):  # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
+class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
     """
     Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
 
@@ -82,6 +80,7 @@ class GoogleAIStudioGeminiConfig(
             "n",
             "stop",
             "logprobs",
+            "frequency_penalty",
         ]
 
     def map_openai_params(
@@ -92,11 +91,6 @@ class GoogleAIStudioGeminiConfig(
         drop_params: bool,
     ) -> Dict:
 
-        # drop frequency_penalty and presence_penalty
-        if "frequency_penalty" in non_default_params:
-            del non_default_params["frequency_penalty"]
-        if "presence_penalty" in non_default_params:
-            del non_default_params["presence_penalty"]
        if litellm.vertex_ai_safety_settings is not None:
            optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
        return super().map_openai_params(
```
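With the penalty-drop logic removed above, `frequency_penalty` and `presence_penalty` now flow through to Google AI Studio Gemini instead of being silently deleted in `map_openai_params`. A minimal usage sketch, assuming a `GEMINI_API_KEY` is configured; the model name and prompt are illustrative:

```python
import litellm

# Illustrative call: frequency_penalty / presence_penalty are now forwarded to
# Gemini (Google AI Studio) rather than dropped by GoogleAIStudioGeminiConfig.
response = litellm.completion(
    model="gemini/gemini-1.5-flash",  # assumes GEMINI_API_KEY is set in the environment
    messages=[{"role": "user", "content": "Write a short haiku about rain."}],
    frequency_penalty=0.5,
    presence_penalty=0.3,
)
print(response.choices[0].message.content)
```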
```diff
@@ -3233,6 +3233,10 @@ class ProxyStartupEvent:
         ) # set the spend logs row count in proxy state. Don't block execution
 
         # run a health check to ensure the DB is ready
-        await prisma_client.health_check()
+        if (
+            get_secret_bool("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", False)
+            is not True
+        ):
+            await prisma_client.health_check()
         return prisma_client
 
```
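The new env var gates the startup health check shown above: setting `DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP=true` in the proxy's environment skips `prisma_client.health_check()` at startup, while leaving it unset keeps the previous behavior. A minimal sketch of the gating semantics, assuming the flag is read as a plain environment variable here (the proxy itself reads it through `get_secret_bool`), with a hypothetical helper name:

```python
import os


def should_run_prisma_health_check() -> bool:
    # The startup health check runs unless DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP
    # is set to a truthy value; mirrors the check added in the diff above.
    flag = os.getenv("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", "false").lower()
    return flag not in ("true", "1")
```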
```diff
@@ -217,19 +217,6 @@ def test_databricks_optional_params():
     assert "user" not in optional_params
 
 
-def test_gemini_optional_params():
-    litellm.drop_params = True
-    optional_params = get_optional_params(
-        model="",
-        custom_llm_provider="gemini",
-        max_tokens=10,
-        frequency_penalty=10,
-    )
-    print(f"optional_params: {optional_params}")
-    assert len(optional_params) == 1
-    assert "frequency_penalty" not in optional_params
-
-
 def test_azure_ai_mistral_optional_params():
     litellm.drop_params = True
     optional_params = get_optional_params(
@@ -1063,6 +1050,7 @@ def test_is_vertex_anthropic_model():
         is False
     )
 
+
 def test_groq_response_format_json_schema():
     optional_params = get_optional_params(
         model="llama-3.1-70b-versatile",
@@ -1072,3 +1060,10 @@ def test_groq_response_format_json_schema():
     assert optional_params is not None
     assert "response_format" in optional_params
     assert optional_params["response_format"]["type"] == "json_object"
+
+
+def test_gemini_frequency_penalty():
+    optional_params = get_optional_params(
+        model="gemini-1.5-flash", custom_llm_provider="gemini", frequency_penalty=0.5
+    )
+    assert optional_params["frequency_penalty"] == 0.5
```
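The added `test_gemini_frequency_penalty` covers only `frequency_penalty`; per the commit message, `presence_penalty` should map through the same way. A hedged companion check, not part of this commit, assuming `get_optional_params` is importable from `litellm.utils` as in the surrounding tests:

```python
from litellm.utils import get_optional_params

# Hypothetical companion assertion mirroring test_gemini_frequency_penalty.
optional_params = get_optional_params(
    model="gemini-1.5-flash", custom_llm_provider="gemini", presence_penalty=0.5
)
assert optional_params["presence_penalty"] == 0.5
```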
```diff
@@ -1447,3 +1447,35 @@ def test_update_key_budget_with_temp_budget_increase():
         },
     )
     assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200
+
+
+from unittest.mock import MagicMock, AsyncMock
+
+
+@pytest.mark.asyncio
+async def test_health_check_not_called_when_disabled(monkeypatch):
+    from litellm.proxy.proxy_server import ProxyStartupEvent
+
+    # Mock environment variable
+    monkeypatch.setenv("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", "true")
+
+    # Create mock prisma client
+    mock_prisma = MagicMock()
+    mock_prisma.connect = AsyncMock()
+    mock_prisma.health_check = AsyncMock()
+    mock_prisma.check_view_exists = AsyncMock()
+    mock_prisma._set_spend_logs_row_count_in_proxy_state = AsyncMock()
+    # Mock PrismaClient constructor
+    monkeypatch.setattr(
+        "litellm.proxy.proxy_server.PrismaClient", lambda **kwargs: mock_prisma
+    )
+
+    # Call the setup function
+    await ProxyStartupEvent._setup_prisma_client(
+        database_url="mock_url",
+        proxy_logging_obj=MagicMock(),
+        user_api_key_cache=MagicMock(),
+    )
+
+    # Verify health check wasn't called
+    mock_prisma.health_check.assert_not_called()
```
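A possible companion test, not part of this commit, for the default path: with the flag unset, startup should still run the health check. It reuses the mocking pattern above and assumes the same `_setup_prisma_client` signature:

```python
import pytest
from unittest.mock import MagicMock, AsyncMock


@pytest.mark.asyncio
async def test_health_check_called_by_default(monkeypatch):
    from litellm.proxy.proxy_server import ProxyStartupEvent

    # Ensure the disable flag is not set for this test.
    monkeypatch.delenv("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", raising=False)

    # Same mock PrismaClient shape as the test above.
    mock_prisma = MagicMock()
    mock_prisma.connect = AsyncMock()
    mock_prisma.health_check = AsyncMock()
    mock_prisma.check_view_exists = AsyncMock()
    mock_prisma._set_spend_logs_row_count_in_proxy_state = AsyncMock()
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.PrismaClient", lambda **kwargs: mock_prisma
    )

    await ProxyStartupEvent._setup_prisma_client(
        database_url="mock_url",
        proxy_logging_obj=MagicMock(),
        user_api_key_cache=MagicMock(),
    )

    # With the flag unset, the startup health check should run.
    mock_prisma.health_check.assert_called()
```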