diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 0164cbc322..457c0537bd 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -63,12 +63,28 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
 
     #### PROMPT MANAGEMENT HOOKS ####
 
+    async def async_get_chat_completion_prompt(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[str, List[AllMessageValues], dict]:
+        """
+        Returns:
+        - model: str - the model to use (can be pulled from prompt management tool)
+        - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool)
+        - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool)
+        """
+        return model, messages, non_default_params
+
     def get_chat_completion_prompt(
         self,
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        headers: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
diff --git a/litellm/integrations/humanloop.py b/litellm/integrations/humanloop.py
index 0ebb423c52..fd3463f9e3 100644
--- a/litellm/integrations/humanloop.py
+++ b/litellm/integrations/humanloop.py
@@ -152,7 +152,6 @@ class HumanloopLogger(CustomLogger):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        headers: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
@@ -170,7 +169,6 @@ class HumanloopLogger(CustomLogger):
             model=model,
             messages=messages,
             non_default_params=non_default_params,
-            headers=headers,
             prompt_id=prompt_id,
             prompt_variables=prompt_variables,
             dynamic_callback_params=dynamic_callback_params,
diff --git a/litellm/integrations/langfuse/langfuse_prompt_management.py b/litellm/integrations/langfuse/langfuse_prompt_management.py
index f900024b59..6d13045597 100644
--- a/litellm/integrations/langfuse/langfuse_prompt_management.py
+++ b/litellm/integrations/langfuse/langfuse_prompt_management.py
@@ -9,10 +9,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, ca
 from packaging.version import Version
 from typing_extensions import TypeAlias
 
-from litellm._logging import verbose_proxy_logger
-from litellm.caching.dual_cache import DualCache
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
 
@@ -144,67 +141,36 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
         else:
             return model.replace("langfuse/", "")
 
-    async def async_pre_call_hook(
+    def _get_optional_params_from_langfuse(
+        self, langfuse_prompt_client: PROMPT_CLIENT
+    ) -> dict:
+        config = langfuse_prompt_client.config
+        optional_params = {}
+        for k, v in config.items():
+            if k != "model":
+                optional_params[k] = v
+        return optional_params
+
+    async def async_get_chat_completion_prompt(
         self,
-        user_api_key_dict: UserAPIKeyAuth,
-        cache: DualCache,
-        data: dict,
-        call_type: Union[
-            Literal["completion"],
-            Literal["text_completion"],
-            Literal["embeddings"],
-            Literal["image_generation"],
-            Literal["moderation"],
-            Literal["audio_transcription"],
-            Literal["pass_through_endpoint"],
-            Literal["rerank"],
-        ],
-    ) -> Union[Exception, str, dict, None]:
-
-        metadata = data.get("metadata") or {}
-
-        if isinstance(metadata, dict):
-            langfuse_prompt_id = cast(Optional[str], metadata.get("langfuse_prompt_id"))
-
-            langfuse_prompt_variables = cast(
-                Optional[dict], metadata.get("langfuse_prompt_variables") or {}
-            )
-        else:
-            return None
-
-        if langfuse_prompt_id is None:
-            return None
-
-        prompt_client = self._get_prompt_from_id(
-            langfuse_prompt_id=langfuse_prompt_id, langfuse_client=self.Langfuse
-        )
-        compiled_prompt: Optional[Union[str, list]] = None
-        if call_type == "completion" or call_type == "text_completion":
-            compiled_prompt = self._compile_prompt(
-                langfuse_prompt_client=prompt_client,
-                langfuse_prompt_variables=langfuse_prompt_variables,
-                call_type=call_type,
-            )
-        if compiled_prompt is None:
-            return await super().async_pre_call_hook(
-                user_api_key_dict, cache, data, call_type
-            )
-        if call_type == "completion":
-            if isinstance(compiled_prompt, list):
-                data["messages"] = compiled_prompt + data["messages"]
-            else:
-                data["messages"] = [
-                    {"role": "system", "content": compiled_prompt}
-                ] + data["messages"]
-        elif call_type == "text_completion" and isinstance(compiled_prompt, str):
-            data["prompt"] = compiled_prompt + "\n" + data["prompt"]
-
-        verbose_proxy_logger.debug(
-            f"LangfusePromptManagement.async_pre_call_hook compiled_prompt: {compiled_prompt}, type: {type(compiled_prompt)}"
-        )
-
-        return await super().async_pre_call_hook(
-            user_api_key_dict, cache, data, call_type
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
+        return self.get_chat_completion_prompt(
+            model,
+            messages,
+            non_default_params,
+            prompt_id,
+            prompt_variables,
+            dynamic_callback_params,
         )
 
     def get_chat_completion_prompt(
@@ -212,7 +178,6 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        headers: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
@@ -255,7 +220,11 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
 
         ## SET MODEL
         model = self._get_model_from_prompt(langfuse_prompt_client, model)
-        return model, messages, non_default_params
+        optional_params = self._get_optional_params_from_langfuse(
+            langfuse_prompt_client
+        )
+
+        return model, messages, optional_params
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         standard_callback_dynamic_params = kwargs.get(
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index b36cd88f9d..3faad6fbbe 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -398,7 +398,6 @@ class Logging(LiteLLMLoggingBaseClass):
         model: str,
         messages: List[AllMessageValues],
         non_default_params: dict,
-        headers: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
     ) -> Tuple[str, List[AllMessageValues], dict]:
@@ -420,7 +419,6 @@ class Logging(LiteLLMLoggingBaseClass):
                 model=model,
                 messages=messages,
                 non_default_params=non_default_params,
-                headers=headers,
                 prompt_id=prompt_id,
                 prompt_variables=prompt_variables,
                 dynamic_callback_params=self.standard_callback_dynamic_params,
diff --git a/litellm/main.py b/litellm/main.py
index 157d2a7865..7e852f70ec 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -79,6 +79,7 @@ from litellm.utils import (
     create_tokenizer,
     get_api_key,
     get_llm_provider,
+    get_non_default_completion_params,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
     get_optional_params_transcription,
@@ -881,12 +882,8 @@ def completion(  # type: ignore # noqa: PLR0915
             assistant_continue_message=assistant_continue_message,
         )
     ######## end of unpacking kwargs ###########
-    openai_params = litellm.OPENAI_CHAT_COMPLETION_PARAMS
-    default_params = openai_params + all_litellm_params
+    non_default_params = get_non_default_completion_params(kwargs=kwargs)
     litellm_params = {}  # used to prevent unbound var errors
-    non_default_params = {
-        k: v for k, v in kwargs.items() if k not in default_params
-    }  # model-specific params - pass them straight to the model/provider
 
     ## PROMPT MANAGEMENT HOOKS ##
     if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
@@ -895,7 +892,6 @@ def completion(  # type: ignore # noqa: PLR0915
             model=model,
             messages=messages,
             non_default_params=non_default_params,
-            headers=headers,
             prompt_id=prompt_id,
             prompt_variables=prompt_variables,
         )
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 191a1bbee0..3222e69d86 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -9,11 +9,13 @@ model_list:
       api_key: os.environ/OPENAI_API_KEY
   - model_name: chatbot_actions
     litellm_params:
-      model: langfuse/azure/gpt-4o
-      api_base: "os.environ/AZURE_API_BASE"
-      api_key: "os.environ/AZURE_API_KEY"
+      model: langfuse/openai-gpt-3.5-turbo
       tpm: 1000000
       prompt_id: "jokes"
+  - model_name: openai-gpt-3.5-turbo
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: os.environ/OPENAI_API_KEY
 
 # litellm_settings:
 #   callbacks: ["otel"]
\ No newline at end of file
diff --git a/litellm/router.py b/litellm/router.py
index d837b203ad..1da1a97730 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -47,6 +47,9 @@ from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from litellm.litellm_core_utils.litellm_logging import (
+    _init_custom_logger_compatible_class,
+)
 from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
@@ -118,6 +121,7 @@ from litellm.utils import (
     EmbeddingResponse,
     ModelResponse,
     get_llm_provider,
+    get_non_default_completion_params,
     get_secret,
     get_utc_datetime,
     is_region_allowed,
@@ -774,19 +778,19 @@ class Router:
 
     @overload
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
+        self, model: str, messages: List[AllMessageValues], stream: Literal[True], **kwargs
     ) -> CustomStreamWrapper:
         ...
 
     @overload
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
+        self, model: str, messages: List[AllMessageValues], stream: Literal[False] = False, **kwargs
     ) -> ModelResponse:
         ...
 
     @overload
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
+        self, model: str, messages: List[AllMessageValues], stream: Union[Literal[True], Literal[False]] = False, **kwargs
     ) -> Union[CustomStreamWrapper, ModelResponse]:
         ...
 
@@ -794,7 +798,11 @@ class Router:
 
     # The actual implementation of the function
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        stream: bool = False,
+        **kwargs,
     ):
         try:
             kwargs["model"] = model
@@ -804,6 +812,14 @@ class Router:
             self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
             request_priority = kwargs.get("priority") or self.default_priority
             start_time = time.time()
+            _is_prompt_management_model = self._is_prompt_management_model(model)
+
+            if _is_prompt_management_model:
+                return await self._prompt_management_factory(
+                    model=model,
+                    messages=messages,
+                    kwargs=kwargs,
+                )
             if request_priority is not None and isinstance(request_priority, int):
                 response = await self.schedule_acompletion(**kwargs)
             else:
@@ -1081,7 +1097,7 @@ class Router:
         ############## Helpers for async completion ##################
 
         async def _async_completion_no_exceptions(
-            model: str, messages: List[Dict[str, str]], **kwargs
+            model: str, messages: List[AllMessageValues], **kwargs
         ):
             """
             Wrapper around self.async_completion that catches exceptions and returns them as a result
@@ -1093,7 +1109,7 @@ class Router:
 
         async def _async_completion_no_exceptions_return_idx(
             model: str,
-            messages: List[Dict[str, str]],
+            messages: List[AllMessageValues],
             idx: int,  # index of message this response corresponds to
             **kwargs,
         ):
@@ -1137,7 +1153,7 @@ class Router:
         return final_responses
 
     async def abatch_completion_one_model_multiple_requests(
-        self, model: str, messages: List[List[Dict[str, str]]], **kwargs
+        self, model: str, messages: List[List[AllMessageValues]], **kwargs
     ):
         """
         Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router
@@ -1159,7 +1175,7 @@ class Router:
         """
 
         async def _async_completion_no_exceptions(
-            model: str, messages: List[Dict[str, str]], **kwargs
+            model: str, messages: List[AllMessageValues], **kwargs
         ):
             """
             Wrapper around self.async_completion that catches exceptions and returns them as a result
@@ -1282,13 +1298,13 @@ class Router:
 
     @overload
     async def schedule_acompletion(
-        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
+        self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[False] = False, **kwargs
     ) -> ModelResponse:
         ...
 
     @overload
     async def schedule_acompletion(
-        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
+        self, model: str, messages: List[AllMessageValues], priority: int, stream: Literal[True], **kwargs
     ) -> CustomStreamWrapper:
         ...
 
@@ -1297,7 +1313,7 @@ class Router:
     async def schedule_acompletion(
         self,
         model: str,
-        messages: List[Dict[str, str]],
+        messages: List[AllMessageValues],
         priority: int,
         stream=False,
         **kwargs,
@@ -1417,6 +1433,88 @@ class Router:
                 llm_provider="openai",
             )
 
+    def _is_prompt_management_model(self, model: str) -> bool:
+        model_list = self.get_model_list(model_name=model)
+        if model_list is None:
+            return False
+        if len(model_list) != 1:
+            return False
+
+        litellm_model = model_list[0]["litellm_params"].get("model", None)
+
+        if litellm_model is None:
+            return False
+
+        if "/" in litellm_model:
+            split_litellm_model = litellm_model.split("/")[0]
+            if split_litellm_model in litellm._known_custom_logger_compatible_callbacks:
+                return True
+        return False
+
+    async def _prompt_management_factory(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        kwargs: Dict[str, Any],
+    ):
+        prompt_management_deployment = self.get_available_deployment(
+            model=model,
+            messages=[{"role": "user", "content": "prompt"}],
+            specific_deployment=kwargs.pop("specific_deployment", None),
+        )
+
+        litellm_model = prompt_management_deployment["litellm_params"].get(
+            "model", None
+        )
+        prompt_id = kwargs.get("prompt_id") or prompt_management_deployment[
+            "litellm_params"
+        ].get("prompt_id", None)
+        prompt_variables = kwargs.get(
+            "prompt_variables"
+        ) or prompt_management_deployment["litellm_params"].get(
+            "prompt_variables", None
+        )
+
+        if litellm_model is None or "/" not in litellm_model:
+            raise ValueError(
+                f"Model is not a custom logger compatible callback. Got={litellm_model}"
+            )
+
+        custom_logger_compatible_callback = litellm_model.split("/", 1)[0]
+        split_litellm_model = litellm_model.split("/", 1)[1]
+
+        custom_logger = _init_custom_logger_compatible_class(
+            logging_integration=custom_logger_compatible_callback,
+            internal_usage_cache=None,
+            llm_router=None,
+        )
+
+        if custom_logger is None:
+            raise ValueError(
+                f"Custom logger is not initialized. Got={custom_logger_compatible_callback}"
+            )
+        model, messages, optional_params = (
+            await custom_logger.async_get_chat_completion_prompt(
+                model=split_litellm_model,
+                messages=messages,
+                non_default_params=get_non_default_completion_params(kwargs=kwargs),
+                prompt_id=prompt_id,
+                prompt_variables=prompt_variables,
+                dynamic_callback_params={},
+            )
+        )
+
+        kwargs = {**kwargs, **optional_params}
+        kwargs["model"] = model
+        kwargs["messages"] = messages
+
+        _model_list = self.get_model_list(model_name=model)
+        if _model_list is None or len(_model_list) == 0:  # if direct call to model
+            kwargs.pop("original_function")
+            return await litellm.acompletion(**kwargs)
+
+        return await self.async_function_with_fallbacks(**kwargs)
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
diff --git a/litellm/utils.py b/litellm/utils.py
index 2ab06ba89a..2a2e6ff45b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -144,6 +144,7 @@ from litellm.types.utils import (
     TextCompletionResponse,
     TranscriptionResponse,
     Usage,
+    all_litellm_params,
 )
 
 with resources.open_text(
@@ -6448,3 +6449,12 @@ def _add_path_to_api_base(api_base: str, ending_path: str) -> str:
 
     # Re-add the original query parameters
     return str(modified_url.copy_with(params=original_url.params))
+
+
+def get_non_default_completion_params(kwargs: dict) -> dict:
+    openai_params = litellm.OPENAI_CHAT_COMPLETION_PARAMS
+    default_params = openai_params + all_litellm_params
+    non_default_params = {
+        k: v for k, v in kwargs.items() if k not in default_params
+    }  # model-specific params - pass them straight to the model/provider
+    return non_default_params
diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py
index 829447ad38..d2c1c8fbec 100644
--- a/tests/local_testing/test_router.py
+++ b/tests/local_testing/test_router.py
@@ -2700,3 +2700,40 @@ def test_router_completion_with_model_id():
     ) as mock_pre_call_checks:
         router.completion(model="123", messages=[{"role": "user", "content": "hi"}])
         mock_pre_call_checks.assert_not_called()
+
+
+def test_router_prompt_management_factory():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            },
+            {
+                "model_name": "chatbot_actions",
+                "litellm_params": {
+                    "model": "langfuse/openai-gpt-3.5-turbo",
+                    "tpm": 1000000,
+                    "prompt_id": "jokes",
+                },
+            },
+            {
+                "model_name": "openai-gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "openai/gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ]
+    )
+
+    assert router._is_prompt_management_model("chatbot_actions") is True
+    assert router._is_prompt_management_model("openai-gpt-3.5-turbo") is False
+
+    response = router._prompt_management_factory(
+        model="chatbot_actions",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        kwargs={},
+    )
+
+    print(response)
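
Usage sketch: a minimal, hedged example of how the router-level prompt management path added above would be exercised, mirroring the new _new_secret_config.yaml entry and test_router_prompt_management_factory. It assumes OPENAI_API_KEY plus Langfuse credentials (LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY) are set in the environment and that a "jokes" prompt exists in the Langfuse project; model group names are illustrative.

import asyncio
import os

from litellm import Router


async def main():
    router = Router(
        model_list=[
            {
                # prompt-management deployment: "langfuse/<model_group>"
                # triggers Router._is_prompt_management_model() -> _prompt_management_factory()
                "model_name": "chatbot_actions",
                "litellm_params": {
                    "model": "langfuse/openai-gpt-3.5-turbo",
                    "prompt_id": "jokes",
                },
            },
            {
                # underlying deployment the rewritten request is routed to
                "model_name": "openai-gpt-3.5-turbo",
                "litellm_params": {
                    "model": "openai/gpt-3.5-turbo",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            },
        ]
    )

    # acompletion() detects the langfuse/ prefix, resolves the "jokes" prompt
    # (and its config as optional params) via async_get_chat_completion_prompt,
    # then forwards the rewritten call through the normal fallback path.
    response = await router.acompletion(
        model="chatbot_actions",
        messages=[{"role": "user", "content": "Tell me a joke."}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())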