forked from phoenix/litellm-mirror
Litellm router trace (#6742)
* feat(router.py): add trace_id to parent functions - allows tracking retry/fallbacks * feat(router.py): log trace id across retry/fallback logic allows grouping llm logs for the same request * test: fix tests * fix: fix test * fix(transformation.py): only set non-none stop_sequences
This commit is contained in:
parent
9053a6a203
commit
02b6f69004
8 changed files with 197 additions and 32 deletions
|
@ -201,6 +201,7 @@ class Logging:
|
||||||
start_time,
|
start_time,
|
||||||
litellm_call_id: str,
|
litellm_call_id: str,
|
||||||
function_id: str,
|
function_id: str,
|
||||||
|
litellm_trace_id: Optional[str] = None,
|
||||||
dynamic_input_callbacks: Optional[
|
dynamic_input_callbacks: Optional[
|
||||||
List[Union[str, Callable, CustomLogger]]
|
List[Union[str, Callable, CustomLogger]]
|
||||||
] = None,
|
] = None,
|
||||||
|
@ -238,6 +239,7 @@ class Logging:
|
||||||
self.start_time = start_time # log the call start time
|
self.start_time = start_time # log the call start time
|
||||||
self.call_type = call_type
|
self.call_type = call_type
|
||||||
self.litellm_call_id = litellm_call_id
|
self.litellm_call_id = litellm_call_id
|
||||||
|
self.litellm_trace_id = litellm_trace_id
|
||||||
self.function_id = function_id
|
self.function_id = function_id
|
||||||
self.streaming_chunks: List[Any] = [] # for generating complete stream response
|
self.streaming_chunks: List[Any] = [] # for generating complete stream response
|
||||||
self.sync_streaming_chunks: List[Any] = (
|
self.sync_streaming_chunks: List[Any] = (
|
||||||
|
@ -274,6 +276,11 @@ class Logging:
|
||||||
self.completion_start_time: Optional[datetime.datetime] = None
|
self.completion_start_time: Optional[datetime.datetime] = None
|
||||||
self._llm_caching_handler: Optional[LLMCachingHandler] = None
|
self._llm_caching_handler: Optional[LLMCachingHandler] = None
|
||||||
|
|
||||||
|
self.model_call_details = {
|
||||||
|
"litellm_trace_id": litellm_trace_id,
|
||||||
|
"litellm_call_id": litellm_call_id,
|
||||||
|
}
|
||||||
|
|
||||||
def process_dynamic_callbacks(self):
|
def process_dynamic_callbacks(self):
|
||||||
"""
|
"""
|
||||||
Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
|
Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
|
||||||
|
@ -381,7 +388,8 @@ class Logging:
|
||||||
self.logger_fn = litellm_params.get("logger_fn", None)
|
self.logger_fn = litellm_params.get("logger_fn", None)
|
||||||
verbose_logger.debug(f"self.optional_params: {self.optional_params}")
|
verbose_logger.debug(f"self.optional_params: {self.optional_params}")
|
||||||
|
|
||||||
self.model_call_details = {
|
self.model_call_details.update(
|
||||||
|
{
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": self.messages,
|
"messages": self.messages,
|
||||||
"optional_params": self.optional_params,
|
"optional_params": self.optional_params,
|
||||||
|
@ -396,6 +404,7 @@ class Logging:
|
||||||
**self.optional_params,
|
**self.optional_params,
|
||||||
**additional_params,
|
**additional_params,
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
|
## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
|
||||||
if "stream_options" in additional_params:
|
if "stream_options" in additional_params:
|
||||||
|
@ -2806,6 +2815,7 @@ def get_standard_logging_object_payload(
|
||||||
|
|
||||||
payload: StandardLoggingPayload = StandardLoggingPayload(
|
payload: StandardLoggingPayload = StandardLoggingPayload(
|
||||||
id=str(id),
|
id=str(id),
|
||||||
|
trace_id=kwargs.get("litellm_trace_id"), # type: ignore
|
||||||
call_type=call_type or "",
|
call_type=call_type or "",
|
||||||
cache_hit=cache_hit,
|
cache_hit=cache_hit,
|
||||||
status=status,
|
status=status,
|
||||||
|
|
|
@ -1066,6 +1066,7 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
|
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
|
||||||
user_continue_message=kwargs.get("user_continue_message"),
|
user_continue_message=kwargs.get("user_continue_message"),
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
|
litellm_trace_id=kwargs.get("litellm_trace_id"),
|
||||||
)
|
)
|
||||||
logging.update_environment_variables(
|
logging.update_environment_variables(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
@ -679,9 +679,8 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["messages"] = messages
|
kwargs["messages"] = messages
|
||||||
kwargs["original_function"] = self._completion
|
kwargs["original_function"] = self._completion
|
||||||
kwargs.get("request_timeout", self.timeout)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
response = self.function_with_fallbacks(**kwargs)
|
response = self.function_with_fallbacks(**kwargs)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -783,8 +782,7 @@ class Router:
|
||||||
kwargs["stream"] = stream
|
kwargs["stream"] = stream
|
||||||
kwargs["original_function"] = self._acompletion
|
kwargs["original_function"] = self._acompletion
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||||
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
|
|
||||||
request_priority = kwargs.get("priority") or self.default_priority
|
request_priority = kwargs.get("priority") or self.default_priority
|
||||||
|
|
||||||
|
@ -948,6 +946,17 @@ class Router:
|
||||||
self.fail_calls[model_name] += 1
|
self.fail_calls[model_name] += 1
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
|
||||||
|
"""
|
||||||
|
Adds/updates to kwargs:
|
||||||
|
- num_retries
|
||||||
|
- litellm_trace_id
|
||||||
|
- metadata
|
||||||
|
"""
|
||||||
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||||
|
kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
|
||||||
|
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||||
|
|
||||||
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
|
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Adds default litellm params to kwargs, if set.
|
Adds default litellm params to kwargs, if set.
|
||||||
|
@ -1511,9 +1520,7 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["file"] = file
|
kwargs["file"] = file
|
||||||
kwargs["original_function"] = self._atranscription
|
kwargs["original_function"] = self._atranscription
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.get("request_timeout", self.timeout)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
response = await self.async_function_with_fallbacks(**kwargs)
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -1688,9 +1695,7 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["input"] = input
|
kwargs["input"] = input
|
||||||
kwargs["original_function"] = self._arerank
|
kwargs["original_function"] = self._arerank
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.get("request_timeout", self.timeout)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
|
|
||||||
response = await self.async_function_with_fallbacks(**kwargs)
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
|
@ -1839,9 +1844,7 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["prompt"] = prompt
|
kwargs["prompt"] = prompt
|
||||||
kwargs["original_function"] = self._atext_completion
|
kwargs["original_function"] = self._atext_completion
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.get("request_timeout", self.timeout)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
response = await self.async_function_with_fallbacks(**kwargs)
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -2112,9 +2115,7 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["input"] = input
|
kwargs["input"] = input
|
||||||
kwargs["original_function"] = self._aembedding
|
kwargs["original_function"] = self._aembedding
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.get("request_timeout", self.timeout)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
response = await self.async_function_with_fallbacks(**kwargs)
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2616,6 +2617,7 @@ class Router:
|
||||||
content_policy_fallbacks: Optional[List] = kwargs.get(
|
content_policy_fallbacks: Optional[List] = kwargs.get(
|
||||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._handle_mock_testing_fallbacks(
|
self._handle_mock_testing_fallbacks(
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
|
|
|
@ -150,6 +150,8 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
max_retries: Optional[int] = None
|
max_retries: Optional[int] = None
|
||||||
organization: Optional[str] = None # for openai orgs
|
organization: Optional[str] = None # for openai orgs
|
||||||
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
|
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
|
||||||
|
## LOGGING PARAMS ##
|
||||||
|
litellm_trace_id: Optional[str] = None
|
||||||
## UNIFIED PROJECT/REGION ##
|
## UNIFIED PROJECT/REGION ##
|
||||||
region_name: Optional[str] = None
|
region_name: Optional[str] = None
|
||||||
## VERTEX AI ##
|
## VERTEX AI ##
|
||||||
|
@ -186,6 +188,8 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
None # timeout when making stream=True calls, if str, pass in as os.environ/
|
None # timeout when making stream=True calls, if str, pass in as os.environ/
|
||||||
),
|
),
|
||||||
organization: Optional[str] = None, # for openai orgs
|
organization: Optional[str] = None, # for openai orgs
|
||||||
|
## LOGGING PARAMS ##
|
||||||
|
litellm_trace_id: Optional[str] = None,
|
||||||
## UNIFIED PROJECT/REGION ##
|
## UNIFIED PROJECT/REGION ##
|
||||||
region_name: Optional[str] = None,
|
region_name: Optional[str] = None,
|
||||||
## VERTEX AI ##
|
## VERTEX AI ##
|
||||||
|
|
|
@ -1334,6 +1334,7 @@ class ResponseFormatChunk(TypedDict, total=False):
|
||||||
|
|
||||||
all_litellm_params = [
|
all_litellm_params = [
|
||||||
"metadata",
|
"metadata",
|
||||||
|
"litellm_trace_id",
|
||||||
"tags",
|
"tags",
|
||||||
"acompletion",
|
"acompletion",
|
||||||
"aimg_generation",
|
"aimg_generation",
|
||||||
|
@ -1523,6 +1524,7 @@ StandardLoggingPayloadStatus = Literal["success", "failure"]
|
||||||
|
|
||||||
class StandardLoggingPayload(TypedDict):
|
class StandardLoggingPayload(TypedDict):
|
||||||
id: str
|
id: str
|
||||||
|
trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
|
||||||
call_type: str
|
call_type: str
|
||||||
response_cost: float
|
response_cost: float
|
||||||
response_cost_failure_debug_info: Optional[
|
response_cost_failure_debug_info: Optional[
|
||||||
|
|
|
@ -527,6 +527,7 @@ def function_setup( # noqa: PLR0915
|
||||||
messages=messages,
|
messages=messages,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
litellm_call_id=kwargs["litellm_call_id"],
|
litellm_call_id=kwargs["litellm_call_id"],
|
||||||
|
litellm_trace_id=kwargs.get("litellm_trace_id"),
|
||||||
function_id=function_id or "",
|
function_id=function_id or "",
|
||||||
call_type=call_type,
|
call_type=call_type,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
|
@ -2056,6 +2057,7 @@ def get_litellm_params(
|
||||||
azure_ad_token_provider=None,
|
azure_ad_token_provider=None,
|
||||||
user_continue_message=None,
|
user_continue_message=None,
|
||||||
base_model=None,
|
base_model=None,
|
||||||
|
litellm_trace_id=None,
|
||||||
):
|
):
|
||||||
litellm_params = {
|
litellm_params = {
|
||||||
"acompletion": acompletion,
|
"acompletion": acompletion,
|
||||||
|
@ -2084,6 +2086,7 @@ def get_litellm_params(
|
||||||
"user_continue_message": user_continue_message,
|
"user_continue_message": user_continue_message,
|
||||||
"base_model": base_model
|
"base_model": base_model
|
||||||
or _get_base_model_from_litellm_call_metadata(metadata=metadata),
|
or _get_base_model_from_litellm_call_metadata(metadata=metadata),
|
||||||
|
"litellm_trace_id": litellm_trace_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
return litellm_params
|
return litellm_params
|
||||||
|
|
|
@ -1624,3 +1624,55 @@ async def test_standard_logging_payload_stream_usage(sync_mode):
|
||||||
print(f"standard_logging_object usage: {built_response.usage}")
|
print(f"standard_logging_object usage: {built_response.usage}")
|
||||||
except litellm.InternalServerError:
|
except litellm.InternalServerError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_standard_logging_retries():
|
||||||
|
"""
|
||||||
|
know if a request was retried.
|
||||||
|
"""
|
||||||
|
from litellm.types.utils import StandardLoggingPayload
|
||||||
|
from litellm.router import Router
|
||||||
|
|
||||||
|
customHandler = CompletionCustomHandler()
|
||||||
|
litellm.callbacks = [customHandler]
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "openai/gpt-3.5-turbo",
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
customHandler, "log_failure_event", new=MagicMock()
|
||||||
|
) as mock_client:
|
||||||
|
try:
|
||||||
|
router.completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
num_retries=1,
|
||||||
|
mock_response="litellm.RateLimitError",
|
||||||
|
)
|
||||||
|
except litellm.RateLimitError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert mock_client.call_count == 2
|
||||||
|
assert (
|
||||||
|
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
|
||||||
|
"trace_id"
|
||||||
|
]
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
|
||||||
|
"trace_id"
|
||||||
|
]
|
||||||
|
== mock_client.call_args_list[1].kwargs["kwargs"][
|
||||||
|
"standard_logging_object"
|
||||||
|
]["trace_id"]
|
||||||
|
)
|
||||||
|
|
|
@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from unittest.mock import patch, MagicMock, AsyncMock
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
@ -83,3 +84,93 @@ def test_returned_settings():
|
||||||
except Exception:
|
except Exception:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
pytest.fail("An error occurred - " + traceback.format_exc())
|
pytest.fail("An error occurred - " + traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
|
from litellm.types.utils import CallTypes
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_kwargs_before_fallbacks_unit_test():
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": "bad-key",
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {"messages": [{"role": "user", "content": "write 1 sentence poem"}]}
|
||||||
|
|
||||||
|
router._update_kwargs_before_fallbacks(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert kwargs["litellm_trace_id"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"call_type",
|
||||||
|
[
|
||||||
|
CallTypes.acompletion,
|
||||||
|
CallTypes.atext_completion,
|
||||||
|
CallTypes.aembedding,
|
||||||
|
CallTypes.arerank,
|
||||||
|
CallTypes.atranscription,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_kwargs_before_fallbacks(call_type):
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": "bad-key",
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if call_type.value.startswith("a"):
|
||||||
|
with patch.object(router, "async_function_with_fallbacks") as mock_client:
|
||||||
|
if call_type.value == "acompletion":
|
||||||
|
input_kwarg = {
|
||||||
|
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||||
|
}
|
||||||
|
elif (
|
||||||
|
call_type.value == "atext_completion"
|
||||||
|
or call_type.value == "aimage_generation"
|
||||||
|
):
|
||||||
|
input_kwarg = {
|
||||||
|
"prompt": "Hello, how are you?",
|
||||||
|
}
|
||||||
|
elif call_type.value == "aembedding" or call_type.value == "arerank":
|
||||||
|
input_kwarg = {
|
||||||
|
"input": "Hello, how are you?",
|
||||||
|
}
|
||||||
|
elif call_type.value == "atranscription":
|
||||||
|
input_kwarg = {
|
||||||
|
"file": "path/to/file",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
input_kwarg = {}
|
||||||
|
|
||||||
|
await getattr(router, call_type.value)(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
**input_kwarg,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.assert_called_once()
|
||||||
|
|
||||||
|
print(mock_client.call_args.kwargs)
|
||||||
|
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue