mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Add datadog health check support + fix bedrock converse cost tracking w/ region name specified (#7958)

* fix(bedrock/converse_handler.py): fix bedrock region name on async calls
* fix(utils.py): fix split model handling. Fixes bedrock cost calculation when a region name is given
* feat(_health_endpoints.py): support health checking the datadog integration

Closes https://github.com/BerriAI/litellm/issues/7921
This commit is contained in:
parent a835baacfc
commit c6e9240405
13 changed files with 254 additions and 33 deletions
@@ -386,6 +386,7 @@ def _select_model_name_for_cost_calc(
     3. If completion response has model set return that
     4. Check if model is passed in return that
     """
     return_model: Optional[str] = None
+    region_name: Optional[str] = None
     custom_llm_provider = _get_provider_for_cost_calc(
litellm/integrations/base_health_check.py (new file, 19 lines)

@@ -0,0 +1,19 @@
+"""
+Base class for health check integrations
+"""
+
+from abc import ABC, abstractmethod
+
+from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
+
+
+class HealthCheckIntegration(ABC):
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    async def async_health_check(self) -> IntegrationHealthCheckStatus:
+        """
+        Check if the service is healthy
+        """
+        pass
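For context, a minimal subclass sketch showing what a new integration would implement to plug into this interface. The class name and body are illustrative, not part of this PR:

    from litellm.integrations.base_health_check import HealthCheckIntegration
    from litellm.types.integrations.base_health_check import (
        IntegrationHealthCheckStatus,
    )


    class HypotheticalLogger(HealthCheckIntegration):
        async def async_health_check(self) -> IntegrationHealthCheckStatus:
            try:
                # send a minimal test payload to the backing service here
                return IntegrationHealthCheckStatus(
                    status="healthy", error_message=None
                )
            except Exception as e:
                return IntegrationHealthCheckStatus(
                    status="unhealthy", error_message=str(e)
                )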
@@ -15,12 +15,14 @@ For batching specific details see CustomBatchLogger class

 import asyncio
 import datetime
+import json
 import os
 import traceback
 import uuid
 from datetime import datetime as datetimeObj
 from typing import Any, List, Optional, Union

 import httpx
+from httpx import Response

 import litellm
@@ -31,14 +33,20 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
     httpxSpecialProvider,
 )
+from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
 from litellm.types.integrations.datadog import *
 from litellm.types.services import ServiceLoggerPayload
 from litellm.types.utils import StandardLoggingPayload

+from ..base_health_check import HealthCheckIntegration
+
 DD_MAX_BATCH_SIZE = 1000  # max number of logs DD API can accept


-class DataDogLogger(CustomBatchLogger):
+class DataDogLogger(
+    CustomBatchLogger,
+    HealthCheckIntegration,
+):
     # Class variables or attributes
     def __init__(
         self,
@@ -235,6 +243,25 @@ class DataDogLogger(CustomBatchLogger):
             if len(self.log_queue) >= self.batch_size:
                 await self.async_send_batch()

+    def _create_datadog_logging_payload_helper(
+        self,
+        standard_logging_object: StandardLoggingPayload,
+        status: DataDogStatus,
+    ) -> DatadogPayload:
+        json_payload = json.dumps(standard_logging_object, default=str)
+        verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
+        dd_payload = DatadogPayload(
+            ddsource=self._get_datadog_source(),
+            ddtags=self._get_datadog_tags(
+                standard_logging_object=standard_logging_object
+            ),
+            hostname=self._get_datadog_hostname(),
+            message=json_payload,
+            service=self._get_datadog_service(),
+            status=status,
+        )
+        return dd_payload
+
     def create_datadog_logging_payload(
         self,
         kwargs: Union[dict, Any],
@@ -254,7 +281,6 @@ class DataDogLogger(CustomBatchLogger):
         Returns:
             DatadogPayload: defined in types.py
         """
-        import json

         standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
             "standard_logging_object", None
@@ -268,18 +294,9 @@ class DataDogLogger(CustomBatchLogger):

         # Build the initial payload
         self.truncate_standard_logging_payload_content(standard_logging_object)
-        json_payload = json.dumps(standard_logging_object, default=str)
-
-        verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
-
-        dd_payload = DatadogPayload(
-            ddsource=self._get_datadog_source(),
-            ddtags=self._get_datadog_tags(
-                standard_logging_object=standard_logging_object
-            ),
-            hostname=self._get_datadog_hostname(),
-            message=json_payload,
-            service=self._get_datadog_service(),
+        dd_payload = self._create_datadog_logging_payload_helper(
+            standard_logging_object=standard_logging_object,
             status=status,
         )
         return dd_payload
@@ -293,6 +310,7 @@ class DataDogLogger(CustomBatchLogger):
         "Datadog recommends sending your logs compressed. Add the Content-Encoding: gzip header to the request when sending"
         """

         import gzip
         import json

@@ -493,3 +511,35 @@ class DataDogLogger(CustomBatchLogger):
     @staticmethod
     def _get_datadog_pod_name():
         return os.getenv("POD_NAME", "unknown")
+
+    async def async_health_check(self) -> IntegrationHealthCheckStatus:
+        """
+        Check if the service is healthy
+        """
+        from litellm.litellm_core_utils.litellm_logging import (
+            create_dummy_standard_logging_payload,
+        )
+
+        standard_logging_object = create_dummy_standard_logging_payload()
+        dd_payload = self._create_datadog_logging_payload_helper(
+            standard_logging_object=standard_logging_object,
+            status=DataDogStatus.INFO,
+        )
+        log_queue = [dd_payload]
+        response = await self.async_send_compressed_data(log_queue)
+        try:
+            response.raise_for_status()
+            return IntegrationHealthCheckStatus(
+                status="healthy",
+                error_message=None,
+            )
+        except httpx.HTTPStatusError as e:
+            return IntegrationHealthCheckStatus(
+                status="unhealthy",
+                error_message=e.response.text,
+            )
+        except Exception as e:
+            return IntegrationHealthCheckStatus(
+                status="unhealthy",
+                error_message=str(e),
+            )
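A quick manual exercise of the new method, as a sketch. It assumes Datadog credentials (DD_API_KEY / DD_SITE) are configured in the environment, which the Datadog integration requires:

    import asyncio

    from litellm.integrations.datadog.datadog import DataDogLogger


    async def main() -> None:
        # sends one INFO-level dummy log to Datadog and reports the outcome
        status = await DataDogLogger().async_health_check()
        print(status["status"], status["error_message"])


    asyncio.run(main())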
@@ -3341,3 +3341,85 @@ def _get_traceback_str_for_error(error_str: str) -> str:
     function wrapped with lru_cache to limit the number of times `traceback.format_exc()` is called
     """
     return traceback.format_exc()
+
+
+from decimal import Decimal
+
+# used for unit testing
+from typing import Any, Dict, List, Optional, Union
+
+
+def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
+    # First create the nested objects with proper typing
+    model_info = StandardLoggingModelInformation(
+        model_map_key="gpt-3.5-turbo", model_map_value=None
+    )
+
+    metadata = StandardLoggingMetadata(  # type: ignore
+        user_api_key_hash=str("test_hash"),
+        user_api_key_alias=str("test_alias"),
+        user_api_key_team_id=str("test_team"),
+        user_api_key_user_id=str("test_user"),
+        user_api_key_team_alias=str("test_team_alias"),
+        user_api_key_org_id=None,
+        spend_logs_metadata=None,
+        requester_ip_address=str("127.0.0.1"),
+        requester_metadata=None,
+        user_api_key_end_user_id=str("test_end_user"),
+    )
+
+    hidden_params = StandardLoggingHiddenParams(
+        model_id=None,
+        cache_key=None,
+        api_base=None,
+        response_cost=None,
+        additional_headers=None,
+        litellm_overhead_time_ms=None,
+    )
+
+    # Convert numeric values to appropriate types
+    response_cost = Decimal("0.1")
+    start_time = Decimal("1234567890.0")
+    end_time = Decimal("1234567891.0")
+    completion_start_time = Decimal("1234567890.5")
+    saved_cache_cost = Decimal("0.0")
+
+    # Create messages and response with proper typing
+    messages: List[Dict[str, str]] = [{"role": "user", "content": "Hello, world!"}]
+    response: Dict[str, List[Dict[str, Dict[str, str]]]] = {
+        "choices": [{"message": {"content": "Hi there!"}}]
+    }
+
+    # Main payload initialization
+    return StandardLoggingPayload(  # type: ignore
+        id=str("test_id"),
+        call_type=str("completion"),
+        stream=bool(False),
+        response_cost=response_cost,
+        response_cost_failure_debug_info=None,
+        status=str("success"),
+        total_tokens=int(30),
+        prompt_tokens=int(20),
+        completion_tokens=int(10),
+        startTime=start_time,
+        endTime=end_time,
+        completionStartTime=completion_start_time,
+        model_map_information=model_info,
+        model=str("gpt-3.5-turbo"),
+        model_id=str("model-123"),
+        model_group=str("openai-gpt"),
+        custom_llm_provider=str("openai"),
+        api_base=str("https://api.openai.com"),
+        metadata=metadata,
+        cache_hit=bool(False),
+        cache_key=None,
+        saved_cache_cost=saved_cache_cost,
+        request_tags=[],
+        end_user=None,
+        requester_ip_address=str("127.0.0.1"),
+        messages=messages,
+        response=response,
+        error_str=None,
+        model_parameters={"stream": True},
+        hidden_params=hidden_params,
+    )
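A short usage sketch of the helper above; the asserted field values come straight from the function body, and StandardLoggingPayload is a TypedDict, so it reads as a plain dict at runtime:

    from litellm.litellm_core_utils.litellm_logging import (
        create_dummy_standard_logging_payload,
    )

    payload = create_dummy_standard_logging_payload()
    assert payload["model"] == "gpt-3.5-turbo"
    assert payload["status"] == "success"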
@@ -110,7 +110,7 @@ def _set_duration_in_model_call_details(
         if logging_obj and hasattr(logging_obj, "model_call_details"):
             logging_obj.model_call_details["llm_api_duration_ms"] = duration_ms
         else:
-            verbose_logger.warning(
+            verbose_logger.debug(
                 "`logging_obj` not found - unable to track `llm_api_duration_ms"
             )
     except Exception as e:
@@ -207,7 +207,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             additional_args={
                 "complete_input_dict": data,
                 "api_base": api_base,
-                "headers": headers,
+                "headers": prepped.headers,
             },
         )
@@ -226,7 +226,10 @@ class BedrockConverseLLM(BaseAWSLLM):

         try:
             response = await client.post(
-                url=api_base, headers=headers, data=data, logging_obj=logging_obj
+                url=api_base,
+                headers=headers,
+                data=data,
+                logging_obj=logging_obj,
             )  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
@@ -267,6 +270,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
     ):
         try:
             from botocore.credentials import Credentials
         except ImportError:
@@ -300,8 +304,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
         aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
-
-        litellm_params["aws_region_name"] = aws_region_name

         ### SET REGION NAME ###
         if aws_region_name is None:
             # check env #
@@ -321,6 +323,10 @@ class BedrockConverseLLM(BaseAWSLLM):
         if aws_region_name is None:
             aws_region_name = "us-west-2"

+        litellm_params["aws_region_name"] = (
+            aws_region_name  # [DO NOT DELETE] important for async calls
+        )
+
         credentials: Credentials = self.get_credentials(
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
@@ -347,7 +353,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"

         ## COMPLETION CALL
-
         headers = {"Content-Type": "application/json"}
         if extra_headers is not None:
             headers = {"Content-Type": "application/json", **extra_headers}
@@ -51,7 +51,6 @@ async def test_endpoint(request: Request):
     "/health/services",
     tags=["health"],
     dependencies=[Depends(user_api_key_auth)],
-    include_in_schema=False,
 )
 async def health_services_endpoint(  # noqa: PLR0915
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
@@ -64,14 +63,19 @@ async def health_services_endpoint(  # noqa: PLR0915
             "webhook",
             "email",
             "braintrust",
+            "datadog",
         ],
         str,
     ] = fastapi.Query(description="Specify the service being hit."),
 ):
     """
     Hidden endpoint.
     Use this admin-only endpoint to check if the service is healthy.

     Used by the UI to let user check if slack alerting is working as expected.
+
+    Example:
+    ```
+    curl -L -X GET 'http://0.0.0.0:4000/health/services?service=datadog' \
+    -H 'Authorization: Bearer sk-1234'
+    ```
     """
     try:
         from litellm.proxy.proxy_server import (
@@ -84,6 +88,7 @@ async def health_services_endpoint(  # noqa: PLR0915
         raise HTTPException(
             status_code=400, detail={"error": "Service must be specified."}
         )

     if service not in [
         "slack_budget_alerts",
         "email",
@@ -95,6 +100,7 @@ async def health_services_endpoint(  # noqa: PLR0915
         "otel",
         "custom_callback_api",
         "langsmith",
+        "datadog",
     ]:
         raise HTTPException(
             status_code=400,
@@ -118,8 +124,20 @@ async def health_services_endpoint(  # noqa: PLR0915
             "status": "success",
             "message": "Mock LLM request made - check {}.".format(service),
         }
+    elif service == "datadog":
+        from litellm.integrations.datadog.datadog import DataDogLogger
+
+        datadog_logger = DataDogLogger()
+        response = await datadog_logger.async_health_check()
+        return {
+            "status": response["status"],
+            "message": (
+                response["error_message"]
+                if response["status"] == "unhealthy"
+                else "Datadog is healthy"
+            ),
+        }

-    if service == "langfuse":
+    elif service == "langfuse":
         from litellm.integrations.langfuse.langfuse import LangFuseLogger

         langfuse_logger = LangFuseLogger()
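Exercising the new branch against a running proxy, as a sketch; the URL and key are the docstring's example values, not fixed ones:

    import httpx

    resp = httpx.get(
        "http://0.0.0.0:4000/health/services",
        params={"service": "datadog"},
        headers={"Authorization": "Bearer sk-1234"},
    )
    # on success the handler above returns:
    # {"status": "healthy", "message": "Datadog is healthy"}
    print(resp.json())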
@@ -1228,9 +1228,6 @@ class PrismaClient:
         """
         Generic implementation of get data
         """
-        verbose_proxy_logger.debug(
-            f"PrismaClient: get_generic_data: {key}, table_name: {table_name}"
-        )
         start_time = time.time()
         try:
             if table_name == "users":
litellm/types/integrations/base_health_check.py (new file, 6 lines)

@@ -0,0 +1,6 @@
+from typing import Literal, Optional, TypedDict
+
+
+class IntegrationHealthCheckStatus(TypedDict):
+    status: Literal["healthy", "unhealthy"]
+    error_message: Optional[str]
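Being a TypedDict, the status is a plain dict at runtime; a small illustration (the error text is hypothetical):

    from litellm.types.integrations.base_health_check import (
        IntegrationHealthCheckStatus,
    )

    healthy: IntegrationHealthCheckStatus = {
        "status": "healthy",
        "error_message": None,
    }
    unhealthy: IntegrationHealthCheckStatus = {
        "status": "unhealthy",
        "error_message": "403 Forbidden from the logging backend",  # hypothetical
    }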
@@ -4092,7 +4092,7 @@ def _get_potential_model_names(
     elif custom_llm_provider and model.startswith(
         custom_llm_provider + "/"
     ):  # handle case where custom_llm_provider is provided and model starts with custom_llm_provider
-        split_model = model.split("/")[1]
+        split_model = model.split("/", 1)[1]
         combined_model_name = model
         stripped_model_name = _strip_model_name(
             model=split_model, custom_llm_provider=custom_llm_provider
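Why split("/", 1) matters, shown with an illustrative region-qualified model string:

    model = "bedrock/us-east-1/anthropic.claude-3-5-sonnet-20240620-v1:0"
    model.split("/")[1]     # 'us-east-1' (model name lost)
    model.split("/", 1)[1]  # 'us-east-1/anthropic.claude-3-5-sonnet-20240620-v1:0'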
@@ -539,17 +539,19 @@ class BaseLLMChatTest(ABC):

         return url

-    def test_completion_cost(self):
+    @pytest.mark.asyncio
+    async def test_completion_cost(self):
         from litellm import completion_cost

         os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
         litellm.model_cost = litellm.get_model_cost_map(url="")

         litellm.set_verbose = True
-        response = self.completion_function(
+        response = await self.async_completion_function(
             **self.get_base_completion_call_args(),
             messages=[{"role": "user", "content": "Hello, how are you?"}],
         )
         print(response._hidden_params)
         cost = completion_cost(response)

         assert cost > 0
@@ -2057,7 +2057,7 @@ def test_bedrock_supports_tool_call(model, expected_supports_tool_call):
     assert "tools" not in supported_openai_params


-class TestBedrockConverseChat(BaseLLMChatTest):
+class TestBedrockConverseChatCrossRegion(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
         os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
         litellm.model_cost = litellm.get_model_cost_map(url="")
@@ -2104,6 +2104,29 @@ class TestBedrockConverseChat(BaseLLMChatTest):
         assert cost > 0


+class TestBedrockConverseChatNormal(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+        litellm.add_known_models()
+        return {
+            "model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
+            "aws_region_name": "us-east-1",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
+
+    def test_multilingual_requests(self):
+        """
+        Bedrock API raises a 400 BadRequest error when the request contains invalid utf-8 sequences.
+
+        Todo: if litellm.modify_params is True ensure it's a valid utf-8 sequence
+        """
+        pass
+
+
 class TestBedrockRerank(BaseLLMRerankTest):
     def get_custom_llm_provider(self) -> litellm.LlmProviders:
         return litellm.LlmProviders.BEDROCK
@@ -2600,7 +2600,14 @@ async def test_test_completion_cost_gpt4o_audio_output_from_model(stream):
     assert round(cost, 2) == round(total_input_cost + total_output_cost, 2)


-def test_completion_cost_azure_ai_meta():
+@pytest.mark.parametrize(
+    "response_model, custom_llm_provider",
+    [
+        ("azure_ai/Meta-Llama-3.1-70B-Instruct", "azure_ai"),
+        ("anthropic.claude-3-5-sonnet-20240620-v1:0", "bedrock"),
+    ],
+)
+def test_completion_cost_model_response_cost(response_model, custom_llm_provider):
     """
     Relevant issue: https://github.com/BerriAI/litellm/issues/6310
     """
@@ -2628,7 +2635,7 @@
         }
     ],
     "created": 1729243714,
-    "model": "azure_ai/Meta-Llama-3.1-70B-Instruct",
+    "model": response_model,
     "object": "chat.completion",
     "service_tier": None,
     "system_fingerprint": None,
@@ -2642,7 +2649,7 @@
     }

     model_response = ModelResponse(**response)
-    cost = completion_cost(model_response, custom_llm_provider="azure_ai")
+    cost = completion_cost(model_response, custom_llm_provider=custom_llm_provider)

     assert cost > 0
@@ -2754,3 +2761,14 @@ def test_add_known_models():
     assert (
         "bedrock/us-west-1/meta.llama3-70b-instruct-v1:0" not in litellm.bedrock_models
     )
+
+
+def test_bedrock_cost_calc_with_region():
+    from litellm import completion
+
+    response = completion(
+        model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        aws_region_name="us-east-1",
+    )
+    assert response._hidden_params["response_cost"] > 0