Litellm dev 12 24 2024 p2 (#7400)

* fix(utils.py): default custom_llm_provider=None for 'supports_response_schema'

Closes https://github.com/BerriAI/litellm/issues/7397

* refactor(langfuse/): call langfuse logger inside customlogger compatible langfuse class, refactor langfuse logger to use verbose_logger.debug instead of print_verbose

* refactor(litellm_pre_call_utils.py): move config based team callbacks inside dynamic team callback logic

enables simpler unit testing for config-based team callbacks

* fix(proxy/_types.py): handle teamcallbackmetadata - none values

Drop None values if present; if all values are None, use a default dict to avoid downstream errors.

* test(test_proxy_utils.py): add unit test preventing future issues - asserts team_id in config state not popped off across calls

Fixes https://github.com/BerriAI/litellm/issues/6787

* fix(langfuse_prompt_management.py): add success + failure logging event support

* fix: fix linting error

* test: fix test

* test: fix test

* test: override o1 prompt caching - openai currently not working

* test: fix test
This commit is contained in:
Krish Dholakia 2024-12-24 20:33:41 -08:00 committed by GitHub
parent d790ba0897
commit c95351e70f
12 changed files with 227 additions and 62 deletions

View file

@ -148,12 +148,7 @@ class LangFuseLogger:
return metadata
# def log_error(kwargs, response_obj, start_time, end_time):
# generation = trace.generation(
# level ="ERROR" # can be any of DEBUG, DEFAULT, WARNING or ERROR
# status_message='error' # can be any string (e.g. stringified stack trace or error body)
# )
def log_event( # noqa: PLR0915
def _old_log_event( # noqa: PLR0915
self,
kwargs,
response_obj,
@ -167,7 +162,7 @@ class LangFuseLogger:
# Method definition
try:
print_verbose(
verbose_logger.debug(
f"Langfuse Logging - Enters logging function for model {kwargs}"
)
@ -260,7 +255,9 @@ class LangFuseLogger:
):
input = prompt
output = response_obj.get("response", "")
print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
verbose_logger.debug(
f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}"
)
trace_id = None
generation_id = None
if self._is_langfuse_v2():
@ -291,7 +288,7 @@ class LangFuseLogger:
input,
response_obj,
)
print_verbose(
verbose_logger.debug(
f"Langfuse Layer Logging - final response object: {response_obj}"
)
verbose_logger.info("Langfuse Layer Logging - logging success")
@ -444,7 +441,7 @@ class LangFuseLogger:
) -> tuple:
import langfuse
print_verbose("Langfuse Layer Logging - logging to langfuse v2")
verbose_logger.debug("Langfuse Layer Logging - logging to langfuse v2")
try:
metadata = self._prepare_metadata(metadata)
@ -577,7 +574,7 @@ class LangFuseLogger:
trace_params["metadata"] = {"metadata_passed_to_litellm": metadata}
cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}")
verbose_logger.debug(f"trace: {cost}")
clean_metadata["litellm_response_cost"] = cost
if standard_logging_object is not None:

View file

