mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
(Observability) - Add more detailed dd tracing on Proxy Auth, Bedrock Auth (#8693)
* add dd tracer * fix dd tracing * add @tracer.wrap() on def user_api_key_auth * add async_function_with_retries * remove dead code * add tracer.wrap on base aws llm * add tracer.wrap on base aws llm * fix print verbose * fix dd tracing * trace base aws llm * fix test base aws llm * fix converse transform * test base aws llm * BASE_AWS_LLM_PATH * BASE_AWS_LLM_PATH * test dd tracing
This commit is contained in:
parent
11a1692c63
commit
f940392971
9 changed files with 256 additions and 42 deletions
53
litellm/litellm_core_utils/dd_tracing.py
Normal file
53
litellm/litellm_core_utils/dd_tracing.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
"""
|
||||||
|
Handles Tracing on DataDog Traces.
|
||||||
|
|
||||||
|
If the ddtrace package is not installed, the tracer will be a no-op.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ddtrace import tracer as dd_tracer
|
||||||
|
|
||||||
|
has_ddtrace = True
|
||||||
|
except ImportError:
|
||||||
|
has_ddtrace = False
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def null_tracer(name, **kwargs):
|
||||||
|
class NullSpan:
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
yield NullSpan()
|
||||||
|
|
||||||
|
class NullTracer:
|
||||||
|
def trace(self, name, **kwargs):
|
||||||
|
class NullSpan:
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return NullSpan()
|
||||||
|
|
||||||
|
def wrap(self, name=None, **kwargs):
|
||||||
|
def decorator(f):
|
||||||
|
return f
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
dd_tracer = NullTracer()
|
||||||
|
|
||||||
|
# Export the tracer instance
|
||||||
|
tracer = dd_tracer
|
|
@ -9,6 +9,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
|
from litellm.litellm_core_utils.dd_tracing import tracer
|
||||||
from litellm.secret_managers.main import get_secret, get_secret_str
|
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -63,6 +64,7 @@ class BaseAWSLLM:
|
||||||
credential_str = json.dumps(credential_args, sort_keys=True)
|
credential_str = json.dumps(credential_args, sort_keys=True)
|
||||||
return hashlib.sha256(credential_str.encode()).hexdigest()
|
return hashlib.sha256(credential_str.encode()).hexdigest()
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
def get_credentials(
|
def get_credentials(
|
||||||
self,
|
self,
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
@ -200,6 +202,7 @@ class BaseAWSLLM:
|
||||||
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
|
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
|
||||||
return credentials
|
return credentials
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
def _auth_with_web_identity_token(
|
def _auth_with_web_identity_token(
|
||||||
self,
|
self,
|
||||||
aws_web_identity_token: str,
|
aws_web_identity_token: str,
|
||||||
|
@ -230,6 +233,7 @@ class BaseAWSLLM:
|
||||||
status_code=401,
|
status_code=401,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with tracer.trace("boto3.client(sts)"):
|
||||||
sts_client = boto3.client(
|
sts_client = boto3.client(
|
||||||
"sts",
|
"sts",
|
||||||
region_name=aws_region_name,
|
region_name=aws_region_name,
|
||||||
|
@ -258,11 +262,13 @@ class BaseAWSLLM:
|
||||||
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
|
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with tracer.trace("boto3.Session(**iam_creds_dict)"):
|
||||||
session = boto3.Session(**iam_creds_dict)
|
session = boto3.Session(**iam_creds_dict)
|
||||||
|
|
||||||
iam_creds = session.get_credentials()
|
iam_creds = session.get_credentials()
|
||||||
return iam_creds, self._get_default_ttl_for_boto3_credentials()
|
return iam_creds, self._get_default_ttl_for_boto3_credentials()
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
def _auth_with_aws_role(
|
def _auth_with_aws_role(
|
||||||
self,
|
self,
|
||||||
aws_access_key_id: Optional[str],
|
aws_access_key_id: Optional[str],
|
||||||
|
@ -276,6 +282,7 @@ class BaseAWSLLM:
|
||||||
import boto3
|
import boto3
|
||||||
from botocore.credentials import Credentials
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
|
with tracer.trace("boto3.client(sts)"):
|
||||||
sts_client = boto3.client(
|
sts_client = boto3.client(
|
||||||
"sts",
|
"sts",
|
||||||
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||||
|
@ -288,7 +295,6 @@ class BaseAWSLLM:
|
||||||
|
|
||||||
# Extract the credentials from the response and convert to Session Credentials
|
# Extract the credentials from the response and convert to Session Credentials
|
||||||
sts_credentials = sts_response["Credentials"]
|
sts_credentials = sts_response["Credentials"]
|
||||||
|
|
||||||
credentials = Credentials(
|
credentials = Credentials(
|
||||||
access_key=sts_credentials["AccessKeyId"],
|
access_key=sts_credentials["AccessKeyId"],
|
||||||
secret_key=sts_credentials["SecretAccessKey"],
|
secret_key=sts_credentials["SecretAccessKey"],
|
||||||
|
@ -301,6 +307,7 @@ class BaseAWSLLM:
|
||||||
sts_ttl = (sts_expiry - current_time).total_seconds() - 60
|
sts_ttl = (sts_expiry - current_time).total_seconds() - 60
|
||||||
return credentials, sts_ttl
|
return credentials, sts_ttl
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
def _auth_with_aws_profile(
|
def _auth_with_aws_profile(
|
||||||
self, aws_profile_name: str
|
self, aws_profile_name: str
|
||||||
) -> Tuple[Credentials, Optional[int]]:
|
) -> Tuple[Credentials, Optional[int]]:
|
||||||
|
@ -310,9 +317,11 @@ class BaseAWSLLM:
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||||
|
with tracer.trace("boto3.Session(profile_name=aws_profile_name)"):
|
||||||
client = boto3.Session(profile_name=aws_profile_name)
|
client = boto3.Session(profile_name=aws_profile_name)
|
||||||
return client.get_credentials(), None
|
return client.get_credentials(), None
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
def _auth_with_aws_session_token(
|
def _auth_with_aws_session_token(
|
||||||
self,
|
self,
|
||||||
aws_access_key_id: str,
|
aws_access_key_id: str,
|
||||||
|
@ -333,6 +342,7 @@ class BaseAWSLLM:
|
||||||
|
|
||||||
return credentials, None
|
return credentials, None
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
def _auth_with_access_key_and_secret_key(
|
def _auth_with_access_key_and_secret_key(
|
||||||
self,
|
self,
|
||||||
aws_access_key_id: str,
|
aws_access_key_id: str,
|
||||||
|
@ -345,7 +355,9 @@ class BaseAWSLLM:
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
# Check if credentials are already in cache. These credentials have no expiry time.
|
# Check if credentials are already in cache. These credentials have no expiry time.
|
||||||
|
with tracer.trace(
|
||||||
|
"boto3.Session(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name)"
|
||||||
|
):
|
||||||
session = boto3.Session(
|
session = boto3.Session(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
@ -355,16 +367,19 @@ class BaseAWSLLM:
|
||||||
credentials = session.get_credentials()
|
credentials = session.get_credentials()
|
||||||
return credentials, self._get_default_ttl_for_boto3_credentials()
|
return credentials, self._get_default_ttl_for_boto3_credentials()
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
|
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
|
||||||
"""
|
"""
|
||||||
Authenticate with AWS Environment Variables
|
Authenticate with AWS Environment Variables
|
||||||
"""
|
"""
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
|
with tracer.trace("boto3.Session()"):
|
||||||
session = boto3.Session()
|
session = boto3.Session()
|
||||||
credentials = session.get_credentials()
|
credentials = session.get_credentials()
|
||||||
return credentials, None
|
return credentials, None
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
def _get_default_ttl_for_boto3_credentials(self) -> int:
|
def _get_default_ttl_for_boto3_credentials(self) -> int:
|
||||||
"""
|
"""
|
||||||
Get the default TTL for boto3 credentials
|
Get the default TTL for boto3 credentials
|
||||||
|
@ -475,6 +490,7 @@ class BaseAWSLLM:
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
def get_request_headers(
|
def get_request_headers(
|
||||||
self,
|
self,
|
||||||
credentials: Credentials,
|
credentials: Credentials,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import urllib
|
import urllib
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -60,7 +60,6 @@ def make_sync_call(
|
||||||
api_key="",
|
api_key="",
|
||||||
data=data,
|
data=data,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
print_verbose=litellm.print_verbose,
|
|
||||||
encoding=litellm.encoding,
|
encoding=litellm.encoding,
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
completion_stream: Any = MockResponseIterator(
|
completion_stream: Any = MockResponseIterator(
|
||||||
|
@ -102,7 +101,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
messages: list,
|
messages: list,
|
||||||
api_base: str,
|
api_base: str,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
print_verbose: Callable,
|
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
encoding,
|
encoding,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
|
@ -170,7 +168,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
messages: list,
|
messages: list,
|
||||||
api_base: str,
|
api_base: str,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
print_verbose: Callable,
|
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
encoding,
|
encoding,
|
||||||
logging_obj: LiteLLMLoggingObject,
|
logging_obj: LiteLLMLoggingObject,
|
||||||
|
@ -247,7 +244,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
api_key="",
|
api_key="",
|
||||||
data=data,
|
data=data,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
print_verbose=print_verbose,
|
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
)
|
)
|
||||||
|
@ -259,7 +255,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
custom_prompt_dict: dict,
|
custom_prompt_dict: dict,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
print_verbose: Callable,
|
|
||||||
encoding,
|
encoding,
|
||||||
logging_obj: LiteLLMLoggingObject,
|
logging_obj: LiteLLMLoggingObject,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
@ -271,11 +266,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
try:
|
|
||||||
from botocore.credentials import Credentials
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
|
||||||
|
|
||||||
## SETUP ##
|
## SETUP ##
|
||||||
stream = optional_params.pop("stream", None)
|
stream = optional_params.pop("stream", None)
|
||||||
modelId = optional_params.pop("model_id", None)
|
modelId = optional_params.pop("model_id", None)
|
||||||
|
@ -367,7 +357,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
api_base=proxy_endpoint_url,
|
api_base=proxy_endpoint_url,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
print_verbose=print_verbose,
|
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
@ -387,7 +376,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
api_base=proxy_endpoint_url,
|
api_base=proxy_endpoint_url,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
print_verbose=print_verbose,
|
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
@ -489,7 +477,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
api_key="",
|
api_key="",
|
||||||
data=data,
|
data=data,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
print_verbose=print_verbose,
|
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,7 +5,7 @@ Translating between OpenAI's `/chat/completion` format and Amazon's `/converse`
|
||||||
import copy
|
import copy
|
||||||
import time
|
import time
|
||||||
import types
|
import types
|
||||||
from typing import Callable, List, Literal, Optional, Tuple, Union, cast, overload
|
from typing import List, Literal, Optional, Tuple, Union, cast, overload
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -542,7 +542,6 @@ class AmazonConverseConfig(BaseConfig):
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
data=request_data,
|
data=request_data,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
print_verbose=None,
|
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -557,7 +556,6 @@ class AmazonConverseConfig(BaseConfig):
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
data: Union[dict, str],
|
data: Union[dict, str],
|
||||||
messages: List,
|
messages: List,
|
||||||
print_verbose: Optional[Callable],
|
|
||||||
encoding,
|
encoding,
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
|
|
@ -2638,7 +2638,6 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
messages=messages,
|
messages=messages,
|
||||||
custom_prompt_dict=custom_prompt_dict,
|
custom_prompt_dict=custom_prompt_dict,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
print_verbose=print_verbose,
|
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
litellm_params=litellm_params, # type: ignore
|
litellm_params=litellm_params, # type: ignore
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
|
|
|
@ -20,6 +20,7 @@ import litellm
|
||||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||||
from litellm._service_logger import ServiceLogging
|
from litellm._service_logger import ServiceLogging
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
from litellm.litellm_core_utils.dd_tracing import tracer
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.proxy.auth.auth_checks import (
|
from litellm.proxy.auth.auth_checks import (
|
||||||
_cache_key_object,
|
_cache_key_object,
|
||||||
|
@ -897,6 +898,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
# Check 3. If token is expired
|
# Check 3. If token is expired
|
||||||
if valid_token.expires is not None:
|
if valid_token.expires is not None:
|
||||||
current_time = datetime.now(timezone.utc)
|
current_time = datetime.now(timezone.utc)
|
||||||
|
if isinstance(valid_token.expires, datetime):
|
||||||
|
expiry_time = valid_token.expires
|
||||||
|
else:
|
||||||
expiry_time = datetime.fromisoformat(valid_token.expires)
|
expiry_time = datetime.fromisoformat(valid_token.expires)
|
||||||
if (
|
if (
|
||||||
expiry_time.tzinfo is None
|
expiry_time.tzinfo is None
|
||||||
|
@ -1127,6 +1131,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
async def user_api_key_auth(
|
async def user_api_key_auth(
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key: str = fastapi.Security(api_key_header),
|
api_key: str = fastapi.Security(api_key_header),
|
||||||
|
|
|
@ -48,6 +48,7 @@ from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.litellm_core_utils.asyncify import run_async_function
|
from litellm.litellm_core_utils.asyncify import run_async_function
|
||||||
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
|
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
|
||||||
|
from litellm.litellm_core_utils.dd_tracing import tracer
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||||
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
|
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
|
||||||
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
||||||
|
@ -2857,6 +2858,7 @@ class Router:
|
||||||
|
|
||||||
#### [END] ASSISTANTS API ####
|
#### [END] ASSISTANTS API ####
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915
|
async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915
|
||||||
"""
|
"""
|
||||||
Try calling the function_with_retries
|
Try calling the function_with_retries
|
||||||
|
@ -3127,6 +3129,7 @@ class Router:
|
||||||
Context_Policy_Fallbacks={content_policy_fallbacks}",
|
Context_Policy_Fallbacks={content_policy_fallbacks}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@tracer.wrap()
|
||||||
async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
|
async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
|
||||||
verbose_router_logger.debug("Inside async function with retries.")
|
verbose_router_logger.debug("Inside async function with retries.")
|
||||||
original_function = kwargs.pop("original_function")
|
original_function = kwargs.pop("original_function")
|
||||||
|
|
53
tests/litellm/litellm_core_utils/test_dd_tracing.py
Normal file
53
tests/litellm/litellm_core_utils/test_dd_tracing.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.litellm_core_utils.dd_tracing import dd_tracer
|
||||||
|
|
||||||
|
|
||||||
|
def test_dd_tracer_when_package_exists():
|
||||||
|
with patch("litellm.litellm_core_utils.dd_tracing.has_ddtrace", True):
|
||||||
|
# Test the trace context manager
|
||||||
|
with dd_tracer.trace("test_operation") as span:
|
||||||
|
assert span is not None
|
||||||
|
|
||||||
|
# Test the wrapper decorator
|
||||||
|
@dd_tracer.wrap(name="test_function")
|
||||||
|
def sample_function():
|
||||||
|
return "test"
|
||||||
|
|
||||||
|
result = sample_function()
|
||||||
|
assert result == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dd_tracer_when_package_not_exists():
|
||||||
|
with patch("litellm.litellm_core_utils.dd_tracing.has_ddtrace", False):
|
||||||
|
# Test the trace context manager with null tracer
|
||||||
|
with dd_tracer.trace("test_operation") as span:
|
||||||
|
assert span is not None
|
||||||
|
# Verify null span methods don't raise exceptions
|
||||||
|
span.finish()
|
||||||
|
|
||||||
|
# Test the wrapper decorator with null tracer
|
||||||
|
@dd_tracer.wrap(name="test_function")
|
||||||
|
def sample_function():
|
||||||
|
return "test"
|
||||||
|
|
||||||
|
result = sample_function()
|
||||||
|
assert result == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_null_tracer_context_manager():
|
||||||
|
with patch("litellm.litellm_core_utils.dd_tracing.has_ddtrace", False):
|
||||||
|
# Test that the context manager works without raising exceptions
|
||||||
|
with dd_tracer.trace("test_operation") as span:
|
||||||
|
# Test that we can call methods on the null span
|
||||||
|
span.finish()
|
||||||
|
assert True # If we get here without exceptions, the test passes
|
100
tests/litellm/llms/bedrock/test_base_aws_llm.py
Normal file
100
tests/litellm/llms/bedrock/test_base_aws_llm.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Dict
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.llms.bedrock.base_aws_llm import (
|
||||||
|
AwsAuthError,
|
||||||
|
BaseAWSLLM,
|
||||||
|
Boto3CredentialsInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global variable for the base_aws_llm.py file path
|
||||||
|
|
||||||
|
BASE_AWS_LLM_PATH = os.path.join(
|
||||||
|
os.path.dirname(__file__), "../../../../litellm/llms/bedrock/base_aws_llm.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_boto3_init_tracer_wrapping():
|
||||||
|
"""
|
||||||
|
Test that all boto3 initializations are wrapped in tracer.trace or @tracer.wrap
|
||||||
|
|
||||||
|
Ensures observability of boto3 calls in litellm.
|
||||||
|
"""
|
||||||
|
# Get the source code of base_aws_llm.py
|
||||||
|
with open(BASE_AWS_LLM_PATH, "r") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# List all boto3 initialization patterns we want to check
|
||||||
|
boto3_init_patterns = ["boto3.client", "boto3.Session"]
|
||||||
|
|
||||||
|
lines = content.split("\n")
|
||||||
|
# Check each boto3 initialization is wrapped in tracer.trace
|
||||||
|
for line_number, line in enumerate(lines, 1):
|
||||||
|
for pattern in boto3_init_patterns:
|
||||||
|
if pattern in line:
|
||||||
|
# Look back up to 5 lines for decorator or trace block
|
||||||
|
start_line = max(0, line_number - 5)
|
||||||
|
context_lines = lines[start_line:line_number]
|
||||||
|
|
||||||
|
has_trace = (
|
||||||
|
"tracer.trace" in line
|
||||||
|
or any("tracer.trace" in prev_line for prev_line in context_lines)
|
||||||
|
or any("@tracer.wrap" in prev_line for prev_line in context_lines)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_trace:
|
||||||
|
print(f"\nContext for line {line_number}:")
|
||||||
|
for i, ctx_line in enumerate(context_lines, start=start_line + 1):
|
||||||
|
print(f"{i}: {ctx_line}")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
has_trace
|
||||||
|
), f"boto3 initialization '{pattern}' on line {line_number} is not wrapped with tracer.trace or @tracer.wrap"
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_functions_tracer_wrapping():
|
||||||
|
"""
|
||||||
|
Test that all _auth functions in base_aws_llm.py are wrapped with @tracer.wrap
|
||||||
|
|
||||||
|
Ensures observability of AWS authentication calls in litellm.
|
||||||
|
"""
|
||||||
|
# Get the source code of base_aws_llm.py
|
||||||
|
with open(BASE_AWS_LLM_PATH, "r") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
lines = content.split("\n")
|
||||||
|
# Check each line for _auth function definitions
|
||||||
|
for line_number, line in enumerate(lines, 1):
|
||||||
|
if line.strip().startswith("def _auth_"):
|
||||||
|
# Look back up to 2 lines for the @tracer.wrap decorator
|
||||||
|
start_line = max(0, line_number - 2)
|
||||||
|
context_lines = lines[start_line:line_number]
|
||||||
|
|
||||||
|
has_tracer_wrap = any(
|
||||||
|
"@tracer.wrap" in prev_line for prev_line in context_lines
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_tracer_wrap:
|
||||||
|
print(f"\nContext for line {line_number}:")
|
||||||
|
for i, ctx_line in enumerate(context_lines, start=start_line + 1):
|
||||||
|
print(f"{i}: {ctx_line}")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
has_tracer_wrap
|
||||||
|
), f"Auth function on line {line_number} is not wrapped with @tracer.wrap: {line.strip()}"
|
Loading…
Add table
Add a link
Reference in a new issue