test(base_llm_unit_tests.py): add test to ensure drop params is respected (#8224)

* test(base_llm_unit_tests.py): add test to ensure drop params is respected

* fix(types/prometheus.py): use typing_extensions for python3.8 compatibility

* build: add cherry picked commits
Krish Dholakia 2025-02-03 16:04:44 -08:00 committed by GitHub
parent d60d3ee970
commit c8494abdea
15 changed files with 250 additions and 71 deletions

.gitignore (vendored)
View file

@@ -71,3 +71,4 @@ tests/local_testing/log.txt
.codegpt
litellm/proxy/_new_new_secret_config.yaml
litellm/proxy/custom_guardrail.py

View file

@@ -524,7 +524,7 @@ guardrails:
- guardrail_name: string # Required: Name of the guardrail
litellm_params: # Required: Configuration parameters
guardrail: string # Required: One of "aporia", "bedrock", "guardrails_ai", "lakera", "presidio", "hide-secrets"
mode: string # Required: One of "pre_call", "post_call", "during_call", "logging_only"
mode: Union[string, List[string]] # Required: One or more of "pre_call", "post_call", "during_call", "logging_only"
api_key: string # Required: API key for the guardrail service
api_base: string # Optional: Base URL for the guardrail service
default_on: boolean # Optional: Default False. When set to True, will run on every request, does not need client to specify guardrail in request
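
For illustration, a hedged sketch of a guardrail entry using the new list form of `mode`; the guardrail name and values are placeholders, not taken from the docs:

guardrails:
  - guardrail_name: "example-guard"          # placeholder name
    litellm_params:
      guardrail: presidio                    # any supported guardrail value
      mode: ["pre_call", "post_call"]        # list form, in addition to the existing single string
      default_on: true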

View file

@@ -12,7 +12,9 @@ class CustomGuardrail(CustomLogger):
self,
guardrail_name: Optional[str] = None,
supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
event_hook: Optional[GuardrailEventHooks] = None,
event_hook: Optional[
Union[GuardrailEventHooks, List[GuardrailEventHooks]]
] = None,
default_on: bool = False,
**kwargs,
):
@@ -27,16 +29,34 @@ class CustomGuardrail(CustomLogger):
"""
self.guardrail_name = guardrail_name
self.supported_event_hooks = supported_event_hooks
self.event_hook: Optional[GuardrailEventHooks] = event_hook
self.event_hook: Optional[
Union[GuardrailEventHooks, List[GuardrailEventHooks]]
] = event_hook
self.default_on: bool = default_on
if supported_event_hooks:
## validate event_hook is in supported_event_hooks
if event_hook and event_hook not in supported_event_hooks:
self._validate_event_hook(event_hook, supported_event_hooks)
super().__init__(**kwargs)
def _validate_event_hook(
self,
event_hook: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]],
supported_event_hooks: List[GuardrailEventHooks],
) -> None:
if event_hook is None:
return
if isinstance(event_hook, list):
for hook in event_hook:
if hook not in supported_event_hooks:
raise ValueError(
f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}"
)
elif isinstance(event_hook, GuardrailEventHooks):
if event_hook not in supported_event_hooks:
raise ValueError(
f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
)
def get_guardrail_from_metadata(
self, data: dict
@@ -88,7 +108,7 @@ class CustomGuardrail(CustomLogger):
):
return False
if self.event_hook and self.event_hook != event_type.value:
if not self._event_hook_is_event_type(event_type):
return False
return True
@@ -100,6 +120,11 @@ class CustomGuardrail(CustomLogger):
eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
"""
if self.event_hook is None:
return True
if isinstance(self.event_hook, list):
return event_type.value in self.event_hook
return self.event_hook == event_type.value
def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
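
A minimal sketch of how the new list-aware validation behaves when `supported_event_hooks` is set (the guardrail name is a placeholder; the hooks mirror the GuardrailEventHooks values used in this diff):

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks

# Accepted: every hook in the list is present in supported_event_hooks
guard = CustomGuardrail(
    guardrail_name="example-guard",
    supported_event_hooks=[GuardrailEventHooks.pre_call, GuardrailEventHooks.post_call],
    event_hook=[GuardrailEventHooks.pre_call, GuardrailEventHooks.post_call],
)

# Rejected: during_call is not supported, so __init__ raises ValueError via _validate_event_hook
try:
    CustomGuardrail(
        guardrail_name="example-guard",
        supported_event_hooks=[GuardrailEventHooks.pre_call],
        event_hook=[GuardrailEventHooks.during_call],
    )
except ValueError as err:
    print(err)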

View file

@@ -142,7 +142,7 @@ def completion(
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"]
model_response.choices[0].logprobs = sum_logprob
model_response.choices[0].logprobs = sum_logprob # type: ignore
else:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",

View file

@@ -130,7 +130,7 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
pass
else:
raise litellm.utils.UnsupportedParamsError(
message="O-1 doesn't support temperature={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
message="O-series models don't support temperature={}. Only temperature=1 is supported. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
temperature_value
),
status_code=400,
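
The reworded message points callers at `drop_params`; a short hedged sketch of both ways to enable it ("o3-mini" is only an example O-series model name, mirroring the tests added further down):

import litellm

# Drop unsupported OpenAI params globally ...
litellm.drop_params = True

# ... or per call, as the new tests in this commit do
resp = litellm.completion(
    model="o3-mini",
    messages=[{"role": "user", "content": "Hello, world!"}],
    temperature=0.0,  # dropped instead of raising UnsupportedParamsError
    drop_params=True,
)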

View file

@@ -94,7 +94,10 @@ class OpenAITextCompletionConfig(BaseTextCompletionConfig, OpenAIGPTConfig):
role="assistant",
)
choice = Choices(
finish_reason=choice["finish_reason"], index=idx, message=message
finish_reason=choice["finish_reason"],
index=idx,
message=message,
logprobs=choice.get("logprobs", None),
)
choice_list.append(choice)
model_response_object.choices = choice_list

View file

@@ -10,6 +10,7 @@ model_list:
- model_name: anthropic-claude
litellm_params:
model: claude-3-5-haiku-20241022
mock_response: Hi!
- model_name: groq/*
litellm_params:
model: groq/*
@@ -28,4 +29,12 @@ model_list:
litellm_settings:
callbacks: ["langsmith"]
disable_no_log_param: true
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
user_id_jwt_field: "sub"
user_email_jwt_field: "email"
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD

View file

@@ -168,4 +168,4 @@ def init_guardrails_v2(
guardrail_list.append(parsed_guardrail)
print(f"\nGuardrail List:{guardrail_list}\n") # noqa
verbose_proxy_logger.info(f"\nGuardrail List:{guardrail_list}\n")

View file

@@ -2,6 +2,7 @@ from enum import Enum
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
import litellm
@@ -233,41 +234,55 @@ from pydantic import BaseModel, Field
class UserAPIKeyLabelValues(BaseModel):
end_user: Optional[str] = None
user: Optional[str] = None
hashed_api_key: Optional[str] = None
api_key_alias: Optional[str] = None
team: Optional[str] = None
team_alias: Optional[str] = None
requested_model: Optional[str] = None
model: Optional[str] = None
litellm_model_name: Optional[str] = None
end_user: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.END_USER.value)
] = None
user: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.USER.value)
] = None
hashed_api_key: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_KEY_HASH.value)
] = None
api_key_alias: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_KEY_ALIAS.value)
] = None
team: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.TEAM.value)
] = None
team_alias: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.TEAM_ALIAS.value)
] = None
requested_model: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.REQUESTED_MODEL.value)
] = None
model: Annotated[
Optional[str],
Field(..., alias=UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value),
] = None
litellm_model_name: Annotated[
Optional[str],
Field(..., alias=UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME.value),
] = None
tags: List[str] = []
custom_metadata_labels: Dict[str, str] = {}
model_id: Optional[str] = None
api_base: Optional[str] = None
api_provider: Optional[str] = None
exception_status: Optional[str] = None
exception_class: Optional[str] = None
status_code: Optional[str] = None
fallback_model: Optional[str] = None
class Config:
fields = {
"end_user": {"alias": UserAPIKeyLabelNames.END_USER},
"user": {"alias": UserAPIKeyLabelNames.USER},
"hashed_api_key": {"alias": UserAPIKeyLabelNames.API_KEY_HASH},
"api_key_alias": {"alias": UserAPIKeyLabelNames.API_KEY_ALIAS},
"team": {"alias": UserAPIKeyLabelNames.TEAM},
"team_alias": {"alias": UserAPIKeyLabelNames.TEAM_ALIAS},
"requested_model": {"alias": UserAPIKeyLabelNames.REQUESTED_MODEL},
"model": {"alias": UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME},
"litellm_model_name": {"alias": UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME},
"model_id": {"alias": UserAPIKeyLabelNames.MODEL_ID},
"api_base": {"alias": UserAPIKeyLabelNames.API_BASE},
"api_provider": {"alias": UserAPIKeyLabelNames.API_PROVIDER},
"exception_status": {"alias": UserAPIKeyLabelNames.EXCEPTION_STATUS},
"exception_class": {"alias": UserAPIKeyLabelNames.EXCEPTION_CLASS},
"status_code": {"alias": UserAPIKeyLabelNames.STATUS_CODE},
"fallback_model": {"alias": UserAPIKeyLabelNames.FALLBACK_MODEL},
}
model_id: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.MODEL_ID.value)
] = None
api_base: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_BASE.value)
] = None
api_provider: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_PROVIDER.value)
] = None
exception_status: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.EXCEPTION_STATUS.value)
] = None
exception_class: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.EXCEPTION_CLASS.value)
] = None
status_code: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.STATUS_CODE.value)
] = None
fallback_model: Annotated[
Optional[str], Field(..., alias=UserAPIKeyLabelNames.FALLBACK_MODEL.value)
] = None
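
The old `class Config: fields = {...}` alias table is replaced by per-field `Annotated[..., Field(alias=...)]` annotations, with `Annotated` imported from `typing_extensions` so Python 3.8 keeps working (per the commit message). A standalone sketch of the pattern, using a made-up alias string instead of the real UserAPIKeyLabelNames values:

from typing import Optional
from pydantic import BaseModel, Field
from typing_extensions import Annotated  # typing_extensions keeps this importable on Python 3.8

class ExampleLabelValues(BaseModel):
    # Stand-in for one UserAPIKeyLabelValues field; the alias "end_user_id" is purely
    # illustrative (the real model pulls aliases from UserAPIKeyLabelNames.*.value)
    end_user: Annotated[Optional[str], Field(alias="end_user_id")] = None

labels = ExampleLabelValues(**{"end_user_id": "user-123"})  # populated by alias
print(labels.end_user)                   # -> "user-123", read by field name
print(labels.model_dump(by_alias=True))  # pydantic v2: {"end_user_id": "user-123"}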

View file

@@ -259,11 +259,35 @@ class ChatCompletionTokenLogprob(OpenAIObject):
returned.
"""
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
class ChoiceLogprobs(OpenAIObject):
content: Optional[List[ChatCompletionTokenLogprob]] = None
"""A list of message content tokens with log probability information."""
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
class FunctionCall(OpenAIObject):
arguments: str
@@ -600,7 +624,10 @@ class Choices(OpenAIObject):
elif isinstance(message, dict):
self.message = Message(**message)
if logprobs is not None:
self.logprobs = logprobs
if isinstance(logprobs, dict):
self.logprobs = ChoiceLogprobs(**logprobs)
else:
self.logprobs = logprobs
if enhancements is not None:
self.enhancements = enhancements
@@ -1544,7 +1571,7 @@ class StandardLoggingPayloadErrorInformation(TypedDict, total=False):
class StandardLoggingGuardrailInformation(TypedDict, total=False):
guardrail_name: Optional[str]
guardrail_mode: Optional[GuardrailEventHooks]
guardrail_mode: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]]
guardrail_response: Optional[Union[dict, str]]
guardrail_status: Literal["success", "failure"]
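
With the added `__contains__`/`get`/`__getitem__` helpers, the logprob types accept both typed attribute access and the older dict-style access, and `Choices` now wraps a plain logprobs dict in `ChoiceLogprobs`. A hedged sketch, assuming both classes are importable from litellm.types.utils as in the tests below:

from litellm.types.utils import ChatCompletionTokenLogprob, ChoiceLogprobs

lp = ChoiceLogprobs(
    content=[
        ChatCompletionTokenLogprob(token="Hello", logprob=-0.1, bytes=[72], top_logprobs=[])
    ]
)

print(lp.content[0].token)        # typed attribute access -> "Hello"
print(lp["content"][0]["token"])  # dict-style access still works -> "Hello"
print("content" in lp)            # __contains__ -> True
print(lp.get("missing", "n/a"))   # .get() with a default -> "n/a"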

View file

@@ -715,3 +715,42 @@ class BaseOSeriesModelsTest(ABC): # test across azure/openai
request_body["messages"][0]["role"] == "developer"
), "Got={} instead of system".format(request_body["messages"][0]["role"])
assert request_body["messages"][0]["content"] == "Be a good bot!"
def test_completion_o_series_models_temperature(self):
"""
Test that temperature is not passed to O-series models
"""
try:
from litellm import completion
client = self.get_client()
completion_args = self.get_base_completion_call_args()
with patch.object(
client.chat.completions.with_raw_response, "create"
) as mock_client:
try:
completion(
**completion_args,
temperature=0.0,
messages=[
{
"role": "user",
"content": "Hello, world!",
}
],
drop_params=True,
client=client,
)
except Exception as e:
print(f"Error: {e}")
mock_client.assert_called_once()
request_body = mock_client.call_args.kwargs
print("request_body: ", request_body)
assert (
"temperature" not in request_body
), "temperature should not be in the request body"
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@@ -4,13 +4,12 @@ import sys
from datetime import datetime
sys.path.insert(
0, os.path.abspath("../../")
0, os.path.abspath("../../../")
) # Adds the parent directory to the system path
import litellm
import pytest
from datetime import timedelta
from litellm.utils import convert_to_model_response_object
from litellm.types.utils import (
ModelResponse,
@@ -20,6 +19,10 @@ from litellm.types.utils import (
CompletionTokensDetailsWrapper,
)
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
convert_to_model_response_object,
)
def test_convert_to_model_response_object_basic():
"""Test basic conversion with all fields present."""
@@ -621,16 +624,21 @@ def test_convert_to_model_response_object_with_logprobs():
"system_fingerprint": None,
}
result = convert_to_model_response_object(
model_response_object=ModelResponse(),
response_object=response_object,
stream=False,
start_time=datetime.now(),
end_time=datetime.now(),
hidden_params=None,
_response_headers=None,
convert_tool_call_to_json_mode=False,
)
print("ENTERING CONVERT")
try:
result = convert_to_model_response_object(
model_response_object=ModelResponse(),
response_object=response_object,
stream=False,
start_time=datetime.now(),
end_time=datetime.now(),
hidden_params=None,
_response_headers=None,
convert_tool_call_to_json_mode=False,
)
except Exception as e:
print(f"ERROR: {e}")
raise e
assert isinstance(result, ModelResponse)
assert result.id == "chatcmpl-123"
@@ -648,7 +656,7 @@ def test_convert_to_model_response_object_with_logprobs():
# Check logprobs
assert choice.logprobs is not None
assert len(choice.logprobs["content"]) == 9
assert len(choice.logprobs.content) == 9
# Check each logprob entry
expected_tokens = [
@@ -662,14 +670,14 @@ def test_convert_to_model_response_object_with_logprobs():
" today",
"?",
]
for i, logprob in enumerate(choice.logprobs["content"]):
assert logprob["token"] == expected_tokens[i]
assert isinstance(logprob["logprob"], float)
assert isinstance(logprob["bytes"], list)
assert len(logprob["top_logprobs"]) == 2
assert isinstance(logprob["top_logprobs"][0]["token"], str)
assert isinstance(logprob["top_logprobs"][0]["logprob"], float)
assert isinstance(logprob["top_logprobs"][0]["bytes"], (list, type(None)))
for i, logprob in enumerate(choice.logprobs.content):
assert logprob.token == expected_tokens[i]
assert isinstance(logprob.logprob, float)
assert isinstance(logprob.bytes, list)
assert len(logprob.top_logprobs) == 2
assert isinstance(logprob.top_logprobs[0].token, str)
assert isinstance(logprob.top_logprobs[0].logprob, float)
assert isinstance(logprob.top_logprobs[0].bytes, (list, type(None)))
assert result.usage.prompt_tokens == 9
assert result.usage.completion_tokens == 9

View file

@@ -4599,6 +4599,21 @@ def test_provider_specific_header(custom_llm_provider, expected_result):
assert "anthropic-beta" in mock_post.call_args.kwargs["headers"]
def test_qwen_text_completion():
# litellm._turn_on_debug()
resp = litellm.completion(
model="gpt-3.5-turbo-instruct",
messages=[{"content": "hello", "role": "user"}],
stream=False,
logprobs=1,
)
assert resp.choices[0].message.content is not None
assert resp.choices[0].logprobs.token_logprobs[0] is not None
print(
f"resp.choices[0].logprobs.token_logprobs[0]: {resp.choices[0].logprobs.token_logprobs[0]}"
)
@pytest.mark.parametrize(
"enable_preview_features",
[True, False],
@@ -4631,3 +4646,22 @@ def test_completion_openai_metadata(monkeypatch, enable_preview_features):
}
else:
assert "metadata" not in mock_completion.call_args.kwargs
def test_completion_o3_mini_temperature():
try:
litellm.set_verbose = True
resp = litellm.completion(
model="o3-mini",
temperature=0.0,
messages=[
{
"role": "user",
"content": "Hello, world!",
}
],
drop_params=True,
)
assert resp.choices[0].message.content is not None
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@@ -72,3 +72,21 @@ def test_guardrail_masking_logging_only():
mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
== "Hey, my name is [NAME]."
)
def test_guardrail_list_of_event_hooks():
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks
cg = CustomGuardrail(
guardrail_name="custom-guard", event_hook=["pre_call", "post_call"]
)
data = {"model": "gpt-3.5-turbo", "metadata": {"guardrails": ["custom-guard"]}}
assert cg.should_run_guardrail(data=data, event_type=GuardrailEventHooks.pre_call)
assert cg.should_run_guardrail(data=data, event_type=GuardrailEventHooks.post_call)
assert not cg.should_run_guardrail(
data=data, event_type=GuardrailEventHooks.during_call
)

View file

@@ -267,7 +267,7 @@ async def test_chat_completion_request_with_redaction():
setattr(proxy_server, "llm_router", router)
_test_logger = testLogger()
litellm.callbacks = [_ENTERPRISE_SecretDetection(), _test_logger]
litellm.set_verbose = True
litellm._turn_on_debug()
# Prepare the query string
query_params = "param1=value1&param2=value2"