mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Litellm dev 01 2025 p4 (#7776)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 13s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 13s
* fix(gemini/): support gemini 'frequency_penalty' and 'presence_penalty'. Closes https://github.com/BerriAI/litellm/issues/7748
* feat(proxy_server.py): new env var to disable prisma health check on startup
* test: fix test
This commit is contained in:
parent
8353caa485
commit
fe60a38c8e
4 changed files with 47 additions and 22 deletions
|
@ -12,9 +12,7 @@ from ...vertex_ai.gemini.transformation import _gemini_convert_messages_with_his
|
||||||
from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
|
from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
|
||||||
|
|
||||||
|
|
||||||
class GoogleAIStudioGeminiConfig(
|
class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
|
||||||
VertexGeminiConfig
|
|
||||||
): # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
|
|
||||||
"""
|
"""
|
||||||
Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
|
Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
|
||||||
|
|
||||||
|
@ -82,6 +80,7 @@ class GoogleAIStudioGeminiConfig(
|
||||||
"n",
|
"n",
|
||||||
"stop",
|
"stop",
|
||||||
"logprobs",
|
"logprobs",
|
||||||
|
"frequency_penalty",
|
||||||
]
|
]
|
||||||
|
|
||||||
def map_openai_params(
|
def map_openai_params(
|
||||||
|
@ -92,11 +91,6 @@ class GoogleAIStudioGeminiConfig(
|
||||||
drop_params: bool,
|
drop_params: bool,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
|
|
||||||
# drop frequency_penalty and presence_penalty
|
|
||||||
if "frequency_penalty" in non_default_params:
|
|
||||||
del non_default_params["frequency_penalty"]
|
|
||||||
if "presence_penalty" in non_default_params:
|
|
||||||
del non_default_params["presence_penalty"]
|
|
||||||
if litellm.vertex_ai_safety_settings is not None:
|
if litellm.vertex_ai_safety_settings is not None:
|
||||||
optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
|
optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
|
||||||
return super().map_openai_params(
|
return super().map_openai_params(
|
||||||
|
|
|
@ -3233,6 +3233,10 @@ class ProxyStartupEvent:
|
||||||
) # set the spend logs row count in proxy state. Don't block execution
|
) # set the spend logs row count in proxy state. Don't block execution
|
||||||
|
|
||||||
# run a health check to ensure the DB is ready
|
# run a health check to ensure the DB is ready
|
||||||
|
if (
|
||||||
|
get_secret_bool("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", False)
|
||||||
|
is not True
|
||||||
|
):
|
||||||
await prisma_client.health_check()
|
await prisma_client.health_check()
|
||||||
return prisma_client
|
return prisma_client
|
||||||
|
|
||||||
|
|
|
@ -217,19 +217,6 @@ def test_databricks_optional_params():
|
||||||
assert "user" not in optional_params
|
assert "user" not in optional_params
|
||||||
|
|
||||||
|
|
||||||
def test_gemini_optional_params():
|
|
||||||
litellm.drop_params = True
|
|
||||||
optional_params = get_optional_params(
|
|
||||||
model="",
|
|
||||||
custom_llm_provider="gemini",
|
|
||||||
max_tokens=10,
|
|
||||||
frequency_penalty=10,
|
|
||||||
)
|
|
||||||
print(f"optional_params: {optional_params}")
|
|
||||||
assert len(optional_params) == 1
|
|
||||||
assert "frequency_penalty" not in optional_params
|
|
||||||
|
|
||||||
|
|
||||||
def test_azure_ai_mistral_optional_params():
|
def test_azure_ai_mistral_optional_params():
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
optional_params = get_optional_params(
|
optional_params = get_optional_params(
|
||||||
|
@ -1063,6 +1050,7 @@ def test_is_vertex_anthropic_model():
|
||||||
is False
|
is False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_groq_response_format_json_schema():
|
def test_groq_response_format_json_schema():
|
||||||
optional_params = get_optional_params(
|
optional_params = get_optional_params(
|
||||||
model="llama-3.1-70b-versatile",
|
model="llama-3.1-70b-versatile",
|
||||||
|
@ -1072,3 +1060,10 @@ def test_groq_response_format_json_schema():
|
||||||
assert optional_params is not None
|
assert optional_params is not None
|
||||||
assert "response_format" in optional_params
|
assert "response_format" in optional_params
|
||||||
assert optional_params["response_format"]["type"] == "json_object"
|
assert optional_params["response_format"]["type"] == "json_object"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gemini_frequency_penalty():
    """Check that `frequency_penalty` is kept in the optional params for the gemini provider."""
    expected_penalty = 0.5
    mapped_params = get_optional_params(
        model="gemini-1.5-flash",
        custom_llm_provider="gemini",
        frequency_penalty=expected_penalty,
    )
    # The value must come back under the same key, unchanged.
    assert mapped_params["frequency_penalty"] == expected_penalty
|
||||||
|
|
|
@ -1447,3 +1447,35 @@ def test_update_key_budget_with_temp_budget_increase():
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200
|
assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200
|
||||||
|
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, AsyncMock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_health_check_not_called_when_disabled(monkeypatch):
    """Check that Prisma setup skips the health check when the opt-out env var is set.

    Sets DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP to "true", replaces the
    PrismaClient name in litellm.proxy.proxy_server with a factory returning
    a stub, awaits ProxyStartupEvent._setup_prisma_client, and asserts that
    the stub's health_check was never called.
    """
    from litellm.proxy.proxy_server import ProxyStartupEvent

    # Stub client whose listed methods are AsyncMock instances.
    fake_client = MagicMock()
    for awaited_name in (
        "connect",
        "health_check",
        "check_view_exists",
        "_set_spend_logs_row_count_in_proxy_state",
    ):
        setattr(fake_client, awaited_name, AsyncMock())

    def _fake_prisma_factory(**kwargs):
        # Stands in for the PrismaClient constructor; always yields the stub.
        return fake_client

    # Opt out of the startup health check via the environment.
    monkeypatch.setenv("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", "true")
    monkeypatch.setattr(
        "litellm.proxy.proxy_server.PrismaClient", _fake_prisma_factory
    )

    # Run the setup routine under test.
    await ProxyStartupEvent._setup_prisma_client(
        database_url="mock_url",
        proxy_logging_obj=MagicMock(),
        user_api_key_cache=MagicMock(),
    )

    # With the opt-out set, the health check must not have been called.
    fake_client.health_check.assert_not_called()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue