test(base_llm_unit_tests.py): add test to ensure drop params is respected (#8224)

* test(base_llm_unit_tests.py): add test to ensure drop params is respected

* fix(types/prometheus.py): use typing_extensions for python3.8 compatibility

* build: add cherry picked commits
Krish Dholakia 2025-02-03 16:04:44 -08:00 committed by GitHub
parent d60d3ee970
commit c8494abdea
15 changed files with 250 additions and 71 deletions

.gitignore

@@ -71,3 +71,4 @@ tests/local_testing/log.txt
 .codegpt
 litellm/proxy/_new_new_secret_config.yaml
+litellm/proxy/custom_guardrail.py


@@ -524,7 +524,7 @@ guardrails:
   - guardrail_name: string                  # Required: Name of the guardrail
     litellm_params:                         # Required: Configuration parameters
       guardrail: string                     # Required: One of "aporia", "bedrock", "guardrails_ai", "lakera", "presidio", "hide-secrets"
-      mode: string                          # Required: One of "pre_call", "post_call", "during_call", "logging_only"
+      mode: Union[string, List[string]]     # Required: One or more of "pre_call", "post_call", "during_call", "logging_only"
       api_key: string                       # Required: API key for the guardrail service
       api_base: string                      # Optional: Base URL for the guardrail service
       default_on: boolean                   # Optional: Default False. When set to True, will run on every request, does not need client to specify guardrail in request


@@ -12,7 +12,9 @@ class CustomGuardrail(CustomLogger):
         self,
         guardrail_name: Optional[str] = None,
         supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
-        event_hook: Optional[GuardrailEventHooks] = None,
+        event_hook: Optional[
+            Union[GuardrailEventHooks, List[GuardrailEventHooks]]
+        ] = None,
         default_on: bool = False,
         **kwargs,
     ):
@@ -27,16 +29,34 @@ class CustomGuardrail(CustomLogger):
         """
         self.guardrail_name = guardrail_name
         self.supported_event_hooks = supported_event_hooks
-        self.event_hook: Optional[GuardrailEventHooks] = event_hook
+        self.event_hook: Optional[
+            Union[GuardrailEventHooks, List[GuardrailEventHooks]]
+        ] = event_hook
         self.default_on: bool = default_on

         if supported_event_hooks:
             ## validate event_hook is in supported_event_hooks
-            if event_hook and event_hook not in supported_event_hooks:
+            self._validate_event_hook(event_hook, supported_event_hooks)
+        super().__init__(**kwargs)
+
+    def _validate_event_hook(
+        self,
+        event_hook: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]],
+        supported_event_hooks: List[GuardrailEventHooks],
+    ) -> None:
+        if event_hook is None:
+            return
+        if isinstance(event_hook, list):
+            for hook in event_hook:
+                if hook not in supported_event_hooks:
+                    raise ValueError(
+                        f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}"
+                    )
+        elif isinstance(event_hook, GuardrailEventHooks):
+            if event_hook not in supported_event_hooks:
                 raise ValueError(
                     f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
                 )
-        super().__init__(**kwargs)

     def get_guardrail_from_metadata(
         self, data: dict
@@ -88,7 +108,7 @@ class CustomGuardrail(CustomLogger):
         ):
             return False

-        if self.event_hook and self.event_hook != event_type.value:
+        if not self._event_hook_is_event_type(event_type):
             return False

         return True
@@ -100,6 +120,11 @@
         eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
         eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
         """
+        if self.event_hook is None:
+            return True
+
+        if isinstance(self.event_hook, list):
+            return event_type.value in self.event_hook
         return self.event_hook == event_type.value

     def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
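
A minimal usage sketch of the new validation path (illustrative only, not part of the commit; the guardrail name is hypothetical): passing a list for event_hook is checked against supported_event_hooks at init time, so an unsupported entry raises ValueError.

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks

try:
    CustomGuardrail(
        guardrail_name="my-guard",  # hypothetical name, for illustration
        supported_event_hooks=[GuardrailEventHooks.pre_call],
        event_hook=[GuardrailEventHooks.pre_call, GuardrailEventHooks.post_call],
    )
except ValueError as e:
    # e.g. "Event hook ...post_call is not in the supported event hooks [...]"
    print(e)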


@@ -142,7 +142,7 @@ def completion(
                 sum_logprob = 0
                 for token in completion_response[0]["details"]["tokens"]:
                     sum_logprob += token["logprob"]
-                model_response.choices[0].logprobs = sum_logprob
+                model_response.choices[0].logprobs = sum_logprob  # type: ignore
             else:
                 raise BasetenError(
                     message=f"Unable to parse response. Original response: {response.text}",


@@ -130,7 +130,7 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
                     pass
                 else:
                     raise litellm.utils.UnsupportedParamsError(
-                        message="O-1 doesn't support temperature={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
+                        message="O-series models don't support temperature={}. Only temperature=1 is supported. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
                             temperature_value
                         ),
                         status_code=400,
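
Usage sketch of the behavior the reworded error message describes (assumes an OpenAI key is configured; mirrors the o3-mini test added further down): with drop_params=True the unsupported temperature is dropped instead of raising UnsupportedParamsError.

import litellm

resp = litellm.completion(
    model="o3-mini",
    temperature=0.0,  # unsupported for O-series models; removed because drop_params=True
    messages=[{"role": "user", "content": "Hello, world!"}],
    drop_params=True,
)
print(resp.choices[0].message.content)
# Without drop_params, the same call raises litellm.utils.UnsupportedParamsError (status 400).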


@@ -94,7 +94,10 @@ class OpenAITextCompletionConfig(BaseTextCompletionConfig, OpenAIGPTConfig):
                     role="assistant",
                 )
                 choice = Choices(
-                    finish_reason=choice["finish_reason"], index=idx, message=message
+                    finish_reason=choice["finish_reason"],
+                    index=idx,
+                    message=message,
+                    logprobs=choice.get("logprobs", None),
                 )
                 choice_list.append(choice)
             model_response_object.choices = choice_list


@@ -10,6 +10,7 @@ model_list:
   - model_name: anthropic-claude
     litellm_params:
       model: claude-3-5-haiku-20241022
+      mock_response: Hi!
   - model_name: groq/*
     litellm_params:
       model: groq/*
@@ -28,4 +29,12 @@ model_list:

 litellm_settings:
   callbacks: ["langsmith"]
+  disable_no_log_param: true
+
+general_settings:
+  enable_jwt_auth: True
+  litellm_jwtauth:
+    user_id_jwt_field: "sub"
+    user_email_jwt_field: "email"
+    team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD


@@ -168,4 +168,4 @@ def init_guardrails_v2(
         guardrail_list.append(parsed_guardrail)

-    print(f"\nGuardrail List:{guardrail_list}\n")  # noqa
+    verbose_proxy_logger.info(f"\nGuardrail List:{guardrail_list}\n")


@@ -2,6 +2,7 @@ from enum import Enum
 from typing import Dict, List, Literal, Optional, Union

 from pydantic import BaseModel, Field
+from typing_extensions import Annotated

 import litellm
@@ -233,41 +234,55 @@ from pydantic import BaseModel, Field

 class UserAPIKeyLabelValues(BaseModel):
-    end_user: Optional[str] = None
-    user: Optional[str] = None
-    hashed_api_key: Optional[str] = None
-    api_key_alias: Optional[str] = None
-    team: Optional[str] = None
-    team_alias: Optional[str] = None
-    requested_model: Optional[str] = None
-    model: Optional[str] = None
-    litellm_model_name: Optional[str] = None
+    end_user: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.END_USER.value)
+    ] = None
+    user: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.USER.value)
+    ] = None
+    hashed_api_key: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_KEY_HASH.value)
+    ] = None
+    api_key_alias: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_KEY_ALIAS.value)
+    ] = None
+    team: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.TEAM.value)
+    ] = None
+    team_alias: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.TEAM_ALIAS.value)
+    ] = None
+    requested_model: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.REQUESTED_MODEL.value)
+    ] = None
+    model: Annotated[
+        Optional[str],
+        Field(..., alias=UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value),
+    ] = None
+    litellm_model_name: Annotated[
+        Optional[str],
+        Field(..., alias=UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME.value),
+    ] = None
     tags: List[str] = []
     custom_metadata_labels: Dict[str, str] = {}
-    model_id: Optional[str] = None
-    api_base: Optional[str] = None
-    api_provider: Optional[str] = None
-    exception_status: Optional[str] = None
-    exception_class: Optional[str] = None
-    status_code: Optional[str] = None
-    fallback_model: Optional[str] = None
-
-    class Config:
-        fields = {
-            "end_user": {"alias": UserAPIKeyLabelNames.END_USER},
-            "user": {"alias": UserAPIKeyLabelNames.USER},
-            "hashed_api_key": {"alias": UserAPIKeyLabelNames.API_KEY_HASH},
-            "api_key_alias": {"alias": UserAPIKeyLabelNames.API_KEY_ALIAS},
-            "team": {"alias": UserAPIKeyLabelNames.TEAM},
-            "team_alias": {"alias": UserAPIKeyLabelNames.TEAM_ALIAS},
-            "requested_model": {"alias": UserAPIKeyLabelNames.REQUESTED_MODEL},
-            "model": {"alias": UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME},
-            "litellm_model_name": {"alias": UserAPIKeyLabelNames.v2_LITELLM_MODEL_NAME},
-            "model_id": {"alias": UserAPIKeyLabelNames.MODEL_ID},
-            "api_base": {"alias": UserAPIKeyLabelNames.API_BASE},
-            "api_provider": {"alias": UserAPIKeyLabelNames.API_PROVIDER},
-            "exception_status": {"alias": UserAPIKeyLabelNames.EXCEPTION_STATUS},
-            "exception_class": {"alias": UserAPIKeyLabelNames.EXCEPTION_CLASS},
-            "status_code": {"alias": UserAPIKeyLabelNames.STATUS_CODE},
-            "fallback_model": {"alias": UserAPIKeyLabelNames.FALLBACK_MODEL},
-        }
+    model_id: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.MODEL_ID.value)
+    ] = None
+    api_base: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_BASE.value)
+    ] = None
+    api_provider: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_PROVIDER.value)
+    ] = None
+    exception_status: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.EXCEPTION_STATUS.value)
+    ] = None
+    exception_class: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.EXCEPTION_CLASS.value)
+    ] = None
+    status_code: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.STATUS_CODE.value)
+    ] = None
+    fallback_model: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.FALLBACK_MODEL.value)
+    ] = None
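
A minimal, self-contained sketch of the Annotated + Field(alias=...) pattern used above (hypothetical model and label names, not the real prometheus types). Annotated is imported from typing_extensions because typing.Annotated only exists on Python 3.9+, which is the 3.8-compatibility fix named in the commit message.

from typing import Optional

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class LabelValues(BaseModel):
    # attribute name is end_user, but it is populated via the alias "end_user_id"
    end_user: Annotated[Optional[str], Field(..., alias="end_user_id")] = None


print(LabelValues(**{"end_user_id": "u-123"}).end_user)  # -> u-123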


@@ -259,11 +259,35 @@ class ChatCompletionTokenLogprob(OpenAIObject):
     returned.
     """

+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+

 class ChoiceLogprobs(OpenAIObject):
     content: Optional[List[ChatCompletionTokenLogprob]] = None
     """A list of message content tokens with log probability information."""

+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+

 class FunctionCall(OpenAIObject):
     arguments: str
@@ -600,7 +624,10 @@ class Choices(OpenAIObject):
         elif isinstance(message, dict):
             self.message = Message(**message)
         if logprobs is not None:
-            self.logprobs = logprobs
+            if isinstance(logprobs, dict):
+                self.logprobs = ChoiceLogprobs(**logprobs)
+            else:
+                self.logprobs = logprobs
         if enhancements is not None:
             self.enhancements = enhancements
@@ -1544,7 +1571,7 @@ class StandardLoggingPayloadErrorInformation(TypedDict, total=False):

 class StandardLoggingGuardrailInformation(TypedDict, total=False):
     guardrail_name: Optional[str]
-    guardrail_mode: Optional[GuardrailEventHooks]
+    guardrail_mode: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]]
     guardrail_response: Optional[Union[dict, str]]
     guardrail_status: Literal["success", "failure"]
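
Quick sketch (assumes this change is applied, and that both types are importable from litellm.types.utils) of why the helper methods were added: ChoiceLogprobs now supports attribute access as well as the dict-style access that older callers use.

from litellm.types.utils import ChatCompletionTokenLogprob, ChoiceLogprobs

lp = ChoiceLogprobs(
    content=[
        ChatCompletionTokenLogprob(token="Hi", logprob=-0.1, bytes=[72, 105], top_logprobs=[])
    ]
)

assert "content" in lp                    # __contains__
assert lp.get("content") is lp.content    # .get() with default
assert lp["content"][0].token == "Hi"     # __getitem__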


@@ -715,3 +715,42 @@ class BaseOSeriesModelsTest(ABC):  # test across azure/openai
                 request_body["messages"][0]["role"] == "developer"
             ), "Got={} instead of system".format(request_body["messages"][0]["role"])
             assert request_body["messages"][0]["content"] == "Be a good bot!"
+
+    def test_completion_o_series_models_temperature(self):
+        """
+        Test that temperature is not passed to O-series models
+        """
+        try:
+            from litellm import completion
+
+            client = self.get_client()
+
+            completion_args = self.get_base_completion_call_args()
+
+            with patch.object(
+                client.chat.completions.with_raw_response, "create"
+            ) as mock_client:
+                try:
+                    completion(
+                        **completion_args,
+                        temperature=0.0,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Hello, world!",
+                            }
+                        ],
+                        drop_params=True,
+                        client=client,
+                    )
+                except Exception as e:
+                    print(f"Error: {e}")
+
+                mock_client.assert_called_once()
+                request_body = mock_client.call_args.kwargs
+                print("request_body: ", request_body)
+                assert (
+                    "temperature" not in request_body
+                ), "temperature should not be in the request body"
+        except Exception as e:
+            pytest.fail(f"Error occurred: {e}")


@@ -4,13 +4,12 @@ import sys
 from datetime import datetime

 sys.path.insert(
-    0, os.path.abspath("../../")
+    0, os.path.abspath("../../../")
 )  # Adds the parent directory to the system path
 import litellm
 import pytest
 from datetime import timedelta
-from litellm.utils import convert_to_model_response_object

 from litellm.types.utils import (
     ModelResponse,
@@ -20,6 +19,10 @@ from litellm.types.utils import (
     CompletionTokensDetailsWrapper,
 )

+from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
+    convert_to_model_response_object,
+)
+

 def test_convert_to_model_response_object_basic():
     """Test basic conversion with all fields present."""
@@ -621,16 +624,21 @@ def test_convert_to_model_response_object_with_logprobs():
         "system_fingerprint": None,
     }

-    result = convert_to_model_response_object(
-        model_response_object=ModelResponse(),
-        response_object=response_object,
-        stream=False,
-        start_time=datetime.now(),
-        end_time=datetime.now(),
-        hidden_params=None,
-        _response_headers=None,
-        convert_tool_call_to_json_mode=False,
-    )
+    print("ENTERING CONVERT")
+    try:
+        result = convert_to_model_response_object(
+            model_response_object=ModelResponse(),
+            response_object=response_object,
+            stream=False,
+            start_time=datetime.now(),
+            end_time=datetime.now(),
+            hidden_params=None,
+            _response_headers=None,
+            convert_tool_call_to_json_mode=False,
+        )
+    except Exception as e:
+        print(f"ERROR: {e}")
+        raise e

     assert isinstance(result, ModelResponse)
     assert result.id == "chatcmpl-123"
@@ -648,7 +656,7 @@ def test_convert_to_model_response_object_with_logprobs():

     # Check logprobs
     assert choice.logprobs is not None
-    assert len(choice.logprobs["content"]) == 9
+    assert len(choice.logprobs.content) == 9

     # Check each logprob entry
     expected_tokens = [
@@ -662,14 +670,14 @@ def test_convert_to_model_response_object_with_logprobs():
         " today",
         "?",
     ]
-    for i, logprob in enumerate(choice.logprobs["content"]):
-        assert logprob["token"] == expected_tokens[i]
-        assert isinstance(logprob["logprob"], float)
-        assert isinstance(logprob["bytes"], list)
-        assert len(logprob["top_logprobs"]) == 2
-        assert isinstance(logprob["top_logprobs"][0]["token"], str)
-        assert isinstance(logprob["top_logprobs"][0]["logprob"], float)
-        assert isinstance(logprob["top_logprobs"][0]["bytes"], (list, type(None)))
+    for i, logprob in enumerate(choice.logprobs.content):
+        assert logprob.token == expected_tokens[i]
+        assert isinstance(logprob.logprob, float)
+        assert isinstance(logprob.bytes, list)
+        assert len(logprob.top_logprobs) == 2
+        assert isinstance(logprob.top_logprobs[0].token, str)
+        assert isinstance(logprob.top_logprobs[0].logprob, float)
+        assert isinstance(logprob.top_logprobs[0].bytes, (list, type(None)))

     assert result.usage.prompt_tokens == 9
     assert result.usage.completion_tokens == 9


@@ -4599,6 +4599,21 @@ def test_provider_specific_header(custom_llm_provider, expected_result):
         assert "anthropic-beta" in mock_post.call_args.kwargs["headers"]

+
+def test_qwen_text_completion():
+    # litellm._turn_on_debug()
+    resp = litellm.completion(
+        model="gpt-3.5-turbo-instruct",
+        messages=[{"content": "hello", "role": "user"}],
+        stream=False,
+        logprobs=1,
+    )
+    assert resp.choices[0].message.content is not None
+    assert resp.choices[0].logprobs.token_logprobs[0] is not None
+    print(
+        f"resp.choices[0].logprobs.token_logprobs[0]: {resp.choices[0].logprobs.token_logprobs[0]}"
+    )
+
+
 @pytest.mark.parametrize(
     "enable_preview_features",
     [True, False],
@@ -4631,3 +4646,22 @@ def test_completion_openai_metadata(monkeypatch, enable_preview_features):
             }
         else:
             assert "metadata" not in mock_completion.call_args.kwargs
+
+
+def test_completion_o3_mini_temperature():
+    try:
+        litellm.set_verbose = True
+        resp = litellm.completion(
+            model="o3-mini",
+            temperature=0.0,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Hello, world!",
+                }
+            ],
+            drop_params=True,
+        )
+        assert resp.choices[0].message.content is not None
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")


@@ -72,3 +72,21 @@ def test_guardrail_masking_logging_only():
             mock_call.call_args.kwargs["kwargs"]["messages"][0]["content"]
             == "Hey, my name is [NAME]."
         )
+
+
+def test_guardrail_list_of_event_hooks():
+    from litellm.integrations.custom_guardrail import CustomGuardrail
+    from litellm.types.guardrails import GuardrailEventHooks
+
+    cg = CustomGuardrail(
+        guardrail_name="custom-guard", event_hook=["pre_call", "post_call"]
+    )
+
+    data = {"model": "gpt-3.5-turbo", "metadata": {"guardrails": ["custom-guard"]}}
+    assert cg.should_run_guardrail(data=data, event_type=GuardrailEventHooks.pre_call)
+    assert cg.should_run_guardrail(data=data, event_type=GuardrailEventHooks.post_call)
+
+    assert not cg.should_run_guardrail(
+        data=data, event_type=GuardrailEventHooks.during_call
+    )


@@ -267,7 +267,7 @@ async def test_chat_completion_request_with_redaction():
     setattr(proxy_server, "llm_router", router)
     _test_logger = testLogger()
     litellm.callbacks = [_ENTERPRISE_SecretDetection(), _test_logger]
-    litellm.set_verbose = True
+    litellm._turn_on_debug()

     # Prepare the query string
     query_params = "param1=value1&param2=value2"