fix:Gemini Flash 2.0 implementation is not returning the logprobs (#9713)

* fix:Gemini Flash 2.0 implementation is not returning the logprobs

* fix: linting error by adding a helper method called _process_candidates
This commit is contained in:
sajda 2025-04-04 00:23:41 +05:30 committed by GitHub
parent 6dda1ba6dd
commit 4a4328b5bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 137 additions and 76 deletions

View file

@ -676,6 +676,66 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
return usage
def _process_candidates(self, _candidates, model_response, litellm_params):
"""Helper method to process candidates and extract metadata"""
grounding_metadata: List[dict] = []
safety_ratings: List = []
citation_metadata: List = []
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
chat_completion_logprobs: Optional[ChoiceLogprobs] = None
tools: Optional[List[ChatCompletionToolCallChunk]] = []
functions: Optional[ChatCompletionToolCallFunctionChunk] = None
for idx, candidate in enumerate(_candidates):
if "content" not in candidate:
continue
if "groundingMetadata" in candidate:
grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore
if "safetyRatings" in candidate:
safety_ratings.append(candidate["safetyRatings"])
if "citationMetadata" in candidate:
citation_metadata.append(candidate["citationMetadata"])
if "parts" in candidate["content"]:
chat_completion_message["content"] = VertexGeminiConfig().get_assistant_content_message(
parts=candidate["content"]["parts"]
)
functions, tools = self._transform_parts(
parts=candidate["content"]["parts"],
index=candidate.get("index", idx),
is_function_call=litellm_params.get("litellm_param_is_function_call"),
)
if "logprobsResult" in candidate:
chat_completion_logprobs = self._transform_logprobs(
logprobs_result=candidate["logprobsResult"]
)
# Handle avgLogprobs for Gemini Flash 2.0
elif "avgLogprobs" in candidate:
chat_completion_logprobs = candidate["avgLogprobs"]
if tools:
chat_completion_message["tool_calls"] = tools
if functions is not None:
chat_completion_message["function_call"] = functions
choice = litellm.Choices(
finish_reason=candidate.get("finishReason", "stop"),
index=candidate.get("index", idx),
message=chat_completion_message, # type: ignore
logprobs=chat_completion_logprobs,
enhancements=None,
)
model_response.choices.append(choice)
return grounding_metadata, safety_ratings, citation_metadata
def transform_response(
self,
model: str,
@ -725,9 +785,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
_candidates = completion_response.get("candidates")
if _candidates and len(_candidates) > 0:
content_policy_violations = (
VertexGeminiConfig().get_flagged_finish_reasons()
)
content_policy_violations = VertexGeminiConfig().get_flagged_finish_reasons()
if (
"finishReason" in _candidates[0]
and _candidates[0]["finishReason"] in content_policy_violations.keys()
@ -740,88 +798,25 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
model_response.choices = [] # type: ignore
try:
## CHECK IF GROUNDING METADATA IN REQUEST
grounding_metadata: List[dict] = []
safety_ratings: List = []
citation_metadata: List = []
## GET TEXT ##
chat_completion_message: ChatCompletionResponseMessage = {
"role": "assistant"
}
chat_completion_logprobs: Optional[ChoiceLogprobs] = None
tools: Optional[List[ChatCompletionToolCallChunk]] = []
functions: Optional[ChatCompletionToolCallFunctionChunk] = None
grounding_metadata, safety_ratings, citation_metadata = [], [], []
if _candidates:
for idx, candidate in enumerate(_candidates):
if "content" not in candidate:
continue
if "groundingMetadata" in candidate:
grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore
if "safetyRatings" in candidate:
safety_ratings.append(candidate["safetyRatings"])
if "citationMetadata" in candidate:
citation_metadata.append(candidate["citationMetadata"])
if "parts" in candidate["content"]:
chat_completion_message[
"content"
] = VertexGeminiConfig().get_assistant_content_message(
parts=candidate["content"]["parts"]
)
functions, tools = self._transform_parts(
parts=candidate["content"]["parts"],
index=candidate.get("index", idx),
is_function_call=litellm_params.get(
"litellm_param_is_function_call"
),
)
if "logprobsResult" in candidate:
chat_completion_logprobs = self._transform_logprobs(
logprobs_result=candidate["logprobsResult"]
)
if tools:
chat_completion_message["tool_calls"] = tools
if functions is not None:
chat_completion_message["function_call"] = functions
choice = litellm.Choices(
finish_reason=candidate.get("finishReason", "stop"),
index=candidate.get("index", idx),
message=chat_completion_message, # type: ignore
logprobs=chat_completion_logprobs,
enhancements=None,
)
model_response.choices.append(choice)
grounding_metadata, safety_ratings, citation_metadata = self._process_candidates(
_candidates, model_response, litellm_params
)
usage = self._calculate_usage(completion_response=completion_response)
setattr(model_response, "usage", usage)
## ADD GROUNDING METADATA ##
## ADD METADATA TO RESPONSE ##
setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata)
model_response._hidden_params[
"vertex_ai_grounding_metadata"
] = ( # older approach - maintaining to prevent regressions
grounding_metadata
)
## ADD SAFETY RATINGS ##
model_response._hidden_params["vertex_ai_grounding_metadata"] = grounding_metadata
setattr(model_response, "vertex_ai_safety_results", safety_ratings)
model_response._hidden_params[
"vertex_ai_safety_results"
] = safety_ratings # older approach - maintaining to prevent regressions
model_response._hidden_params["vertex_ai_safety_results"] = safety_ratings # older approach - maintaining to prevent regressions
## ADD CITATION METADATA ##
setattr(model_response, "vertex_ai_citation_metadata", citation_metadata)
model_response._hidden_params[
"vertex_ai_citation_metadata"
] = citation_metadata # older approach - maintaining to prevent regressions
model_response._hidden_params["vertex_ai_citation_metadata"] = citation_metadata # older approach - maintaining to prevent regressions
except Exception as e:
raise VertexAIError(

View file

@ -0,0 +1,66 @@
import pytest
import asyncio
from unittest.mock import MagicMock
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
import litellm
from litellm import ModelResponse
@pytest.mark.asyncio
async def test_transform_response_with_avglogprobs():
"""
Test that the transform_response method correctly handles the avgLogprobs key
from Gemini Flash 2.0 responses.
"""
# Create a mock response with avgLogprobs
response_json = {
"candidates": [{
"content": {"parts": [{"text": "Test response"}], "role": "model"},
"finishReason": "STOP",
"avgLogprobs": -0.3445799010140555
}],
"usageMetadata": {
"promptTokenCount": 10,
"candidatesTokenCount": 5,
"totalTokenCount": 15
}
}
# Create a mock HTTP response
mock_response = MagicMock()
mock_response.json.return_value = response_json
# Create a mock logging object
mock_logging = MagicMock()
# Create an instance of VertexGeminiConfig
config = VertexGeminiConfig()
# Create a ModelResponse object
model_response = ModelResponse(
id="test-id",
choices=[],
created=1234567890,
model="gemini-2.0-flash",
usage={
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
)
# Call the transform_response method
transformed_response = config.transform_response(
model="gemini-2.0-flash",
raw_response=mock_response,
model_response=model_response,
logging_obj=mock_logging,
request_data={},
messages=[],
optional_params={},
litellm_params={},
encoding=None
)
# Assert that the avgLogprobs was correctly added to the model response
assert len(transformed_response.choices) == 1
assert transformed_response.choices[0].logprobs == -0.3445799010140555