fix(transformation.py): pass back in gemini thinking content to api (#10173)

Ensures the assistant's thinking (reasoning) content is always passed back to the Gemini API.
Krish Dholakia 2025-04-19 18:03:05 -07:00 committed by GitHub
parent bbfcb1ac7e
commit 55a17730fb
4 changed files with 54 additions and 18 deletions


@@ -216,6 +216,11 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                msg_dict = messages[msg_i]  # type: ignore
            assistant_msg = ChatCompletionAssistantMessage(**msg_dict)  # type: ignore
            _message_content = assistant_msg.get("content", None)
            reasoning_content = assistant_msg.get("reasoning_content", None)
            if reasoning_content is not None:
                assistant_content.append(
                    PartType(thought=True, text=reasoning_content)
                )
            if _message_content is not None and isinstance(_message_content, list):
                _parts = []
                for element in _message_content:
@@ -223,6 +228,7 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                        if element["type"] == "text":
                            _part = PartType(text=element["text"])
                            _parts.append(_part)
                assistant_content.extend(_parts)
            elif (
                _message_content is not None

@@ -651,6 +651,7 @@ class OpenAIChatCompletionAssistantMessage(TypedDict, total=False):
    name: Optional[str]
    tool_calls: Optional[List[ChatCompletionAssistantToolCall]]
    function_call: Optional[ChatCompletionToolCallFunctionChunk]
    reasoning_content: Optional[str]


class ChatCompletionAssistantMessage(OpenAIChatCompletionAssistantMessage, total=False):
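
With the new optional key, assistant messages carrying reasoning now type-check cleanly. A quick, hypothetical illustration (values made up; assumes the usual role/content keys on the parent TypedDict):

    msg = ChatCompletionAssistantMessage(
        role="assistant",
        content="Okay, let's break down Occam's Razor.",
        reasoning_content="I'm thinking about Occam's Razor.",  # now a known key
    )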
@@ -823,12 +824,12 @@ class OpenAIChatCompletionChunk(ChatCompletionChunk):
class Hyperparameters(BaseModel):
    batch_size: Optional[Union[str, int]] = None  # "Number of examples in each batch."
    learning_rate_multiplier: Optional[Union[str, float]] = (
        None  # Scaling factor for the learning rate
    )
    n_epochs: Optional[Union[str, int]] = (
        None  # "The number of epochs to train the model for"
    )
    learning_rate_multiplier: Optional[
        Union[str, float]
    ] = None  # Scaling factor for the learning rate
    n_epochs: Optional[
        Union[str, int]
    ] = None  # "The number of epochs to train the model for"


class FineTuningJobCreate(BaseModel):
@@ -855,18 +856,18 @@ class FineTuningJobCreate(BaseModel):
    model: str  # "The name of the model to fine-tune."
    training_file: str  # "The ID of an uploaded file that contains training data."
    hyperparameters: Optional[Hyperparameters] = (
        None  # "The hyperparameters used for the fine-tuning job."
    )
    suffix: Optional[str] = (
        None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
    )
    validation_file: Optional[str] = (
        None  # "The ID of an uploaded file that contains validation data."
    )
    integrations: Optional[List[str]] = (
        None  # "A list of integrations to enable for your fine-tuning job."
    )
    hyperparameters: Optional[
        Hyperparameters
    ] = None  # "The hyperparameters used for the fine-tuning job."
    suffix: Optional[
        str
    ] = None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
    validation_file: Optional[
        str
    ] = None  # "The ID of an uploaded file that contains validation data."
    integrations: Optional[
        List[str]
    ] = None  # "A list of integrations to enable for your fine-tuning job."
    seed: Optional[int] = None  # "The seed controls the reproducibility of the job."


@@ -39,6 +39,7 @@ class PartType(TypedDict, total=False):
    file_data: FileDataType
    function_call: FunctionCall
    function_response: FunctionResponse
    thought: bool


class HttpxFunctionCall(TypedDict):
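
Downstream, an assistant turn with reasoning should serialize into the Gemini request body roughly like this (an illustrative sketch, not captured wire output):

    assistant_turn = {
        "role": "model",
        "parts": [
            {"thought": True, "text": "I'm thinking about Occam's Razor."},  # new
            {"text": "Okay, let's break down Occam's Razor."},
        ],
    }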


@@ -89,3 +89,31 @@ def test_gemini_image_generation():
def test_gemini_thinking():
    litellm._turn_on_debug()
    from litellm.types.utils import Message, CallTypes
    from litellm.utils import return_raw_request
    import json

    messages = [
        {"role": "user", "content": "Explain the concept of Occam's Razor and provide a simple, everyday example"}
    ]
    reasoning_content = "I'm thinking about Occam's Razor."
    assistant_message = Message(content='Okay, let\'s break down Occam\'s Razor.', reasoning_content=reasoning_content, role='assistant', tool_calls=None, function_call=None, provider_specific_fields=None)
    messages.append(assistant_message)

    raw_request = return_raw_request(
        endpoint=CallTypes.completion,
        kwargs={
            "model": "gemini/gemini-2.5-flash-preview-04-17",
            "messages": messages,
        }
    )
    assert reasoning_content in json.dumps(raw_request)

    response = completion(
        model="gemini/gemini-2.5-flash-preview-04-17",
        messages=messages,  # make sure call works
    )
    print(response.choices[0].message)
    assert response.choices[0].message.content is not None
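
Beyond the test, the intended round-trip looks roughly like the sketch below: reasoning returned on one response is fed back into the next request. This assumes the response message carries reasoning_content when the model emits thinking; names and prompts are illustrative.

    first = completion(
        model="gemini/gemini-2.5-flash-preview-04-17",
        messages=[{"role": "user", "content": "Explain Occam's Razor."}],
    )
    history = [
        {"role": "user", "content": "Explain Occam's Razor."},
        first.choices[0].message,  # carries reasoning_content when present
        {"role": "user", "content": "Give one more everyday example."},
    ]
    second = completion(
        model="gemini/gemini-2.5-flash-preview-04-17",
        messages=history,  # thinking content is passed back, per this fix
    )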