forked from phoenix/litellm-mirror
Compare commits: main...litellm_fi (8 commits)

Commits (SHA1):
ae23c02b2f
7973665219
69aa10d536
501bf6961f
9116c09386
02cf18be83
9864459f4d
2639c1971d
4 changed files with 299 additions and 188 deletions
litellm/proxy/hooks/proxy_track_cost_callback.py (new file, 136 additions)

@@ -0,0 +1,136 @@
"""
Proxy Success Callback - handles storing cost of a request in LiteLLM DB.

Updates cost for the following in LiteLLM DB:
- spend logs
- virtual key spend
- internal user, team, external user spend
"""

import asyncio
import traceback
from typing import Optional

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.proxy.utils import log_db_metrics
from litellm.types.utils import StandardLoggingPayload


@log_db_metrics
async def _PROXY_track_cost_callback(
    kwargs,  # kwargs to completion
    completion_response: litellm.ModelResponse,  # response from completion
    start_time=None,
    end_time=None,  # start/end time for completion
):
    """
    Callback handles storing cost of a request in LiteLLM DB.

    Updates cost for the following in LiteLLM DB:
    - spend logs
    - virtual key spend
    - internal user, team, external user spend
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        update_cache,
        update_database,
    )

    verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
    try:
        # check if it has collected an entire stream response
        verbose_proxy_logger.debug(
            "Proxy: In track_cost_callback for: kwargs=%s and completion_response: %s",
            kwargs,
            completion_response,
        )
        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if standard_logging_payload is None:
            raise ValueError(
                "standard_logging_payload is none in kwargs, cannot track cost without it"
            )
        end_user_id = standard_logging_payload.get("end_user")
        user_api_key = standard_logging_payload.get("metadata", {}).get(
            "user_api_key_hash"
        )
        user_id = standard_logging_payload.get("metadata", {}).get(
            "user_api_key_user_id"
        )
        team_id = standard_logging_payload.get("metadata", {}).get(
            "user_api_key_team_id"
        )
        org_id = standard_logging_payload.get("metadata", {}).get("user_api_key_org_id")
        key_alias = standard_logging_payload.get("metadata", {}).get(
            "user_api_key_alias"
        )
        end_user_max_budget = standard_logging_payload.get("metadata", {}).get(
            "user_api_end_user_max_budget"
        )
        response_cost: Optional[float] = standard_logging_payload.get("response_cost")

        if response_cost is not None:
            if user_api_key is not None or user_id is not None or team_id is not None:
                ## UPDATE DATABASE
                await update_database(
                    token=user_api_key,
                    response_cost=response_cost,
                    user_id=user_id,
                    end_user_id=end_user_id,
                    team_id=team_id,
                    kwargs=kwargs,
                    completion_response=completion_response,
                    start_time=start_time,
                    end_time=end_time,
                    org_id=org_id,
                )

                # update cache
                asyncio.create_task(
                    update_cache(
                        token=user_api_key,
                        user_id=user_id,
                        end_user_id=end_user_id,
                        response_cost=response_cost,
                        team_id=team_id,
                        parent_otel_span=parent_otel_span,
                    )
                )

                await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
                    token=user_api_key,
                    key_alias=key_alias,
                    end_user_id=end_user_id,
                    response_cost=response_cost,
                    max_budget=end_user_max_budget,
                )
            else:
                raise Exception(
                    "User API key and team id and user id missing from custom callback."
                )
        else:
            cost_tracking_failure_debug_info = standard_logging_payload.get(
                "response_cost_failure_debug_info"
            )
            model = kwargs.get("model")
            raise ValueError(
                f"Failed to write cost to DB, for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
            )
    except Exception as e:
        error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
        model = kwargs.get("model", "")
        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
        error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
        asyncio.create_task(
            proxy_logging_obj.failed_tracking_alert(
                error_message=error_msg,
                failing_model=model,
            )
        )
        verbose_proxy_logger.debug("error in tracking cost callback - %s", e)
litellm/proxy/proxy_server.py

@@ -303,6 +303,8 @@ from fastapi.security import OAuth2PasswordBearer
 from fastapi.security.api_key import APIKeyHeader
 from fastapi.staticfiles import StaticFiles
 
+from litellm.proxy.hooks.proxy_track_cost_callback import _PROXY_track_cost_callback
+
 # import enterprise folder
 try:
     # when using litellm cli
@@ -747,118 +749,6 @@ async def _PROXY_failure_handler(
     pass
 
 
-@log_db_metrics
-async def _PROXY_track_cost_callback(
-    kwargs,  # kwargs to completion
-    completion_response: litellm.ModelResponse,  # response from completion
-    start_time=None,
-    end_time=None,  # start/end time for completion
-):
-    verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
-    global prisma_client
-    try:
-        verbose_proxy_logger.debug(
-            f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
-        )
-        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
-        litellm_params = kwargs.get("litellm_params", {}) or {}
-        proxy_server_request = litellm_params.get("proxy_server_request") or {}
-        end_user_id = proxy_server_request.get("body", {}).get("user", None)
-        metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
-        user_id = metadata.get("user_api_key_user_id", None)
-        team_id = metadata.get("user_api_key_team_id", None)
-        org_id = metadata.get("user_api_key_org_id", None)
-        key_alias = metadata.get("user_api_key_alias", None)
-        end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
-        sl_object: Optional[StandardLoggingPayload] = kwargs.get(
-            "standard_logging_object", None
-        )
-        response_cost = (
-            sl_object.get("response_cost", None)
-            if sl_object is not None
-            else kwargs.get("response_cost", None)
-        )
-
-        if response_cost is not None:
-            user_api_key = metadata.get("user_api_key", None)
-            if kwargs.get("cache_hit", False) is True:
-                response_cost = 0.0
-                verbose_proxy_logger.info(
-                    f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
-                )
-
-            verbose_proxy_logger.debug(
-                f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
-            )
-            if user_api_key is not None or user_id is not None or team_id is not None:
-                ## UPDATE DATABASE
-                await update_database(
-                    token=user_api_key,
-                    response_cost=response_cost,
-                    user_id=user_id,
-                    end_user_id=end_user_id,
-                    team_id=team_id,
-                    kwargs=kwargs,
-                    completion_response=completion_response,
-                    start_time=start_time,
-                    end_time=end_time,
-                    org_id=org_id,
-                )
-
-                # update cache
-                asyncio.create_task(
-                    update_cache(
-                        token=user_api_key,
-                        user_id=user_id,
-                        end_user_id=end_user_id,
-                        response_cost=response_cost,
-                        team_id=team_id,
-                        parent_otel_span=parent_otel_span,
-                    )
-                )
-
-                await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
-                    token=user_api_key,
-                    key_alias=key_alias,
-                    end_user_id=end_user_id,
-                    response_cost=response_cost,
-                    max_budget=end_user_max_budget,
-                )
-            else:
-                raise Exception(
-                    "User API key and team id and user id missing from custom callback."
-                )
-        else:
-            if kwargs["stream"] is not True or (
-                kwargs["stream"] is True and "complete_streaming_response" in kwargs
-            ):
-                if sl_object is not None:
-                    cost_tracking_failure_debug_info: Union[dict, str] = (
-                        sl_object["response_cost_failure_debug_info"]  # type: ignore
-                        or "response_cost_failure_debug_info is None in standard_logging_object"
-                    )
-                else:
-                    cost_tracking_failure_debug_info = (
-                        "standard_logging_object not found"
-                    )
-                model = kwargs.get("model")
-                raise Exception(
-                    f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
-                )
-    except Exception as e:
-        error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
-        model = kwargs.get("model", "")
-        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
-        error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
-        asyncio.create_task(
-            proxy_logging_obj.failed_tracking_alert(
-                error_message=error_msg,
-                failing_model=model,
-            )
-        )
-        verbose_proxy_logger.debug(error_msg)
 
 
 def error_tracking():
     global prisma_client
     if prisma_client is not None:
litellm/types/utils.py

@@ -1460,7 +1460,28 @@ class AdapterCompletionStreamWrapper:
        raise StopAsyncIteration


class StandardLoggingBudgetMetadata(TypedDict, total=False):
    """
    Store Budget related metadata for Team, Internal User, End User etc
    """

    user_api_end_user_max_budget: Optional[float]


class StandardLoggingUserAPIKeyMetadata(TypedDict):
    """
    Store User API Key related metadata to identify the request

    Example:
        user_api_key_hash: "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b"
        user_api_key_alias: "litellm-key-123"
        user_api_key_org_id: "123"
        user_api_key_team_id: "456"
        user_api_key_user_id: "789"
        user_api_key_team_alias: "litellm-team-123"
    """

    user_api_key_hash: Optional[str]  # hash of the litellm virtual key used
    user_api_key_alias: Optional[str]
    user_api_key_org_id: Optional[str]

@@ -1469,7 +1490,9 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
     user_api_key_team_alias: Optional[str]
 
 
-class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
+class StandardLoggingMetadata(
+    StandardLoggingUserAPIKeyMetadata, StandardLoggingBudgetMetadata
+):
     """
     Specific metadata k,v pairs logged to integration for easier cost tracking
     """
@@ -107,6 +107,12 @@ from litellm.proxy._types import (
     UpdateUserRequest,
     UserAPIKeyAuth,
 )
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingModelInformation,
+    StandardLoggingMetadata,
+    StandardLoggingHiddenParams,
+)
 
 proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
 
@@ -143,6 +149,58 @@ def prisma_client():
    return prisma_client


def create_simple_standard_logging_payload() -> StandardLoggingPayload:
    return StandardLoggingPayload(
        id="test_id",
        call_type="completion",
        response_cost=0.1,
        response_cost_failure_debug_info=None,
        status="success",
        total_tokens=30,
        prompt_tokens=20,
        completion_tokens=10,
        startTime=1234567890.0,
        endTime=1234567891.0,
        completionStartTime=1234567890.5,
        model_map_information=StandardLoggingModelInformation(
            model_map_key="gpt-3.5-turbo", model_map_value=None
        ),
        model="gpt-3.5-turbo",
        model_id="model-123",
        model_group="openai-gpt",
        api_base="https://api.openai.com",
        metadata=StandardLoggingMetadata(
            user_api_key_hash="test_hash",
            user_api_key_org_id=None,
            user_api_key_alias="test_alias",
            user_api_key_team_id="test_team",
            user_api_key_user_id="test_user",
            user_api_key_team_alias="test_team_alias",
            spend_logs_metadata=None,
            requester_ip_address="127.0.0.1",
            requester_metadata=None,
        ),
        cache_hit=False,
        cache_key=None,
        saved_cache_cost=0.0,
        request_tags=[],
        end_user=None,
        requester_ip_address="127.0.0.1",
        messages=[{"role": "user", "content": "Hello, world!"}],
        response={"choices": [{"message": {"content": "Hi there!"}}]},
        error_str=None,
        model_parameters={"stream": True},
        hidden_params=StandardLoggingHiddenParams(
            model_id="model-123",
            cache_key=None,
            api_base="https://api.openai.com",
            response_cost="0.1",
            additional_headers=None,
        ),
    )


@pytest.mark.asyncio()
@pytest.mark.flaky(retries=6, delay=1)
async def test_new_user_response(prisma_client):
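Since StandardLoggingPayload is a TypedDict, the helper's return value is a plain dict at runtime; each test below builds a fresh baseline payload and overwrites only the fields its scenario exercises before handing the payload to the callback, e.g.:

    # Pattern repeated in the tests below; the names come from this diff.
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00002
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id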
@@ -521,16 +579,17 @@ def test_call_with_user_over_budget(prisma_client):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00002
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id

    await track_cost_callback(
        kwargs={
            "stream": False,
            "litellm_params": {
                "metadata": {
                    "user_api_key": generated_key,
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 0.00002,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=resp,
        start_time=datetime.now(),
@@ -618,21 +677,17 @@ def test_call_with_end_user_over_budget(prisma_client):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 10
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token="sk-1234"
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user
    standard_logging_payload["end_user"] = user
    await track_cost_callback(
        kwargs={
            "stream": False,
            "litellm_params": {
                "metadata": {
                    "user_api_key": "sk-1234",
                    "user_api_key_user_id": user,
                },
                "proxy_server_request": {
                    "body": {
                        "user": user,
                    }
                },
            },
            "response_cost": 10,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=resp,
        start_time=datetime.now(),
@@ -724,16 +779,16 @@ def test_call_with_proxy_over_budget(prisma_client):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00002
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "stream": False,
            "litellm_params": {
                "metadata": {
                    "user_api_key": generated_key,
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 0.00002,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=resp,
        start_time=datetime.now(),
@@ -815,17 +870,17 @@ def test_call_with_user_over_budget_stream(prisma_client):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00002
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "stream": True,
            "complete_streaming_response": resp,
            "litellm_params": {
                "metadata": {
                    "user_api_key": generated_key,
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 0.00002,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=ModelResponse(),
        start_time=datetime.now(),
@@ -921,17 +976,17 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00002
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "stream": True,
            "complete_streaming_response": resp,
            "litellm_params": {
                "metadata": {
                    "user_api_key": generated_key,
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 0.00002,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=ModelResponse(),
        start_time=datetime.now(),
@@ -1493,17 +1548,17 @@ def test_call_with_key_over_budget(prisma_client):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00002
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "model": "chatgpt-v-2",
            "stream": False,
            "litellm_params": {
                "metadata": {
                    "user_api_key": hash_token(generated_key),
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 0.00002,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=resp,
        start_time=datetime.now(),
@@ -1610,17 +1665,17 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00002
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "model": "chatgpt-v-2",
            "stream": False,
            "litellm_params": {
                "metadata": {
                    "user_api_key": hash_token(generated_key),
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 0.00002,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=resp,
        start_time=datetime.now(),
@@ -1734,10 +1789,17 @@ def test_call_with_key_over_model_budget(prisma_client):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00002
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "model": "chatgpt-v-2",
            "stream": False,
            "standard_logging_object": standard_logging_payload,
            "litellm_params": {
                "metadata": {
                    "user_api_key": hash_token(generated_key),
@@ -1840,17 +1902,17 @@ async def test_call_with_key_never_over_budget(prisma_client):
            prompt_tokens=210000, completion_tokens=200000, total_tokens=41000
        ),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 200000
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "model": "chatgpt-v-2",
            "stream": False,
            "litellm_params": {
                "metadata": {
                    "user_api_key": hash_token(generated_key),
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 200000,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=resp,
        start_time=datetime.now(),
@@ -1921,19 +1983,19 @@ async def test_call_with_key_over_budget_stream(prisma_client):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00005
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "call_type": "acompletion",
            "model": "sagemaker-chatgpt-v-2",
            "stream": True,
            "complete_streaming_response": resp,
            "litellm_params": {
                "metadata": {
                    "user_api_key": hash_token(generated_key),
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 0.00005,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=resp,
        start_time=datetime.now(),
@@ -2329,19 +2391,19 @@ async def track_cost_callback_helper_fn(generated_key: str, user_id: str):
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    standard_logging_payload = create_simple_standard_logging_payload()
    standard_logging_payload["response_cost"] = 0.00005
    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
        token=generated_key
    )
    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "call_type": "acompletion",
            "model": "sagemaker-chatgpt-v-2",
            "stream": True,
            "complete_streaming_response": resp,
            "litellm_params": {
                "metadata": {
                    "user_api_key": hash_token(generated_key),
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 0.00005,
            "standard_logging_object": standard_logging_payload,
        },
        completion_response=resp,
        start_time=datetime.now(),