Gemini-2.5-flash - support reasoning cost calc + return reasoning content (#10141)

* build(model_prices_and_context_window.json): add vertex ai gemini-2.5-flash pricing

* build(model_prices_and_context_window.json): add gemini reasoning token pricing

* fix(vertex_and_google_ai_studio_gemini.py): support counting thinking tokens for gemini

allows accurate cost calc

* fix(utils.py): add reasoning token cost calc to generic cost calc

ensures gemini-2.5-flash cost calculation is accurate

* build(model_prices_and_context_window.json): mark gemini-2.5-flash as 'supports_reasoning'

* feat(gemini/): support 'thinking' + 'reasoning_effort' params + new unit tests

allow controlling thinking effort for gemini-2.5-flash models

* test: update unit testing

* feat(vertex_and_google_ai_studio_gemini.py): return reasoning content if given in gemini response

* test: update model name

* fix: fix ruff check

* test(test_spend_management_endpoints.py): update tests to be less sensitive to new keys / updates to usage object

* fix(vertex_and_google_ai_studio_gemini.py): fix translation
This commit is contained in:
Krish Dholakia 2025-04-19 09:20:52 -07:00 committed by GitHub
parent 81a08babb1
commit 468cd46bc1
16 changed files with 453 additions and 88 deletions

1
.gitignore vendored
View file

@ -86,3 +86,4 @@ litellm/proxy/db/migrations/0_init/migration.sql
litellm/proxy/db/migrations/*
litellm/proxy/migrations/*config.yaml
litellm/proxy/migrations/*
tests/litellm/litellm_core_utils/llm_cost_calc/log.txt

View file

@ -21,6 +21,10 @@ DEFAULT_MAX_TOKENS = 256 # used when providers need a default
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 1024 # 1MB = 1024KB
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests to consider "reasonable traffic". Used for single-deployment cooldown logic.
DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET = 1024
DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET = 2048
DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET = 4096
########## Networking constants ##############################################################
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour

View file

@ -267,6 +267,7 @@ def generic_cost_per_token(
## CALCULATE OUTPUT COST
text_tokens = usage.completion_tokens
audio_tokens = 0
reasoning_tokens = 0
if usage.completion_tokens_details is not None:
audio_tokens = (
cast(
@ -282,7 +283,13 @@ def generic_cost_per_token(
)
or usage.completion_tokens # default to completion tokens, if this field is not set
)
reasoning_tokens = (
cast(
Optional[int],
getattr(usage.completion_tokens_details, "reasoning_tokens", 0),
)
or 0
)
## TEXT COST
completion_cost = float(text_tokens) * completion_base_cost
@ -290,6 +297,10 @@ def generic_cost_per_token(
"output_cost_per_audio_token"
)
_output_cost_per_reasoning_token: Optional[float] = model_info.get(
"output_cost_per_reasoning_token"
)
## AUDIO COST
if (
_output_cost_per_audio_token is not None
@ -298,4 +309,12 @@ def generic_cost_per_token(
):
completion_cost += float(audio_tokens) * _output_cost_per_audio_token
## REASONING COST
if (
_output_cost_per_reasoning_token is not None
and reasoning_tokens
and reasoning_tokens > 0
):
completion_cost += float(reasoning_tokens) * _output_cost_per_reasoning_token
return prompt_cost, completion_cost

View file

@ -7,6 +7,9 @@ import httpx
import litellm
from litellm.constants import (
DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS,
DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET,
DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET,
DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET,
RESPONSE_FORMAT_TOOL_NAME,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
@ -276,11 +279,20 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig):
if reasoning_effort is None:
return None
elif reasoning_effort == "low":
return AnthropicThinkingParam(type="enabled", budget_tokens=1024)
return AnthropicThinkingParam(
type="enabled",
budget_tokens=DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET,
)
elif reasoning_effort == "medium":
return AnthropicThinkingParam(type="enabled", budget_tokens=2048)
return AnthropicThinkingParam(
type="enabled",
budget_tokens=DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET,
)
elif reasoning_effort == "high":
return AnthropicThinkingParam(type="enabled", budget_tokens=4096)
return AnthropicThinkingParam(
type="enabled",
budget_tokens=DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET,
)
else:
raise ValueError(f"Unmapped reasoning effort: {reasoning_effort}")

View file

@ -7,6 +7,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.vertex_ai import ContentType, PartType
from litellm.utils import supports_reasoning
from ...vertex_ai.gemini.transformation import _gemini_convert_messages_with_history
from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
@ -67,7 +68,7 @@ class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
return super().get_config()
def get_supported_openai_params(self, model: str) -> List[str]:
return [
supported_params = [
"temperature",
"top_p",
"max_tokens",
@ -83,6 +84,10 @@ class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
"frequency_penalty",
"modalities",
]
if supports_reasoning(model):
supported_params.append("reasoning_effort")
supported_params.append("thinking")
return supported_params
def map_openai_params(
self,

View file

@ -24,6 +24,11 @@ import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.constants import (
DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET,
DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET,
DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
@ -31,6 +36,7 @@ from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.llms.anthropic import AnthropicThinkingParam
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionResponseMessage,
@ -45,6 +51,7 @@ from litellm.types.llms.vertex_ai import (
ContentType,
FunctionCallingConfig,
FunctionDeclaration,
GeminiThinkingConfig,
GenerateContentResponseBody,
HttpxPartType,
LogprobsResult,
@ -59,7 +66,7 @@ from litellm.types.utils import (
TopLogprob,
Usage,
)
from litellm.utils import CustomStreamWrapper, ModelResponse
from litellm.utils import CustomStreamWrapper, ModelResponse, supports_reasoning
from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ..common_utils import VertexAIError, _build_vertex_schema
@ -190,7 +197,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
return super().get_config()
def get_supported_openai_params(self, model: str) -> List[str]:
return [
supported_params = [
"temperature",
"top_p",
"max_tokens",
@ -210,6 +217,10 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
"top_logprobs",
"modalities",
]
if supports_reasoning(model):
supported_params.append("reasoning_effort")
supported_params.append("thinking")
return supported_params
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict]
@ -313,10 +324,14 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
if isinstance(old_schema, list):
for item in old_schema:
if isinstance(item, dict):
item = _build_vertex_schema(parameters=item, add_property_ordering=True)
item = _build_vertex_schema(
parameters=item, add_property_ordering=True
)
elif isinstance(old_schema, dict):
old_schema = _build_vertex_schema(parameters=old_schema, add_property_ordering=True)
old_schema = _build_vertex_schema(
parameters=old_schema, add_property_ordering=True
)
return old_schema
def apply_response_schema_transformation(self, value: dict, optional_params: dict):
@ -343,6 +358,43 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
value=optional_params["response_schema"]
)
@staticmethod
def _map_reasoning_effort_to_thinking_budget(
reasoning_effort: str,
) -> GeminiThinkingConfig:
if reasoning_effort == "low":
return {
"thinkingBudget": DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET,
"includeThoughts": True,
}
elif reasoning_effort == "medium":
return {
"thinkingBudget": DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET,
"includeThoughts": True,
}
elif reasoning_effort == "high":
return {
"thinkingBudget": DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET,
"includeThoughts": True,
}
else:
raise ValueError(f"Invalid reasoning effort: {reasoning_effort}")
@staticmethod
def _map_thinking_param(
thinking_param: AnthropicThinkingParam,
) -> GeminiThinkingConfig:
thinking_enabled = thinking_param.get("type") == "enabled"
thinking_budget = thinking_param.get("budget_tokens")
params: GeminiThinkingConfig = {}
if thinking_enabled:
params["includeThoughts"] = True
if thinking_budget:
params["thinkingBudget"] = thinking_budget
return params
def map_openai_params(
self,
non_default_params: Dict,
@ -399,6 +451,16 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
optional_params["tool_choice"] = _tool_choice_value
elif param == "seed":
optional_params["seed"] = value
elif param == "reasoning_effort" and isinstance(value, str):
optional_params[
"thinkingConfig"
] = VertexGeminiConfig._map_reasoning_effort_to_thinking_budget(value)
elif param == "thinking":
optional_params[
"thinkingConfig"
] = VertexGeminiConfig._map_thinking_param(
cast(AnthropicThinkingParam, value)
)
elif param == "modalities" and isinstance(value, list):
response_modalities = []
for modality in value:
@ -514,19 +576,27 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
def get_assistant_content_message(
self, parts: List[HttpxPartType]
) -> Optional[str]:
_content_str = ""
) -> Tuple[Optional[str], Optional[str]]:
content_str: Optional[str] = None
reasoning_content_str: Optional[str] = None
for part in parts:
_content_str = ""
if "text" in part:
_content_str += part["text"]
elif "inlineData" in part: # base64 encoded image
_content_str += "data:{};base64,{}".format(
part["inlineData"]["mimeType"], part["inlineData"]["data"]
)
if part.get("thought") is True:
if reasoning_content_str is None:
reasoning_content_str = ""
reasoning_content_str += _content_str
else:
if content_str is None:
content_str = ""
content_str += _content_str
if _content_str:
return _content_str
return None
return content_str, reasoning_content_str
def _transform_parts(
self,
@ -677,6 +747,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
audio_tokens: Optional[int] = None
text_tokens: Optional[int] = None
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
reasoning_tokens: Optional[int] = None
if "cachedContentTokenCount" in completion_response["usageMetadata"]:
cached_tokens = completion_response["usageMetadata"][
"cachedContentTokenCount"
@ -687,7 +758,10 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
audio_tokens = detail["tokenCount"]
elif detail["modality"] == "TEXT":
text_tokens = detail["tokenCount"]
if "thoughtsTokenCount" in completion_response["usageMetadata"]:
reasoning_tokens = completion_response["usageMetadata"][
"thoughtsTokenCount"
]
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cached_tokens,
audio_tokens=audio_tokens,
@ -703,6 +777,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
),
total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
prompt_tokens_details=prompt_tokens_details,
reasoning_tokens=reasoning_tokens,
)
return usage
@ -731,11 +806,16 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
citation_metadata.append(candidate["citationMetadata"])
if "parts" in candidate["content"]:
chat_completion_message[
"content"
] = VertexGeminiConfig().get_assistant_content_message(
(
content,
reasoning_content,
) = VertexGeminiConfig().get_assistant_content_message(
parts=candidate["content"]["parts"]
)
if content is not None:
chat_completion_message["content"] = content
if reasoning_content is not None:
chat_completion_message["reasoning_content"] = reasoning_content
functions, tools = self._transform_parts(
parts=candidate["content"]["parts"],

View file

@ -5178,9 +5178,10 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_audio_token": 0.0000001,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000060,
"input_cost_per_audio_token": 1e-6,
"input_cost_per_token": 0.15e-6,
"output_cost_per_token": 0.6e-6,
"output_cost_per_reasoning_token": 3.5e-6,
"litellm_provider": "gemini",
"mode": "chat",
"rpm": 10,
@ -5188,9 +5189,39 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_reasoning": true,
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
"supported_endpoints": ["/v1/chat/completions", "/v1/completions"],
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview"
},
"gemini-2.5-flash-preview-04-17": {
"max_tokens": 65536,
"max_input_tokens": 1048576,
"max_output_tokens": 65536,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_audio_token": 1e-6,
"input_cost_per_token": 0.15e-6,
"output_cost_per_token": 0.6e-6,
"output_cost_per_reasoning_token": 3.5e-6,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_reasoning": true,
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
"supported_endpoints": ["/v1/chat/completions", "/v1/completions", "/v1/batch"],
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview"

View file

@ -69,6 +69,7 @@ class HttpxPartType(TypedDict, total=False):
functionResponse: FunctionResponse
executableCode: HttpxExecutableCode
codeExecutionResult: HttpxCodeExecutionResult
thought: bool
class HttpxContentType(TypedDict, total=False):
@ -166,6 +167,11 @@ class SafetSettingsConfig(TypedDict, total=False):
method: HarmBlockMethod
class GeminiThinkingConfig(TypedDict, total=False):
includeThoughts: bool
thinkingBudget: int
class GenerationConfig(TypedDict, total=False):
temperature: float
top_p: float
@ -181,6 +187,7 @@ class GenerationConfig(TypedDict, total=False):
responseLogprobs: bool
logprobs: int
responseModalities: List[Literal["TEXT", "IMAGE", "AUDIO", "VIDEO"]]
thinkingConfig: GeminiThinkingConfig
class Tools(TypedDict, total=False):
@ -212,6 +219,7 @@ class UsageMetadata(TypedDict, total=False):
candidatesTokenCount: int
cachedContentTokenCount: int
promptTokensDetails: List[PromptTokensDetails]
thoughtsTokenCount: int
class CachedContent(TypedDict, total=False):

View file

@ -150,6 +150,7 @@ class ModelInfoBase(ProviderSpecificModelInfo, total=False):
] # only for vertex ai models
output_cost_per_image: Optional[float]
output_vector_size: Optional[int]
output_cost_per_reasoning_token: Optional[float]
output_cost_per_video_per_second: Optional[float] # only for vertex ai models
output_cost_per_audio_per_second: Optional[float] # only for vertex ai models
output_cost_per_second: Optional[float] # for OpenAI Speech models
@ -829,8 +830,11 @@ class Usage(CompletionUsage):
# handle reasoning_tokens
_completion_tokens_details: Optional[CompletionTokensDetailsWrapper] = None
if reasoning_tokens:
text_tokens = (
completion_tokens - reasoning_tokens if completion_tokens else None
)
completion_tokens_details = CompletionTokensDetailsWrapper(
reasoning_tokens=reasoning_tokens
reasoning_tokens=reasoning_tokens, text_tokens=text_tokens
)
# Ensure completion_tokens_details is properly handled

View file

@ -516,9 +516,9 @@ def function_setup( # noqa: PLR0915
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
## DYNAMIC CALLBACKS ##
dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
kwargs.pop("callbacks", None)
)
dynamic_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = kwargs.pop("callbacks", None)
all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)
if len(all_callbacks) > 0:
@ -1202,9 +1202,9 @@ def client(original_function): # noqa: PLR0915
exception=e,
retry_policy=kwargs.get("retry_policy"),
)
kwargs["retry_policy"] = (
reset_retry_policy()
) # prevent infinite loops
kwargs[
"retry_policy"
] = reset_retry_policy() # prevent infinite loops
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)
@ -3013,16 +3013,16 @@ def get_optional_params( # noqa: PLR0915
True # so that main.py adds the function call to the prompt
)
if "tools" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("tools")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("tools")
non_default_params.pop(
"tool_choice", None
) # causes ollama requests to hang
elif "functions" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("functions")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("functions")
elif (
litellm.add_function_to_prompt
): # if user opts to add it to prompt instead
@ -3045,10 +3045,10 @@ def get_optional_params( # noqa: PLR0915
if "response_format" in non_default_params:
if provider_config is not None:
non_default_params["response_format"] = (
provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
non_default_params[
"response_format"
] = provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
else:
non_default_params["response_format"] = type_to_response_format_param(
@ -4064,9 +4064,9 @@ def _count_characters(text: str) -> int:
def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
_choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
response_obj.choices
)
_choices: Union[
List[Union[Choices, StreamingChoices]], List[StreamingChoices]
] = response_obj.choices
response_str = ""
for choice in _choices:
@ -4563,6 +4563,9 @@ def _get_model_info_helper( # noqa: PLR0915
output_cost_per_character=_model_info.get(
"output_cost_per_character", None
),
output_cost_per_reasoning_token=_model_info.get(
"output_cost_per_reasoning_token", None
),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),

View file

@ -5178,9 +5178,10 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_audio_token": 0.0000001,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000060,
"input_cost_per_audio_token": 1e-6,
"input_cost_per_token": 0.15e-6,
"output_cost_per_token": 0.6e-6,
"output_cost_per_reasoning_token": 3.5e-6,
"litellm_provider": "gemini",
"mode": "chat",
"rpm": 10,
@ -5188,9 +5189,39 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_reasoning": true,
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
"supported_endpoints": ["/v1/chat/completions", "/v1/completions"],
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview"
},
"gemini-2.5-flash-preview-04-17": {
"max_tokens": 65536,
"max_input_tokens": 1048576,
"max_output_tokens": 65536,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_audio_token": 1e-6,
"input_cost_per_token": 0.15e-6,
"output_cost_per_token": 0.6e-6,
"output_cost_per_reasoning_token": 3.5e-6,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_reasoning": true,
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
"supported_endpoints": ["/v1/chat/completions", "/v1/completions", "/v1/batch"],
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview"

View file

@ -10,7 +10,13 @@ from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
StandardBuiltInToolCostTracking,
)
from litellm.types.llms.openai import FileSearchTool, WebSearchOptions
from litellm.types.utils import ModelInfo, ModelResponse, StandardBuiltInToolsParams
from litellm.types.utils import (
CompletionTokensDetailsWrapper,
ModelInfo,
ModelResponse,
PromptTokensDetailsWrapper,
StandardBuiltInToolsParams,
)
sys.path.insert(
0, os.path.abspath("../../..")
@ -20,6 +26,51 @@ from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_toke
from litellm.types.utils import Usage
def test_reasoning_tokens_gemini():
model = "gemini-2.5-flash-preview-04-17"
custom_llm_provider = "gemini"
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
usage = Usage(
completion_tokens=1578,
prompt_tokens=17,
total_tokens=1595,
completion_tokens_details=CompletionTokensDetailsWrapper(
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=952,
rejected_prediction_tokens=None,
text_tokens=626,
),
prompt_tokens_details=PromptTokensDetailsWrapper(
audio_tokens=None, cached_tokens=None, text_tokens=17, image_tokens=None
),
)
model_cost_map = litellm.model_cost[model]
prompt_cost, completion_cost = generic_cost_per_token(
model=model,
usage=usage,
custom_llm_provider=custom_llm_provider,
)
assert round(prompt_cost, 10) == round(
model_cost_map["input_cost_per_token"] * usage.prompt_tokens,
10,
)
assert round(completion_cost, 10) == round(
(
model_cost_map["output_cost_per_token"]
* usage.completion_tokens_details.text_tokens
)
+ (
model_cost_map["output_cost_per_reasoning_token"]
* usage.completion_tokens_details.reasoning_tokens
),
10,
)
def test_generic_cost_per_token_above_200k_tokens():
model = "gemini-2.5-pro-exp-03-25"
custom_llm_provider = "vertex_ai"

View file

@ -1,7 +1,9 @@
import asyncio
from typing import List, cast
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel
import litellm
from litellm import ModelResponse
@ -9,8 +11,6 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
from litellm.types.utils import ChoiceLogprobs
from pydantic import BaseModel
from typing import List, cast
def test_top_logprobs():
@ -66,7 +66,6 @@ def test_get_model_name_from_gemini_spec_model():
assert result == "ft-uuid-123"
def test_vertex_ai_response_schema_dict():
v = VertexGeminiConfig()
transformed_request = v.map_openai_params(
@ -221,3 +220,22 @@ def test_vertex_ai_retain_property_ordering():
schema = transformed_request["response_schema"]
# should leave existing value alone, despite dictionary ordering
assert schema["propertyOrdering"] == ["thought", "output"]
def test_vertex_ai_thinking_output_part():
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
from litellm.types.llms.vertex_ai import HttpxPartType
v = VertexGeminiConfig()
parts = [
HttpxPartType(
thought=True,
text="I'm thinking...",
),
HttpxPartType(text="Hello world"),
]
content, reasoning_content = v.get_assistant_content_message(parts=parts)
assert content == "Hello world"
assert reasoning_content == "I'm thinking..."

View file

@ -20,6 +20,16 @@ from litellm.proxy.hooks.proxy_track_cost_callback import _ProxyDBLogger
from litellm.proxy.proxy_server import app, prisma_client
from litellm.router import Router
ignored_keys = [
"request_id",
"startTime",
"endTime",
"completionStartTime",
"endTime",
"metadata.model_map_information",
"metadata.usage_object",
]
@pytest.fixture
def client():
@ -457,7 +467,7 @@ class TestSpendLogsPayload:
"model": "gpt-4o",
"user": "",
"team_id": "",
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "usage_object": {"completion_tokens": 20, "prompt_tokens": 10, "total_tokens": 30, "completion_tokens_details": null, "prompt_tokens_details": null}, "model_map_information": {"model_map_key": "gpt-4o", "model_map_value": {"key": "gpt-4o", "max_tokens": 16384, "max_input_tokens": 128000, "max_output_tokens": 16384, "input_cost_per_token": 2.5e-06, "cache_creation_input_token_cost": null, "cache_read_input_token_cost": 1.25e-06, "input_cost_per_character": null, "input_cost_per_token_above_128k_tokens": null, "input_cost_per_token_above_200k_tokens": null, "input_cost_per_query": null, "input_cost_per_second": null, "input_cost_per_audio_token": null, "input_cost_per_token_batches": 1.25e-06, "output_cost_per_token_batches": 5e-06, "output_cost_per_token": 1e-05, "output_cost_per_audio_token": null, "output_cost_per_character": null, "output_cost_per_token_above_128k_tokens": null, "output_cost_per_character_above_128k_tokens": null, "output_cost_per_token_above_200k_tokens": null, "output_cost_per_second": null, "output_cost_per_image": null, "output_vector_size": null, "litellm_provider": "openai", "mode": "chat", "supports_system_messages": true, "supports_response_schema": true, "supports_vision": true, "supports_function_calling": true, "supports_tool_choice": true, "supports_assistant_prefill": false, "supports_prompt_caching": true, "supports_audio_input": false, "supports_audio_output": false, "supports_pdf_input": false, "supports_embedding_image_input": false, "supports_native_streaming": null, "supports_web_search": true, "supports_reasoning": false, "search_context_cost_per_query": {"search_context_size_low": 0.03, "search_context_size_medium": 0.035, "search_context_size_high": 0.05}, "tpm": null, "rpm": null, "supported_openai_params": ["frequency_penalty", "logit_bias", "logprobs", "top_logprobs", "max_tokens", "max_completion_tokens", "modalities", "prediction", "n", "presence_penalty", "seed", "stop", "stream", "stream_options", "temperature", "top_p", "tools", "tool_choice", "function_call", "functions", "max_retries", "extra_headers", "parallel_tool_calls", "audio", "response_format", "user"]}}, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}',
"metadata": '{"applied_guardrails": [], "batch_models": null, "mcp_tool_call_metadata": null, "usage_object": {"completion_tokens": 20, "prompt_tokens": 10, "total_tokens": 30, "completion_tokens_details": null, "prompt_tokens_details": null}, "model_map_information": {"model_map_key": "gpt-4o", "model_map_value": {"key": "gpt-4o", "max_tokens": 16384, "max_input_tokens": 128000, "max_output_tokens": 16384, "input_cost_per_token": 2.5e-06, "cache_creation_input_token_cost": null, "cache_read_input_token_cost": 1.25e-06, "input_cost_per_character": null, "input_cost_per_token_above_128k_tokens": null, "input_cost_per_token_above_200k_tokens": null, "input_cost_per_query": null, "input_cost_per_second": null, "input_cost_per_audio_token": null, "input_cost_per_token_batches": 1.25e-06, "output_cost_per_token_batches": 5e-06, "output_cost_per_token": 1e-05, "output_cost_per_audio_token": null, "output_cost_per_character": null, "output_cost_per_token_above_128k_tokens": null, "output_cost_per_character_above_128k_tokens": null, "output_cost_per_token_above_200k_tokens": null, "output_cost_per_second": null, "output_cost_per_reasoning_token": null, "output_cost_per_image": null, "output_vector_size": null, "litellm_provider": "openai", "mode": "chat", "supports_system_messages": true, "supports_response_schema": true, "supports_vision": true, "supports_function_calling": true, "supports_tool_choice": true, "supports_assistant_prefill": false, "supports_prompt_caching": true, "supports_audio_input": false, "supports_audio_output": false, "supports_pdf_input": false, "supports_embedding_image_input": false, "supports_native_streaming": null, "supports_web_search": true, "supports_reasoning": false, "search_context_cost_per_query": {"search_context_size_low": 0.03, "search_context_size_medium": 0.035, "search_context_size_high": 0.05}, "tpm": null, "rpm": null, "supported_openai_params": ["frequency_penalty", "logit_bias", "logprobs", "top_logprobs", "max_tokens", "max_completion_tokens", "modalities", "prediction", "n", "presence_penalty", "seed", "stop", "stream", "stream_options", "temperature", "top_p", "tools", "tool_choice", "function_call", "functions", "max_retries", "extra_headers", "parallel_tool_calls", "audio", "response_format", "user"]}}, "additional_usage_values": {"completion_tokens_details": null, "prompt_tokens_details": null}}',
"cache_key": "Cache OFF",
"spend": 0.00022500000000000002,
"total_tokens": 30,
@ -475,19 +485,11 @@ class TestSpendLogsPayload:
}
)
for key, value in expected_payload.items():
if key in [
"request_id",
"startTime",
"endTime",
"completionStartTime",
"endTime",
]:
assert payload[key] is not None
else:
assert (
payload[key] == value
), f"Expected {key} to be {value}, but got {payload[key]}"
differences = _compare_nested_dicts(
payload, expected_payload, ignore_keys=ignored_keys
)
if differences:
assert False, f"Dictionary mismatch: {differences}"
def mock_anthropic_response(*args, **kwargs):
mock_response = MagicMock()
@ -573,19 +575,11 @@ class TestSpendLogsPayload:
}
)
for key, value in expected_payload.items():
if key in [
"request_id",
"startTime",
"endTime",
"completionStartTime",
"endTime",
]:
assert payload[key] is not None
else:
assert (
payload[key] == value
), f"Expected {key} to be {value}, but got {payload[key]}"
differences = _compare_nested_dicts(
payload, expected_payload, ignore_keys=ignored_keys
)
if differences:
assert False, f"Dictionary mismatch: {differences}"
@pytest.mark.asyncio
async def test_spend_logs_payload_success_log_with_router(self):
@ -669,16 +663,71 @@ class TestSpendLogsPayload:
}
)
for key, value in expected_payload.items():
if key in [
"request_id",
"startTime",
"endTime",
"completionStartTime",
"endTime",
]:
assert payload[key] is not None
else:
assert (
payload[key] == value
), f"Expected {key} to be {value}, but got {payload[key]}"
differences = _compare_nested_dicts(
payload, expected_payload, ignore_keys=ignored_keys
)
if differences:
assert False, f"Dictionary mismatch: {differences}"
def _compare_nested_dicts(
actual: dict, expected: dict, path: str = "", ignore_keys: list[str] = []
) -> list[str]:
"""Compare nested dictionaries and return a list of differences in a human-friendly format."""
differences = []
# Check if current path should be ignored
if path in ignore_keys:
return differences
# Check for keys in actual but not in expected
for key in actual.keys():
current_path = f"{path}.{key}" if path else key
if current_path not in ignore_keys and key not in expected:
differences.append(f"Extra key in actual: {current_path}")
for key, expected_value in expected.items():
current_path = f"{path}.{key}" if path else key
if current_path in ignore_keys:
continue
if key not in actual:
differences.append(f"Missing key: {current_path}")
continue
actual_value = actual[key]
# Try to parse JSON strings
if isinstance(expected_value, str):
try:
expected_value = json.loads(expected_value)
except json.JSONDecodeError:
pass
if isinstance(actual_value, str):
try:
actual_value = json.loads(actual_value)
except json.JSONDecodeError:
pass
if isinstance(expected_value, dict) and isinstance(actual_value, dict):
differences.extend(
_compare_nested_dicts(
actual_value, expected_value, current_path, ignore_keys
)
)
elif isinstance(expected_value, dict) or isinstance(actual_value, dict):
differences.append(
f"Type mismatch at {current_path}: expected dict, got {type(actual_value).__name__}"
)
else:
# For non-dict values, only report if they're different
if actual_value != expected_value:
# Format the values to be more readable
actual_str = str(actual_value)
expected_str = str(expected_value)
if len(actual_str) > 50 or len(expected_str) > 50:
actual_str = f"{actual_str[:50]}..."
expected_str = f"{expected_str[:50]}..."
differences.append(
f"Value mismatch at {current_path}:\n expected: {expected_str}\n got: {actual_str}"
)
return differences

View file

@ -76,6 +76,11 @@ class BaseLLMChatTest(ABC):
"""Must return the base completion call args"""
pass
def get_base_completion_call_args_with_reasoning_model(self) -> dict:
"""Must return the base completion call args with reasoning_effort"""
return {}
def test_developer_role_translation(self):
"""
Test that the developer role is translated correctly for non-OpenAI providers.
@ -1126,6 +1131,46 @@ class BaseLLMChatTest(ABC):
print(response)
def test_reasoning_effort(self):
"""Test that reasoning_effort is passed correctly to the model"""
from litellm.utils import supports_reasoning
from litellm import completion
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
base_completion_call_args = self.get_base_completion_call_args_with_reasoning_model()
if len(base_completion_call_args) == 0:
print("base_completion_call_args is empty")
pytest.skip("Model does not support reasoning")
if not supports_reasoning(base_completion_call_args["model"], None):
print("Model does not support reasoning")
pytest.skip("Model does not support reasoning")
_, provider, _, _ = litellm.get_llm_provider(
model=base_completion_call_args["model"]
)
## CHECK PARAM MAPPING
optional_params = get_optional_params(
model=base_completion_call_args["model"],
custom_llm_provider=provider,
reasoning_effort="high",
)
# either accepts reasoning effort or thinking budget
assert "reasoning_effort" in optional_params or "4096" in json.dumps(optional_params)
try:
litellm._turn_on_debug()
response = completion(
**base_completion_call_args,
reasoning_effort="low",
messages=[{"role": "user", "content": "Hello!"}],
)
print(f"response: {response}")
except Exception as e:
pytest.fail(f"Error: {e}")
class BaseOSeriesModelsTest(ABC): # test across azure/openai

View file

@ -17,6 +17,9 @@ from litellm import completion
class TestGoogleAIStudioGemini(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "gemini/gemini-2.0-flash"}
def get_base_completion_call_args_with_reasoning_model(self) -> dict:
return {"model": "gemini/gemini-2.5-flash-preview-04-17"}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
@ -85,3 +88,4 @@ def test_gemini_image_generation():
assert response.choices[0].message.content is not None