Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
fix: Gemini Flash 2.0 implementation is not returning the logprobs (#9713)

* fix: Gemini Flash 2.0 implementation is not returning the logprobs
* fix: linting error by adding a helper method called _process_candidates
parent 6dda1ba6dd
commit 4a4328b5bb
2 changed files with 137 additions and 76 deletions
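For context, here is a minimal usage sketch of what this fix enables. It is not part of the commit; it assumes Vertex AI credentials are already configured for litellm and uses litellm's standard `vertex_ai/` provider prefix.

# Minimal sketch, not part of this commit. Assumes Vertex AI credentials and
# project settings are already configured for litellm.
import litellm

response = litellm.completion(
    model="vertex_ai/gemini-2.0-flash",
    messages=[{"role": "user", "content": "Say hello"}],
)

# Before this fix, a Gemini Flash 2.0 candidate that only carried "avgLogprobs"
# left choices[0].logprobs empty; after it, the average log probability is
# attached to the choice.
print(response.choices[0].logprobs)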
@@ -676,6 +676,66 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
         return usage
 
+    def _process_candidates(self, _candidates, model_response, litellm_params):
+        """Helper method to process candidates and extract metadata"""
+        grounding_metadata: List[dict] = []
+        safety_ratings: List = []
+        citation_metadata: List = []
+        chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
+        chat_completion_logprobs: Optional[ChoiceLogprobs] = None
+        tools: Optional[List[ChatCompletionToolCallChunk]] = []
+        functions: Optional[ChatCompletionToolCallFunctionChunk] = None
+
+        for idx, candidate in enumerate(_candidates):
+            if "content" not in candidate:
+                continue
+
+            if "groundingMetadata" in candidate:
+                grounding_metadata.append(candidate["groundingMetadata"])  # type: ignore
+
+            if "safetyRatings" in candidate:
+                safety_ratings.append(candidate["safetyRatings"])
+
+            if "citationMetadata" in candidate:
+                citation_metadata.append(candidate["citationMetadata"])
+
+            if "parts" in candidate["content"]:
+                chat_completion_message["content"] = VertexGeminiConfig().get_assistant_content_message(
+                    parts=candidate["content"]["parts"]
+                )
+
+                functions, tools = self._transform_parts(
+                    parts=candidate["content"]["parts"],
+                    index=candidate.get("index", idx),
+                    is_function_call=litellm_params.get("litellm_param_is_function_call"),
+                )
+
+            if "logprobsResult" in candidate:
+                chat_completion_logprobs = self._transform_logprobs(
+                    logprobs_result=candidate["logprobsResult"]
+                )
+            # Handle avgLogprobs for Gemini Flash 2.0
+            elif "avgLogprobs" in candidate:
+                chat_completion_logprobs = candidate["avgLogprobs"]
+
+            if tools:
+                chat_completion_message["tool_calls"] = tools
+
+            if functions is not None:
+                chat_completion_message["function_call"] = functions
+
+            choice = litellm.Choices(
+                finish_reason=candidate.get("finishReason", "stop"),
+                index=candidate.get("index", idx),
+                message=chat_completion_message,  # type: ignore
+                logprobs=chat_completion_logprobs,
+                enhancements=None,
+            )
+
+            model_response.choices.append(choice)
+
+        return grounding_metadata, safety_ratings, citation_metadata
+
     def transform_response(
         self,
         model: str,
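The crux of the new helper is the fallback from the detailed logprobsResult object to the scalar avgLogprobs that Gemini Flash 2.0 returns. Below is a standalone illustration of that branch only; extract_logprobs is a hypothetical name used for this sketch, not a litellm function.

# Hypothetical, self-contained illustration of the fallback implemented above.
from typing import Any, Optional


def extract_logprobs(candidate: dict) -> Optional[Any]:
    # Prefer the token-level logprobsResult when the API returns it.
    if "logprobsResult" in candidate:
        return candidate["logprobsResult"]
    # Gemini Flash 2.0 responses may only carry a single averaged value.
    if "avgLogprobs" in candidate:
        return candidate["avgLogprobs"]
    return None


print(extract_logprobs({"avgLogprobs": -0.3445799010140555}))  # -0.3445799010140555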
@@ -725,9 +785,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
 
         _candidates = completion_response.get("candidates")
         if _candidates and len(_candidates) > 0:
-            content_policy_violations = (
-                VertexGeminiConfig().get_flagged_finish_reasons()
-            )
+            content_policy_violations = VertexGeminiConfig().get_flagged_finish_reasons()
             if (
                 "finishReason" in _candidates[0]
                 and _candidates[0]["finishReason"] in content_policy_violations.keys()
@@ -740,88 +798,25 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
         model_response.choices = []  # type: ignore
 
         try:
-            ## CHECK IF GROUNDING METADATA IN REQUEST
-            grounding_metadata: List[dict] = []
-            safety_ratings: List = []
-            citation_metadata: List = []
-            ## GET TEXT ##
-            chat_completion_message: ChatCompletionResponseMessage = {
-                "role": "assistant"
-            }
-            chat_completion_logprobs: Optional[ChoiceLogprobs] = None
-            tools: Optional[List[ChatCompletionToolCallChunk]] = []
-            functions: Optional[ChatCompletionToolCallFunctionChunk] = None
+            grounding_metadata, safety_ratings, citation_metadata = [], [], []
             if _candidates:
-                for idx, candidate in enumerate(_candidates):
-                    if "content" not in candidate:
-                        continue
-
-                    if "groundingMetadata" in candidate:
-                        grounding_metadata.append(candidate["groundingMetadata"])  # type: ignore
-
-                    if "safetyRatings" in candidate:
-                        safety_ratings.append(candidate["safetyRatings"])
-
-                    if "citationMetadata" in candidate:
-                        citation_metadata.append(candidate["citationMetadata"])
-                    if "parts" in candidate["content"]:
-                        chat_completion_message[
-                            "content"
-                        ] = VertexGeminiConfig().get_assistant_content_message(
-                            parts=candidate["content"]["parts"]
-                        )
-
-                        functions, tools = self._transform_parts(
-                            parts=candidate["content"]["parts"],
-                            index=candidate.get("index", idx),
-                            is_function_call=litellm_params.get(
-                                "litellm_param_is_function_call"
-                            ),
-                        )
-
-                    if "logprobsResult" in candidate:
-                        chat_completion_logprobs = self._transform_logprobs(
-                            logprobs_result=candidate["logprobsResult"]
-                        )
-
-                    if tools:
-                        chat_completion_message["tool_calls"] = tools
-
-                    if functions is not None:
-                        chat_completion_message["function_call"] = functions
-                    choice = litellm.Choices(
-                        finish_reason=candidate.get("finishReason", "stop"),
-                        index=candidate.get("index", idx),
-                        message=chat_completion_message,  # type: ignore
-                        logprobs=chat_completion_logprobs,
-                        enhancements=None,
-                    )
-
-                    model_response.choices.append(choice)
+                grounding_metadata, safety_ratings, citation_metadata = self._process_candidates(
+                    _candidates, model_response, litellm_params
+                )
 
             usage = self._calculate_usage(completion_response=completion_response)
             setattr(model_response, "usage", usage)
 
-            ## ADD GROUNDING METADATA ##
+            ## ADD METADATA TO RESPONSE ##
             setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata)
-            model_response._hidden_params[
-                "vertex_ai_grounding_metadata"
-            ] = (  # older approach - maintaining to prevent regressions
-                grounding_metadata
-            )
-
-            ## ADD SAFETY RATINGS ##
+            model_response._hidden_params["vertex_ai_grounding_metadata"] = grounding_metadata
+
             setattr(model_response, "vertex_ai_safety_results", safety_ratings)
-            model_response._hidden_params[
-                "vertex_ai_safety_results"
-            ] = safety_ratings  # older approach - maintaining to prevent regressions
-
-            ## ADD CITATION METADATA ##
+            model_response._hidden_params["vertex_ai_safety_results"] = safety_ratings  # older approach - maintaining to prevent regressions
+
             setattr(model_response, "vertex_ai_citation_metadata", citation_metadata)
-            model_response._hidden_params[
-                "vertex_ai_citation_metadata"
-            ] = citation_metadata  # older approach - maintaining to prevent regressions
+            model_response._hidden_params["vertex_ai_citation_metadata"] = citation_metadata  # older approach - maintaining to prevent regressions
 
         except Exception as e:
             raise VertexAIError(
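As a follow-up, this is a sketch of how a caller can read back the Vertex-specific metadata attached above. It is not part of the commit; it assumes Vertex AI credentials are configured, and the attribute names come from the setattr and _hidden_params lines in this hunk.

# Sketch only, not part of the commit.
import litellm

response = litellm.completion(
    model="vertex_ai/gemini-2.0-flash",
    messages=[{"role": "user", "content": "Cite a fact about the moon."}],
)

# Attribute names come from the setattr calls in this hunk; the same values are
# mirrored into _hidden_params to avoid regressions for older callers.
print(getattr(response, "vertex_ai_safety_results", []))
print(response._hidden_params.get("vertex_ai_grounding_metadata"))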
@@ -0,0 +1,66 @@
+import pytest
+import asyncio
+from unittest.mock import MagicMock
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
+import litellm
+from litellm import ModelResponse
+
+
+@pytest.mark.asyncio
+async def test_transform_response_with_avglogprobs():
+    """
+    Test that the transform_response method correctly handles the avgLogprobs key
+    from Gemini Flash 2.0 responses.
+    """
+    # Create a mock response with avgLogprobs
+    response_json = {
+        "candidates": [{
+            "content": {"parts": [{"text": "Test response"}], "role": "model"},
+            "finishReason": "STOP",
+            "avgLogprobs": -0.3445799010140555
+        }],
+        "usageMetadata": {
+            "promptTokenCount": 10,
+            "candidatesTokenCount": 5,
+            "totalTokenCount": 15
+        }
+    }
+
+    # Create a mock HTTP response
+    mock_response = MagicMock()
+    mock_response.json.return_value = response_json
+
+    # Create a mock logging object
+    mock_logging = MagicMock()
+
+    # Create an instance of VertexGeminiConfig
+    config = VertexGeminiConfig()
+
+    # Create a ModelResponse object
+    model_response = ModelResponse(
+        id="test-id",
+        choices=[],
+        created=1234567890,
+        model="gemini-2.0-flash",
+        usage={
+            "prompt_tokens": 10,
+            "completion_tokens": 5,
+            "total_tokens": 15
+        }
+    )
+
+    # Call the transform_response method
+    transformed_response = config.transform_response(
+        model="gemini-2.0-flash",
+        raw_response=mock_response,
+        model_response=model_response,
+        logging_obj=mock_logging,
+        request_data={},
+        messages=[],
+        optional_params={},
+        litellm_params={},
+        encoding=None
+    )
+
+    # Assert that the avgLogprobs was correctly added to the model response
+    assert len(transformed_response.choices) == 1
+    assert transformed_response.choices[0].logprobs == -0.3445799010140555
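Note on the final assertion: because the mocked candidate carries only avgLogprobs, the new elif branch in _process_candidates stores that scalar directly on the choice, so choices[0].logprobs compares equal to the raw float rather than to a ChoiceLogprobs object.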