diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 69d6adca4..9239b7b90 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -2359,6 +2359,7 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         _in_memory_loggers.append(_mlflow_logger)
         return _mlflow_logger  # type: ignore
+
 
 def get_custom_logger_compatible_class(
     logging_integration: litellm._custom_logger_compatible_callbacks_literal,
 ) -> Optional[CustomLogger]:
@@ -2719,6 +2720,31 @@ class StandardLoggingPayloadSetup:
         return clean_hidden_params
 
 
+class StandardLoggingPayloadAccessors:
+    """
+    Accessor methods for StandardLoggingPayload.
+
+    Allows safely reading fields from a StandardLoggingPayload.
+
+    """
+
+    @staticmethod
+    def get_custom_llm_provider_from_standard_logging_payload(
+        standard_logging_payload: Optional[StandardLoggingPayload],
+    ) -> Optional[str]:
+        """
+        Safely read custom_llm_provider from a standard_logging_payload.
+        """
+        if standard_logging_payload is None:
+            return None
+        model_map_information = (
+            standard_logging_payload.get("model_map_information") or {}
+        )
+        model_map_value = model_map_information.get("model_map_value") or {}
+        custom_llm_provider = model_map_value.get("litellm_provider")
+        return custom_llm_provider
+
+
 def get_standard_logging_object_payload(
     kwargs: Optional[dict],
     init_response_obj: Union[Any, BaseModel, dict],
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index f5851ded9..09fe3c2fa 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1773,6 +1773,7 @@ class SpendLogsPayload(TypedDict):
     model_id: Optional[str]
     model_group: Optional[str]
     api_base: str
+    custom_llm_provider: Optional[str]
     user: str
     metadata: str  # json str
     cache_hit: str
diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma
index 64045999c..7f717e6c8 100644
--- a/litellm/proxy/schema.prisma
+++ b/litellm/proxy/schema.prisma
@@ -192,6 +192,7 @@ model LiteLLM_SpendLogs {
   model_id String? @default("") // the model id stored in proxy model db
   model_group String? @default("") // public model_name / model_group
   api_base String? @default("")
+  custom_llm_provider String? @default("") // openai, vertex_ai, etc.
   user String? @default("")
   metadata Json? @default("{}")
   cache_hit String? @default("")
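Aside (not part of the patch): a minimal, self-contained sketch of the lookup the new accessor performs, with the payload stubbed as a plain dict. The `or {}` chain is what lets a partially populated payload fall through to None instead of raising:

    from typing import Optional

    def get_provider(payload: Optional[dict]) -> Optional[str]:
        # Same traversal as StandardLoggingPayloadAccessors.
        # get_custom_llm_provider_from_standard_logging_payload above
        if payload is None:
            return None
        model_map_information = payload.get("model_map_information") or {}
        model_map_value = model_map_information.get("model_map_value") or {}
        return model_map_value.get("litellm_provider")

    assert get_provider(None) is None
    assert get_provider({"model_map_information": None}) is None
    assert (
        get_provider(
            {"model_map_information": {"model_map_value": {"litellm_provider": "openai"}}}
        )
        == "openai"
    )
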
@default("") diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 48924d521..e6a79b05e 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -10,8 +10,10 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadAccessors from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload from litellm.proxy.utils import PrismaClient, hash_token +from litellm.types.utils import StandardLoggingPayload def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool: @@ -49,6 +51,10 @@ def get_logging_payload( response_obj = {} # standardize this function to be used across, s3, dynamoDB, langfuse logging litellm_params = kwargs.get("litellm_params", {}) + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + metadata = ( litellm_params.get("metadata", {}) or {} ) # if litellm_params['metadata'] == None @@ -150,6 +156,9 @@ def get_logging_payload( request_tags=request_tags, end_user=end_user_id or "", api_base=litellm_params.get("api_base", ""), + custom_llm_provider=StandardLoggingPayloadAccessors.get_custom_llm_provider_from_standard_logging_payload( + standard_logging_payload + ), model_group=_model_group, model_id=_model_id, requester_ip_address=clean_metadata.get("requester_ip_address", None), diff --git a/schema.prisma b/schema.prisma index 64045999c..7f717e6c8 100644 --- a/schema.prisma +++ b/schema.prisma @@ -192,6 +192,7 @@ model LiteLLM_SpendLogs { model_id String? @default("") // the model id stored in proxy model db model_group String? @default("") // public model_name / model_group api_base String? @default("") + custom_llm_provider String? @default("") // openai, vertex_ai etc user String? @default("") metadata Json? @default("{}") cache_hit String? @default("") diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 3ce4cb7d7..43bcfc882 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import anthropic_messages_pt -# litellm.num_retries=3 +# litellm.num_retries = 3 litellm.cache = None litellm.success_callback = [] diff --git a/tests/local_testing/test_spend_logs.py b/tests/local_testing/test_spend_logs.py index 926f4b5ad..fe292c12d 100644 --- a/tests/local_testing/test_spend_logs.py +++ b/tests/local_testing/test_spend_logs.py @@ -28,12 +28,25 @@ import litellm from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload from litellm.proxy.utils import SpendLogsMetadata, SpendLogsPayload # noqa: E402 +from litellm.types.utils import ( + StandardLoggingPayload, + StandardLoggingModelInformation, + StandardLoggingMetadata, + StandardLoggingHiddenParams, + ModelInfo, +) + def test_spend_logs_payload(): """ Ensure only expected values are logged in spend logs payload. 
""" + standard_logging_payload = _create_standard_logging_payload() + standard_logging_payload["model_map_information"]["model_map_value"][ + "litellm_provider" + ] = "very-obscure-name" + input_args: dict = { "kwargs": { "model": "chatgpt-v-2", @@ -47,6 +60,7 @@ def test_spend_logs_payload(): "user": "116544810872468347480", "extra_body": {}, }, + "standard_logging_object": standard_logging_payload, "litellm_params": { "acompletion": True, "api_key": "23c217a5b59f41b6b7a198017f4792f2", @@ -205,6 +219,9 @@ def test_spend_logs_payload(): assert ( payload["request_tags"] == '["model-anthropic-claude-v2.1", "app-ishaan-prod"]' ) + print("payload['custom_llm_provider']", payload["custom_llm_provider"]) + # Ensures custom llm provider is logged + read from standard logging payload + assert payload["custom_llm_provider"] == "very-obscure-name" def test_spend_logs_payload_whisper(): @@ -292,3 +309,61 @@ def test_spend_logs_payload_whisper(): assert payload["call_type"] == "atranscription" assert payload["spend"] == 0.00023398580000000003 + + +def _create_standard_logging_payload() -> StandardLoggingPayload: + """ + helper function that creates a standard logging payload for testing + + in the test you can edit the values in SLP that you would need + """ + return StandardLoggingPayload( + id="test_id", + type="test_id", + call_type="completion", + response_cost=0.1, + response_cost_failure_debug_info=None, + status="success", + total_tokens=30, + prompt_tokens=20, + completion_tokens=10, + startTime=1234567890.0, + endTime=1234567891.0, + completionStartTime=1234567890.5, + model_map_information=StandardLoggingModelInformation( + model_map_key="gpt-3.5-turbo", + model_map_value=ModelInfo(litellm_provider="azure"), + ), + model="gpt-3.5-turbo", + model_id="model-123", + model_group="openai-gpt", + api_base="https://api.openai.com", + metadata=StandardLoggingMetadata( + user_api_key_hash="test_hash", + user_api_key_org_id=None, + user_api_key_alias="test_alias", + user_api_key_team_id="test_team", + user_api_key_user_id="test_user", + user_api_key_team_alias="test_team_alias", + spend_logs_metadata=None, + requester_ip_address="127.0.0.1", + requester_metadata=None, + ), + cache_hit=False, + cache_key=None, + saved_cache_cost=0.0, + request_tags=[], + end_user=None, + requester_ip_address="127.0.0.1", + messages=[{"role": "user", "content": "Hello, world!"}], + response={"choices": [{"message": {"content": "Hi there!"}}]}, + error_str=None, + model_parameters={"stream": True}, + hidden_params=StandardLoggingHiddenParams( + model_id="model-123", + cache_key=None, + api_base="https://api.openai.com", + response_cost="0.1", + additional_headers=None, + ), + )