@ -14,7 +14,9 @@ from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
from .langfuse import LangFuseLogger
if TYPE_CHECKING:
from langfuse import Langfuse
@ -92,7 +94,7 @@ def langfuse_client_init(
return client
class LangfusePromptManagement(CustomLogger):
class LangfusePromptManagement(LangFuseLogger, CustomLogger):
def __init__(
self,
langfuse_public_key=None,
@ -248,3 +250,31 @@ class LangfusePromptManagement(CustomLogger):
model = self._get_model_from_prompt(langfuse_prompt_client, model)
return model, messages, non_default_params
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """Forward a successful LLM call to the legacy Langfuse logging entrypoint."""
    user = kwargs.get("user", None)
    self._old_log_event(
        kwargs=kwargs,
        response_obj=response_obj,
        start_time=start_time,
        end_time=end_time,
        user_id=user,
        print_verbose=None,
    )
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """Forward a failed LLM call to Langfuse; no-op without a standard logging payload."""
    logging_payload = cast(
        Optional[StandardLoggingPayload],
        kwargs.get("standard_logging_object", None),
    )
    if logging_payload is None:
        # Nothing structured to report — skip rather than log a partial error.
        return
    self._old_log_event(
        kwargs=kwargs,
        response_obj=None,
        start_time=start_time,
        end_time=end_time,
        user_id=kwargs.get("user", None),
        print_verbose=None,
        status_message=logging_payload["error_str"],
        level="ERROR",
    )

View file

@ -1202,7 +1202,7 @@ class Logging(LiteLLMLoggingBaseClass):
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
)
if langfuse_logger_to_use is not None:
_response = langfuse_logger_to_use.log_event(
_response = langfuse_logger_to_use._old_log_event(
kwargs=kwargs,
response_obj=result,
start_time=start_time,
@ -1925,7 +1925,7 @@ class Logging(LiteLLMLoggingBaseClass):
standard_callback_dynamic_params=self.standard_callback_dynamic_params,
in_memory_dynamic_logger_cache=in_memory_dynamic_logger_cache,
)
_response = langfuse_logger_to_use.log_event(
_response = langfuse_logger_to_use._old_log_event(
start_time=start_time,
end_time=end_time,
response_obj=None,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1011,8 +1011,24 @@ class TeamCallbackMetadata(LiteLLMPydanticObjectBase):
@model_validator(mode="before")
@classmethod
def validate_callback_vars(cls, values):
success_callback = values.get("success_callback", [])
if success_callback is None:
values.pop("success_callback", None)
failure_callback = values.get("failure_callback", [])
if failure_callback is None:
values.pop("failure_callback", None)
callback_vars = values.get("callback_vars", {})
if callback_vars is None:
values.pop("callback_vars", None)
if all(val is None for val in values.values()):
return {
"success_callback": [],
"failure_callback": [],
"callback_vars": {},
}
valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys())
if callback_vars is not None:
for key in callback_vars:
if key not in valid_keys:
raise ValueError(

View file

@ -120,7 +120,7 @@ def convert_key_logging_metadata_to_callback(
def _get_dynamic_logging_metadata(
user_api_key_dict: UserAPIKeyAuth,
user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig
) -> Optional[TeamCallbackMetadata]:
callback_settings_obj: Optional[TeamCallbackMetadata] = None
if (
@ -132,14 +132,10 @@ def _get_dynamic_logging_metadata(
data=AddTeamCallback(**item),
team_callback_settings_obj=callback_settings_obj,
)
elif user_api_key_dict.team_metadata is not None:
team_metadata = user_api_key_dict.team_metadata
if "callback_settings" in team_metadata:
callback_settings = team_metadata.get("callback_settings", None) or {}
callback_settings_obj = TeamCallbackMetadata(**callback_settings)
verbose_proxy_logger.debug(
"Team callback settings activated: %s", callback_settings_obj
)
elif (
user_api_key_dict.team_metadata is not None
and "callback_settings" in user_api_key_dict.team_metadata
):
"""
callback_settings = {
{
@ -149,7 +145,18 @@ def _get_dynamic_logging_metadata(
}
}
"""
team_metadata = user_api_key_dict.team_metadata
callback_settings = team_metadata.get("callback_settings", None) or {}
callback_settings_obj = TeamCallbackMetadata(**callback_settings)
verbose_proxy_logger.debug(
"Team callback settings activated: %s", callback_settings_obj
)
elif user_api_key_dict.team_id is not None:
callback_settings_obj = (
LiteLLMProxyRequestSetup.add_team_based_callbacks_from_config(
team_id=user_api_key_dict.team_id, proxy_config=proxy_config
)
)
return callback_settings_obj
@ -343,6 +350,29 @@ class LiteLLMProxyRequestSetup:
return final_tags
@staticmethod
def add_team_based_callbacks_from_config(
    team_id: str,
    proxy_config: ProxyConfig,
) -> Optional[TeamCallbackMetadata]:
    """
    Build team-scoped callback settings from the proxy config.

    Returns None when the config has no entry for this team.
    """
    team_config = proxy_config.load_team_config(team_id=team_id)
    if not team_config:
        return None

    # Prefer an explicit 'callback_vars' sub-dict; otherwise treat the whole
    # team config as callback vars, minus the reserved keys stripped below.
    callback_vars_dict = dict(team_config.get("callback_vars", team_config))
    for reserved_key in ("team_id", "success_callback", "failure_callback"):
        callback_vars_dict.pop(reserved_key, None)

    return TeamCallbackMetadata(
        success_callback=team_config.get("success_callback", None),
        failure_callback=team_config.get("failure_callback", None),
        callback_vars=callback_vars_dict,
    )
async def add_litellm_data_to_request( # noqa: PLR0915
data: dict,
@ -551,24 +581,9 @@ async def add_litellm_data_to_request( # noqa: PLR0915
if "tags" in data:
data[_metadata_variable_name]["tags"] = data["tags"]
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data[_metadata_variable_name]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# Team Callbacks controls
callback_settings_obj = _get_dynamic_logging_metadata(
user_api_key_dict=user_api_key_dict
user_api_key_dict=user_api_key_dict, proxy_config=proxy_config
)
if callback_settings_obj is not None:
data["success_callback"] = callback_settings_obj.success_callback

View file

@ -1362,14 +1362,14 @@ class ProxyConfig:
team_config[k] = get_secret(v)
return team_config
async def load_team_config(self, team_id: str):
def load_team_config(self, team_id: str):
"""
- for a given team id
- return the relevant completion() call params
"""
# load existing config
config = self.config
config = self.get_config_state()
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get("litellm_settings", {})
@ -1459,6 +1459,14 @@ class ProxyConfig:
def update_config_state(self, config: dict):
self.config = config
def get_config_state(self):
    """
    Return a deep copy of the proxy config.

    Callers receive an independent snapshot, so mutations they make
    cannot leak back into the stored config state.
    """
    snapshot = copy.deepcopy(self.config)
    return snapshot
async def load_config( # noqa: PLR0915
self, router: Optional[litellm.Router], config_file_path: str
):

View file

@ -53,6 +53,8 @@ async def langfuse_proxy_route(
[Docs](https://docs.litellm.ai/docs/pass_through/langfuse)
"""
from litellm.proxy.proxy_server import proxy_config
## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
api_key = request.headers.get("Authorization") or ""
@ -68,7 +70,9 @@ async def langfuse_proxy_route(
)
callback_settings_obj: Optional[TeamCallbackMetadata] = (
_get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict)
_get_dynamic_logging_metadata(
user_api_key_dict=user_api_key_dict, proxy_config=proxy_config
)
)
dynamic_langfuse_public_key: Optional[str] = None

View file

@ -1658,7 +1658,9 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
)
def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool:
def supports_response_schema(
model: str, custom_llm_provider: Optional[str] = None
) -> bool:
"""
Check if the given model + provider supports 'response_schema' as a param.

View file

@ -231,6 +231,9 @@ def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
os.environ["LANGFUSE_PUBLIC_KEY_TEMP"] = "pk-lf-9636b7a6-c066"
os.environ["LANGFUSE_SECRET_KEY_TEMP"] = "sk-lf-7cc8b620"
os.environ["LANGFUSE_HOST_TEMP"] = "https://us.cloud.langfuse.com"
from litellm.proxy.proxy_server import ProxyConfig
proxy_config = ProxyConfig()
user_api_key_dict = UserAPIKeyAuth(
token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432",
key_name="sk-...63Fg",
@ -288,7 +291,9 @@ def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
rpm_limit_per_model=None,
tpm_limit_per_model=None,
)
callbacks = _get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict)
callbacks = _get_dynamic_logging_metadata(
user_api_key_dict=user_api_key_dict, proxy_config=proxy_config
)
assert callbacks is not None
@ -308,6 +313,9 @@ def test_dynamic_logging_metadata_key_and_team_metadata(callback_vars):
],
)
def test_dynamic_turn_off_message_logging(callback_vars):
from litellm.proxy.proxy_server import ProxyConfig
proxy_config = ProxyConfig()
user_api_key_dict = UserAPIKeyAuth(
token="6f8688eaff1d37555bb9e9a6390b6d7032b3ab2526ba0152da87128eab956432",
key_name="sk-...63Fg",
@ -364,7 +372,9 @@ def test_dynamic_turn_off_message_logging(callback_vars):
rpm_limit_per_model=None,
tpm_limit_per_model=None,
)
callbacks = _get_dynamic_logging_metadata(user_api_key_dict=user_api_key_dict)
callbacks = _get_dynamic_logging_metadata(
user_api_key_dict=user_api_key_dict, proxy_config=proxy_config
)
assert callbacks is not None
assert (
@ -1008,3 +1018,89 @@ def test_get_complete_model_list(proxy_model_list, provider):
for _model in complete_list:
assert provider in _model
def test_team_callback_metadata_all_none_values():
    from litellm.proxy._types import TeamCallbackMetadata

    # Passing None for every field should collapse to empty defaults
    # instead of propagating None downstream.
    all_none = {
        key: None
        for key in ("success_callback", "failure_callback", "callback_vars")
    }
    resp = TeamCallbackMetadata(**all_none)

    assert resp.success_callback == []
    assert resp.failure_callback == []
    assert resp.callback_vars == {}
@pytest.mark.parametrize(
    "none_key",
    [
        "success_callback",
        "failure_callback",
        "callback_vars",
    ],
)
def test_team_callback_metadata_none_values(none_key):
    from litellm.proxy._types import TeamCallbackMetadata

    # callback_vars is None in every case; each remaining field is ["test"]
    # unless it is the field under test (which is set to None).
    field_names = ("success_callback", "failure_callback", "callback_vars")
    args = {
        name: None if (name == none_key or name == "callback_vars") else ["test"]
        for name in field_names
    }

    resp = TeamCallbackMetadata(**args)

    assert none_key not in resp
def test_proxy_config_state_post_init_callback_call():
    """
    Ensures team_id is still in config after the callback is called

    Addresses issue: https://github.com/BerriAI/litellm/issues/6787

    Where team_id was being popped from config, after callback was called
    """
    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
    from litellm.proxy.proxy_server import ProxyConfig

    team_settings = {
        "team_id": "test",
        "success_callback": ["langfuse"],
        "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
        "langfuse_secret": "os.environ/LANGFUSE_SECRET_KEY",
    }

    pc = ProxyConfig()
    pc.update_config_state(
        config={"litellm_settings": {"default_team_settings": [team_settings]}}
    )

    LiteLLMProxyRequestSetup.add_team_based_callbacks_from_config(
        team_id="test", proxy_config=pc
    )

    # The stored config must be untouched by the callback lookup above.
    stored = pc.get_config_state()
    assert stored["litellm_settings"]["default_team_settings"][0]["team_id"] == "test"