Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
Litellm dev 01 06 2025 p1 (#7594)
* fix(custom_logger.py): expose new 'async_get_chat_completion_prompt' event hook
* fix(custom_logger.py, langfuse_prompt_management.py): remove 'headers' from the custom logger 'async_get_chat_completion_prompt' and 'get_chat_completion_prompt' event hooks
* feat(router.py): expose new function for prompt-management-based routing
* feat(router.py): partial working router prompt factory logic; allows a load balanced model to be used for the model name w/ a langfuse prompt management call
* feat(router.py): fix prompt management with load balanced model group
* feat(langfuse_prompt_management.py): support reading in openai params from langfuse; enables the user to define optional params on langfuse instead of in client code
* test(test_Router.py): add unit test for router based langfuse prompt management
* fix: fix linting errors
This commit is contained in:
parent 56827bde7a
commit 4760693094

9 changed files with 214 additions and 90 deletions
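
For context, a minimal sketch of how the router-based Langfuse prompt management added in this commit might be exercised. This is not part of the diff; it assumes LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY and OPENAI_API_KEY are set, that a "jokes" prompt exists in Langfuse, and it mirrors the model names used in the unit test below.

# Sketch only: exercises the new prompt-management routing path.
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "chatbot_actions",
            "litellm_params": {
                "model": "langfuse/openai-gpt-3.5-turbo",
                "prompt_id": "jokes",
            },
        },
        {
            "model_name": "openai-gpt-3.5-turbo",
            "litellm_params": {"model": "openai/gpt-3.5-turbo"},
        },
    ]
)


async def main():
    # The router detects the "langfuse/" prefix, pulls the prompt (and any
    # optional params defined on it) from Langfuse, then load-balances the call
    # to the underlying "openai-gpt-3.5-turbo" model group.
    response = await router.acompletion(
        model="chatbot_actions",
        messages=[{"role": "user", "content": "Tell me a joke."}],
    )
    print(response)


asyncio.run(main())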
@@ -63,12 +63,28 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
 
     #### PROMPT MANAGEMENT HOOKS ####
 
+    async def async_get_chat_completion_prompt(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[str, List[AllMessageValues], dict]:
+        """
+        Returns:
+        - model: str - the model to use (can be pulled from prompt management tool)
+        - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
+        - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
+        """
+        return model, messages, non_default_params
+
     def get_chat_completion_prompt(
         self,
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        headers: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
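
For illustration, a minimal sketch (not part of the diff) of a CustomLogger subclass that implements the new hook signature shown above. The class name and the injected system message are hypothetical; a real integration would look the prompt up by prompt_id.

# Sketch only: "MyPromptManager" and the hard-coded system prompt are made up.
from typing import List, Optional, Tuple

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams


class MyPromptManager(CustomLogger):
    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        # Prepend a system message; model and optional params pass through
        # unchanged. A real hook could also swap the model or set temperature.
        system_msg: AllMessageValues = {
            "role": "system",
            "content": f"You are a helpful assistant. (prompt_id={prompt_id})",
        }
        return model, [system_msg, *messages], non_default_params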
@@ -152,7 +152,6 @@ class HumanloopLogger(CustomLogger):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        headers: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
@@ -170,7 +169,6 @@ class HumanloopLogger(CustomLogger):
             model=model,
             messages=messages,
             non_default_params=non_default_params,
-            headers=headers,
             prompt_id=prompt_id,
             prompt_variables=prompt_variables,
             dynamic_callback_params=dynamic_callback_params,
@@ -9,10 +9,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, ca
 from packaging.version import Version
 from typing_extensions import TypeAlias
 
-from litellm._logging import verbose_proxy_logger
-from litellm.caching.dual_cache import DualCache
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
 
@@ -144,67 +141,36 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
         else:
             return model.replace("langfuse/", "")
 
-    async def async_pre_call_hook(
+    def _get_optional_params_from_langfuse(
+        self, langfuse_prompt_client: PROMPT_CLIENT
+    ) -> dict:
+        config = langfuse_prompt_client.config
+        optional_params = {}
+        for k, v in config.items():
+            if k != "model":
+                optional_params[k] = v
+        return optional_params
+
+    async def async_get_chat_completion_prompt(
         self,
-        user_api_key_dict: UserAPIKeyAuth,
-        cache: DualCache,
-        data: dict,
-        call_type: Union[
-            Literal["completion"],
-            Literal["text_completion"],
-            Literal["embeddings"],
-            Literal["image_generation"],
-            Literal["moderation"],
-            Literal["audio_transcription"],
-            Literal["pass_through_endpoint"],
-            Literal["rerank"],
-        ],
-    ) -> Union[Exception, str, dict, None]:
-
-        metadata = data.get("metadata") or {}
-
-        if isinstance(metadata, dict):
-            langfuse_prompt_id = cast(Optional[str], metadata.get("langfuse_prompt_id"))
-
-            langfuse_prompt_variables = cast(
-                Optional[dict], metadata.get("langfuse_prompt_variables") or {}
-            )
-        else:
-            return None
-
-        if langfuse_prompt_id is None:
-            return None
-
-        prompt_client = self._get_prompt_from_id(
-            langfuse_prompt_id=langfuse_prompt_id, langfuse_client=self.Langfuse
-        )
-        compiled_prompt: Optional[Union[str, list]] = None
-        if call_type == "completion" or call_type == "text_completion":
-            compiled_prompt = self._compile_prompt(
-                langfuse_prompt_client=prompt_client,
-                langfuse_prompt_variables=langfuse_prompt_variables,
-                call_type=call_type,
-            )
-        if compiled_prompt is None:
-            return await super().async_pre_call_hook(
-                user_api_key_dict, cache, data, call_type
-            )
-        if call_type == "completion":
-            if isinstance(compiled_prompt, list):
-                data["messages"] = compiled_prompt + data["messages"]
-            else:
-                data["messages"] = [
-                    {"role": "system", "content": compiled_prompt}
-                ] + data["messages"]
-        elif call_type == "text_completion" and isinstance(compiled_prompt, str):
-            data["prompt"] = compiled_prompt + "\n" + data["prompt"]
-
-        verbose_proxy_logger.debug(
-            f"LangfusePromptManagement.async_pre_call_hook compiled_prompt: {compiled_prompt}, type: {type(compiled_prompt)}"
-        )
-
-        return await super().async_pre_call_hook(
-            user_api_key_dict, cache, data, call_type
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
+        return self.get_chat_completion_prompt(
+            model,
+            messages,
+            non_default_params,
+            prompt_id,
+            prompt_variables,
+            dynamic_callback_params,
         )
 
     def get_chat_completion_prompt(
@@ -212,7 +178,6 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        headers: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
@@ -255,7 +220,11 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
         ## SET MODEL
         model = self._get_model_from_prompt(langfuse_prompt_client, model)
 
-        return model, messages, non_default_params
+        optional_params = self._get_optional_params_from_langfuse(
+            langfuse_prompt_client
+        )
+
+        return model, messages, optional_params
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         standard_callback_dynamic_params = kwargs.get(
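
For illustration, a standalone sketch of what _get_optional_params_from_langfuse extracts from a hypothetical Langfuse prompt config; the config values below are made up.

# Sketch only: the dict stands in for langfuse_prompt_client.config.
config = {"model": "gpt-4o", "temperature": 0.2, "max_tokens": 256}

# Mirrors _get_optional_params_from_langfuse: everything except "model" is
# forwarded as optional params on the completion call; "model" itself is
# resolved separately via _get_model_from_prompt.
optional_params = {k: v for k, v in config.items() if k != "model"}
assert optional_params == {"temperature": 0.2, "max_tokens": 256}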
@@ -398,7 +398,6 @@ class Logging(LiteLLMLoggingBaseClass):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        headers: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -420,7 +419,6 @@ class Logging(LiteLLMLoggingBaseClass):
             model=model,
             messages=messages,
             non_default_params=non_default_params,
-            headers=headers,
             prompt_id=prompt_id,
             prompt_variables=prompt_variables,
             dynamic_callback_params=self.standard_callback_dynamic_params,
@@ -79,6 +79,7 @@ from litellm.utils import (
     create_tokenizer,
     get_api_key,
     get_llm_provider,
+    get_non_default_completion_params,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
     get_optional_params_transcription,
@@ -881,12 +882,8 @@ def completion(  # type: ignore # noqa: PLR0915
             assistant_continue_message=assistant_continue_message,
         )
     ######## end of unpacking kwargs ###########
-    openai_params = litellm.OPENAI_CHAT_COMPLETION_PARAMS
-    default_params = openai_params + all_litellm_params
+    non_default_params = get_non_default_completion_params(kwargs=kwargs)
     litellm_params = {}  # used to prevent unbound var errors
-    non_default_params = {
-        k: v for k, v in kwargs.items() if k not in default_params
-    }  # model-specific params - pass them straight to the model/provider
     ## PROMPT MANAGEMENT HOOKS ##
 
     if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
@@ -895,7 +892,6 @@ def completion(  # type: ignore # noqa: PLR0915
             model=model,
             messages=messages,
             non_default_params=non_default_params,
-            headers=headers,
             prompt_id=prompt_id,
             prompt_variables=prompt_variables,
         )
@@ -9,11 +9,13 @@ model_list:
       api_key: os.environ/OPENAI_API_KEY
   - model_name: chatbot_actions
     litellm_params:
-      model: langfuse/azure/gpt-4o
-      api_base: "os.environ/AZURE_API_BASE"
-      api_key: "os.environ/AZURE_API_KEY"
+      model: langfuse/openai-gpt-3.5-turbo
       tpm: 1000000
       prompt_id: "jokes"
+  - model_name: openai-gpt-3.5-turbo
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: os.environ/OPENAI_API_KEY
 
 # litellm_settings:
 #   callbacks: ["otel"]
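
For illustration, a minimal sketch (not part of the diff) of calling the chatbot_actions entry above through the proxy with the OpenAI SDK. The base_url and api_key are placeholders and assume a locally running `litellm --config proxy_config.yaml`.

# Sketch only: adjust base_url / api_key for your proxy deployment.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="chatbot_actions",  # served by the langfuse/ prompt-management entry
    messages=[{"role": "user", "content": "Tell me a joke."}],
)
print(response.choices[0].message.content)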
@@ -47,6 +47,9 @@ from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from litellm.litellm_core_utils.litellm_logging import (
+    _init_custom_logger_compatible_class,
+)
 from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
@@ -118,6 +121,7 @@ from litellm.utils import (
     EmbeddingResponse,
     ModelResponse,
     get_llm_provider,
+    get_non_default_completion_params,
     get_secret,
     get_utc_datetime,
     is_region_allowed,
@@ -774,19 +778,19 @@ class Router:
 
     @overload
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
+        self, model: str, messages: List[AllMessageValues], stream: Literal[True], **kwargs
     ) -> CustomStreamWrapper:
         ...
 
     @overload
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
+        self, model: str, messages: List[AllMessageValues], stream: Literal[False] = False, **kwargs
     ) -> ModelResponse:
         ...
 
     @overload
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
+        self, model: str, messages: List[AllMessageValues], stream: Union[Literal[True], Literal[False]] = False, **kwargs
     ) -> Union[CustomStreamWrapper, ModelResponse]:
         ...
 
@@ -794,7 +798,11 @@ class Router:
 
     # The actual implementation of the function
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        stream: bool = False,
+        **kwargs,
     ):
         try:
             kwargs["model"] = model
@@ -804,6 +812,14 @@ class Router:
             self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
             request_priority = kwargs.get("priority") or self.default_priority
             start_time = time.time()
+            _is_prompt_management_model = self._is_prompt_management_model(model)
+
+            if _is_prompt_management_model:
+                return await self._prompt_management_factory(
+                    model=model,
+                    messages=messages,
+                    kwargs=kwargs,
+                )
             if request_priority is not None and isinstance(request_priority, int):
                 response = await self.schedule_acompletion(**kwargs)
             else:
@@ -1081,7 +1097,7 @@ class Router:
         ############## Helpers for async completion ##################
 
         async def _async_completion_no_exceptions(
-            model: str, messages: List[Dict[str, str]], **kwargs
+            model: str, messages: List[AllMessageValues], **kwargs
         ):
             """
             Wrapper around self.async_completion that catches exceptions and returns them as a result
@@ -1093,7 +1109,7 @@ class Router:
 
         async def _async_completion_no_exceptions_return_idx(
             model: str,
-            messages: List[Dict[str, str]],
+            messages: List[AllMessageValues],
            idx: int,  # index of message this response corresponds to
            **kwargs,
        ):
@@ -1137,7 +1153,7 @@ class Router:
         return final_responses
 
     async def abatch_completion_one_model_multiple_requests(
-        self, model: str, messages: List[List[Dict[str, str]]], **kwargs
+        self, model: str, messages: List[List[AllMessageValues]], **kwargs
     ):
         """
         Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router
@@ -1159,7 +1175,7 @@ class Router:
         """
 
         async def _async_completion_no_exceptions(
-            model: str, messages: List[Dict[str, str]], **kwargs
+            model: str, messages: List[AllMessageValues], **kwargs
         ):
             """
             Wrapper around self.async_completion that catches exceptions and returns them as a result
@@ -1282,13 +1298,13 @@ class Router:
 
     @overload
     async def schedule_acompletion(
-        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
+        self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[False] = False, **kwargs
     ) -> ModelResponse:
         ...
 
     @overload
     async def schedule_acompletion(
-        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
+        self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[True], **kwargs
     ) -> CustomStreamWrapper:
         ...
 
@@ -1297,7 +1313,7 @@ class Router:
     async def schedule_acompletion(
         self,
         model: str,
-        messages: List[Dict[str, str]],
+        messages: List[AllMessageValues],
         priority: int,
         stream=False,
         **kwargs,
@@ -1417,6 +1433,88 @@ class Router:
                 llm_provider="openai",
             )
 
+    def _is_prompt_management_model(self, model: str) -> bool:
+        model_list = self.get_model_list(model_name=model)
+        if model_list is None:
+            return False
+        if len(model_list) != 1:
+            return False
+
+        litellm_model = model_list[0]["litellm_params"].get("model", None)
+
+        if litellm_model is None:
+            return False
+
+        if "/" in litellm_model:
+            split_litellm_model = litellm_model.split("/")[0]
+            if split_litellm_model in litellm._known_custom_logger_compatible_callbacks:
+                return True
+        return False
+
+    async def _prompt_management_factory(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        kwargs: Dict[str, Any],
+    ):
+        prompt_management_deployment = self.get_available_deployment(
+            model=model,
+            messages=[{"role": "user", "content": "prompt"}],
+            specific_deployment=kwargs.pop("specific_deployment", None),
+        )
+
+        litellm_model = prompt_management_deployment["litellm_params"].get(
+            "model", None
+        )
+        prompt_id = kwargs.get("prompt_id") or prompt_management_deployment[
+            "litellm_params"
+        ].get("prompt_id", None)
+        prompt_variables = kwargs.get(
+            "prompt_variables"
+        ) or prompt_management_deployment["litellm_params"].get(
+            "prompt_variables", None
+        )
+
+        if litellm_model is None or "/" not in litellm_model:
+            raise ValueError(
+                f"Model is not a custom logger compatible callback. Got={litellm_model}"
+            )
+
+        custom_logger_compatible_callback = litellm_model.split("/", 1)[0]
+        split_litellm_model = litellm_model.split("/", 1)[1]
+
+        custom_logger = _init_custom_logger_compatible_class(
+            logging_integration=custom_logger_compatible_callback,
+            internal_usage_cache=None,
+            llm_router=None,
+        )
+
+        if custom_logger is None:
+            raise ValueError(
+                f"Custom logger is not initialized. Got={custom_logger_compatible_callback}"
+            )
+        model, messages, optional_params = (
+            await custom_logger.async_get_chat_completion_prompt(
+                model=split_litellm_model,
+                messages=messages,
+                non_default_params=get_non_default_completion_params(kwargs=kwargs),
+                prompt_id=prompt_id,
+                prompt_variables=prompt_variables,
+                dynamic_callback_params={},
+            )
+        )
+
+        kwargs = {**kwargs, **optional_params}
+        kwargs["model"] = model
+        kwargs["messages"] = messages
+
+        _model_list = self.get_model_list(model_name=model)
+        if _model_list is None or len(_model_list) == 0:  # if direct call to model
+            kwargs.pop("original_function")
+            return await litellm.acompletion(**kwargs)
+
+        return await self.async_function_with_fallbacks(**kwargs)
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
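
For illustration, a minimal sketch of the routing decision the new helpers make; the model string mirrors the config and test in this commit.

# Sketch only: a model group is treated as a prompt-management model when it
# maps to exactly one deployment whose litellm_params["model"] is prefixed by a
# custom-logger-compatible callback such as "langfuse".
litellm_model = "langfuse/openai-gpt-3.5-turbo"

callback, underlying = litellm_model.split("/", 1)
print(callback)    # "langfuse" -> checked against litellm._known_custom_logger_compatible_callbacks
print(underlying)  # "openai-gpt-3.5-turbo" -> resolved again against the router's model_list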
@@ -144,6 +144,7 @@ from litellm.types.utils import (
     TextCompletionResponse,
     TranscriptionResponse,
     Usage,
+    all_litellm_params,
 )
 
 with resources.open_text(
@@ -6448,3 +6449,12 @@ def _add_path_to_api_base(api_base: str, ending_path: str) -> str:
 
     # Re-add the original query parameters
     return str(modified_url.copy_with(params=original_url.params))
+
+
+def get_non_default_completion_params(kwargs: dict) -> dict:
+    openai_params = litellm.OPENAI_CHAT_COMPLETION_PARAMS
+    default_params = openai_params + all_litellm_params
+    non_default_params = {
+        k: v for k, v in kwargs.items() if k not in default_params
+    }  # model-specific params - pass them straight to the model/provider
+    return non_default_params
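
For illustration, a minimal sketch of what the new helper filters out; the example kwargs and the expected result are assumptions for this sketch, not output from the test suite.

# Sketch only: "temperature" is an OpenAI chat-completion param and "metadata"
# is a litellm param, so only the provider-specific "top_k" should survive.
from litellm.utils import get_non_default_completion_params

kwargs = {"temperature": 0.2, "metadata": {"run": 1}, "top_k": 40}
print(get_non_default_completion_params(kwargs=kwargs))  # expected: {"top_k": 40}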
@@ -2700,3 +2700,40 @@ def test_router_completion_with_model_id():
     ) as mock_pre_call_checks:
         router.completion(model="123", messages=[{"role": "user", "content": "hi"}])
         mock_pre_call_checks.assert_not_called()
+
+
+def test_router_prompt_management_factory():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            },
+            {
+                "model_name": "chatbot_actions",
+                "litellm_params": {
+                    "model": "langfuse/openai-gpt-3.5-turbo",
+                    "tpm": 1000000,
+                    "prompt_id": "jokes",
+                },
+            },
+            {
+                "model_name": "openai-gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "openai/gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ]
+    )
+
+    assert router._is_prompt_management_model("chatbot_actions") is True
+    assert router._is_prompt_management_model("openai-gpt-3.5-turbo") is False
+
+    response = router._prompt_management_factory(
+        model="chatbot_actions",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        kwargs={},
+    )
+
+    print(response)