fix handle o1 not supporting system message

Ishaan Jaff 2024-09-12 14:09:13 -07:00
parent f5e9e9fc9a
commit a5a0773b19
4 changed files with 51 additions and 5 deletions
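
In practice, the fix means a request like the one below no longer fails on o1 models: litellm now rewrites the system message to a user message before calling OpenAI. A minimal sketch, assuming a litellm build that includes this commit; the model name "o1-preview" and a configured OPENAI_API_KEY are assumptions for illustration, not part of the change.

import litellm

# With this fix, the system message is translated to role "user" for o1
# models rather than being rejected by the OpenAI API.
response = litellm.completion(
    model="o1-preview",  # assumed model name for illustration
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
    ],
)
print(response.choices[0].message.content)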

litellm/__init__.py

@@ -944,6 +944,9 @@ from .llms.OpenAI.openai import (
     GroqConfig,
     AzureAIStudioConfig,
 )
+from .llms.OpenAI.o1_reasoning import (
+    OpenAIO1Config,
+)
 from .llms.nvidia_nim import NvidiaNimConfig
 from .llms.cerebras.chat import CerebrasConfig
 from .llms.AI21.chat import AI21ChatConfig

litellm/llms/OpenAI/o1_reasoning.py

@@ -12,7 +12,7 @@ Translations handled by LiteLLM:
 """
 import types
-from typing import Optional, Union
+from typing import Any, List, Optional, Union
 import litellm
@@ -49,7 +49,7 @@ class OpenAIO1Config(OpenAIConfig):
         """
-        all_openai_params = litellm.OpenAIConfig.get_supported_openai_params(
+        all_openai_params = litellm.OpenAIConfig().get_supported_openai_params(
             model="gpt-4o"
         )
         non_supported_params = [
@@ -70,3 +70,34 @@ class OpenAIO1Config(OpenAIConfig):
             if param == "max_tokens":
                 optional_params["max_completion_tokens"] = value
         return optional_params
+
+    def is_model_o1_reasoning_model(self, model: str) -> bool:
+        if "o1" in model:
+            return True
+        return False
+
+    def o1_prompt_factory(self, messages: List[Any]):
+        """
+        Handles limitations of the o1 model family:
+        - modalities: image => drop param (if the user opts in to dropping params)
+        - role: system => translate to role "user"
+        """
+        for message in messages:
+            if message["role"] == "system":
+                message["role"] = "user"
+
+            if isinstance(message["content"], list):
+                new_content = []
+                for content_item in message["content"]:
+                    if content_item.get("type") == "image_url":
+                        if litellm.drop_params is not True:
+                            raise ValueError(
+                                "Image content is not supported for o1 models. Set litellm.drop_params to True to drop image content."
+                            )
+                        # If drop_params is True, the image content is simply omitted from new_content.
+                    else:
+                        new_content.append(content_item)
+                message["content"] = new_content
+        return messages
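
A standalone sketch of what the new o1_prompt_factory does, assuming a litellm build that includes this commit:

import litellm

config = litellm.OpenAIO1Config()
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
# The factory rewrites role "system" to role "user" in place.
translated = config.o1_prompt_factory(messages=messages)
print(translated[0]["role"])  # "user"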

litellm/llms/OpenAI/openai.py

@@ -550,6 +550,8 @@ class OpenAIConfig:
         ]  # works across all models
         model_specific_params = []
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().get_supported_openai_params(model=model)
         if (
             model != "gpt-3.5-turbo-16k" and model != "gpt-4"
         ):  # gpt-4 does not support 'response_format'
@@ -566,6 +568,12 @@
     def map_openai_params(
         self, non_default_params: dict, optional_params: dict, model: str
     ) -> dict:
+        """ """
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+            )
         supported_openai_params = self.get_supported_openai_params(model)
         for param, value in non_default_params.items():
             if param in supported_openai_params:
@@ -861,6 +869,13 @@ class OpenAIChatCompletion(BaseLLM):
                 messages=messages,
                 custom_llm_provider=custom_llm_provider,
             )
+        if (
+            litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model)
+            and messages is not None
+        ):
+            messages = litellm.OpenAIO1Config().o1_prompt_factory(
+                messages=messages,
+            )
         for _ in range(
             2
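
The parameter mapping can also be exercised directly. A sketch, using the OpenAIO1Config.map_openai_params signature shown in this diff and the max_tokens translation from the o1_reasoning.py hunk above:

import litellm

o1_config = litellm.OpenAIO1Config()
# o1 models take max_completion_tokens instead of max_tokens.
mapped = o1_config.map_openai_params(
    non_default_params={"max_tokens": 256},
    optional_params={},
)
print(mapped)  # {'max_completion_tokens': 256}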


@@ -51,6 +51,3 @@ async def test_o1_handle_system_role(respx_mock: MockRouter):
     print(f"response: {response}")
     assert isinstance(response, ModelResponse)
     assert response.choices[0].message.content == "Mocked response"
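
A companion assertion one could add to such a respx-based test (hypothetical helper, not part of this commit): inspect the captured outbound request and confirm the system role was translated before the request left litellm.

import json

def assert_system_role_translated(respx_mock):
    # respx records every intercepted call; the first call's request body
    # should contain no "system" role after o1_prompt_factory has run.
    request_body = json.loads(respx_mock.calls[0].request.content)
    assert all(m["role"] != "system" for m in request_body["messages"])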