forked from phoenix/litellm-mirror
fix: handle o1 not supporting system messages
parent f5e9e9fc9a
commit a5a0773b19

4 changed files with 51 additions and 5 deletions
@@ -944,6 +944,9 @@ from .llms.OpenAI.openai import (
     GroqConfig,
     AzureAIStudioConfig,
 )
+from .llms.OpenAI.o1_reasoning import (
+    OpenAIO1Config,
+)
 from .llms.nvidia_nim import NvidiaNimConfig
 from .llms.cerebras.chat import CerebrasConfig
 from .llms.AI21.chat import AI21ChatConfig

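Illustration (not part of the diff): the later hunks call litellm.OpenAIO1Config(), so this import is what exposes the new config at the package root. A minimal sanity check, assuming a build of this branch is installed:

# Not part of the commit -- quick check that the re-export works.
import litellm

config = litellm.OpenAIO1Config()  # exposed via the new o1_reasoning import
print(type(config).__name__)       # -> "OpenAIO1Config"
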
@@ -12,7 +12,7 @@ Translations handled by LiteLLM:
 """

 import types
-from typing import Optional, Union
+from typing import Any, List, Optional, Union

 import litellm

@@ -49,7 +49,7 @@ class OpenAIO1Config(OpenAIConfig):

         """

-        all_openai_params = litellm.OpenAIConfig.get_supported_openai_params(
+        all_openai_params = litellm.OpenAIConfig().get_supported_openai_params(
             model="gpt-4o"
         )
         non_supported_params = [

@@ -70,3 +70,34 @@ class OpenAIO1Config(OpenAIConfig):
             if param == "max_tokens":
                 optional_params["max_completion_tokens"] = value
         return optional_params
+
+    def is_model_o1_reasoning_model(self, model: str) -> bool:
+        if "o1" in model:
+            return True
+        return False
+
+    def o1_prompt_factory(self, messages: List[Any]):
+        """
+        Handles limitations of the O-1 model family.
+        - modalities: image => drop param (if the user opts in to dropping params)
+        - role: system => translate to role 'user'
+        """
+
+        for message in messages:
+            if message["role"] == "system":
+                message["role"] = "user"
+
+            if isinstance(message["content"], list):
+                new_content = []
+                for content_item in message["content"]:
+                    if content_item.get("type") == "image_url":
+                        if litellm.drop_params is not True:
+                            raise ValueError(
+                                "Image content is not supported for O-1 models. Set litellm.drop_params to True to drop image content."
+                            )
+                        # If drop_params is True, we simply don't add the image content to new_content
+                    else:
+                        new_content.append(content_item)
+                message["content"] = new_content
+
+        return messages

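Illustration (not part of the diff): a minimal sketch of what the new o1_prompt_factory does to a message list, assuming litellm is installed from this branch. The system role is rewritten to user, and image parts are dropped only when litellm.drop_params is enabled:

# Not part of the commit -- sketch of the o1_prompt_factory translation.
import litellm

litellm.drop_params = True  # opt in to dropping unsupported image content

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    },
]

translated = litellm.OpenAIO1Config().o1_prompt_factory(messages=messages)
print(translated[0]["role"])                           # -> "user"
print([c["type"] for c in translated[1]["content"]])   # -> ["text"] (image part dropped)
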
@@ -550,6 +550,8 @@ class OpenAIConfig:
         ]  # works across all models

         model_specific_params = []
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().get_supported_openai_params(model=model)
         if (
             model != "gpt-3.5-turbo-16k" and model != "gpt-4"
         ):  # gpt-4 does not support 'response_format'

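Illustration (not part of the diff): with the early return above, an o1 model name gets its supported-parameter list from OpenAIO1Config rather than the generic OpenAI list. A hedged check; the exact excluded params (e.g. temperature) come from OpenAIO1Config and are assumptions here:

# Not part of the commit -- compares supported params for gpt-4o vs an o1 model.
import litellm

gpt4o_params = litellm.OpenAIConfig().get_supported_openai_params(model="gpt-4o")
o1_params = litellm.OpenAIConfig().get_supported_openai_params(model="o1-preview")

print("temperature" in gpt4o_params)  # True for regular chat models
print("temperature" in o1_params)     # expected False: o1 rejects temperature
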
@@ -566,6 +568,12 @@ class OpenAIConfig:
     def map_openai_params(
         self, non_default_params: dict, optional_params: dict, model: str
     ) -> dict:
+        """ """
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+            )
         supported_openai_params = self.get_supported_openai_params(model)
         for param, value in non_default_params.items():
             if param in supported_openai_params:

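Illustration (not part of the diff): for an o1 model, map_openai_params now delegates to OpenAIO1Config, which (per the o1_reasoning.py hunk above) renames max_tokens to max_completion_tokens. A small sketch; the o1-preview model name is an assumption:

# Not part of the commit -- sketch of the o1 param-mapping path.
import litellm

mapped = litellm.OpenAIConfig().map_openai_params(
    non_default_params={"max_tokens": 256},
    optional_params={},
    model="o1-preview",
)
print(mapped)  # expected: {"max_completion_tokens": 256}
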
@@ -861,6 +869,13 @@ class OpenAIChatCompletion(BaseLLM):
                     messages=messages,
                     custom_llm_provider=custom_llm_provider,
                 )
+        if (
+            litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model)
+            and messages is not None
+        ):
+            messages = litellm.OpenAIO1Config().o1_prompt_factory(
+                messages=messages,
+            )

         for _ in range(
             2

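Illustration (not part of the diff): the end-to-end effect of the hook above is that a system message no longer breaks o1 requests; it is rewritten to a user message before the request is sent. Illustrative only (requires a real OPENAI_API_KEY; the model name is an assumption):

# Not part of the commit -- illustrative end-to-end call.
import litellm

response = litellm.completion(
    model="o1-preview",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ],
)
print(response.choices[0].message.content)
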
@@ -51,6 +51,3 @@ async def test_o1_handle_system_role(respx_mock: MockRouter):
     print(f"response: {response}")
     assert isinstance(response, ModelResponse)
     assert response.choices[0].message.content == "Mocked response"
-
-
-# ... existing code ...