Litellm dev 01 06 2025 p1 (#7594)

* fix(custom_logger.py): expose new 'async_get_chat_completion_prompt' event hook

* fix(custom_logger.py, langfuse_prompt_management.py):

remove 'headers' from custom logger 'async_get_chat_completion_prompt' and 'get_chat_completion_prompt' event hooks

* feat(router.py): expose new function for prompt management based routing

* feat(router.py): partially working router prompt factory logic

allows a load-balanced model to be used as the model name with a langfuse prompt management call

* feat(router.py): fix prompt management with load-balanced model group

* feat(langfuse_prompt_management.py): support reading in openai params from langfuse

enables the user to define optional params on langfuse instead of in client code

* test(test_Router.py): add unit test for router-based langfuse prompt management

* fix: fix linting errors
Krish Dholakia 2025-01-06 21:26:21 -08:00 committed by GitHub
parent 7133cf5b74
commit fef7839e8a
9 changed files with 214 additions and 90 deletions


@@ -63,12 +63,28 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
#### PROMPT MANAGEMENT HOOKS ####
async def async_get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[str, List[AllMessageValues], dict]:
"""
Returns:
- model: str - the model to use (can be pulled from prompt management tool)
- messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
- non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
"""
return model, messages, non_default_params
def get_chat_completion_prompt(
self,
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
headers: dict,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
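
For illustration only (not part of this diff), a minimal sketch of a custom logger implementing the updated hook signature; the class name and the hard-coded prompt lookup below are hypothetical:

from typing import List, Optional, Tuple

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams


class MyPromptManager(CustomLogger):  # hypothetical example class
    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        # Prepend a system prompt resolved from the prompt_id (placeholder logic).
        prompt_text = f"You are running prompt '{prompt_id}'."
        messages = [{"role": "system", "content": prompt_text}] + list(messages)
        # Merge any optional params the prompt store defines (e.g. temperature).
        non_default_params = {**non_default_params, "temperature": 0.2}
        return model, messages, non_default_params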


@@ -152,7 +152,6 @@ class HumanloopLogger(CustomLogger):
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
headers: dict,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
@@ -170,7 +169,6 @@ class HumanloopLogger(CustomLogger):
model=model,
messages=messages,
non_default_params=non_default_params,
headers=headers,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=dynamic_callback_params,


@@ -9,10 +9,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, ca
from packaging.version import Version
from typing_extensions import TypeAlias
from litellm._logging import verbose_proxy_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
@@ -144,67 +141,36 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
else:
return model.replace("langfuse/", "")
async def async_pre_call_hook(
def _get_optional_params_from_langfuse(
self, langfuse_prompt_client: PROMPT_CLIENT
) -> dict:
config = langfuse_prompt_client.config
optional_params = {}
for k, v in config.items():
if k != "model":
optional_params[k] = v
return optional_params
async def async_get_chat_completion_prompt(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Union[
Literal["completion"],
Literal["text_completion"],
Literal["embeddings"],
Literal["image_generation"],
Literal["moderation"],
Literal["audio_transcription"],
Literal["pass_through_endpoint"],
Literal["rerank"],
],
) -> Union[Exception, str, dict, None]:
metadata = data.get("metadata") or {}
if isinstance(metadata, dict):
langfuse_prompt_id = cast(Optional[str], metadata.get("langfuse_prompt_id"))
langfuse_prompt_variables = cast(
Optional[dict], metadata.get("langfuse_prompt_variables") or {}
)
else:
return None
if langfuse_prompt_id is None:
return None
prompt_client = self._get_prompt_from_id(
langfuse_prompt_id=langfuse_prompt_id, langfuse_client=self.Langfuse
)
compiled_prompt: Optional[Union[str, list]] = None
if call_type == "completion" or call_type == "text_completion":
compiled_prompt = self._compile_prompt(
langfuse_prompt_client=prompt_client,
langfuse_prompt_variables=langfuse_prompt_variables,
call_type=call_type,
)
if compiled_prompt is None:
return await super().async_pre_call_hook(
user_api_key_dict, cache, data, call_type
)
if call_type == "completion":
if isinstance(compiled_prompt, list):
data["messages"] = compiled_prompt + data["messages"]
else:
data["messages"] = [
{"role": "system", "content": compiled_prompt}
] + data["messages"]
elif call_type == "text_completion" and isinstance(compiled_prompt, str):
data["prompt"] = compiled_prompt + "\n" + data["prompt"]
verbose_proxy_logger.debug(
f"LangfusePromptManagement.async_pre_call_hook compiled_prompt: {compiled_prompt}, type: {type(compiled_prompt)}"
)
return await super().async_pre_call_hook(
user_api_key_dict, cache, data, call_type
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
return self.get_chat_completion_prompt(
model,
messages,
non_default_params,
prompt_id,
prompt_variables,
dynamic_callback_params,
)
def get_chat_completion_prompt(
@@ -212,7 +178,6 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
headers: dict,
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
@@ -255,7 +220,11 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
## SET MODEL
model = self._get_model_from_prompt(langfuse_prompt_client, model)
return model, messages, non_default_params
optional_params = self._get_optional_params_from_langfuse(
langfuse_prompt_client
)
return model, messages, optional_params
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
standard_callback_dynamic_params = kwargs.get(
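
To see what this buys in practice, a rough usage sketch (assuming a Langfuse prompt named "jokes" whose config defines optional params such as temperature or max_tokens; those values come back as optional_params and are merged into the request, so they no longer need to be set in client code):

import litellm

# Assumed Langfuse prompt "jokes" with a config like
# {"model": "gpt-4o", "temperature": 0.7, "max_tokens": 256}
response = litellm.completion(
    model="langfuse/gpt-3.5-turbo",      # may be overridden by the prompt's "model" field
    prompt_id="jokes",                   # Langfuse prompt to fetch and compile
    prompt_variables={"topic": "cats"},  # substituted into the prompt template
    messages=[{"role": "user", "content": "tell me a joke"}],
)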


@@ -398,7 +398,6 @@ class Logging(LiteLLMLoggingBaseClass):
model: str,
messages: List[AllMessageValues],
non_default_params: dict,
headers: dict,
prompt_id: str,
prompt_variables: Optional[dict],
) -> Tuple[str, List[AllMessageValues], dict]:
@@ -420,7 +419,6 @@ class Logging(LiteLLMLoggingBaseClass):
model=model,
messages=messages,
non_default_params=non_default_params,
headers=headers,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=self.standard_callback_dynamic_params,


@@ -79,6 +79,7 @@ from litellm.utils import (
create_tokenizer,
get_api_key,
get_llm_provider,
get_non_default_completion_params,
get_optional_params_embeddings,
get_optional_params_image_gen,
get_optional_params_transcription,
@@ -881,12 +882,8 @@ def completion( # type: ignore # noqa: PLR0915
assistant_continue_message=assistant_continue_message,
)
######## end of unpacking kwargs ###########
openai_params = litellm.OPENAI_CHAT_COMPLETION_PARAMS
default_params = openai_params + all_litellm_params
non_default_params = get_non_default_completion_params(kwargs=kwargs)
litellm_params = {} # used to prevent unbound var errors
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
## PROMPT MANAGEMENT HOOKS ##
if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
@@ -895,7 +892,6 @@
model=model,
messages=messages,
non_default_params=non_default_params,
headers=headers,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
)


@@ -9,11 +9,13 @@ model_list:
api_key: os.environ/OPENAI_API_KEY
- model_name: chatbot_actions
litellm_params:
model: langfuse/azure/gpt-4o
api_base: "os.environ/AZURE_API_BASE"
api_key: "os.environ/AZURE_API_KEY"
model: langfuse/openai-gpt-3.5-turbo
tpm: 1000000
prompt_id: "jokes"
- model_name: openai-gpt-3.5-turbo
litellm_params:
model: openai/gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
# litellm_settings:
# callbacks: ["otel"]
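
With this config, the proxy exposes "chatbot_actions" as an OpenAI-compatible model group backed by the langfuse/ prompt-managed deployment; a rough client-side sketch (the base URL and virtual key below are assumed):

from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")  # assumed proxy address and key

response = client.chat.completions.create(
    model="chatbot_actions",  # resolves the "jokes" prompt, then load-balances onto openai-gpt-3.5-turbo
    messages=[{"role": "user", "content": "tell me a joke"}],
)
print(response.choices[0].message.content)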


@@ -47,6 +47,9 @@ from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
@@ -118,6 +121,7 @@ from litellm.utils import (
EmbeddingResponse,
ModelResponse,
get_llm_provider,
get_non_default_completion_params,
get_secret,
get_utc_datetime,
is_region_allowed,
@@ -774,19 +778,19 @@
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
self, model: str, messages: List[AllMessageValues], stream: Literal[True], **kwargs
) -> CustomStreamWrapper:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
self, model: str, messages: List[AllMessageValues], stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
self, model: str, messages: List[AllMessageValues], stream: Union[Literal[True], Literal[False]] = False, **kwargs
) -> Union[CustomStreamWrapper, ModelResponse]:
...
@@ -794,7 +798,11 @@
# The actual implementation of the function
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
self,
model: str,
messages: List[AllMessageValues],
stream: bool = False,
**kwargs,
):
try:
kwargs["model"] = model
@@ -804,6 +812,14 @@
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
request_priority = kwargs.get("priority") or self.default_priority
start_time = time.time()
_is_prompt_management_model = self._is_prompt_management_model(model)
if _is_prompt_management_model:
return await self._prompt_management_factory(
model=model,
messages=messages,
kwargs=kwargs,
)
if request_priority is not None and isinstance(request_priority, int):
response = await self.schedule_acompletion(**kwargs)
else:
@@ -1081,7 +1097,7 @@
############## Helpers for async completion ##################
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
model: str, messages: List[AllMessageValues], **kwargs
):
"""
Wrapper around self.async_completion that catches exceptions and returns them as a result
@@ -1093,7 +1109,7 @@
async def _async_completion_no_exceptions_return_idx(
model: str,
messages: List[Dict[str, str]],
messages: List[AllMessageValues],
idx: int, # index of message this response corresponds to
**kwargs,
):
@@ -1137,7 +1153,7 @@
return final_responses
async def abatch_completion_one_model_multiple_requests(
self, model: str, messages: List[List[Dict[str, str]]], **kwargs
self, model: str, messages: List[List[AllMessageValues]], **kwargs
):
"""
Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router
@@ -1159,7 +1175,7 @@
"""
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
model: str, messages: List[AllMessageValues], **kwargs
):
"""
Wrapper around self.async_completion that catches exceptions and returns them as a result
@@ -1282,13 +1298,13 @@
@overload
async def schedule_acompletion(
self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
@overload
async def schedule_acompletion(
self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[True], **kwargs
) -> CustomStreamWrapper:
...
@@ -1297,7 +1313,7 @@
async def schedule_acompletion(
self,
model: str,
messages: List[Dict[str, str]],
messages: List[AllMessageValues],
priority: int,
stream=False,
**kwargs,
@@ -1417,6 +1433,88 @@
llm_provider="openai",
)
def _is_prompt_management_model(self, model: str) -> bool:
model_list = self.get_model_list(model_name=model)
if model_list is None:
return False
if len(model_list) != 1:
return False
litellm_model = model_list[0]["litellm_params"].get("model", None)
if litellm_model is None:
return False
if "/" in litellm_model:
split_litellm_model = litellm_model.split("/")[0]
if split_litellm_model in litellm._known_custom_logger_compatible_callbacks:
return True
return False
async def _prompt_management_factory(
self,
model: str,
messages: List[AllMessageValues],
kwargs: Dict[str, Any],
):
prompt_management_deployment = self.get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
litellm_model = prompt_management_deployment["litellm_params"].get(
"model", None
)
prompt_id = kwargs.get("prompt_id") or prompt_management_deployment[
"litellm_params"
].get("prompt_id", None)
prompt_variables = kwargs.get(
"prompt_variables"
) or prompt_management_deployment["litellm_params"].get(
"prompt_variables", None
)
if litellm_model is None or "/" not in litellm_model:
raise ValueError(
f"Model is not a custom logger compatible callback. Got={litellm_model}"
)
custom_logger_compatible_callback = litellm_model.split("/", 1)[0]
split_litellm_model = litellm_model.split("/", 1)[1]
custom_logger = _init_custom_logger_compatible_class(
logging_integration=custom_logger_compatible_callback,
internal_usage_cache=None,
llm_router=None,
)
if custom_logger is None:
raise ValueError(
f"Custom logger is not initialized. Got={custom_logger_compatible_callback}"
)
model, messages, optional_params = (
await custom_logger.async_get_chat_completion_prompt(
model=split_litellm_model,
messages=messages,
non_default_params=get_non_default_completion_params(kwargs=kwargs),
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params={},
)
)
kwargs = {**kwargs, **optional_params}
kwargs["model"] = model
kwargs["messages"] = messages
_model_list = self.get_model_list(model_name=model)
if _model_list is None or len(_model_list) == 0: # if direct call to model
kwargs.pop("original_function")
return await litellm.acompletion(**kwargs)
return await self.async_function_with_fallbacks(**kwargs)
def image_generation(self, prompt: str, model: str, **kwargs):
try:
kwargs["model"] = model


@@ -144,6 +144,7 @@ from litellm.types.utils import (
TextCompletionResponse,
TranscriptionResponse,
Usage,
all_litellm_params,
)
with resources.open_text(
@@ -6448,3 +6449,12 @@ def _add_path_to_api_base(api_base: str, ending_path: str) -> str:
# Re-add the original query parameters
return str(modified_url.copy_with(params=original_url.params))
def get_non_default_completion_params(kwargs: dict) -> dict:
openai_params = litellm.OPENAI_CHAT_COMPLETION_PARAMS
default_params = openai_params + all_litellm_params
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
return non_default_params
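
A quick illustration of the helper's behaviour (example values assumed): standard OpenAI and litellm params are filtered out, while provider-specific extras pass through untouched.

from litellm.utils import get_non_default_completion_params

kwargs = {
    "temperature": 0.2,             # standard OpenAI chat param -> filtered out
    "metadata": {"run_id": "abc"},  # litellm-internal param -> filtered out
    "top_k": 40,                    # provider-specific param -> kept
}
print(get_non_default_completion_params(kwargs=kwargs))
# expected, roughly: {'top_k': 40}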


@@ -2700,3 +2700,40 @@ def test_router_completion_with_model_id():
) as mock_pre_call_checks:
router.completion(model="123", messages=[{"role": "user", "content": "hi"}])
mock_pre_call_checks.assert_not_called()
def test_router_prompt_management_factory():
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
},
{
"model_name": "chatbot_actions",
"litellm_params": {
"model": "langfuse/openai-gpt-3.5-turbo",
"tpm": 1000000,
"prompt_id": "jokes",
},
},
{
"model_name": "openai-gpt-3.5-turbo",
"litellm_params": {
"model": "openai/gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
]
)
assert router._is_prompt_management_model("chatbot_actions") is True
assert router._is_prompt_management_model("openai-gpt-3.5-turbo") is False
response = router._prompt_management_factory(
model="chatbot_actions",
messages=[{"role": "user", "content": "Hello world!"}],
kwargs={},
)
print(response)
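
Note that _prompt_management_factory is a coroutine, so the call above yields an un-awaited coroutine object. Awaiting the public entry point exercises the same code path; a sketch of doing so (assuming Langfuse and OpenAI credentials are available, or that the network calls are mocked):

import asyncio

# hypothetical follow-up, not part of this commit
response = asyncio.run(
    router.acompletion(
        model="chatbot_actions",
        messages=[{"role": "user", "content": "Hello world!"}],
    )
)
print(response)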