Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
fix(vertex_and_google_ai_studio_gemini.py): handle nuance in counting exclusive vs. inclusive tokens
Addresses https://github.com/BerriAI/litellm/pull/10141#discussion_r2052272035
parent e434ccc7e1
commit 72d89cc47a
2 changed files with 82 additions and 4 deletions
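The nuance in one place: some Gemini usageMetadata payloads fold the thinking tokens into candidatesTokenCount, while others report them separately, and the only way to tell is whether the counts add up. A minimal standalone sketch of the check, using plain dicts and a hypothetical free function rather than the litellm types and method added in the diff below:

def candidates_include_thoughts(usage_metadata: dict) -> bool:
    # Mirrors the check this commit adds: if prompt + candidates already equals
    # the reported total, the candidate count is inclusive of thinking tokens;
    # otherwise thoughtsTokenCount still has to be added on top.
    prompt = usage_metadata.get("promptTokenCount", 0)
    candidates = usage_metadata.get("candidatesTokenCount", 0)
    total = usage_metadata.get("totalTokenCount", 0)
    return prompt + candidates == total

# Inclusive shape: 10 + 10 == 20, so the 5 thinking tokens are already counted.
assert candidates_include_thoughts(
    {"promptTokenCount": 10, "candidatesTokenCount": 10,
     "totalTokenCount": 20, "thoughtsTokenCount": 5}
)

# Exclusive shape: 10 + 5 != 20, so the thinking tokens sit outside the candidate count.
assert not candidates_include_thoughts(
    {"promptTokenCount": 10, "candidatesTokenCount": 5,
     "totalTokenCount": 20, "thoughtsTokenCount": 5}
)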
@@ -57,6 +57,7 @@ from litellm.types.llms.vertex_ai import (
     LogprobsResult,
     ToolConfig,
     Tools,
+    UsageMetadata,
 )
 from litellm.types.utils import (
     ChatCompletionTokenLogprob,
@@ -740,6 +741,23 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
 
         return model_response
 
+    def is_candidate_token_count_inclusive(self, usage_metadata: UsageMetadata) -> bool:
+        """
+        Check if the candidate token count is inclusive of the thinking token count
+
+        if promptTokenCount + candidatesTokenCount == totalTokenCount, then the candidate token count is inclusive of the thinking token count
+
+        else the candidate token count is exclusive of the thinking token count
+
+        Addresses - https://github.com/BerriAI/litellm/pull/10141#discussion_r2052272035
+        """
+        if usage_metadata.get("promptTokenCount", 0) + usage_metadata.get(
+            "candidatesTokenCount", 0
+        ) == usage_metadata.get("totalTokenCount", 0):
+            return True
+        else:
+            return False
+
     def _calculate_usage(
         self,
         completion_response: GenerateContentResponseBody,
@@ -768,14 +786,23 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             audio_tokens=audio_tokens,
             text_tokens=text_tokens,
         )
+
+        completion_tokens = completion_response["usageMetadata"].get(
+            "candidatesTokenCount", 0
+        )
+        if (
+            not self.is_candidate_token_count_inclusive(
+                completion_response["usageMetadata"]
+            )
+            and reasoning_tokens
+        ):
+            completion_tokens = reasoning_tokens + completion_tokens
         ## GET USAGE ##
         usage = Usage(
             prompt_tokens=completion_response["usageMetadata"].get(
                 "promptTokenCount", 0
             ),
-            completion_tokens=completion_response["usageMetadata"].get(
-                "candidatesTokenCount", 0
-            ),
+            completion_tokens=completion_tokens,
             total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
             prompt_tokens_details=prompt_tokens_details,
             reasoning_tokens=reasoning_tokens,
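Worked through with the numbers the new test below uses (not figures captured from a live Gemini call): for the exclusive shape (promptTokenCount=10, candidatesTokenCount=5, totalTokenCount=20, thoughtsTokenCount=5) the check fails because 10 + 5 != 20, so completion_tokens becomes 5 + 5 = 10; for the inclusive shape (candidatesTokenCount=10) the check passes because 10 + 10 == 20, the thinking tokens are already folded in, and completion_tokens stays 10. In both cases total_tokens is the reported totalTokenCount of 20, which is exactly what the parametrized test in the second changed file asserts.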
@@ -10,7 +10,8 @@ from litellm import ModelResponse
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
 )
-from litellm.types.utils import ChoiceLogprobs
+from litellm.types.llms.vertex_ai import UsageMetadata
+from litellm.types.utils import ChoiceLogprobs, Usage
 
 
 def test_top_logprobs():
@@ -259,3 +260,53 @@ def test_vertex_ai_empty_content():
     content, reasoning_content = v.get_assistant_content_message(parts=parts)
     assert content is None
     assert reasoning_content is None
+
+
+@pytest.mark.parametrize(
+    "usage_metadata, inclusive, expected_usage",
+    [
+        (
+            UsageMetadata(
+                promptTokenCount=10,
+                candidatesTokenCount=10,
+                totalTokenCount=20,
+                thoughtsTokenCount=5,
+            ),
+            True,
+            Usage(
+                prompt_tokens=10,
+                completion_tokens=10,
+                total_tokens=20,
+                reasoning_tokens=5,
+            ),
+        ),
+        (
+            UsageMetadata(
+                promptTokenCount=10,
+                candidatesTokenCount=5,
+                totalTokenCount=20,
+                thoughtsTokenCount=5,
+            ),
+            False,
+            Usage(
+                prompt_tokens=10,
+                completion_tokens=10,
+                total_tokens=20,
+                reasoning_tokens=5,
+            ),
+        ),
+    ],
+)
+def test_vertex_ai_candidate_token_count_inclusive(
+    usage_metadata, inclusive, expected_usage
+):
+    """
+    Test whether the candidate token count is inclusive of the thinking token count
+    """
+    v = VertexGeminiConfig()
+    assert v.is_candidate_token_count_inclusive(usage_metadata) is inclusive
+
+    usage = v._calculate_usage(completion_response={"usageMetadata": usage_metadata})
+    assert usage.prompt_tokens == expected_usage.prompt_tokens
+    assert usage.completion_tokens == expected_usage.completion_tokens
+    assert usage.total_tokens == expected_usage.total_tokens
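For a quick check of the same behaviour outside pytest, the calls the test makes can be reused directly. A sketch, assuming VertexGeminiConfig() still takes no constructor arguments and that _calculate_usage accepts a response containing only usageMetadata, as it does in the test above:

from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
    VertexGeminiConfig,
)

v = VertexGeminiConfig()
# Exclusive shape from the test data: candidates do not include the thinking tokens.
exclusive = {
    "promptTokenCount": 10,
    "candidatesTokenCount": 5,
    "totalTokenCount": 20,
    "thoughtsTokenCount": 5,
}
print(v.is_candidate_token_count_inclusive(exclusive))  # False
usage = v._calculate_usage(completion_response={"usageMetadata": exclusive})
print(usage.completion_tokens, usage.total_tokens)  # expect 10 20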