(Observability) - Add more detailed dd tracing on Proxy Auth, Bedrock Auth (#8693)

* add dd tracer

* fix dd tracing

* add @tracer.wrap() on def user_api_key_auth

* add async_function_with_retries

* remove dead code

* add tracer.wrap on base aws llm

* add tracer.wrap on base aws llm

* fix print verbose

* fix dd tracing

* trace base aws llm

* fix test base aws llm

* fix converse transform

* test base aws llm

* BASE_AWS_LLM_PATH

* BASE_AWS_LLM_PATH

* test dd tracing
This commit is contained in:
Ishaan Jaff 2025-02-20 18:00:41 -08:00 committed by GitHub
parent 11a1692c63
commit f940392971
9 changed files with 256 additions and 42 deletions

View file

@ -0,0 +1,53 @@
"""
Handles Tracing on DataDog Traces.
If the ddtrace package is not installed, the tracer will be a no-op.
"""
from contextlib import contextmanager
try:
    from ddtrace import tracer as dd_tracer

    has_ddtrace = True
except ImportError:
    # ddtrace is an optional dependency. When it is missing, fall back to
    # no-op stand-ins exposing the `trace` / `wrap` surface used by callers.
    has_ddtrace = False

    class NullSpan:
        """
        No-op span returned when ddtrace is unavailable.

        Works as a context manager (`with tracer.trace(...) as span:`) and
        accepts the span calls `finish` and `set_tag` without doing anything.
        Defined once so a new class is not built on every `trace()` call.
        """

        def __enter__(self):
            return self

        def __exit__(self, *args):
            # Returns None (falsy): exceptions raised in the `with` body propagate.
            pass

        def finish(self, finish_time=None):
            """Do nothing; `finish_time` is accepted and ignored."""
            pass

        def set_tag(self, key, value=None):
            """Do nothing; the tag is discarded."""
            pass

    @contextmanager
    def null_tracer(name, **kwargs):
        """Context manager that ignores its arguments and yields a NullSpan."""
        yield NullSpan()

    class NullTracer:
        """No-op replacement for the ddtrace tracer object."""

        def trace(self, name, **kwargs):
            """Return a fresh NullSpan; `name` and `kwargs` are ignored."""
            return NullSpan()

        def wrap(self, name=None, **kwargs):
            """Return a decorator that hands back the function unchanged."""

            def decorator(f):
                return f

            return decorator

    dd_tracer = NullTracer()

# Export the tracer instance
tracer = dd_tracer

View file

@ -9,6 +9,7 @@ from pydantic import BaseModel
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.secret_managers.main import get_secret, get_secret_str from litellm.secret_managers.main import get_secret, get_secret_str
if TYPE_CHECKING: if TYPE_CHECKING:
@ -63,6 +64,7 @@ class BaseAWSLLM:
credential_str = json.dumps(credential_args, sort_keys=True) credential_str = json.dumps(credential_args, sort_keys=True)
return hashlib.sha256(credential_str.encode()).hexdigest() return hashlib.sha256(credential_str.encode()).hexdigest()
@tracer.wrap()
def get_credentials( def get_credentials(
self, self,
aws_access_key_id: Optional[str] = None, aws_access_key_id: Optional[str] = None,
@ -200,6 +202,7 @@ class BaseAWSLLM:
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl) self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
return credentials return credentials
@tracer.wrap()
def _auth_with_web_identity_token( def _auth_with_web_identity_token(
self, self,
aws_web_identity_token: str, aws_web_identity_token: str,
@ -230,6 +233,7 @@ class BaseAWSLLM:
status_code=401, status_code=401,
) )
with tracer.trace("boto3.client(sts)"):
sts_client = boto3.client( sts_client = boto3.client(
"sts", "sts",
region_name=aws_region_name, region_name=aws_region_name,
@ -258,11 +262,13 @@ class BaseAWSLLM:
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}" f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
) )
with tracer.trace("boto3.Session(**iam_creds_dict)"):
session = boto3.Session(**iam_creds_dict) session = boto3.Session(**iam_creds_dict)
iam_creds = session.get_credentials() iam_creds = session.get_credentials()
return iam_creds, self._get_default_ttl_for_boto3_credentials() return iam_creds, self._get_default_ttl_for_boto3_credentials()
@tracer.wrap()
def _auth_with_aws_role( def _auth_with_aws_role(
self, self,
aws_access_key_id: Optional[str], aws_access_key_id: Optional[str],
@ -276,6 +282,7 @@ class BaseAWSLLM:
import boto3 import boto3
from botocore.credentials import Credentials from botocore.credentials import Credentials
with tracer.trace("boto3.client(sts)"):
sts_client = boto3.client( sts_client = boto3.client(
"sts", "sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL] aws_access_key_id=aws_access_key_id, # [OPTIONAL]
@ -288,7 +295,6 @@ class BaseAWSLLM:
# Extract the credentials from the response and convert to Session Credentials # Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"] sts_credentials = sts_response["Credentials"]
credentials = Credentials( credentials = Credentials(
access_key=sts_credentials["AccessKeyId"], access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"], secret_key=sts_credentials["SecretAccessKey"],
@ -301,6 +307,7 @@ class BaseAWSLLM:
sts_ttl = (sts_expiry - current_time).total_seconds() - 60 sts_ttl = (sts_expiry - current_time).total_seconds() - 60
return credentials, sts_ttl return credentials, sts_ttl
@tracer.wrap()
def _auth_with_aws_profile( def _auth_with_aws_profile(
self, aws_profile_name: str self, aws_profile_name: str
) -> Tuple[Credentials, Optional[int]]: ) -> Tuple[Credentials, Optional[int]]:
@ -310,9 +317,11 @@ class BaseAWSLLM:
import boto3 import boto3
# uses auth values from AWS profile usually stored in ~/.aws/credentials # uses auth values from AWS profile usually stored in ~/.aws/credentials
with tracer.trace("boto3.Session(profile_name=aws_profile_name)"):
client = boto3.Session(profile_name=aws_profile_name) client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials(), None return client.get_credentials(), None
@tracer.wrap()
def _auth_with_aws_session_token( def _auth_with_aws_session_token(
self, self,
aws_access_key_id: str, aws_access_key_id: str,
@ -333,6 +342,7 @@ class BaseAWSLLM:
return credentials, None return credentials, None
@tracer.wrap()
def _auth_with_access_key_and_secret_key( def _auth_with_access_key_and_secret_key(
self, self,
aws_access_key_id: str, aws_access_key_id: str,
@ -345,7 +355,9 @@ class BaseAWSLLM:
import boto3 import boto3
# Check if credentials are already in cache. These credentials have no expiry time. # Check if credentials are already in cache. These credentials have no expiry time.
with tracer.trace(
"boto3.Session(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name)"
):
session = boto3.Session( session = boto3.Session(
aws_access_key_id=aws_access_key_id, aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
@ -355,16 +367,19 @@ class BaseAWSLLM:
credentials = session.get_credentials() credentials = session.get_credentials()
return credentials, self._get_default_ttl_for_boto3_credentials() return credentials, self._get_default_ttl_for_boto3_credentials()
@tracer.wrap()
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]: def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
""" """
Authenticate with AWS Environment Variables Authenticate with AWS Environment Variables
""" """
import boto3 import boto3
with tracer.trace("boto3.Session()"):
session = boto3.Session() session = boto3.Session()
credentials = session.get_credentials() credentials = session.get_credentials()
return credentials, None return credentials, None
@tracer.wrap()
def _get_default_ttl_for_boto3_credentials(self) -> int: def _get_default_ttl_for_boto3_credentials(self) -> int:
""" """
Get the default TTL for boto3 credentials Get the default TTL for boto3 credentials
@ -475,6 +490,7 @@ class BaseAWSLLM:
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
) )
@tracer.wrap()
def get_request_headers( def get_request_headers(
self, self,
credentials: Credentials, credentials: Credentials,

View file

@ -1,6 +1,6 @@
import json import json
import urllib import urllib
from typing import Any, Callable, Optional, Union from typing import Any, Optional, Union
import httpx import httpx
@ -60,7 +60,6 @@ def make_sync_call(
api_key="", api_key="",
data=data, data=data,
messages=messages, messages=messages,
print_verbose=litellm.print_verbose,
encoding=litellm.encoding, encoding=litellm.encoding,
) # type: ignore ) # type: ignore
completion_stream: Any = MockResponseIterator( completion_stream: Any = MockResponseIterator(
@ -102,7 +101,6 @@ class BedrockConverseLLM(BaseAWSLLM):
messages: list, messages: list,
api_base: str, api_base: str,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable,
timeout: Optional[Union[float, httpx.Timeout]], timeout: Optional[Union[float, httpx.Timeout]],
encoding, encoding,
logging_obj, logging_obj,
@ -170,7 +168,6 @@ class BedrockConverseLLM(BaseAWSLLM):
messages: list, messages: list,
api_base: str, api_base: str,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable,
timeout: Optional[Union[float, httpx.Timeout]], timeout: Optional[Union[float, httpx.Timeout]],
encoding, encoding,
logging_obj: LiteLLMLoggingObject, logging_obj: LiteLLMLoggingObject,
@ -247,7 +244,6 @@ class BedrockConverseLLM(BaseAWSLLM):
api_key="", api_key="",
data=data, data=data,
messages=messages, messages=messages,
print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
encoding=encoding, encoding=encoding,
) )
@ -259,7 +255,6 @@ class BedrockConverseLLM(BaseAWSLLM):
api_base: Optional[str], api_base: Optional[str],
custom_prompt_dict: dict, custom_prompt_dict: dict,
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable,
encoding, encoding,
logging_obj: LiteLLMLoggingObject, logging_obj: LiteLLMLoggingObject,
optional_params: dict, optional_params: dict,
@ -271,11 +266,6 @@ class BedrockConverseLLM(BaseAWSLLM):
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
): ):
try:
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ## ## SETUP ##
stream = optional_params.pop("stream", None) stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None) modelId = optional_params.pop("model_id", None)
@ -367,7 +357,6 @@ class BedrockConverseLLM(BaseAWSLLM):
messages=messages, messages=messages,
api_base=proxy_endpoint_url, api_base=proxy_endpoint_url,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose,
encoding=encoding, encoding=encoding,
logging_obj=logging_obj, logging_obj=logging_obj,
optional_params=optional_params, optional_params=optional_params,
@ -387,7 +376,6 @@ class BedrockConverseLLM(BaseAWSLLM):
messages=messages, messages=messages,
api_base=proxy_endpoint_url, api_base=proxy_endpoint_url,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose,
encoding=encoding, encoding=encoding,
logging_obj=logging_obj, logging_obj=logging_obj,
optional_params=optional_params, optional_params=optional_params,
@ -489,7 +477,6 @@ class BedrockConverseLLM(BaseAWSLLM):
api_key="", api_key="",
data=data, data=data,
messages=messages, messages=messages,
print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
encoding=encoding, encoding=encoding,
) )

View file

@ -5,7 +5,7 @@ Translating between OpenAI's `/chat/completion` format and Amazon's `/converse`
import copy import copy
import time import time
import types import types
from typing import Callable, List, Literal, Optional, Tuple, Union, cast, overload from typing import List, Literal, Optional, Tuple, Union, cast, overload
import httpx import httpx
@ -542,7 +542,6 @@ class AmazonConverseConfig(BaseConfig):
api_key=api_key, api_key=api_key,
data=request_data, data=request_data,
messages=messages, messages=messages,
print_verbose=None,
encoding=encoding, encoding=encoding,
) )
@ -557,7 +556,6 @@ class AmazonConverseConfig(BaseConfig):
api_key: Optional[str], api_key: Optional[str],
data: Union[dict, str], data: Union[dict, str],
messages: List, messages: List,
print_verbose: Optional[Callable],
encoding, encoding,
) -> ModelResponse: ) -> ModelResponse:
## LOGGING ## LOGGING

View file

@ -2638,7 +2638,6 @@ def completion( # type: ignore # noqa: PLR0915
messages=messages, messages=messages,
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params, # type: ignore litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn, logger_fn=logger_fn,

View file

@ -20,6 +20,7 @@ import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging from litellm._service_logger import ServiceLogging
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.proxy._types import * from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import ( from litellm.proxy.auth.auth_checks import (
_cache_key_object, _cache_key_object,
@ -897,6 +898,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
# Check 3. If token is expired # Check 3. If token is expired
if valid_token.expires is not None: if valid_token.expires is not None:
current_time = datetime.now(timezone.utc) current_time = datetime.now(timezone.utc)
if isinstance(valid_token.expires, datetime):
expiry_time = valid_token.expires
else:
expiry_time = datetime.fromisoformat(valid_token.expires) expiry_time = datetime.fromisoformat(valid_token.expires)
if ( if (
expiry_time.tzinfo is None expiry_time.tzinfo is None
@ -1127,6 +1131,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
) )
@tracer.wrap()
async def user_api_key_auth( async def user_api_key_auth(
request: Request, request: Request,
api_key: str = fastapi.Security(api_key_header), api_key: str = fastapi.Security(api_key_header),

View file

@ -48,6 +48,7 @@ from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.asyncify import run_async_function from litellm.litellm_core_utils.asyncify import run_async_function
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
@ -2857,6 +2858,7 @@ class Router:
#### [END] ASSISTANTS API #### #### [END] ASSISTANTS API ####
@tracer.wrap()
async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915
""" """
Try calling the function_with_retries Try calling the function_with_retries
@ -3127,6 +3129,7 @@ class Router:
Context_Policy_Fallbacks={content_policy_fallbacks}", Context_Policy_Fallbacks={content_policy_fallbacks}",
) )
@tracer.wrap()
async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915 async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
verbose_router_logger.debug("Inside async function with retries.") verbose_router_logger.debug("Inside async function with retries.")
original_function = kwargs.pop("original_function") original_function = kwargs.pop("original_function")

View file

@ -0,0 +1,53 @@
import json
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.litellm_core_utils.dd_tracing import dd_tracer
def test_dd_tracer_when_package_exists():
    """
    Smoke-test `dd_tracer.trace` and `dd_tracer.wrap` with the module's
    `has_ddtrace` flag patched to True.
    """
    flag_patch = patch("litellm.litellm_core_utils.dd_tracing.has_ddtrace", True)
    with flag_patch:
        # trace() must hand back a usable span from its context manager
        with dd_tracer.trace("test_operation") as active_span:
            assert active_span is not None

        # wrap() must leave the decorated function's return value intact
        def sample_function():
            return "test"

        sample_function = dd_tracer.wrap(name="test_function")(sample_function)

        assert sample_function() == "test"
def test_dd_tracer_when_package_not_exists():
    """
    Test the no-op tracer that dd_tracing falls back to when ddtrace cannot be imported.

    `dd_tracer` is bound once, when dd_tracing is imported, so patching the
    `has_ddtrace` flag afterwards does not change which tracer is in use.
    The module is therefore reloaded with `ddtrace` blocked in `sys.modules`,
    which forces the ImportError branch, and reloaded again afterwards to
    restore its normal state.
    """
    import importlib

    import litellm.litellm_core_utils.dd_tracing as dd_tracing_module

    try:
        # A None entry in sys.modules makes `from ddtrace import ...` raise ImportError
        with patch.dict(sys.modules, {"ddtrace": None}):
            importlib.reload(dd_tracing_module)
            null_tracer = dd_tracing_module.dd_tracer

            assert dd_tracing_module.has_ddtrace is False
            assert dd_tracing_module.tracer is null_tracer

            # Test the trace context manager with null tracer
            with null_tracer.trace("test_operation") as span:
                assert span is not None
                # Verify null span methods don't raise exceptions
                span.finish()

            # Test the wrapper decorator with null tracer
            @null_tracer.wrap(name="test_function")
            def sample_function():
                return "test"

            result = sample_function()
            assert result == "test"
    finally:
        # Re-import with the real environment so later tests see the usual tracer
        importlib.reload(dd_tracing_module)
def test_null_tracer_context_manager():
    """
    Test that the null tracer's span is a transparent context manager.

    `dd_tracer` is bound once at import time, so the dd_tracing module is
    reloaded with `ddtrace` blocked in `sys.modules` to force the no-op
    fallback, then reloaded again to restore its normal state.
    """
    import importlib

    import litellm.litellm_core_utils.dd_tracing as dd_tracing_module

    try:
        # A None entry in sys.modules makes `from ddtrace import ...` raise ImportError
        with patch.dict(sys.modules, {"ddtrace": None}):
            importlib.reload(dd_tracing_module)
            null_tracer = dd_tracing_module.dd_tracer

            # Entering the context yields a span whose methods can be called safely
            with null_tracer.trace("test_operation") as span:
                assert span is not None
                span.finish()

            # The null span must not swallow exceptions raised inside the block
            with pytest.raises(ValueError):
                with null_tracer.trace("test_operation"):
                    raise ValueError("raised inside null span")
    finally:
        # Re-import with the real environment so later tests see the usual tracer
        importlib.reload(dd_tracing_module)

View file

@ -0,0 +1,100 @@
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from datetime import datetime, timezone
from typing import Any, Dict
from unittest.mock import MagicMock, patch
from botocore.credentials import Credentials
import litellm
from litellm.llms.bedrock.base_aws_llm import (
AwsAuthError,
BaseAWSLLM,
Boto3CredentialsInfo,
)
# Path to base_aws_llm.py, built relative to this test file's directory.
# The tests below open this file and scan its source text line by line.
BASE_AWS_LLM_PATH = os.path.join(
os.path.dirname(__file__), "../../../../litellm/llms/bedrock/base_aws_llm.py"
)
def test_boto3_init_tracer_wrapping():
    """
    Check that every boto3 client/session construction in base_aws_llm.py is traced.

    The file at BASE_AWS_LLM_PATH is scanned as plain text. Any line that
    mentions `boto3.client` or `boto3.Session` must have `tracer.trace` or
    `@tracer.wrap` on that same line or on one of the four lines directly
    above it; otherwise the surrounding lines are printed and the test fails.
    """
    with open(BASE_AWS_LLM_PATH, "r") as source_file:
        source_lines = source_file.read().split("\n")

    for line_number, line in enumerate(source_lines, 1):
        for pattern in ("boto3.client", "boto3.Session"):
            if pattern not in line:
                continue

            # Window = the current line plus up to four preceding lines
            window_start = max(0, line_number - 5)
            window = source_lines[window_start:line_number]

            has_trace = "tracer.trace" in line
            if not has_trace:
                has_trace = any(
                    "tracer.trace" in candidate or "@tracer.wrap" in candidate
                    for candidate in window
                )

            if not has_trace:
                # Show the inspected window with its real file line numbers
                print(f"\nContext for line {line_number}:")
                for offset, ctx_line in enumerate(window):
                    print(f"{window_start + 1 + offset}: {ctx_line}")

            assert (
                has_trace
            ), f"boto3 initialization '{pattern}' on line {line_number} is not wrapped with tracer.trace or @tracer.wrap"
def test_auth_functions_tracer_wrapping():
    """
    Check that every `_auth_*` function in base_aws_llm.py carries @tracer.wrap.

    The file at BASE_AWS_LLM_PATH is scanned as plain text. For each line whose
    stripped form starts with `def _auth_`, the definition line and the line
    immediately before it are searched for `@tracer.wrap`; if it is missing,
    those lines are printed and the test fails.
    """
    with open(BASE_AWS_LLM_PATH, "r") as source_file:
        source_lines = source_file.read().split("\n")

    for line_number, line in enumerate(source_lines, 1):
        stripped = line.strip()
        if not stripped.startswith("def _auth_"):
            continue

        # Window = the definition line plus the single line above it
        window_start = max(0, line_number - 2)
        window = source_lines[window_start:line_number]

        has_tracer_wrap = False
        for candidate in window:
            if "@tracer.wrap" in candidate:
                has_tracer_wrap = True
                break

        if not has_tracer_wrap:
            # Show the inspected window with its real file line numbers
            print(f"\nContext for line {line_number}:")
            for offset, ctx_line in enumerate(window):
                print(f"{window_start + 1 + offset}: {ctx_line}")

        assert (
            has_tracer_wrap
        ), f"Auth function on line {line_number} is not wrapped with @tracer.wrap: {stripped}"