Mirror of https://github.com/BerriAI/litellm.git
fix(transformation.py): pass back in gemini thinking content to api
Ensures thinking content always returned
parent bbfcb1ac7e
commit c141e573ab

4 changed files with 54 additions and 18 deletions
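In practice, this closes the round trip for Gemini "thinking" models: the reasoning_content returned on an assistant message is no longer dropped when that message is sent back in a follow-up request. A minimal sketch of the intended flow (model name taken from the test below; exact response fields depend on the model):

    import litellm

    messages = [{"role": "user", "content": "Explain Occam's Razor"}]
    resp = litellm.completion(model="gemini/gemini-2.5-flash-preview-04-17", messages=messages)

    # The assistant message may carry reasoning_content; before this fix it was
    # silently dropped when converting the history back into a Gemini request.
    messages.append(resp.choices[0].message)
    messages.append({"role": "user", "content": "Now give a counterexample"})

    followup = litellm.completion(model="gemini/gemini-2.5-flash-preview-04-17", messages=messages)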
@@ -216,6 +216,11 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
             msg_dict = messages[msg_i]  # type: ignore
             assistant_msg = ChatCompletionAssistantMessage(**msg_dict)  # type: ignore
             _message_content = assistant_msg.get("content", None)
+            reasoning_content = assistant_msg.get("reasoning_content", None)
+            if reasoning_content is not None:
+                assistant_content.append(
+                    PartType(thought=True, text=reasoning_content)
+                )
             if _message_content is not None and isinstance(_message_content, list):
                 _parts = []
                 for element in _message_content:
@@ -223,6 +228,7 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                     if element["type"] == "text":
                         _part = PartType(text=element["text"])
                         _parts.append(_part)
+
                 assistant_content.extend(_parts)
             elif (
                 _message_content is not None
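For clarity, a sketch of what the converter now produces for an assistant turn that carries reasoning (names and shapes taken from the hunk above; the surrounding converter logic is assumed):

    assistant_msg = {
        "role": "assistant",
        "content": [{"type": "text", "text": "Okay, let's break down Occam's Razor."}],
        "reasoning_content": "I'm thinking about Occam's Razor.",
    }
    # Per the hunk above, assistant_content now begins with a thought part,
    # followed by the regular text parts:
    # [PartType(thought=True, text="I'm thinking about Occam's Razor."),
    #  PartType(text="Okay, let's break down Occam's Razor.")]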
@@ -651,6 +651,7 @@ class OpenAIChatCompletionAssistantMessage(TypedDict, total=False):
     name: Optional[str]
     tool_calls: Optional[List[ChatCompletionAssistantToolCall]]
     function_call: Optional[ChatCompletionToolCallFunctionChunk]
+    reasoning_content: Optional[str]


 class ChatCompletionAssistantMessage(OpenAIChatCompletionAssistantMessage, total=False):
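With the new optional key, an assistant message dict can carry the model's reasoning next to its visible content, e.g. (a sketch; all keys are optional since the TypedDict is total=False):

    msg: ChatCompletionAssistantMessage = {
        "role": "assistant",
        "content": "Okay, let's break down Occam's Razor.",
        "reasoning_content": "I'm thinking about Occam's Razor.",  # new optional key
    }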
@@ -823,12 +824,12 @@ class OpenAIChatCompletionChunk(ChatCompletionChunk):


 class Hyperparameters(BaseModel):
     batch_size: Optional[Union[str, int]] = None  # "Number of examples in each batch."
-    learning_rate_multiplier: Optional[Union[str, float]] = (
-        None  # Scaling factor for the learning rate
-    )
-    n_epochs: Optional[Union[str, int]] = (
-        None  # "The number of epochs to train the model for"
-    )
+    learning_rate_multiplier: Optional[
+        Union[str, float]
+    ] = None  # Scaling factor for the learning rate
+    n_epochs: Optional[
+        Union[str, int]
+    ] = None  # "The number of epochs to train the model for"


 class FineTuningJobCreate(BaseModel):
@@ -855,18 +856,18 @@ class FineTuningJobCreate(BaseModel):

     model: str  # "The name of the model to fine-tune."
     training_file: str  # "The ID of an uploaded file that contains training data."
-    hyperparameters: Optional[Hyperparameters] = (
-        None  # "The hyperparameters used for the fine-tuning job."
-    )
-    suffix: Optional[str] = (
-        None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
-    )
-    validation_file: Optional[str] = (
-        None  # "The ID of an uploaded file that contains validation data."
-    )
-    integrations: Optional[List[str]] = (
-        None  # "A list of integrations to enable for your fine-tuning job."
-    )
+    hyperparameters: Optional[
+        Hyperparameters
+    ] = None  # "The hyperparameters used for the fine-tuning job."
+    suffix: Optional[
+        str
+    ] = None  # "A string of up to 18 characters that will be added to your fine-tuned model name."
+    validation_file: Optional[
+        str
+    ] = None  # "The ID of an uploaded file that contains validation data."
+    integrations: Optional[
+        List[str]
+    ] = None  # "A list of integrations to enable for your fine-tuning job."
     seed: Optional[int] = None  # "The seed controls the reproducibility of the job."
@@ -39,6 +39,7 @@ class PartType(TypedDict, total=False):
     file_data: FileDataType
     function_call: FunctionCall
     function_response: FunctionResponse
+    thought: bool


 class HttpxFunctionCall(TypedDict):
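The new thought flag marks a part as model reasoning when the request is serialized. Roughly, the assistant turn in the outgoing Gemini payload would look like this (a sketch; the exact envelope is built elsewhere in the request pipeline):

    assistant_turn = {
        "role": "model",
        "parts": [
            {"thought": True, "text": "I'm thinking about Occam's Razor."},  # from reasoning_content
            {"text": "Okay, let's break down Occam's Razor."},  # visible content
        ],
    }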
@@ -89,3 +89,31 @@ def test_gemini_image_generation():



+
+def test_gemini_thinking():
+    litellm._turn_on_debug()
+    from litellm.types.utils import Message, CallTypes
+    from litellm.utils import return_raw_request
+    import json
+
+    messages = [
+        {"role": "user", "content": "Explain the concept of Occam's Razor and provide a simple, everyday example"}
+    ]
+    reasoning_content = "I'm thinking about Occam's Razor."
+    assistant_message = Message(content="Okay, let's break down Occam's Razor.", reasoning_content=reasoning_content, role="assistant", tool_calls=None, function_call=None, provider_specific_fields=None)
+
+    messages.append(assistant_message)
+
+    raw_request = return_raw_request(
+        endpoint=CallTypes.completion,
+        kwargs={
+            "model": "gemini/gemini-2.5-flash-preview-04-17",
+            "messages": messages,
+        }
+    )
+    assert reasoning_content in json.dumps(raw_request)
+    response = completion(
+        model="gemini/gemini-2.5-flash-preview-04-17",
+        messages=messages,  # make sure call works
+    )
+    print(response.choices[0].message)
+    assert response.choices[0].message.content is not None