Litellm dev 01 06 2025 p1 (#7594)

* fix(custom_logger.py): expose new 'async_get_chat_completion_prompt' event hook

* fix(custom_logger.py): langfuse_prompt_management.py

remove 'headers' from custom logger 'async_get_chat_completion_prompt' and 'get_chat_completion_prompt' event hooks

* feat(router.py): expose new function for prompt management based routing

* feat(router.py): partial working router prompt factory logic

allows load balanced model to be used for model name w/ langfuse prompt management call

* feat(router.py): fix prompt management with load balanced model group

* feat(langfuse_prompt_management.py): support reading in openai params from langfuse

enables user to define optional params on langfuse vs. client code

* test(test_Router.py): add unit test for router based langfuse prompt management

* fix: fix linting errors
This commit is contained in:
Krish Dholakia 2025-01-06 21:26:21 -08:00 committed by GitHub
parent 56827bde7a
commit 4760693094
9 changed files with 214 additions and 90 deletions

View file

@ -47,6 +47,9 @@ from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
@ -118,6 +121,7 @@ from litellm.utils import (
EmbeddingResponse,
ModelResponse,
get_llm_provider,
get_non_default_completion_params,
get_secret,
get_utc_datetime,
is_region_allowed,
@ -774,19 +778,19 @@ class Router:
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
self, model: str, messages: List[AllMessageValues], stream: Literal[True], **kwargs
) -> CustomStreamWrapper:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
self, model: str, messages: List[AllMessageValues], stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
self, model: str, messages: List[AllMessageValues], stream: Union[Literal[True], Literal[False]] = False, **kwargs
) -> Union[CustomStreamWrapper, ModelResponse]:
...
@ -794,7 +798,11 @@ class Router:
# The actual implementation of the function
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
self,
model: str,
messages: List[AllMessageValues],
stream: bool = False,
**kwargs,
):
try:
kwargs["model"] = model
@ -804,6 +812,14 @@ class Router:
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
request_priority = kwargs.get("priority") or self.default_priority
start_time = time.time()
_is_prompt_management_model = self._is_prompt_management_model(model)
if _is_prompt_management_model:
return await self._prompt_management_factory(
model=model,
messages=messages,
kwargs=kwargs,
)
if request_priority is not None and isinstance(request_priority, int):
response = await self.schedule_acompletion(**kwargs)
else:
@ -1081,7 +1097,7 @@ class Router:
############## Helpers for async completion ##################
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
model: str, messages: List[AllMessageValues], **kwargs
):
"""
Wrapper around self.async_completion that catches exceptions and returns them as a result
@ -1093,7 +1109,7 @@ class Router:
async def _async_completion_no_exceptions_return_idx(
model: str,
messages: List[Dict[str, str]],
messages: List[AllMessageValues],
idx: int, # index of message this response corresponds to
**kwargs,
):
@ -1137,7 +1153,7 @@ class Router:
return final_responses
async def abatch_completion_one_model_multiple_requests(
self, model: str, messages: List[List[Dict[str, str]]], **kwargs
self, model: str, messages: List[List[AllMessageValues]], **kwargs
):
"""
Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router
@ -1159,7 +1175,7 @@ class Router:
"""
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
model: str, messages: List[AllMessageValues], **kwargs
):
"""
Wrapper around self.async_completion that catches exceptions and returns them as a result
@ -1282,13 +1298,13 @@ class Router:
@overload
async def schedule_acompletion(
self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
@overload
async def schedule_acompletion(
self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[True], **kwargs
) -> CustomStreamWrapper:
...
@ -1297,7 +1313,7 @@ class Router:
async def schedule_acompletion(
self,
model: str,
messages: List[Dict[str, str]],
messages: List[AllMessageValues],
priority: int,
stream=False,
**kwargs,
@ -1417,6 +1433,88 @@ class Router:
llm_provider="openai",
)
def _is_prompt_management_model(self, model: str) -> bool:
model_list = self.get_model_list(model_name=model)
if model_list is None:
return False
if len(model_list) != 1:
return False
litellm_model = model_list[0]["litellm_params"].get("model", None)
if litellm_model is None:
return False
if "/" in litellm_model:
split_litellm_model = litellm_model.split("/")[0]
if split_litellm_model in litellm._known_custom_logger_compatible_callbacks:
return True
return False
async def _prompt_management_factory(
self,
model: str,
messages: List[AllMessageValues],
kwargs: Dict[str, Any],
):
prompt_management_deployment = self.get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
litellm_model = prompt_management_deployment["litellm_params"].get(
"model", None
)
prompt_id = kwargs.get("prompt_id") or prompt_management_deployment[
"litellm_params"
].get("prompt_id", None)
prompt_variables = kwargs.get(
"prompt_variables"
) or prompt_management_deployment["litellm_params"].get(
"prompt_variables", None
)
if litellm_model is None or "/" not in litellm_model:
raise ValueError(
f"Model is not a custom logger compatible callback. Got={litellm_model}"
)
custom_logger_compatible_callback = litellm_model.split("/", 1)[0]
split_litellm_model = litellm_model.split("/", 1)[1]
custom_logger = _init_custom_logger_compatible_class(
logging_integration=custom_logger_compatible_callback,
internal_usage_cache=None,
llm_router=None,
)
if custom_logger is None:
raise ValueError(
f"Custom logger is not initialized. Got={custom_logger_compatible_callback}"
)
model, messages, optional_params = (
await custom_logger.async_get_chat_completion_prompt(
model=split_litellm_model,
messages=messages,
non_default_params=get_non_default_completion_params(kwargs=kwargs),
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params={},
)
)
kwargs = {**kwargs, **optional_params}
kwargs["model"] = model
kwargs["messages"] = messages
_model_list = self.get_model_list(model_name=model)
if _model_list is None or len(_model_list) == 0: # if direct call to model
kwargs.pop("original_function")
return await litellm.acompletion(**kwargs)
return await self.async_function_with_fallbacks(**kwargs)
def image_generation(self, prompt: str, model: str, **kwargs):
try:
kwargs["model"] = model