Compare commits

...
Sign in to create a new pull request.

8 commits

Author SHA1 Message Date
Ishaan Jaff
ae23c02b2f fix merge updates 2024-11-12 10:34:08 -08:00
Ishaan Jaff
7973665219 Merge branch 'main' into litellm_fix_track_cost_callback 2024-11-12 10:29:40 -08:00
Ishaan Jaff
69aa10d536 fix test_check_num_callbacks_on_lowest_latency 2024-11-04 13:44:30 -08:00
Ishaan Jaff
501bf6961f fix test_call_with_key_over_model_budget 2024-11-04 12:01:46 -08:00
Ishaan Jaff
9116c09386 fix test key gen prisma 2024-11-04 11:33:27 -08:00
Ishaan Jaff
02cf18be83 StandardLoggingBudgetMetadata 2024-11-04 11:11:09 -08:00
Ishaan Jaff
9864459f4d fix use standard_logging_payload for track cost callback 2024-11-04 11:01:58 -08:00
Ishaan Jaff
2639c1971d use separate file for _PROXY_track_cost_callback 2024-11-04 10:44:13 -08:00
4 changed files with 299 additions and 188 deletions

View file

@ -0,0 +1,136 @@
"""
Proxy Success Callback - handles storing cost of a request in LiteLLM DB.
Updates cost for the following in LiteLLM DB:
- spend logs
- virtual key spend
- internal user, team, external user spend
"""
import asyncio
import traceback
from typing import Optional
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.proxy.utils import log_db_metrics
from litellm.types.utils import StandardLoggingPayload
@log_db_metrics
async def _PROXY_track_cost_callback(
    kwargs,  # kwargs to completion
    completion_response: litellm.ModelResponse,  # response from completion
    start_time=None,
    end_time=None,  # start/end time for completion
):
    """
    Callback handles storing cost of a request in LiteLLM DB.

    Updates cost for the following in LiteLLM DB:
    - spend logs
    - virtual key spend
    - internal user, team, external user spend

    Args:
        kwargs: completion kwargs; must carry a "standard_logging_object"
            (StandardLoggingPayload) or cost tracking is aborted.
        completion_response: response returned by the completion call.
        start_time: request start time, forwarded to update_database.
        end_time: request end time, forwarded to update_database.

    Never raises to the caller: every failure is converted into a
    failed_tracking_alert task plus a debug log in the except block below.
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        proxy_logging_obj,
        update_cache,
        update_database,
    )

    verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
    try:
        # check if it has collected an entire stream response
        verbose_proxy_logger.debug(
            "Proxy: In track_cost_callback for: kwargs=%s and completion_response: %s",
            kwargs,
            completion_response,
        )
        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if standard_logging_payload is None:
            raise ValueError(
                "standard_logging_payload is none in kwargs, cannot track cost without it"
            )

        # All request-identity fields live on the payload's metadata dict;
        # read it once instead of re-fetching it for every field.
        metadata = standard_logging_payload.get("metadata", {})
        end_user_id = standard_logging_payload.get("end_user")
        user_api_key = metadata.get("user_api_key_hash")
        user_id = metadata.get("user_api_key_user_id")
        team_id = metadata.get("user_api_key_team_id")
        org_id = metadata.get("user_api_key_org_id")
        key_alias = metadata.get("user_api_key_alias")
        end_user_max_budget = metadata.get("user_api_end_user_max_budget")

        response_cost: Optional[float] = standard_logging_payload.get("response_cost")
        if response_cost is None:
            # No cost on the payload - surface the failure debug info so the
            # operator knows to add custom pricing.
            cost_tracking_failure_debug_info = standard_logging_payload.get(
                "response_cost_failure_debug_info"
            )
            model = kwargs.get("model")
            raise ValueError(
                f"Failed to write cost to DB, for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
            )
        if user_api_key is None and user_id is None and team_id is None:
            # Nothing to attribute the spend to.
            raise Exception(
                "User API key and team id and user id missing from custom callback."
            )

        ## UPDATE DATABASE
        await update_database(
            token=user_api_key,
            response_cost=response_cost,
            user_id=user_id,
            end_user_id=end_user_id,
            team_id=team_id,
            kwargs=kwargs,
            completion_response=completion_response,
            start_time=start_time,
            end_time=end_time,
            org_id=org_id,
        )
        # update cache (fire-and-forget; the DB write above is awaited)
        asyncio.create_task(
            update_cache(
                token=user_api_key,
                user_id=user_id,
                end_user_id=end_user_id,
                response_cost=response_cost,
                team_id=team_id,
                parent_otel_span=parent_otel_span,
            )
        )
        await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
            token=user_api_key,
            key_alias=key_alias,
            end_user_id=end_user_id,
            response_cost=response_cost,
            max_budget=end_user_max_budget,
        )
    except Exception as e:
        # Cost tracking must never break the request path: alert
        # (fire-and-forget) and log instead of propagating.
        error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
        model = kwargs.get("model", "")
        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
        error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
        asyncio.create_task(
            proxy_logging_obj.failed_tracking_alert(
                error_message=error_msg,
                failing_model=model,
            )
        )
        verbose_proxy_logger.debug("error in tracking cost callback - %s", e)

View file

@ -303,6 +303,8 @@ from fastapi.security import OAuth2PasswordBearer
from fastapi.security.api_key import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from litellm.proxy.hooks.proxy_track_cost_callback import _PROXY_track_cost_callback
# import enterprise folder
try:
# when using litellm cli
@ -747,118 +749,6 @@ async def _PROXY_failure_handler(
pass
@log_db_metrics
async def _PROXY_track_cost_callback(
    kwargs,  # kwargs to completion
    completion_response: litellm.ModelResponse,  # response from completion
    start_time=None,
    end_time=None,  # start/end time for completion
):
    """
    Success callback that writes the cost of a completed request to the DB.

    Resolves key / user / team / org identity from the request metadata, then
    awaits update_database(...), schedules update_cache(...) as a background
    task, and sends a customer spend alert. All exceptions are caught: a
    failure triggers a failed_tracking_alert task and a debug log instead of
    propagating to the caller.
    """
    verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
    global prisma_client
    try:
        # check if it has collected an entire stream response
        verbose_proxy_logger.debug(
            f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
        )
        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
        # end-user id comes from the "user" field of the original request body
        litellm_params = kwargs.get("litellm_params", {}) or {}
        proxy_server_request = litellm_params.get("proxy_server_request") or {}
        end_user_id = proxy_server_request.get("body", {}).get("user", None)
        metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
        user_id = metadata.get("user_api_key_user_id", None)
        team_id = metadata.get("user_api_key_team_id", None)
        org_id = metadata.get("user_api_key_org_id", None)
        key_alias = metadata.get("user_api_key_alias", None)
        end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
        sl_object: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        # prefer the cost computed on the standard logging payload; fall back
        # to the raw kwargs value when the payload is absent
        response_cost = (
            sl_object.get("response_cost", None)
            if sl_object is not None
            else kwargs.get("response_cost", None)
        )
        if response_cost is not None:
            user_api_key = metadata.get("user_api_key", None)
            if kwargs.get("cache_hit", False) is True:
                # cache hits are billed at zero, overriding any computed cost
                response_cost = 0.0
                verbose_proxy_logger.info(
                    f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
                )
            verbose_proxy_logger.debug(
                f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
            )
            if user_api_key is not None or user_id is not None or team_id is not None:
                ## UPDATE DATABASE
                await update_database(
                    token=user_api_key,
                    response_cost=response_cost,
                    user_id=user_id,
                    end_user_id=end_user_id,
                    team_id=team_id,
                    kwargs=kwargs,
                    completion_response=completion_response,
                    start_time=start_time,
                    end_time=end_time,
                    org_id=org_id,
                )
                # update cache (fire-and-forget background task)
                asyncio.create_task(
                    update_cache(
                        token=user_api_key,
                        user_id=user_id,
                        end_user_id=end_user_id,
                        response_cost=response_cost,
                        team_id=team_id,
                        parent_otel_span=parent_otel_span,
                    )
                )
                await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
                    token=user_api_key,
                    key_alias=key_alias,
                    end_user_id=end_user_id,
                    response_cost=response_cost,
                    max_budget=end_user_max_budget,
                )
            else:
                raise Exception(
                    "User API key and team id and user id missing from custom callback."
                )
        else:
            # only treat a missing cost as a failure once the response is
            # final: non-streaming, or the stream has been fully aggregated
            if kwargs["stream"] is not True or (
                kwargs["stream"] is True and "complete_streaming_response" in kwargs
            ):
                if sl_object is not None:
                    cost_tracking_failure_debug_info: Union[dict, str] = (
                        sl_object["response_cost_failure_debug_info"]  # type: ignore
                        or "response_cost_failure_debug_info is None in standard_logging_object"
                    )
                else:
                    cost_tracking_failure_debug_info = (
                        "standard_logging_object not found"
                    )
                model = kwargs.get("model")
                raise Exception(
                    f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
                )
    except Exception as e:
        # swallow everything: cost tracking must never break the request path
        error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
        model = kwargs.get("model", "")
        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
        error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
        asyncio.create_task(
            proxy_logging_obj.failed_tracking_alert(
                error_message=error_msg,
                failing_model=model,
            )
        )
        verbose_proxy_logger.debug(error_msg)
def error_tracking():
global prisma_client
if prisma_client is not None:

View file

@ -1460,7 +1460,28 @@ class AdapterCompletionStreamWrapper:
raise StopAsyncIteration
class StandardLoggingBudgetMetadata(TypedDict, total=False):
    """
    Store Budget related metadata for Team, Internal User, End User etc

    total=False: every key declared here may be absent from the payload.
    """

    # Max budget configured for the request's end user, if any
    # (read by the cost-tracking callback as "user_api_end_user_max_budget").
    user_api_end_user_max_budget: Optional[float]
class StandardLoggingUserAPIKeyMetadata(TypedDict):
"""
Store User API Key related metadata to identify the request
Example:
user_api_key_hash: "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b"
user_api_key_alias: "litellm-key-123"
user_api_key_org_id: "123"
user_api_key_team_id: "456"
user_api_key_user_id: "789"
user_api_key_team_alias: "litellm-team-123"
"""
user_api_key_hash: Optional[str] # hash of the litellm virtual key used
user_api_key_alias: Optional[str]
user_api_key_org_id: Optional[str]
@ -1469,7 +1490,9 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
user_api_key_team_alias: Optional[str]
class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
class StandardLoggingMetadata(
StandardLoggingUserAPIKeyMetadata, StandardLoggingBudgetMetadata
):
"""
Specific metadata k,v pairs logged to integration for easier cost tracking
"""

View file

@ -107,6 +107,12 @@ from litellm.proxy._types import (
UpdateUserRequest,
UserAPIKeyAuth,
)
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingModelInformation,
StandardLoggingMetadata,
StandardLoggingHiddenParams,
)
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
@ -143,6 +149,58 @@ def prisma_client():
return prisma_client
def create_simple_standard_logging_payload() -> StandardLoggingPayload:
    """Build a minimal, fully-populated StandardLoggingPayload test fixture.

    Tests mutate the returned payload in place (e.g. response_cost,
    metadata["user_api_key_hash"]) before handing it to the cost-tracking
    callback under test.
    """
    # Build the nested structures first so the final constructor call stays flat.
    fixture_metadata = StandardLoggingMetadata(
        user_api_key_hash="test_hash",
        user_api_key_org_id=None,
        user_api_key_alias="test_alias",
        user_api_key_team_id="test_team",
        user_api_key_user_id="test_user",
        user_api_key_team_alias="test_team_alias",
        spend_logs_metadata=None,
        requester_ip_address="127.0.0.1",
        requester_metadata=None,
    )
    fixture_model_info = StandardLoggingModelInformation(
        model_map_key="gpt-3.5-turbo", model_map_value=None
    )
    fixture_hidden_params = StandardLoggingHiddenParams(
        model_id="model-123",
        cache_key=None,
        api_base="https://api.openai.com",
        response_cost="0.1",
        additional_headers=None,
    )
    return StandardLoggingPayload(
        id="test_id",
        call_type="completion",
        response_cost=0.1,
        response_cost_failure_debug_info=None,
        status="success",
        total_tokens=30,
        prompt_tokens=20,
        completion_tokens=10,
        startTime=1234567890.0,
        endTime=1234567891.0,
        completionStartTime=1234567890.5,
        model_map_information=fixture_model_info,
        model="gpt-3.5-turbo",
        model_id="model-123",
        model_group="openai-gpt",
        api_base="https://api.openai.com",
        metadata=fixture_metadata,
        cache_hit=False,
        cache_key=None,
        saved_cache_cost=0.0,
        request_tags=[],
        end_user=None,
        requester_ip_address="127.0.0.1",
        messages=[{"role": "user", "content": "Hello, world!"}],
        response={"choices": [{"message": {"content": "Hi there!"}}]},
        error_str=None,
        model_parameters={"stream": True},
        hidden_params=fixture_hidden_params,
    )
@pytest.mark.asyncio()
@pytest.mark.flaky(retries=6, delay=1)
async def test_new_user_response(prisma_client):
@ -521,16 +579,17 @@ def test_call_with_user_over_budget(prisma_client):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 0.00002
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"stream": False,
"litellm_params": {
"metadata": {
"user_api_key": generated_key,
"user_api_key_user_id": user_id,
}
},
"response_cost": 0.00002,
"standard_logging_object": standard_logging_payload,
},
completion_response=resp,
start_time=datetime.now(),
@ -618,21 +677,17 @@ def test_call_with_end_user_over_budget(prisma_client):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 10
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token="sk-1234"
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user
standard_logging_payload["end_user"] = user
await track_cost_callback(
kwargs={
"stream": False,
"litellm_params": {
"metadata": {
"user_api_key": "sk-1234",
"user_api_key_user_id": user,
},
"proxy_server_request": {
"body": {
"user": user,
}
},
},
"response_cost": 10,
"standard_logging_object": standard_logging_payload,
},
completion_response=resp,
start_time=datetime.now(),
@ -724,16 +779,16 @@ def test_call_with_proxy_over_budget(prisma_client):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 0.00002
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"stream": False,
"litellm_params": {
"metadata": {
"user_api_key": generated_key,
"user_api_key_user_id": user_id,
}
},
"response_cost": 0.00002,
"standard_logging_object": standard_logging_payload,
},
completion_response=resp,
start_time=datetime.now(),
@ -815,17 +870,17 @@ def test_call_with_user_over_budget_stream(prisma_client):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 0.00002
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"stream": True,
"complete_streaming_response": resp,
"litellm_params": {
"metadata": {
"user_api_key": generated_key,
"user_api_key_user_id": user_id,
}
},
"response_cost": 0.00002,
"standard_logging_object": standard_logging_payload,
},
completion_response=ModelResponse(),
start_time=datetime.now(),
@ -921,17 +976,17 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 0.00002
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"stream": True,
"complete_streaming_response": resp,
"litellm_params": {
"metadata": {
"user_api_key": generated_key,
"user_api_key_user_id": user_id,
}
},
"response_cost": 0.00002,
"standard_logging_object": standard_logging_payload,
},
completion_response=ModelResponse(),
start_time=datetime.now(),
@ -1493,17 +1548,17 @@ def test_call_with_key_over_budget(prisma_client):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 0.00002
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"model": "chatgpt-v-2",
"stream": False,
"litellm_params": {
"metadata": {
"user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id,
}
},
"response_cost": 0.00002,
"standard_logging_object": standard_logging_payload,
},
completion_response=resp,
start_time=datetime.now(),
@ -1610,17 +1665,17 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 0.00002
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"model": "chatgpt-v-2",
"stream": False,
"litellm_params": {
"metadata": {
"user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id,
}
},
"response_cost": 0.00002,
"standard_logging_object": standard_logging_payload,
},
completion_response=resp,
start_time=datetime.now(),
@ -1734,10 +1789,17 @@ def test_call_with_key_over_model_budget(prisma_client):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 0.00002
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"model": "chatgpt-v-2",
"stream": False,
"standard_logging_object": standard_logging_payload,
"litellm_params": {
"metadata": {
"user_api_key": hash_token(generated_key),
@ -1840,17 +1902,17 @@ async def test_call_with_key_never_over_budget(prisma_client):
prompt_tokens=210000, completion_tokens=200000, total_tokens=41000
),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 200000
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"model": "chatgpt-v-2",
"stream": False,
"litellm_params": {
"metadata": {
"user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id,
}
},
"response_cost": 200000,
"standard_logging_object": standard_logging_payload,
},
completion_response=resp,
start_time=datetime.now(),
@ -1921,19 +1983,19 @@ async def test_call_with_key_over_budget_stream(prisma_client):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 0.00005
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"call_type": "acompletion",
"model": "sagemaker-chatgpt-v-2",
"stream": True,
"complete_streaming_response": resp,
"litellm_params": {
"metadata": {
"user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id,
}
},
"response_cost": 0.00005,
"standard_logging_object": standard_logging_payload,
},
completion_response=resp,
start_time=datetime.now(),
@ -2329,19 +2391,19 @@ async def track_cost_callback_helper_fn(generated_key: str, user_id: str):
model="gpt-35-turbo", # azure always has model written like this
usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
)
standard_logging_payload = create_simple_standard_logging_payload()
standard_logging_payload["response_cost"] = 0.00005
standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
token=generated_key
)
standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
await track_cost_callback(
kwargs={
"call_type": "acompletion",
"model": "sagemaker-chatgpt-v-2",
"stream": True,
"complete_streaming_response": resp,
"litellm_params": {
"metadata": {
"user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id,
}
},
"response_cost": 0.00005,
"standard_logging_object": standard_logging_payload,
},
completion_response=resp,
start_time=datetime.now(),