Litellm dev 01 14 2025 p2 (#7772)

* feat(pass_through_endpoints.py): fix anthropic end user cost tracking

* fix(anthropic/chat/transformation.py): use returned provider model for anthropic

handles anthropic `-latest` tag in request body throwing cost calculation errors

ensures we can be accurate in our model cost tracking

* feat(model_prices_and_context_window.json): add gemini-2.0-flash-thinking-exp pricing

* test: update test to use assumption that user_api_key_dict can get anthropic user id

* test: fix test

* fix: fix test

* fix(anthropic_pass_through.py): uncomment previous anthropic end-user cost tracking code block

can't guarantee user api key dict always has end user id - too many code paths

* fix(user_api_key_auth.py): this allows end user id from request body to always be read and set in auth object

* fix(auth_check.py): fix linting error

* test: fix auth check

* fix(auth_utils.py): fix get end user id to handle metadata = None
This commit is contained in:
Krish Dholakia 2025-01-15 21:34:50 -08:00 committed by GitHub
parent 73c004cfe5
commit 543655adc7
16 changed files with 287 additions and 43 deletions

View file

@ -3101,7 +3101,7 @@ def get_standard_logging_object_payload(
# standardize this function to be used across, s3, dynamoDB, langfuse logging
litellm_params = kwargs.get("litellm_params", {})
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
metadata: dict = (
litellm_params.get("litellm_metadata")
or litellm_params.get("metadata", None)
@ -3149,6 +3149,11 @@ def get_standard_logging_object_payload(
prompt_integration=kwargs.get("prompt_integration", None),
)
_request_body = proxy_server_request.get("body", {})
end_user_id = clean_metadata["user_api_key_end_user_id"] or _request_body.get(
"user", None
) # maintain backwards compatibility with old request body check
saved_cache_cost: float = 0.0
if cache_hit is True:

View file

@ -14,6 +14,7 @@ import litellm.types
import litellm.types.utils
from litellm import LlmProviders
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -214,6 +215,7 @@ class AnthropicChatCompletion(BaseLLM):
optional_params: dict,
json_mode: bool,
litellm_params: dict,
provider_config: BaseConfig,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
@ -248,7 +250,7 @@ class AnthropicChatCompletion(BaseLLM):
headers=error_headers,
)
return AnthropicConfig().transform_response(
return provider_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
@ -282,6 +284,7 @@ class AnthropicChatCompletion(BaseLLM):
headers={},
client=None,
):
optional_params = copy.deepcopy(optional_params)
stream = optional_params.pop("stream", None)
json_mode: bool = optional_params.pop("json_mode", False)
@ -362,6 +365,7 @@ class AnthropicChatCompletion(BaseLLM):
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
provider_config=config,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
@ -426,7 +430,7 @@ class AnthropicChatCompletion(BaseLLM):
headers=error_headers,
)
return AnthropicConfig().transform_response(
return config.transform_response(
model=model,
raw_response=response,
model_response=model_response,

View file

@ -668,7 +668,7 @@ class AnthropicConfig(BaseConfig):
cache_read_input_tokens: int = 0
model_response.created = int(time.time())
model_response.model = model
model_response.model = completion_response["model"]
if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
prompt_tokens += cache_creation_input_tokens

View file

@ -1,11 +1,13 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
from typing import List
from typing import Any, List, Optional
import httpx
import litellm
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ....anthropic.chat.transformation import AnthropicConfig
@ -64,6 +66,37 @@ class VertexAIAnthropicConfig(AnthropicConfig):
data.pop("model", None) # vertex anthropic doesn't accept 'model' parameter
return data
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
response = super().transform_response(
model,
raw_response,
model_response,
logging_obj,
request_data,
messages,
optional_params,
litellm_params,
encoding,
api_key,
json_mode,
)
response.model = model
return response
@classmethod
def is_supported_model(cls, model: str, custom_llm_provider: str) -> bool:
"""

View file

@ -194,6 +194,7 @@ class VertexAIPartnerModels(VertexBase):
"is_vertex_request": True,
}
)
return anthropic_chat_completions.completion(
model=model,
messages=messages,

View file

@ -3263,6 +3263,39 @@
"supports_audio_output": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash"
},
"gemini-2.0-flash-thinking-exp": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_image": 0,
"input_cost_per_video_per_second": 0,
"input_cost_per_audio_per_second": 0,
"input_cost_per_token": 0,
"input_cost_per_character": 0,
"input_cost_per_token_above_128k_tokens": 0,
"input_cost_per_character_above_128k_tokens": 0,
"input_cost_per_image_above_128k_tokens": 0,
"input_cost_per_video_per_second_above_128k_tokens": 0,
"input_cost_per_audio_per_second_above_128k_tokens": 0,
"output_cost_per_token": 0,
"output_cost_per_character": 0,
"output_cost_per_token_above_128k_tokens": 0,
"output_cost_per_character_above_128k_tokens": 0,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash"
},
"gemini/gemini-2.0-flash-exp": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
@ -3298,6 +3331,41 @@
"rpm": 10,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash"
},
"gemini/gemini-2.0-flash-thinking-exp": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_image": 0,
"input_cost_per_video_per_second": 0,
"input_cost_per_audio_per_second": 0,
"input_cost_per_token": 0,
"input_cost_per_character": 0,
"input_cost_per_token_above_128k_tokens": 0,
"input_cost_per_character_above_128k_tokens": 0,
"input_cost_per_image_above_128k_tokens": 0,
"input_cost_per_video_per_second_above_128k_tokens": 0,
"input_cost_per_audio_per_second_above_128k_tokens": 0,
"output_cost_per_token": 0,
"output_cost_per_character": 0,
"output_cost_per_token_above_128k_tokens": 0,
"output_cost_per_character_above_128k_tokens": 0,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"tpm": 4000000,
"rpm": 10,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash"
},
"vertex_ai/claude-3-sonnet": {
"max_tokens": 4096,
"max_input_tokens": 200000,

View file

@ -55,15 +55,15 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
def _allowed_import_check() -> bool:
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.auth.user_api_key_auth import _user_api_key_auth_builder
# Get the calling frame
caller_frame = inspect.stack()[2]
caller_function = caller_frame.function
caller_function_callable = caller_frame.frame.f_globals.get(caller_function)
allowed_function = "user_api_key_auth"
allowed_signature = inspect.signature(user_api_key_auth)
allowed_function = "_user_api_key_auth_builder"
allowed_signature = inspect.signature(_user_api_key_auth_builder)
if caller_function_callable is None or not callable(caller_function_callable):
raise Exception(f"Caller function {caller_function} is not callable")
caller_signature = inspect.signature(caller_function_callable)
@ -303,7 +303,11 @@ def get_actual_routes(allowed_routes: list) -> list:
for route_name in allowed_routes:
try:
route_value = LiteLLMRoutes[route_name].value
actual_routes = actual_routes + route_value
if isinstance(route_value, set):
actual_routes.extend(list(route_value))
else:
actual_routes.extend(route_value)
except KeyError:
actual_routes.append(route_name)
return actual_routes

View file

@ -464,6 +464,7 @@ def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
from litellm.proxy.proxy_server import general_settings, premium_user
if premium_user is not True:
return False
# premium use has opted into using client credentials
@ -493,3 +494,17 @@ def _has_user_setup_sso():
)
return sso_setup
def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
# openai - check 'user'
if "user" in request_body:
return request_body["user"]
# anthropic - check 'litellm_metadata'
end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
if end_user_id:
return end_user_id
metadata = request_body.get("metadata")
if metadata and "user_id" in metadata:
return metadata["user_id"]
return None

View file

@ -35,6 +35,7 @@ from litellm.proxy.auth.auth_checks import (
)
from litellm.proxy.auth.auth_utils import (
_get_request_ip_address,
get_end_user_id_from_request_body,
get_request_route,
is_pass_through_provider_route,
pre_db_read_auth_checks,
@ -213,17 +214,25 @@ async def user_api_key_auth_websocket(websocket: WebSocket):
raise HTTPException(status_code=403, detail=str(e))
async def user_api_key_auth( # noqa: PLR0915
request: Request,
api_key: str = fastapi.Security(api_key_header),
azure_api_key_header: str = fastapi.Security(azure_api_key_header),
anthropic_api_key_header: Optional[str] = fastapi.Security(
anthropic_api_key_header
),
google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
google_ai_studio_api_key_header
),
def update_valid_token_with_end_user_params(
valid_token: UserAPIKeyAuth, end_user_params: dict
) -> UserAPIKeyAuth:
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
valid_token.allowed_model_region = end_user_params.get("allowed_model_region")
return valid_token
async def _user_api_key_auth_builder( # noqa: PLR0915
request: Request,
api_key: str,
azure_api_key_header: str,
anthropic_api_key_header: Optional[str],
google_ai_studio_api_key_header: Optional[str],
request_data: dict,
) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
general_settings,
jwt_handler,
@ -243,8 +252,9 @@ async def user_api_key_auth( # noqa: PLR0915
start_time = datetime.now()
route: str = get_request_route(request=request)
try:
# get the request body
request_data = await _read_request_body(request=request)
await pre_db_read_auth_checks(
request_data=request_data,
request=request,
@ -608,9 +618,10 @@ async def user_api_key_auth( # noqa: PLR0915
## Check END-USER OBJECT
_end_user_object = None
end_user_params = {}
if "user" in request_data:
end_user_id = get_end_user_id_from_request_body(request_data)
if end_user_id:
try:
end_user_id = request_data["user"]
end_user_params["end_user_id"] = end_user_id
# get end-user object
@ -671,11 +682,8 @@ async def user_api_key_auth( # noqa: PLR0915
and valid_token.user_role == LitellmUserRoles.PROXY_ADMIN
):
# update end-user params on valid token
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
valid_token.allowed_model_region = end_user_params.get(
"allowed_model_region"
valid_token = update_valid_token_with_end_user_params(
valid_token=valid_token, end_user_params=end_user_params
)
valid_token.parent_otel_span = parent_otel_span
@ -753,6 +761,10 @@ async def user_api_key_auth( # noqa: PLR0915
)
)
_user_api_key_obj = update_valid_token_with_end_user_params(
valid_token=_user_api_key_obj, end_user_params=end_user_params
)
return _user_api_key_obj
## IF it's not a master key
@ -1235,7 +1247,6 @@ async def user_api_key_auth( # noqa: PLR0915
parent_otel_span=parent_otel_span,
api_key=api_key,
)
request_data = await _read_request_body(request=request)
asyncio.create_task(
proxy_logging_obj.post_call_failure_hook(
request_data=request_data,
@ -1270,6 +1281,39 @@ async def user_api_key_auth( # noqa: PLR0915
)
async def user_api_key_auth(
request: Request,
api_key: str = fastapi.Security(api_key_header),
azure_api_key_header: str = fastapi.Security(azure_api_key_header),
anthropic_api_key_header: Optional[str] = fastapi.Security(
anthropic_api_key_header
),
google_ai_studio_api_key_header: Optional[str] = fastapi.Security(
google_ai_studio_api_key_header
),
) -> UserAPIKeyAuth:
"""
Parent function to authenticate user api key / jwt token.
"""
request_data = await _read_request_body(request=request)
user_api_key_auth_obj = await _user_api_key_auth_builder(
request=request,
api_key=api_key,
azure_api_key_header=azure_api_key_header,
anthropic_api_key_header=anthropic_api_key_header,
google_ai_studio_api_key_header=google_ai_studio_api_key_header,
request_data=request_data,
)
end_user_id = get_end_user_id_from_request_body(request_data)
if end_user_id is not None:
user_api_key_auth_obj.end_user_id = end_user_id
return user_api_key_auth_obj
async def _return_user_api_key_auth_obj(
user_obj: Optional[LiteLLM_UserTable],
api_key: str,

View file

@ -15,6 +15,7 @@ from litellm.llms.anthropic.chat.handler import (
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
from litellm.types.utils import ModelResponse, TextCompletionResponse
@ -78,12 +79,7 @@ class AnthropicPassthroughLoggingHandler:
) -> Optional[str]:
request_body = passthrough_logging_payload.get("request_body")
if request_body:
end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
if end_user_id:
return end_user_id
return request_body.get("metadata", {}).get(
"user_id", None
) # support anthropic param - https://docs.anthropic.com/en/api/messages
return get_end_user_id_from_request_body(request_body)
return None
@staticmethod

View file

@ -566,7 +566,7 @@ def _init_kwargs_for_pass_through_endpoint(
"user_api_key": user_api_key_dict.api_key,
"user_api_key_user_id": user_api_key_dict.user_id,
"user_api_key_team_id": user_api_key_dict.team_id,
"user_api_key_end_user_id": user_api_key_dict.user_id,
"user_api_key_end_user_id": user_api_key_dict.end_user_id,
}
if _litellm_metadata:
_metadata.update(_litellm_metadata)

View file

@ -3263,6 +3263,39 @@
"supports_audio_output": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash"
},
"gemini-2.0-flash-thinking-exp": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_image": 0,
"input_cost_per_video_per_second": 0,
"input_cost_per_audio_per_second": 0,
"input_cost_per_token": 0,
"input_cost_per_character": 0,
"input_cost_per_token_above_128k_tokens": 0,
"input_cost_per_character_above_128k_tokens": 0,
"input_cost_per_image_above_128k_tokens": 0,
"input_cost_per_video_per_second_above_128k_tokens": 0,
"input_cost_per_audio_per_second_above_128k_tokens": 0,
"output_cost_per_token": 0,
"output_cost_per_character": 0,
"output_cost_per_token_above_128k_tokens": 0,
"output_cost_per_character_above_128k_tokens": 0,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash"
},
"gemini/gemini-2.0-flash-exp": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
@ -3298,6 +3331,41 @@
"rpm": 10,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash"
},
"gemini/gemini-2.0-flash-thinking-exp": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_image": 0,
"input_cost_per_video_per_second": 0,
"input_cost_per_audio_per_second": 0,
"input_cost_per_token": 0,
"input_cost_per_character": 0,
"input_cost_per_token_above_128k_tokens": 0,
"input_cost_per_character_above_128k_tokens": 0,
"input_cost_per_image_above_128k_tokens": 0,
"input_cost_per_video_per_second_above_128k_tokens": 0,
"input_cost_per_audio_per_second_above_128k_tokens": 0,
"output_cost_per_token": 0,
"output_cost_per_character": 0,
"output_cost_per_token_above_128k_tokens": 0,
"output_cost_per_character_above_128k_tokens": 0,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"tpm": 4000000,
"rpm": 10,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash"
},
"vertex_ai/claude-3-sonnet": {
"max_tokens": 4096,
"max_input_tokens": 200000,

View file

@ -1460,8 +1460,11 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
httpx_response.side_effect = vertex_httpx_mock_post_valid_response_anthropic
else:
httpx_response.side_effect = vertex_httpx_mock_post_valid_response
resp = None
with patch.object(client, "post", new=httpx_response) as mock_call:
print("SENDING CLIENT POST={}".format(client.post))
litellm.set_verbose = True
print(f"model entering completion: {model}")
try:
resp = completion(
model=model,
@ -1502,6 +1505,9 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
"text"
]
)
elif resp is not None:
assert resp.model == model.split("/")[1].split("@")[0]
@pytest.mark.parametrize(

View file

@ -174,6 +174,7 @@ async def test_anthropic_streaming_with_headers():
"stream": True,
"litellm_metadata": {
"tags": ["test-tag-stream-1", "test-tag-stream-2"],
"user": "test-user-1",
},
}
@ -225,9 +226,9 @@ async def test_anthropic_streaming_with_headers():
assert (
log_entry["call_type"] == "pass_through_endpoint"
), "Call type should be pass_through_endpoint"
assert (
log_entry["api_base"] == "https://api.anthropic.com/v1/messages"
), "API base should be Anthropic's endpoint"
# assert (
# log_entry["api_base"] == "https://api.anthropic.com/v1/messages"
# ), "API base should be Anthropic's endpoint"
# Token and spend assertions
assert log_entry["spend"] > 0, "Spend value should not be None"
@ -265,3 +266,5 @@ async def test_anthropic_streaming_with_headers():
), "Should have user API key in metadata"
assert "claude" in log_entry["model"]
assert log_entry["end_user"] == "test-user-1"

View file

@ -65,6 +65,7 @@ def mock_user_api_key_dict():
api_key="test-key",
user_id="test-user",
team_id="test-team",
end_user_id="test-user",
)

View file

@ -201,7 +201,3 @@ def test_create_anthropic_response_logging_payload(mock_logging_obj, metadata_pa
assert "model" in result
assert "response_cost" in result
assert "standard_logging_object" in result
if metadata_params:
assert "test" == result["standard_logging_object"]["end_user"]
else:
assert "" == result["standard_logging_object"]["end_user"]