diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index d2e65742c..15f7f59fa 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -201,6 +201,7 @@ class Logging:
         start_time,
         litellm_call_id: str,
         function_id: str,
+        litellm_trace_id: Optional[str] = None,
         dynamic_input_callbacks: Optional[
             List[Union[str, Callable, CustomLogger]]
         ] = None,
@@ -238,6 +239,7 @@ class Logging:
         self.start_time = start_time  # log the call start time
         self.call_type = call_type
         self.litellm_call_id = litellm_call_id
+        self.litellm_trace_id = litellm_trace_id
         self.function_id = function_id
         self.streaming_chunks: List[Any] = []  # for generating complete stream response
         self.sync_streaming_chunks: List[Any] = (
@@ -274,6 +276,11 @@ class Logging:
         self.completion_start_time: Optional[datetime.datetime] = None
         self._llm_caching_handler: Optional[LLMCachingHandler] = None

+        self.model_call_details = {
+            "litellm_trace_id": litellm_trace_id,
+            "litellm_call_id": litellm_call_id,
+        }
+
     def process_dynamic_callbacks(self):
         """
         Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
@@ -381,21 +388,23 @@ class Logging:
         self.logger_fn = litellm_params.get("logger_fn", None)
         verbose_logger.debug(f"self.optional_params: {self.optional_params}")

-        self.model_call_details = {
-            "model": self.model,
-            "messages": self.messages,
-            "optional_params": self.optional_params,
-            "litellm_params": self.litellm_params,
-            "start_time": self.start_time,
-            "stream": self.stream,
-            "user": user,
-            "call_type": str(self.call_type),
-            "litellm_call_id": self.litellm_call_id,
-            "completion_start_time": self.completion_start_time,
-            "standard_callback_dynamic_params": self.standard_callback_dynamic_params,
-            **self.optional_params,
-            **additional_params,
-        }
+        self.model_call_details.update(
+            {
+                "model": self.model,
+                "messages": self.messages,
+                "optional_params": self.optional_params,
+                "litellm_params": self.litellm_params,
+                "start_time": self.start_time,
+                "stream": self.stream,
+                "user": user,
+                "call_type": str(self.call_type),
+                "litellm_call_id": self.litellm_call_id,
+                "completion_start_time": self.completion_start_time,
+                "standard_callback_dynamic_params": self.standard_callback_dynamic_params,
+                **self.optional_params,
+                **additional_params,
+            }
+        )

         ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
         if "stream_options" in additional_params:
@@ -2806,6 +2815,7 @@ def get_standard_logging_object_payload(

         payload: StandardLoggingPayload = StandardLoggingPayload(
             id=str(id),
+            trace_id=kwargs.get("litellm_trace_id"),  # type: ignore
             call_type=call_type or "",
             cache_hit=cache_hit,
             status=status,
diff --git a/litellm/main.py b/litellm/main.py
index ad8f791c3..543a93eea 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1066,6 +1066,7 @@ def completion(  # type: ignore # noqa: PLR0915
         azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
         user_continue_message=kwargs.get("user_continue_message"),
         base_model=base_model,
+        litellm_trace_id=kwargs.get("litellm_trace_id"),
     )
     logging.update_environment_variables(
         model=model,
diff --git a/litellm/router.py b/litellm/router.py
index 4735d422b..400347ff2 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -679,9 +679,8 @@ class Router:
             kwargs["model"] = model
             kwargs["messages"] = messages
             kwargs["original_function"] = self._completion
-            kwargs.get("request_timeout", self.timeout)
-            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("num_retries", self.num_retries) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) + response = self.function_with_fallbacks(**kwargs) return response except Exception as e: @@ -783,8 +782,7 @@ class Router: kwargs["stream"] = stream kwargs["original_function"] = self._acompletion kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) request_priority = kwargs.get("priority") or self.default_priority @@ -948,6 +946,17 @@ class Router: self.fail_calls[model_name] += 1 raise e + def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None: + """ + Adds/updates to kwargs: + - num_retries + - litellm_trace_id + - metadata + """ + kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) + kwargs.setdefault("litellm_trace_id", str(uuid.uuid4())) + kwargs.setdefault("metadata", {}).update({"model_group": model}) + def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None: """ Adds default litellm params to kwargs, if set. @@ -1511,9 +1520,7 @@ class Router: kwargs["model"] = model kwargs["file"] = file kwargs["original_function"] = self._atranscription - kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - kwargs.get("request_timeout", self.timeout) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) response = await self.async_function_with_fallbacks(**kwargs) return response @@ -1688,9 +1695,7 @@ class Router: kwargs["model"] = model kwargs["input"] = input kwargs["original_function"] = self._arerank - kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - kwargs.get("request_timeout", self.timeout) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) response = await self.async_function_with_fallbacks(**kwargs) @@ -1839,9 +1844,7 @@ class Router: kwargs["model"] = model kwargs["prompt"] = prompt kwargs["original_function"] = self._atext_completion - kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - kwargs.get("request_timeout", self.timeout) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) response = await self.async_function_with_fallbacks(**kwargs) return response @@ -2112,9 +2115,7 @@ class Router: kwargs["model"] = model kwargs["input"] = input kwargs["original_function"] = self._aembedding - kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) - kwargs.get("request_timeout", self.timeout) - kwargs.setdefault("metadata", {}).update({"model_group": model}) + self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs) response = await self.async_function_with_fallbacks(**kwargs) return response except Exception as e: @@ -2616,6 +2617,7 @@ class Router: content_policy_fallbacks: Optional[List] = kwargs.get( "content_policy_fallbacks", self.content_policy_fallbacks ) + try: self._handle_mock_testing_fallbacks( kwargs=kwargs, diff --git a/litellm/types/router.py b/litellm/types/router.py index 6119ca4b7..bb93aaa63 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -150,6 +150,8 @@ class GenericLiteLLMParams(BaseModel): max_retries: Optional[int] = None organization: Optional[str] 
     configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
+    ## LOGGING PARAMS ##
+    litellm_trace_id: Optional[str] = None
     ## UNIFIED PROJECT/REGION ##
     region_name: Optional[str] = None
     ## VERTEX AI ##
@@ -186,6 +188,8 @@ class GenericLiteLLMParams(BaseModel):
             None  # timeout when making stream=True calls, if str, pass in as os.environ/
         ),
         organization: Optional[str] = None,  # for openai orgs
+        ## LOGGING PARAMS ##
+        litellm_trace_id: Optional[str] = None,
         ## UNIFIED PROJECT/REGION ##
         region_name: Optional[str] = None,
         ## VERTEX AI ##
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index e3df357be..d02129681 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1334,6 +1334,7 @@ class ResponseFormatChunk(TypedDict, total=False):

 all_litellm_params = [
     "metadata",
+    "litellm_trace_id",
     "tags",
     "acompletion",
     "aimg_generation",
@@ -1523,6 +1524,7 @@ StandardLoggingPayloadStatus = Literal["success", "failure"]

 class StandardLoggingPayload(TypedDict):
     id: str
+    trace_id: str  # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
     call_type: str
     response_cost: float
     response_cost_failure_debug_info: Optional[
diff --git a/litellm/utils.py b/litellm/utils.py
index 802bcfc04..fdb533e4e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -527,6 +527,7 @@ def function_setup(  # noqa: PLR0915
             messages=messages,
             stream=stream,
             litellm_call_id=kwargs["litellm_call_id"],
+            litellm_trace_id=kwargs.get("litellm_trace_id"),
             function_id=function_id or "",
             call_type=call_type,
             start_time=start_time,
@@ -2056,6 +2057,7 @@ def get_litellm_params(
     azure_ad_token_provider=None,
     user_continue_message=None,
     base_model=None,
+    litellm_trace_id=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2084,6 +2086,7 @@ def get_litellm_params(
         "user_continue_message": user_continue_message,
         "base_model": base_model
         or _get_base_model_from_litellm_call_metadata(metadata=metadata),
+        "litellm_trace_id": litellm_trace_id,
     }
     return litellm_params

diff --git a/tests/local_testing/test_custom_callback_input.py b/tests/local_testing/test_custom_callback_input.py
index 1744d3891..9b7b6d532 100644
--- a/tests/local_testing/test_custom_callback_input.py
+++ b/tests/local_testing/test_custom_callback_input.py
@@ -1624,3 +1624,55 @@ async def test_standard_logging_payload_stream_usage(sync_mode):
             print(f"standard_logging_object usage: {built_response.usage}")
     except litellm.InternalServerError:
         pass
+
+
+def test_standard_logging_retries():
+    """
+    Test that all retries of a request share the same trace_id, so logs can tell a request was retried.
+ """ + from litellm.types.utils import StandardLoggingPayload + from litellm.router import Router + + customHandler = CompletionCustomHandler() + litellm.callbacks = [customHandler] + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "openai/gpt-3.5-turbo", + "api_key": "test-api-key", + }, + } + ] + ) + + with patch.object( + customHandler, "log_failure_event", new=MagicMock() + ) as mock_client: + try: + router.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + num_retries=1, + mock_response="litellm.RateLimitError", + ) + except litellm.RateLimitError: + pass + + assert mock_client.call_count == 2 + assert ( + mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][ + "trace_id" + ] + is not None + ) + assert ( + mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][ + "trace_id" + ] + == mock_client.call_args_list[1].kwargs["kwargs"][ + "standard_logging_object" + ]["trace_id"] + ) diff --git a/tests/local_testing/test_router_utils.py b/tests/local_testing/test_router_utils.py index 538ab4d0b..d266cfbd9 100644 --- a/tests/local_testing/test_router_utils.py +++ b/tests/local_testing/test_router_utils.py @@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo from concurrent.futures import ThreadPoolExecutor from collections import defaultdict from dotenv import load_dotenv +from unittest.mock import patch, MagicMock, AsyncMock load_dotenv() @@ -83,3 +84,93 @@ def test_returned_settings(): except Exception: print(traceback.format_exc()) pytest.fail("An error occurred - " + traceback.format_exc()) + + +from litellm.types.utils import CallTypes + + +def test_update_kwargs_before_fallbacks_unit_test(): + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + } + ], + ) + + kwargs = {"messages": [{"role": "user", "content": "write 1 sentence poem"}]} + + router._update_kwargs_before_fallbacks( + model="gpt-3.5-turbo", + kwargs=kwargs, + ) + + assert kwargs["litellm_trace_id"] is not None + + +@pytest.mark.parametrize( + "call_type", + [ + CallTypes.acompletion, + CallTypes.atext_completion, + CallTypes.aembedding, + CallTypes.arerank, + CallTypes.atranscription, + ], +) +@pytest.mark.asyncio +async def test_update_kwargs_before_fallbacks(call_type): + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": "bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + } + ], + ) + + if call_type.value.startswith("a"): + with patch.object(router, "async_function_with_fallbacks") as mock_client: + if call_type.value == "acompletion": + input_kwarg = { + "messages": [{"role": "user", "content": "Hello, how are you?"}], + } + elif ( + call_type.value == "atext_completion" + or call_type.value == "aimage_generation" + ): + input_kwarg = { + "prompt": "Hello, how are you?", + } + elif call_type.value == "aembedding" or call_type.value == "arerank": + input_kwarg = { + "input": "Hello, how are you?", + } + elif call_type.value == "atranscription": + input_kwarg = { + "file": "path/to/file", + } + else: + input_kwarg = {} + + await getattr(router, call_type.value)( + model="gpt-3.5-turbo", + **input_kwarg, + ) + + 
+            mock_client.assert_called_once()
+
+            print(mock_client.call_args.kwargs)
+            assert mock_client.call_args.kwargs["litellm_trace_id"] is not None
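
A minimal sketch of how the new trace_id can be consumed, assuming the CustomLogger failure hook that the tests above patch (log_failure_event) and the standard_logging_object kwarg they assert on; the TraceIdLogger class name and its print output are illustrative only, not part of this diff:

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    class TraceIdLogger(CustomLogger):  # hypothetical handler, for illustration
        def log_failure_event(self, kwargs, response_obj, start_time, end_time):
            # standard_logging_object is built by get_standard_logging_object_payload();
            # all retries/fallbacks of one router request share the same trace_id,
            # while each attempt keeps its own id (litellm_call_id).
            slo = kwargs.get("standard_logging_object") or {}
            print("call:", slo.get("id"), "trace:", slo.get("trace_id"))

    litellm.callbacks = [TraceIdLogger()]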