fix(transformation.py): pass back in gemini thinking content to api (#10173)

Ensures Gemini thinking (reasoning) content is always passed back to the API on subsequent requests, instead of being dropped when assistant messages are converted to Gemini's request format.

parent bbfcb1ac7e
commit 55a17730fb

4 changed files with 54 additions and 18 deletions
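In practice this means an assistant turn that carries reasoning_content can be appended back into the conversation and litellm will forward the thinking text to Gemini. A minimal sketch of that round trip (the model name is taken from the test below; everything else is illustrative, not from the diff):

    import litellm

    messages = [{"role": "user", "content": "Explain Occam's Razor"}]
    first = litellm.completion(
        model="gemini/gemini-2.5-flash-preview-04-17",
        messages=messages,
    )

    # Thinking models may return `reasoning_content` alongside `content`.
    assistant = first.choices[0].message
    messages.append(assistant)
    messages.append({"role": "user", "content": "Now give an everyday example."})

    # With this fix, `assistant.reasoning_content` is converted into a Gemini
    # "thought" part on the follow-up request rather than being dropped.
    followup = litellm.completion(
        model="gemini/gemini-2.5-flash-preview-04-17",
        messages=messages,
    )
    print(followup.choices[0].message.content)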
@@ -216,6 +216,11 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
         msg_dict = messages[msg_i]  # type: ignore
         assistant_msg = ChatCompletionAssistantMessage(**msg_dict)  # type: ignore
         _message_content = assistant_msg.get("content", None)
+        reasoning_content = assistant_msg.get("reasoning_content", None)
+        if reasoning_content is not None:
+            assistant_content.append(
+                PartType(thought=True, text=reasoning_content)
+            )
         if _message_content is not None and isinstance(_message_content, list):
             _parts = []
             for element in _message_content:
@@ -223,6 +228,7 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                 if element["type"] == "text":
                     _part = PartType(text=element["text"])
                     _parts.append(_part)
+
             assistant_content.extend(_parts)
         elif (
             _message_content is not None
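To see the new branch in isolation: a self-contained toy version of the mapping above. The real function handles many more content types and nesting levels; PartType here is a simplified stand-in for litellm's TypedDict of the same name.

    from typing import List, TypedDict

    class PartType(TypedDict, total=False):  # simplified stand-in
        text: str
        thought: bool

    def convert_assistant_message(msg: dict) -> List[PartType]:
        parts: List[PartType] = []
        reasoning_content = msg.get("reasoning_content", None)
        if reasoning_content is not None:
            # New in this commit: reasoning text is re-sent as a "thought" part.
            parts.append(PartType(thought=True, text=reasoning_content))
        content = msg.get("content", None)
        if isinstance(content, str):
            parts.append(PartType(text=content))
        elif isinstance(content, list):
            parts.extend(
                PartType(text=el["text"]) for el in content if el.get("type") == "text"
            )
        return parts

    print(convert_assistant_message({
        "role": "assistant",
        "content": "Okay, let's break down Occam's Razor.",
        "reasoning_content": "I'm thinking about Occam's Razor.",
    }))
    # [{'thought': True, 'text': "I'm thinking..."}, {'text': "Okay, let's..."}]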
@@ -651,6 +651,7 @@ class OpenAIChatCompletionAssistantMessage(TypedDict, total=False):
     name: Optional[str]
     tool_calls: Optional[List[ChatCompletionAssistantToolCall]]
     function_call: Optional[ChatCompletionToolCallFunctionChunk]
+    reasoning_content: Optional[str]


 class ChatCompletionAssistantMessage(OpenAIChatCompletionAssistantMessage, total=False):
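Because the TypedDict is declared with total=False, the new key is opt-in and existing callers keep type-checking unchanged. A sketch with simplified stand-in types (only the fields relevant here):

    from typing import Optional, TypedDict

    class OpenAIChatCompletionAssistantMessage(TypedDict, total=False):
        role: str
        content: Optional[str]
        reasoning_content: Optional[str]  # the field added above

    # Both are valid; omitting reasoning_content stays legal.
    without_thinking: OpenAIChatCompletionAssistantMessage = {
        "role": "assistant",
        "content": "Hello!",
    }
    with_thinking: OpenAIChatCompletionAssistantMessage = {
        "role": "assistant",
        "content": "Okay, let's break down Occam's Razor.",
        "reasoning_content": "I'm thinking about Occam's Razor.",
    }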
@@ -823,12 +824,12 @@ class OpenAIChatCompletionChunk(ChatCompletionChunk):

 class Hyperparameters(BaseModel):
     batch_size: Optional[Union[str, int]] = None  # "Number of examples in each batch."
-    learning_rate_multiplier: Optional[Union[str, float]] = (
-        None  # Scaling factor for the learning rate
-    )
-    n_epochs: Optional[Union[str, int]] = (
-        None  # "The number of epochs to train the model for"
-    )
+    learning_rate_multiplier: Optional[
+        Union[str, float]
+    ] = None  # Scaling factor for the learning rate
+    n_epochs: Optional[
+        Union[str, int]
+    ] = None  # "The number of epochs to train the model for"


 class FineTuningJobCreate(BaseModel):
@@ -855,18 +856,18 @@ class FineTuningJobCreate(BaseModel):

     model: str  # "The name of the model to fine-tune."
     training_file: str  # "The ID of an uploaded file that contains training data."
-    hyperparameters: Optional[Hyperparameters] = (
-        None  # "The hyperparameters used for the fine-tuning job."
-    )
-    suffix: Optional[str] = (
-        None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
-    )
-    validation_file: Optional[str] = (
-        None  # "The ID of an uploaded file that contains validation data."
-    )
-    integrations: Optional[List[str]] = (
-        None  # "A list of integrations to enable for your fine-tuning job."
-    )
+    hyperparameters: Optional[
+        Hyperparameters
+    ] = None  # "The hyperparameters used for the fine-tuning job."
+    suffix: Optional[
+        str
+    ] = None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
+    validation_file: Optional[
+        str
+    ] = None  # "The ID of an uploaded file that contains validation data."
+    integrations: Optional[
+        List[str]
+    ] = None  # "A list of integrations to enable for your fine-tuning job."
     seed: Optional[int] = None  # "The seed controls the reproducibility of the job."
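The two hunks above are formatting-only changes to these request models. For context, the Union[str, ...] annotations exist because the OpenAI fine-tuning API accepts either the literal "auto" or a concrete number for these hyperparameters. A self-contained sketch with simplified stand-ins (field names from the hunks; values made up):

    from typing import List, Optional, Union
    from pydantic import BaseModel

    class Hyperparameters(BaseModel):  # simplified stand-in
        batch_size: Optional[Union[str, int]] = None
        learning_rate_multiplier: Optional[Union[str, float]] = None
        n_epochs: Optional[Union[str, int]] = None

    class FineTuningJobCreate(BaseModel):  # simplified stand-in
        model: str
        training_file: str
        hyperparameters: Optional[Hyperparameters] = None
        suffix: Optional[str] = None
        validation_file: Optional[str] = None
        integrations: Optional[List[str]] = None
        seed: Optional[int] = None

    job = FineTuningJobCreate(
        model="gpt-3.5-turbo",
        training_file="file-abc123",
        hyperparameters=Hyperparameters(n_epochs="auto"),  # "auto" or an int
        seed=42,
    )
    print(job.model_dump(exclude_none=True))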
@@ -39,6 +39,7 @@ class PartType(TypedDict, total=False):
     file_data: FileDataType
     function_call: FunctionCall
     function_response: FunctionResponse
+    thought: bool


 class HttpxFunctionCall(TypedDict):
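With thought: bool added to PartType, the assistant turn from the test below would serialize into the outgoing Gemini request body roughly like this. This is a sketch inferred from the types (Gemini uses role "model" for assistant turns), not captured traffic:

    assistant_turn = {
        "role": "model",
        "parts": [
            {"thought": True, "text": "I'm thinking about Occam's Razor."},
            {"text": "Okay, let's break down Occam's Razor."},
        ],
    }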
@@ -89,3 +89,31 @@ def test_gemini_image_generation():


+def test_gemini_thinking():
+    litellm._turn_on_debug()
+    from litellm.types.utils import Message, CallTypes
+    from litellm.utils import return_raw_request
+    import json
+
+    messages = [
+        {"role": "user", "content": "Explain the concept of Occam's Razor and provide a simple, everyday example"}
+    ]
+    reasoning_content = "I'm thinking about Occam's Razor."
+    assistant_message = Message(content="Okay, let's break down Occam's Razor.", reasoning_content=reasoning_content, role="assistant", tool_calls=None, function_call=None, provider_specific_fields=None)
+
+    messages.append(assistant_message)
+
+    raw_request = return_raw_request(
+        endpoint=CallTypes.completion,
+        kwargs={
+            "model": "gemini/gemini-2.5-flash-preview-04-17",
+            "messages": messages,
+        }
+    )
+    assert reasoning_content in json.dumps(raw_request)
+    response = completion(
+        model="gemini/gemini-2.5-flash-preview-04-17",
+        messages=messages,  # make sure call works
+    )
+    print(response.choices[0].message)
+    assert response.choices[0].message.content is not None
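The first assertion works because return_raw_request renders the provider payload without sending it, which also makes it a handy debugging tool. A minimal standalone sketch of the same pattern (output shape may vary by provider and version, and rendering the request may require GEMINI_API_KEY to be set since auth is part of the payload):

    import json
    from litellm.types.utils import CallTypes
    from litellm.utils import return_raw_request

    raw = return_raw_request(
        endpoint=CallTypes.completion,
        kwargs={
            "model": "gemini/gemini-2.5-flash-preview-04-17",
            "messages": [{"role": "user", "content": "hi"}],
        },
    )
    # Inspect the URL, headers, and body litellm would have sent.
    print(json.dumps(raw, indent=2, default=str))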