Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
Add datadog health check support + fix bedrock converse cost tracking w/ region name specified (#7958)
* fix(bedrock/converse_handler.py): fix bedrock region name on async calls
* fix(utils.py): fix split model handling. Fixes bedrock cost calculation when region name is given
* feat(_health_endpoints.py): support health checking datadog integration. Closes https://github.com/BerriAI/litellm/issues/7921
This commit is contained in: parent c0e83ab377, commit fe460f19f5
13 changed files with 254 additions and 33 deletions
@@ -386,6 +386,7 @@ def _select_model_name_for_cost_calc(
     3. If completion response has model set return that
     4. Check if model is passed in return that
     """
 
     return_model: Optional[str] = None
+    region_name: Optional[str] = None
     custom_llm_provider = _get_provider_for_cost_calc(
litellm/integrations/base_health_check.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+"""
+Base class for health check integrations
+"""
+
+from abc import ABC, abstractmethod
+
+from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
+
+
+class HealthCheckIntegration(ABC):
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    async def async_health_check(self) -> IntegrationHealthCheckStatus:
+        """
+        Check if the service is healthy
+        """
+        pass
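
For orientation, a minimal sketch of how an integration would implement this new interface. The EchoHealthCheck class below is hypothetical and not part of this commit:

```
# Hypothetical example (not in this commit): the smallest possible
# implementation of the HealthCheckIntegration interface added above.
import asyncio

from litellm.integrations.base_health_check import HealthCheckIntegration
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus


class EchoHealthCheck(HealthCheckIntegration):
    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        # A real integration would ping its backing service here and map
        # failures to status="unhealthy" plus an error_message.
        return IntegrationHealthCheckStatus(status="healthy", error_message=None)


print(asyncio.run(EchoHealthCheck().async_health_check()))
```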

@@ -15,12 +15,14 @@ For batching specific details see CustomBatchLogger class
 
 import asyncio
 import datetime
+import json
 import os
 import traceback
 import uuid
 from datetime import datetime as datetimeObj
 from typing import Any, List, Optional, Union
 
+import httpx
 from httpx import Response
 
 import litellm
@@ -31,14 +33,20 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
     httpxSpecialProvider,
 )
+from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
 from litellm.types.integrations.datadog import *
 from litellm.types.services import ServiceLoggerPayload
 from litellm.types.utils import StandardLoggingPayload
 
+from ..base_health_check import HealthCheckIntegration
+
 DD_MAX_BATCH_SIZE = 1000  # max number of logs DD API can accept
 
 
-class DataDogLogger(CustomBatchLogger):
+class DataDogLogger(
+    CustomBatchLogger,
+    HealthCheckIntegration,
+):
     # Class variables or attributes
     def __init__(
         self,
@@ -235,6 +243,25 @@ class DataDogLogger(CustomBatchLogger):
             if len(self.log_queue) >= self.batch_size:
                 await self.async_send_batch()
 
+    def _create_datadog_logging_payload_helper(
+        self,
+        standard_logging_object: StandardLoggingPayload,
+        status: DataDogStatus,
+    ) -> DatadogPayload:
+        json_payload = json.dumps(standard_logging_object, default=str)
+        verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
+        dd_payload = DatadogPayload(
+            ddsource=self._get_datadog_source(),
+            ddtags=self._get_datadog_tags(
+                standard_logging_object=standard_logging_object
+            ),
+            hostname=self._get_datadog_hostname(),
+            message=json_payload,
+            service=self._get_datadog_service(),
+            status=status,
+        )
+        return dd_payload
+
     def create_datadog_logging_payload(
         self,
         kwargs: Union[dict, Any],
@@ -254,7 +281,6 @@ class DataDogLogger(CustomBatchLogger):
         Returns:
             DatadogPayload: defined in types.py
         """
-        import json
 
         standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
             "standard_logging_object", None
@@ -268,18 +294,9 @@ class DataDogLogger(CustomBatchLogger):
 
         # Build the initial payload
         self.truncate_standard_logging_payload_content(standard_logging_object)
-        json_payload = json.dumps(standard_logging_object, default=str)
-
-        verbose_logger.debug("Datadog: Logger - Logging payload = %s", json_payload)
-
-        dd_payload = DatadogPayload(
-            ddsource=self._get_datadog_source(),
-            ddtags=self._get_datadog_tags(
-                standard_logging_object=standard_logging_object
-            ),
-            hostname=self._get_datadog_hostname(),
-            message=json_payload,
-            service=self._get_datadog_service(),
+        dd_payload = self._create_datadog_logging_payload_helper(
+            standard_logging_object=standard_logging_object,
             status=status,
         )
         return dd_payload
@@ -293,6 +310,7 @@ class DataDogLogger(CustomBatchLogger):
 
         "Datadog recommends sending your logs compressed. Add the Content-Encoding: gzip header to the request when sending"
         """
+
         import gzip
         import json
 
@@ -493,3 +511,35 @@ class DataDogLogger(CustomBatchLogger):
     @staticmethod
     def _get_datadog_pod_name():
         return os.getenv("POD_NAME", "unknown")
+
+    async def async_health_check(self) -> IntegrationHealthCheckStatus:
+        """
+        Check if the service is healthy
+        """
+        from litellm.litellm_core_utils.litellm_logging import (
+            create_dummy_standard_logging_payload,
+        )
+
+        standard_logging_object = create_dummy_standard_logging_payload()
+        dd_payload = self._create_datadog_logging_payload_helper(
+            standard_logging_object=standard_logging_object,
+            status=DataDogStatus.INFO,
+        )
+        log_queue = [dd_payload]
+        response = await self.async_send_compressed_data(log_queue)
+        try:
+            response.raise_for_status()
+            return IntegrationHealthCheckStatus(
+                status="healthy",
+                error_message=None,
+            )
+        except httpx.HTTPStatusError as e:
+            return IntegrationHealthCheckStatus(
+                status="unhealthy",
+                error_message=e.response.text,
+            )
+        except Exception as e:
+            return IntegrationHealthCheckStatus(
+                status="unhealthy",
+                error_message=str(e),
+            )
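
A rough sketch of exercising the new health check directly, outside the proxy. This assumes the Datadog environment variables the logger already requires (e.g. DD_API_KEY and DD_SITE) are set; otherwise DataDogLogger raises at construction time:

```
# Sketch: call the new async_health_check() directly.
# Assumes DD_API_KEY and DD_SITE are set in the environment.
import asyncio

from litellm.integrations.datadog.datadog import DataDogLogger


async def main() -> None:
    status = await DataDogLogger().async_health_check()
    print(status["status"], status["error_message"])


asyncio.run(main())
```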
@@ -3341,3 +3341,85 @@ def _get_traceback_str_for_error(error_str: str) -> str:
     function wrapped with lru_cache to limit the number of times `traceback.format_exc()` is called
     """
     return traceback.format_exc()
+
+
+from decimal import Decimal
+
+# used for unit testing
+from typing import Any, Dict, List, Optional, Union
+
+
+def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
+    # First create the nested objects with proper typing
+    model_info = StandardLoggingModelInformation(
+        model_map_key="gpt-3.5-turbo", model_map_value=None
+    )
+
+    metadata = StandardLoggingMetadata(  # type: ignore
+        user_api_key_hash=str("test_hash"),
+        user_api_key_alias=str("test_alias"),
+        user_api_key_team_id=str("test_team"),
+        user_api_key_user_id=str("test_user"),
+        user_api_key_team_alias=str("test_team_alias"),
+        user_api_key_org_id=None,
+        spend_logs_metadata=None,
+        requester_ip_address=str("127.0.0.1"),
+        requester_metadata=None,
+        user_api_key_end_user_id=str("test_end_user"),
+    )
+
+    hidden_params = StandardLoggingHiddenParams(
+        model_id=None,
+        cache_key=None,
+        api_base=None,
+        response_cost=None,
+        additional_headers=None,
+        litellm_overhead_time_ms=None,
+    )
+
+    # Convert numeric values to appropriate types
+    response_cost = Decimal("0.1")
+    start_time = Decimal("1234567890.0")
+    end_time = Decimal("1234567891.0")
+    completion_start_time = Decimal("1234567890.5")
+    saved_cache_cost = Decimal("0.0")
+
+    # Create messages and response with proper typing
+    messages: List[Dict[str, str]] = [{"role": "user", "content": "Hello, world!"}]
+    response: Dict[str, List[Dict[str, Dict[str, str]]]] = {
+        "choices": [{"message": {"content": "Hi there!"}}]
+    }
+
+    # Main payload initialization
+    return StandardLoggingPayload(  # type: ignore
+        id=str("test_id"),
+        call_type=str("completion"),
+        stream=bool(False),
+        response_cost=response_cost,
+        response_cost_failure_debug_info=None,
+        status=str("success"),
+        total_tokens=int(30),
+        prompt_tokens=int(20),
+        completion_tokens=int(10),
+        startTime=start_time,
+        endTime=end_time,
+        completionStartTime=completion_start_time,
+        model_map_information=model_info,
+        model=str("gpt-3.5-turbo"),
+        model_id=str("model-123"),
+        model_group=str("openai-gpt"),
+        custom_llm_provider=str("openai"),
+        api_base=str("https://api.openai.com"),
+        metadata=metadata,
+        cache_hit=bool(False),
+        cache_key=None,
+        saved_cache_cost=saved_cache_cost,
+        request_tags=[],
+        end_user=None,
+        requester_ip_address=str("127.0.0.1"),
+        messages=messages,
+        response=response,
+        error_str=None,
+        model_parameters={"stream": True},
+        hidden_params=hidden_params,
+    )
@@ -110,7 +110,7 @@ def _set_duration_in_model_call_details(
         if logging_obj and hasattr(logging_obj, "model_call_details"):
             logging_obj.model_call_details["llm_api_duration_ms"] = duration_ms
         else:
-            verbose_logger.warning(
+            verbose_logger.debug(
                 "`logging_obj` not found - unable to track `llm_api_duration_ms"
             )
     except Exception as e:
@@ -207,7 +207,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             additional_args={
                 "complete_input_dict": data,
                 "api_base": api_base,
-                "headers": headers,
+                "headers": prepped.headers,
             },
         )
 
@@ -226,7 +226,10 @@ class BedrockConverseLLM(BaseAWSLLM):
 
         try:
             response = await client.post(
-                url=api_base, headers=headers, data=data, logging_obj=logging_obj
+                url=api_base,
+                headers=headers,
+                data=data,
+                logging_obj=logging_obj,
             )  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
@@ -267,6 +270,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
     ):
+
         try:
             from botocore.credentials import Credentials
         except ImportError:
@@ -300,8 +304,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
         aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
 
-        litellm_params["aws_region_name"] = aws_region_name
-
         ### SET REGION NAME ###
         if aws_region_name is None:
             # check env #
@@ -321,6 +323,10 @@ class BedrockConverseLLM(BaseAWSLLM):
         if aws_region_name is None:
             aws_region_name = "us-west-2"
 
+        litellm_params["aws_region_name"] = (
+            aws_region_name  # [DO NOT DELETE] important for async calls
+        )
+
         credentials: Credentials = self.get_credentials(
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
@@ -347,7 +353,6 @@ class BedrockConverseLLM(BaseAWSLLM):
         proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
 
         ## COMPLETION CALL
-
         headers = {"Content-Type": "application/json"}
         if extra_headers is not None:
             headers = {"Content-Type": "application/json", **extra_headers}
@@ -51,7 +51,6 @@ async def test_endpoint(request: Request):
     "/health/services",
     tags=["health"],
     dependencies=[Depends(user_api_key_auth)],
-    include_in_schema=False,
 )
 async def health_services_endpoint(  # noqa: PLR0915
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
@@ -64,14 +63,19 @@ async def health_services_endpoint(  # noqa: PLR0915
             "webhook",
             "email",
             "braintrust",
+            "datadog",
         ],
         str,
     ] = fastapi.Query(description="Specify the service being hit."),
 ):
     """
-    Hidden endpoint.
+    Use this admin-only endpoint to check if the service is healthy.
 
-    Used by the UI to let user check if slack alerting is working as expected.
+    Example:
+    ```
+    curl -L -X GET 'http://0.0.0.0:4000/health/services?service=datadog' \
+    -H 'Authorization: Bearer sk-1234'
+    ```
     """
     try:
         from litellm.proxy.proxy_server import (
@@ -84,6 +88,7 @@ async def health_services_endpoint(  # noqa: PLR0915
         raise HTTPException(
             status_code=400, detail={"error": "Service must be specified."}
         )
+
     if service not in [
         "slack_budget_alerts",
         "email",
@@ -95,6 +100,7 @@ async def health_services_endpoint(  # noqa: PLR0915
         "otel",
         "custom_callback_api",
         "langsmith",
+        "datadog",
     ]:
         raise HTTPException(
             status_code=400,
@@ -118,8 +124,20 @@ async def health_services_endpoint(  # noqa: PLR0915
                 "status": "success",
                 "message": "Mock LLM request made - check {}.".format(service),
             }
+        elif service == "datadog":
+            from litellm.integrations.datadog.datadog import DataDogLogger
 
-        if service == "langfuse":
+            datadog_logger = DataDogLogger()
+            response = await datadog_logger.async_health_check()
+            return {
+                "status": response["status"],
+                "message": (
+                    response["error_message"]
+                    if response["status"] == "unhealthy"
+                    else "Datadog is healthy"
+                ),
+            }
+        elif service == "langfuse":
             from litellm.integrations.langfuse.langfuse import LangFuseLogger
 
             langfuse_logger = LangFuseLogger()
@@ -1228,9 +1228,6 @@ class PrismaClient:
         """
        Generic implementation of get data
        """
-        verbose_proxy_logger.debug(
-            f"PrismaClient: get_generic_data: {key}, table_name: {table_name}"
-        )
         start_time = time.time()
         try:
             if table_name == "users":
litellm/types/integrations/base_health_check.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+from typing import Literal, Optional, TypedDict
+
+
+class IntegrationHealthCheckStatus(TypedDict):
+    status: Literal["healthy", "unhealthy"]
+    error_message: Optional[str]
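
Because the status object is a plain TypedDict, callers (like the /health/services handler shown earlier) can index into it directly. A small illustration with example values:

```
# Illustration: constructing and consuming the TypedDict defined above.
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus

status = IntegrationHealthCheckStatus(
    status="unhealthy", error_message="connection refused"  # example values
)
if status["status"] == "unhealthy":
    print(status["error_message"])
```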
@@ -4092,7 +4092,7 @@ def _get_potential_model_names(
     elif custom_llm_provider and model.startswith(
         custom_llm_provider + "/"
     ):  # handle case where custom_llm_provider is provided and model starts with custom_llm_provider
-        split_model = model.split("/")[1]
+        split_model = model.split("/", 1)[1]
         combined_model_name = model
         stripped_model_name = _strip_model_name(
             model=split_model, custom_llm_provider=custom_llm_provider
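
A quick illustration of what the added maxsplit argument changes: for region-qualified Bedrock model strings, splitting on every "/" kept only the region segment and dropped the model id, which broke the cost lookup. The model string below is an example:

```
# Illustration of the one-line fix above (example model string assumed).
model = "bedrock/us-east-1/anthropic.claude-3-5-sonnet-20240620-v1:0"

print(model.split("/")[1])     # "us-east-1"  <- old behavior: model id lost
print(model.split("/", 1)[1])  # "us-east-1/anthropic.claude-3-5-sonnet-20240620-v1:0"
```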
@@ -539,17 +539,19 @@ class BaseLLMChatTest(ABC):
 
         return url
 
-    def test_completion_cost(self):
+    @pytest.mark.asyncio
+    async def test_completion_cost(self):
         from litellm import completion_cost
 
         os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
         litellm.model_cost = litellm.get_model_cost_map(url="")
 
         litellm.set_verbose = True
-        response = self.completion_function(
+        response = await self.async_completion_function(
             **self.get_base_completion_call_args(),
             messages=[{"role": "user", "content": "Hello, how are you?"}],
         )
+        print(response._hidden_params)
         cost = completion_cost(response)
 
         assert cost > 0
@@ -2057,7 +2057,7 @@ def test_bedrock_supports_tool_call(model, expected_supports_tool_call):
     assert "tools" not in supported_openai_params
 
 
-class TestBedrockConverseChat(BaseLLMChatTest):
+class TestBedrockConverseChatCrossRegion(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
         os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
         litellm.model_cost = litellm.get_model_cost_map(url="")
@@ -2104,6 +2104,29 @@ class TestBedrockConverseChat(BaseLLMChatTest):
         assert cost > 0
 
 
+class TestBedrockConverseChatNormal(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+        litellm.add_known_models()
+        return {
+            "model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
+            "aws_region_name": "us-east-1",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
+
+    def test_multilingual_requests(self):
+        """
+        Bedrock API raises a 400 BadRequest error when the request contains invalid utf-8 sequences.
+
+        Todo: if litellm.modify_params is True ensure it's a valid utf-8 sequence
+        """
+        pass
+
+
 class TestBedrockRerank(BaseLLMRerankTest):
     def get_custom_llm_provider(self) -> litellm.LlmProviders:
         return litellm.LlmProviders.BEDROCK
@@ -2600,7 +2600,14 @@ async def test_test_completion_cost_gpt4o_audio_output_from_model(stream):
     assert round(cost, 2) == round(total_input_cost + total_output_cost, 2)
 
 
-def test_completion_cost_azure_ai_meta():
+@pytest.mark.parametrize(
+    "response_model, custom_llm_provider",
+    [
+        ("azure_ai/Meta-Llama-3.1-70B-Instruct", "azure_ai"),
+        ("anthropic.claude-3-5-sonnet-20240620-v1:0", "bedrock"),
+    ],
+)
+def test_completion_cost_model_response_cost(response_model, custom_llm_provider):
     """
     Relevant issue: https://github.com/BerriAI/litellm/issues/6310
     """
@@ -2628,7 +2635,7 @@ def test_completion_cost_azure_ai_meta():
             }
         ],
         "created": 1729243714,
-        "model": "azure_ai/Meta-Llama-3.1-70B-Instruct",
+        "model": response_model,
         "object": "chat.completion",
         "service_tier": None,
         "system_fingerprint": None,
@@ -2642,7 +2649,7 @@ def test_completion_cost_azure_ai_meta():
         }
 
     model_response = ModelResponse(**response)
-    cost = completion_cost(model_response, custom_llm_provider="azure_ai")
+    cost = completion_cost(model_response, custom_llm_provider=custom_llm_provider)
 
     assert cost > 0
@@ -2754,3 +2761,14 @@ def test_add_known_models():
     assert (
         "bedrock/us-west-1/meta.llama3-70b-instruct-v1:0" not in litellm.bedrock_models
     )
+
+
+def test_bedrock_cost_calc_with_region():
+    from litellm import completion
+
+    response = completion(
+        model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        aws_region_name="us-east-1",
+    )
+    assert response._hidden_params["response_cost"] > 0