fix(vertex_ai/gemini/transformation.py): handle 'http://' in gemini process url (#7660)

* fix(vertex_ai/gemini/transformation.py): handle 'http://' in gemini process url
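For context, a minimal usage sketch of what this enables (the model name and URL below are illustrative, not taken from this change): a plain `http://` image URL sent to a Gemini model is now converted to inline base64 data, the same path already used for `https://` and base64 inputs.

```python
import litellm

# Illustrative call - "vertex_ai/gemini-1.5-flash" and the URL are placeholders.
response = litellm.completion(
    model="vertex_ai/gemini-1.5-flash",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "image_url",
                    # http:// URLs previously fell through; they are now inlined as base64
                    "image_url": {"url": "http://example.com/image.jpg"},
                },
            ],
        }
    ],
)
```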

* refactor(router.py): refactor '_prompt_management_factory' to use the logging object's get_chat_completion_prompt logic

deduplicates code

* fix(litellm_logging.py): update 'get_chat_completion_prompt' to update logging object messages

* docs(prompt_management.md): update prompt management to be in beta

given feedback, this still needs to be revised (e.g. passing in the user message rather than ignoring it)

* refactor(prompt_management_base.py): introduce base class for prompt management

allows consistent behaviour across prompt management integrations

* feat(prompt_management_base.py): support adding client message to template message + refactor langfuse prompt management to use prompt management base
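To illustrate the base class contract, a minimal sketch of a custom integration (the `MyPromptStore` class, its template, and the model names are hypothetical, not part of this commit). An integration only implements `integration_name`, `should_run_prompt_management`, and `_compile_prompt_helper`; the base class then prepends the compiled template to the client's messages.

```python
from typing import Optional

from litellm.integrations.prompt_management_base import (
    PromptManagementBase,
    PromptManagementClient,
)
from litellm.types.utils import StandardCallbackDynamicParams


class MyPromptStore(PromptManagementBase):
    """Hypothetical integration; only the three abstract members are required."""

    @property
    def integration_name(self) -> str:
        return "my_prompt_store"  # used to strip the "my_prompt_store/" model prefix

    def should_run_prompt_management(
        self,
        prompt_id: str,
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        return prompt_id is not None

    def _compile_prompt_helper(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        # Pretend the store resolved the template + model for this prompt id.
        return PromptManagementClient(
            prompt_id=prompt_id,
            prompt_template=[
                {"role": "system", "content": "You are a helpful assistant."}
            ],
            prompt_template_model="gpt-4o-mini",
            prompt_template_optional_params=None,
            completed_messages=None,
        )


# get_chat_completion_prompt() prepends the template to the client messages:
model, messages, params = MyPromptStore().get_chat_completion_prompt(
    model="my_prompt_store/gpt-4o-mini",
    messages=[{"role": "user", "content": "Tell me a joke"}],
    non_default_params={},
    prompt_id="jokes",
    prompt_variables=None,
    dynamic_callback_params={},  # empty dict is fine at runtime (TypedDict)
)
# messages == [system template message, user message]; model == "gpt-4o-mini"
```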

* fix(litellm_logging.py): log prompt id + prompt variables to langfuse if set

allows tracking what prompt was used for what purpose
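A hedged usage sketch of the tracking this enables (the prompt variables and key values are illustrative; the `langfuse/` model and `"jokes"` prompt id mirror the example proxy config in this commit): the prompt id and variables passed to `completion` are now also logged to Langfuse, so a trace can be tied back to the prompt that produced it.

```python
import os

import litellm

# Langfuse credentials are read from the environment by the Langfuse integration.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-..."  # placeholder
os.environ["LANGFUSE_SECRET_KEY"] = "sk-..."  # placeholder

response = litellm.completion(
    model="langfuse/gpt-3.5-turbo",           # "langfuse/" prefix routes through prompt management
    prompt_id="jokes",                        # prompt defined in Langfuse (see example config below)
    prompt_variables={"topic": "databases"},  # hypothetical template variables
    messages=[{"role": "user", "content": "Keep it short."}],
)
```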

* feat(litellm_logging.py): log prompt management metadata in standard logging payload + use in langfuse

allows logging prompt id / prompt variables to langfuse
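For reference, a small sketch of the payload shape introduced here (values are placeholders); this is what gets nested under `metadata.prompt_management_metadata` in the standard logging payload and forwarded to Langfuse:

```python
from litellm.types.utils import StandardLoggingPromptManagementMetadata

# Shape defined in litellm/types/utils.py in this commit; values are illustrative.
prompt_management_metadata = StandardLoggingPromptManagementMetadata(
    prompt_id="jokes",
    prompt_variables={"topic": "databases"},
)

# The Langfuse logger reads it from:
#   standard_logging_object["metadata"]["prompt_management_metadata"]
```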

* test: fix test

* fix(router.py): cleanup unused imports

* fix: fix linting error

* fix: fix trace param typing

* fix: fix linting errors

* fix: fix code qa check
Krish Dholakia 2025-01-10 07:31:59 -08:00 committed by GitHub
parent 865e6d5bda
commit c10ae8879e
15 changed files with 340 additions and 76 deletions


@@ -2,7 +2,13 @@ import Image from '@theme/IdealImage';
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
-# Prompt Management
+# [BETA] Prompt Management
+:::info
+This feature is currently in beta, and might change unexpectedly. We expect this to be more stable by next month (February 2025).
+:::
 Run experiments or change the specific model (e.g. from gpt-4o to gpt4o-mini finetune) from your prompt management tool (e.g. Langfuse) instead of making changes in the application.


@@ -179,6 +179,7 @@ class LangFuseLogger:
         optional_params = copy.deepcopy(kwargs.get("optional_params", {}))
         prompt = {"messages": kwargs.get("messages")}
         functions = optional_params.pop("functions", None)
         tools = optional_params.pop("tools", None)
         if functions is not None:
@@ -462,15 +463,27 @@
         if standard_logging_object is None:
             end_user_id = None
+            prompt_management_metadata: Optional[dict] = None
         else:
             end_user_id = standard_logging_object["metadata"].get(
                 "user_api_key_end_user_id", None
             )
+            prompt_management_metadata = cast(
+                Optional[dict],
+                standard_logging_object["metadata"].get(
+                    "prompt_management_metadata", None
+                ),
+            )
         # Clean Metadata before logging - never log raw metadata
         # the raw metadata can contain circular references which leads to infinite recursion
         # we clean out all extra litellm metadata params before logging
-        clean_metadata = {}
+        clean_metadata: Dict[str, Any] = {}
+        if prompt_management_metadata is not None:
+            clean_metadata["prompt_management_metadata"] = (
+                prompt_management_metadata
+            )
         if isinstance(metadata, dict):
             for key, value in metadata.items():
                 # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
@@ -498,10 +511,10 @@
         )
         session_id = clean_metadata.pop("session_id", None)
-        trace_name = clean_metadata.pop("trace_name", None)
+        trace_name = cast(Optional[str], clean_metadata.pop("trace_name", None))
         trace_id = clean_metadata.pop("trace_id", litellm_call_id)
         existing_trace_id = clean_metadata.pop("existing_trace_id", None)
-        update_trace_keys = clean_metadata.pop("update_trace_keys", [])
+        update_trace_keys = cast(list, clean_metadata.pop("update_trace_keys", []))
         debug = clean_metadata.pop("debug_langfuse", None)
         mask_input = clean_metadata.pop("mask_input", False)
         mask_output = clean_metadata.pop("mask_output", False)
@@ -514,7 +527,7 @@
             trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
         if existing_trace_id is not None:
-            trace_params = {"id": existing_trace_id}
+            trace_params: Dict[str, Any] = {"id": existing_trace_id}
             # Update the following keys for this trace
             for metadata_param_key in update_trace_keys:
@@ -656,8 +669,12 @@
         # if `generation_name` is None, use sensible default values
         # If using litellm proxy user `key_alias` if not None
         # If `key_alias` is None, just log `litellm-{call_type}` as the generation name
-        _user_api_key_alias = clean_metadata.get("user_api_key_alias", None)
-        generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
+        _user_api_key_alias = cast(
+            Optional[str], clean_metadata.get("user_api_key_alias", None)
+        )
+        generation_name = (
+            f"litellm-{cast(str, kwargs.get('call_type', 'completion'))}"
+        )
         if _user_api_key_alias is not None:
             generation_name = f"litellm:{_user_api_key_alias}"


@@ -10,12 +10,14 @@ from packaging.version import Version
 from typing_extensions import TypeAlias
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.types.llms.openai import AllMessageValues
+from litellm.integrations.prompt_management_base import PromptManagementClient
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionSystemMessage
 from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
 from ...litellm_core_utils.specialty_caches.dynamic_logging_cache import (
     DynamicLoggingCache,
 )
+from ..prompt_management_base import PromptManagementBase
 from .langfuse import LangFuseLogger
 from .langfuse_handler import LangFuseHandler
@@ -97,7 +99,7 @@ def langfuse_client_init(
     return client
-class LangfusePromptManagement(LangFuseLogger, CustomLogger):
+class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogger):
     def __init__(
         self,
         langfuse_public_key=None,
@@ -112,6 +114,10 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
             flush_interval=flush_interval,
         )
+    @property
+    def integration_name(self):
+        return "langfuse"
     def _get_prompt_from_id(
         self, langfuse_prompt_id: str, langfuse_client: LangfuseClass
     ) -> PROMPT_CLIENT:
@@ -122,7 +128,7 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
         langfuse_prompt_client: PROMPT_CLIENT,
         langfuse_prompt_variables: Optional[dict],
         call_type: Union[Literal["completion"], Literal["text_completion"]],
-    ) -> Optional[Union[str, list]]:
+    ) -> List[AllMessageValues]:
         compiled_prompt: Optional[Union[str, list]] = None
         if langfuse_prompt_variables is None:
@@ -130,16 +136,14 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
         compiled_prompt = langfuse_prompt_client.compile(**langfuse_prompt_variables)
-        return compiled_prompt
-
-    def _get_model_from_prompt(
-        self, langfuse_prompt_client: PROMPT_CLIENT, model: str
-    ) -> str:
-        config = langfuse_prompt_client.config
-        if "model" in config:
-            return config["model"]
-        else:
-            return model.replace("langfuse/", "")
+        if isinstance(compiled_prompt, str):
+            compiled_prompt = [
+                ChatCompletionSystemMessage(role="system", content=compiled_prompt)
+            ]
+        else:
+            compiled_prompt = cast(List[AllMessageValues], compiled_prompt)
+        return compiled_prompt
     def _get_optional_params_from_langfuse(
         self, langfuse_prompt_client: PROMPT_CLIENT
@@ -173,23 +177,27 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
             dynamic_callback_params,
         )
-    def get_chat_completion_prompt(
+    def should_run_prompt_management(
+        self,
+        prompt_id: str,
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> bool:
+        langfuse_client = langfuse_client_init(
+            langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
+            langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
+            langfuse_host=dynamic_callback_params.get("langfuse_host"),
+        )
+        langfuse_prompt_client = self._get_prompt_from_id(
+            langfuse_prompt_id=prompt_id, langfuse_client=langfuse_client
+        )
+        return langfuse_prompt_client is not None
+    def _compile_prompt_helper(
         self,
-        model: str,
-        messages: List[AllMessageValues],
-        non_default_params: dict,
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[
-        str,
-        List[AllMessageValues],
-        dict,
-    ]:
-        if prompt_id is None:
-            raise ValueError(
-                "Langfuse prompt id is required. Pass in as parameter 'langfuse_prompt_id'"
-            )
+    ) -> PromptManagementClient:
         langfuse_client = langfuse_client_init(
             langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
             langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
@@ -206,25 +214,19 @@ class LangfusePromptManagement(LangFuseLogger, CustomLogger):
             call_type="completion",
         )
-        if compiled_prompt is None:
-            raise ValueError(f"Langfuse prompt not found. Prompt id={prompt_id}")
-        if isinstance(compiled_prompt, list):
-            messages = compiled_prompt
-        elif isinstance(compiled_prompt, str):
-            messages = [{"role": "user", "content": compiled_prompt}]
-        else:
-            raise ValueError(
-                f"Langfuse prompt is not a list or string. Prompt id={prompt_id}, compiled_prompt type={type(compiled_prompt)}"
-            )
+        template_model = langfuse_prompt_client.config.get("model")
-        ## SET MODEL
-        model = self._get_model_from_prompt(langfuse_prompt_client, model)
-        optional_params = self._get_optional_params_from_langfuse(
+        template_optional_params = self._get_optional_params_from_langfuse(
             langfuse_prompt_client
         )
-        return model, messages, optional_params
+        return PromptManagementClient(
+            prompt_id=prompt_id,
+            prompt_template=compiled_prompt,
+            prompt_template_model=template_model,
+            prompt_template_optional_params=template_optional_params,
+            completed_messages=None,
+        )
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         standard_callback_dynamic_params = kwargs.get(


@@ -0,0 +1,118 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple, TypedDict
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import StandardCallbackDynamicParams
+class PromptManagementClient(TypedDict):
+    prompt_id: str
+    prompt_template: List[AllMessageValues]
+    prompt_template_model: Optional[str]
+    prompt_template_optional_params: Optional[Dict[str, Any]]
+    completed_messages: Optional[List[AllMessageValues]]
+class PromptManagementBase(ABC):
+    @property
+    @abstractmethod
+    def integration_name(self) -> str:
+        pass
+    @abstractmethod
+    def should_run_prompt_management(
+        self,
+        prompt_id: str,
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> bool:
+        pass
+    @abstractmethod
+    def _compile_prompt_helper(
+        self,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> PromptManagementClient:
+        pass
+    def merge_messages(
+        self,
+        prompt_template: List[AllMessageValues],
+        client_messages: List[AllMessageValues],
+    ) -> List[AllMessageValues]:
+        return prompt_template + client_messages
+    def compile_prompt(
+        self,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        client_messages: List[AllMessageValues],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> PromptManagementClient:
+        compiled_prompt_client = self._compile_prompt_helper(
+            prompt_id=prompt_id,
+            prompt_variables=prompt_variables,
+            dynamic_callback_params=dynamic_callback_params,
+        )
+        try:
+            messages = compiled_prompt_client["prompt_template"] + client_messages
+        except Exception as e:
+            raise ValueError(
+                f"Error compiling prompt: {e}. Prompt id={prompt_id}, prompt_variables={prompt_variables}, client_messages={client_messages}, dynamic_callback_params={dynamic_callback_params}"
+            )
+        compiled_prompt_client["completed_messages"] = messages
+        return compiled_prompt_client
+    def _get_model_from_prompt(
+        self, prompt_management_client: PromptManagementClient, model: str
+    ) -> str:
+        if prompt_management_client["prompt_template_model"] is not None:
+            return prompt_management_client["prompt_template_model"]
+        else:
+            return model.replace("{}/".format(self.integration_name), "")
+    def get_chat_completion_prompt(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        non_default_params: dict,
+        prompt_id: str,
+        prompt_variables: Optional[dict],
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Tuple[
+        str,
+        List[AllMessageValues],
+        dict,
+    ]:
+        if not self.should_run_prompt_management(
+            prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
+        ):
+            return model, messages, non_default_params
+        prompt_template = self.compile_prompt(
+            prompt_id=prompt_id,
+            prompt_variables=prompt_variables,
+            client_messages=messages,
+            dynamic_callback_params=dynamic_callback_params,
+        )
+        completed_messages = prompt_template["completed_messages"] or messages
+        prompt_template_optional_params = (
+            prompt_template["prompt_template_optional_params"] or {}
+        )
+        updated_non_default_params = {
+            **non_default_params,
+            **prompt_template_optional_params,
+        }
+        model = self._get_model_from_prompt(
+            prompt_management_client=prompt_template, model=model
+        )
+        return model, completed_messages, updated_non_default_params


@@ -59,6 +59,7 @@ from litellm.types.utils import (
     StandardLoggingPayload,
     StandardLoggingPayloadErrorInformation,
     StandardLoggingPayloadStatus,
+    StandardLoggingPromptManagementMetadata,
     TextCompletionResponse,
     TranscriptionResponse,
     Usage,
@@ -424,6 +425,7 @@ class Logging(LiteLLMLoggingBaseClass):
                 dynamic_callback_params=self.standard_callback_dynamic_params,
             )
         )
+        self.messages = messages
         return model, messages, non_default_params
@@ -431,6 +433,7 @@
         """
        Common helper function across the sync + async pre-call function
         """
         self.model_call_details["input"] = input
         self.model_call_details["api_key"] = api_key
         self.model_call_details["additional_args"] = additional_args
@@ -2628,7 +2631,7 @@ class StandardLoggingPayloadSetup:
     @staticmethod
     def get_standard_logging_metadata(
-        metadata: Optional[Dict[str, Any]]
+        metadata: Optional[Dict[str, Any]], litellm_params: Optional[dict] = None
     ) -> StandardLoggingMetadata:
         """
         Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -2643,6 +2646,20 @@
         - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
         - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
         """
+        prompt_management_metadata: Optional[
+            StandardLoggingPromptManagementMetadata
+        ] = None
+        if litellm_params is not None:
+            prompt_id = cast(Optional[str], litellm_params.get("prompt_id", None))
+            prompt_variables = cast(
+                Optional[dict], litellm_params.get("prompt_variables", None)
+            )
+            prompt_management_metadata = StandardLoggingPromptManagementMetadata(
+                prompt_id=prompt_id,
+                prompt_variables=prompt_variables,
+            )
         # Initialize with default values
         clean_metadata = StandardLoggingMetadata(
             user_api_key_hash=None,
@@ -2655,6 +2672,7 @@
             requester_ip_address=None,
             requester_metadata=None,
             user_api_key_end_user_id=None,
+            prompt_management_metadata=prompt_management_metadata,
         )
         if isinstance(metadata, dict):
             # Filter the metadata dictionary to include only the specified keys
@@ -2949,7 +2967,7 @@ def get_standard_logging_object_payload(
        )
        # clean up litellm metadata
        clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
-            metadata=metadata
+            metadata=metadata, litellm_params=litellm_params
        )
        saved_cache_cost: float = 0.0
@@ -2966,6 +2984,7 @@
        ## Get model cost information ##
        base_model = _get_base_model_from_metadata(model_call_details=kwargs)
        custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
        model_cost_information = StandardLoggingPayloadSetup.get_model_cost_information(
            base_model=base_model,
            custom_pricing=custom_pricing,
@@ -3072,6 +3091,7 @@ def get_standard_logging_metadata(
        requester_ip_address=None,
        requester_metadata=None,
        user_api_key_end_user_id=None,
+        prompt_management_metadata=None,
    )
    if isinstance(metadata, dict):
        # Filter the metadata dictionary to include only the specified keys


@@ -82,7 +82,7 @@ def _process_gemini_image(image_url: str) -> PartType:
     ):
         file_data = FileDataType(file_uri=image_url, mime_type=image_type)
         return PartType(file_data=file_data)
-    elif "https://" in image_url or "base64" in image_url:
+    elif "http://" in image_url or "https://" in image_url or "base64" in image_url:
         # https links for unsupported mime types and base64 images
         image = convert_to_anthropic_image_obj(image_url)
         _blob = BlobType(data=image["data"], mime_type=image["media_type"])


@@ -1077,6 +1077,8 @@ def completion( # type: ignore # noqa: PLR0915
         litellm_metadata=kwargs.get("litellm_metadata"),
         disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
         drop_params=kwargs.get("drop_params"),
+        prompt_id=prompt_id,
+        prompt_variables=prompt_variables,
     )
     logging.update_environment_variables(
         model=model,


@@ -10,9 +10,8 @@ model_list:
       api_key: os.environ/OPENAI_API_KEY
   - model_name: chatbot_actions
     litellm_params:
-      model: langfuse/azure/gpt-4o
-      api_base: "os.environ/AZURE_API_BASE"
-      api_key: "os.environ/AZURE_API_KEY"
+      model: langfuse/gpt-3.5-turbo
+      api_key: os.environ/OPENAI_API_KEY
       tpm: 1000000
       prompt_id: "jokes"
   - model_name: openai-deepseek


@@ -2565,7 +2565,6 @@ class ProxyConfig:
             for response in responses:
                 if response is not None:
                     param_name = getattr(response, "param_name", None)
-                    verbose_proxy_logger.info(f"loading {param_name} settings from db")
                     if param_name == "litellm_settings":
                         verbose_proxy_logger.info(
                             f"litellm_settings: {response.param_value}"


@@ -47,9 +47,6 @@ from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
-from litellm.litellm_core_utils.litellm_logging import (
-    _init_custom_logger_compatible_class,
-)
 from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
@@ -120,6 +117,8 @@ from litellm.utils import (
     CustomStreamWrapper,
     EmbeddingResponse,
     ModelResponse,
+    Rules,
+    function_setup,
     get_llm_provider,
     get_non_default_completion_params,
     get_secret,
@@ -1457,6 +1456,17 @@ class Router:
         messages: List[AllMessageValues],
         kwargs: Dict[str, Any],
     ):
+        litellm_logging_object = kwargs.get("litellm_logging_obj", None)
+        if litellm_logging_object is None:
+            litellm_logging_object, kwargs = function_setup(
+                **{
+                    "original_function": "acompletion",
+                    "rules_obj": Rules(),
+                    "start_time": get_utc_datetime(),
+                    **kwargs,
+                }
+            )
+        litellm_logging_object = cast(LiteLLMLogging, litellm_logging_object)
         prompt_management_deployment = self.get_available_deployment(
             model=model,
             messages=[{"role": "user", "content": "prompt"}],
@@ -1475,38 +1485,31 @@ class Router:
             "prompt_variables", None
         )
-        if litellm_model is None or "/" not in litellm_model:
+        if prompt_id is None or not isinstance(prompt_id, str):
             raise ValueError(
-                f"Model is not a custom logger compatible callback. Got={litellm_model}"
+                f"Prompt ID is not set or not a string. Got={prompt_id}, type={type(prompt_id)}"
             )
-        custom_logger_compatible_callback = litellm_model.split("/", 1)[0]
-        split_litellm_model = litellm_model.split("/", 1)[1]
-        custom_logger = _init_custom_logger_compatible_class(
-            logging_integration=custom_logger_compatible_callback,
-            internal_usage_cache=None,
-            llm_router=None,
-        )
-        if custom_logger is None:
+        if prompt_variables is not None and not isinstance(prompt_variables, dict):
             raise ValueError(
-                f"Custom logger is not initialized. Got={custom_logger_compatible_callback}"
+                f"Prompt variables is set but not a dictionary. Got={prompt_variables}, type={type(prompt_variables)}"
             )
         model, messages, optional_params = (
-            await custom_logger.async_get_chat_completion_prompt(
-                model=split_litellm_model,
+            litellm_logging_object.get_chat_completion_prompt(
+                model=litellm_model,
                 messages=messages,
                 non_default_params=get_non_default_completion_params(kwargs=kwargs),
                 prompt_id=prompt_id,
                 prompt_variables=prompt_variables,
-                dynamic_callback_params={},
             )
         )
         kwargs = {**kwargs, **optional_params}
         kwargs["model"] = model
         kwargs["messages"] = messages
+        kwargs["litellm_logging_obj"] = litellm_logging_object
+        kwargs["prompt_id"] = prompt_id
+        kwargs["prompt_variables"] = prompt_variables
         _model_list = self.get_model_list(model_name=model)
         if _model_list is None or len(_model_list) == 0:  # if direct call to model


@@ -1457,9 +1457,14 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
     user_api_key_end_user_id: Optional[str]
+class StandardLoggingPromptManagementMetadata(TypedDict):
+    prompt_id: Optional[str]
+    prompt_variables: Optional[dict]
 class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
     """
-    Specific metadata k,v pairs logged to integration for easier cost tracking
+    Specific metadata k,v pairs logged to integration for easier cost tracking and prompt management
     """
     spend_logs_metadata: Optional[
@@ -1467,6 +1472,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
     ]  # special param to log k,v pairs to spendlogs for a call
     requester_ip_address: Optional[str]
     requester_metadata: Optional[dict]
+    prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
 class StandardLoggingAdditionalHeaders(TypedDict, total=False):


@@ -2034,6 +2034,8 @@ def get_litellm_params(
     litellm_metadata: Optional[dict] = None,
     disable_add_transform_inline_image_block: Optional[bool] = None,
     drop_params: Optional[bool] = None,
+    prompt_id: Optional[str] = None,
+    prompt_variables: Optional[dict] = None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2068,6 +2070,8 @@ def get_litellm_params(
         "litellm_metadata": litellm_metadata,
         "disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
         "drop_params": drop_params,
+        "prompt_id": prompt_id,
+        "prompt_variables": prompt_variables,
     }
     return litellm_params


@@ -5,6 +5,10 @@ import traceback
 from dotenv import load_dotenv
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.prompt_templates
+import litellm.litellm_core_utils.prompt_templates.factory
 load_dotenv()
 import io
 from unittest.mock import AsyncMock, MagicMock, patch
@@ -16,6 +20,8 @@ import pytest
 import litellm
 from litellm import get_optional_params
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.llms.vertex_ai.gemini.transformation import _process_gemini_image
+from litellm.types.llms.vertex_ai import PartType, BlobType
 import httpx
@@ -1240,3 +1246,53 @@ def test_vertex_embedding_url(model, expected_url):
     assert url == expected_url
     assert endpoint == "predict"
+import pytest
+from unittest.mock import Mock, patch
+from typing import Dict, Any
+# Import your actual module here
+# from your_module import _process_gemini_image, PartType, FileDataType, BlobType
+@pytest.fixture
+def mock_convert_url_to_base64():
+    with patch(
+        "litellm.litellm_core_utils.prompt_templates.factory.convert_url_to_base64",
+    ) as mock:
+        # Setup the mock to return a valid image object
+        mock.return_value = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
+        yield mock
+@pytest.fixture
+def mock_blob():
+    return Mock(spec=BlobType)
+@pytest.mark.parametrize(
+    "http_url",
+    [
+        "http://img1.etsystatic.com/260/0/7813604/il_fullxfull.4226713999_q86e.jpg",
+        "http://example.com/image.jpg",
+        "http://subdomain.domain.com/path/to/image.png",
+    ],
+)
+def test_process_gemini_image_http_url(
+    http_url: str, mock_convert_url_to_base64: Mock, mock_blob: Mock
+) -> None:
+    """
+    Test that _process_gemini_image correctly handles HTTP URLs.
+    Args:
+        http_url: Test HTTP URL
+        mock_convert_to_anthropic: Mocked convert_to_anthropic_image_obj function
+        mock_blob: Mocked BlobType instance
+    """
+    # Arrange
+    expected_image_data = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
+    mock_convert_url_to_base64.return_value = expected_image_data
+    # Act
+    result = _process_gemini_image(http_url)


@@ -500,6 +500,37 @@ def test_get_supported_openai_params() -> None:
     assert get_supported_openai_params("nonexistent") is None
+def test_get_chat_completion_prompt():
+    """
+    Unit test to ensure get_chat_completion_prompt updates messages in logging object.
+    """
+    from litellm.litellm_core_utils.litellm_logging import Logging
+    litellm_logging_obj = Logging(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "hi"}],
+        stream=False,
+        call_type="acompletion",
+        litellm_call_id="1234",
+        start_time=datetime.now(),
+        function_id="1234",
+    )
+    updated_message = "hello world"
+    litellm_logging_obj.get_chat_completion_prompt(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": updated_message}],
+        non_default_params={},
+        prompt_id="1234",
+        prompt_variables=None,
+    )
+    assert litellm_logging_obj.messages == [
+        {"role": "user", "content": updated_message}
+    ]
 def test_redact_msgs_from_logs():
     """
     Tests that turn_off_message_logging does not modify the response_obj


@@ -259,6 +259,7 @@ def validate_redacted_message_span_attributes(span):
         "gen_ai.response.id",
         "gen_ai.response.model",
         "llm.usage.total_tokens",
+        "metadata.prompt_management_metadata",
         "gen_ai.usage.completion_tokens",
         "gen_ai.usage.prompt_tokens",
         "metadata.user_api_key_hash",