From a5a0773b19bfc65bb7342e87c8ac48565a1c8645 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 12 Sep 2024 14:09:13 -0700
Subject: [PATCH] fix: handle o1 not supporting system messages

---
 litellm/__init__.py                 |  3 +++
 litellm/llms/OpenAI/o1_reasoning.py | 35 +++++++++++++++++++++++++++--
 litellm/llms/OpenAI/openai.py       | 15 +++++++++++++
 litellm/tests/test_openai_o1.py     |  3 ---
 4 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 95c276edf..6afec1079 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -944,6 +944,9 @@ from .llms.OpenAI.openai import (
     GroqConfig,
     AzureAIStudioConfig,
 )
+from .llms.OpenAI.o1_reasoning import (
+    OpenAIO1Config,
+)
 from .llms.nvidia_nim import NvidiaNimConfig
 from .llms.cerebras.chat import CerebrasConfig
 from .llms.AI21.chat import AI21ChatConfig
diff --git a/litellm/llms/OpenAI/o1_reasoning.py b/litellm/llms/OpenAI/o1_reasoning.py
index 03038b7ce..dcfe2d06c 100644
--- a/litellm/llms/OpenAI/o1_reasoning.py
+++ b/litellm/llms/OpenAI/o1_reasoning.py
@@ -12,7 +12,7 @@ Translations handled by LiteLLM:
 """
 
 import types
-from typing import Optional, Union
+from typing import Any, List, Optional, Union
 
 import litellm
 
@@ -49,7 +49,7 @@ class OpenAIO1Config(OpenAIConfig):
 
     """
 
-    all_openai_params = litellm.OpenAIConfig.get_supported_openai_params(
+    all_openai_params = litellm.OpenAIConfig().get_supported_openai_params(
         model="gpt-4o"
     )
     non_supported_params = [
@@ -70,3 +70,34 @@ class OpenAIO1Config(OpenAIConfig):
         if param == "max_tokens":
             optional_params["max_completion_tokens"] = value
         return optional_params
+
+    def is_model_o1_reasoning_model(self, model: str) -> bool:
+        if "o1" in model:
+            return True
+        return False
+
+    def o1_prompt_factory(self, messages: List[Any]):
+        """
+        Handles limitations of the o1 model family.
+        - modalities: image => drop param (if user opts in to dropping param)
+        - role: system => translate to role 'user'
+        """
+
+        for message in messages:
+            if message["role"] == "system":
+                message["role"] = "user"
+
+            if isinstance(message["content"], list):
+                new_content = []
+                for content_item in message["content"]:
+                    if content_item.get("type") == "image_url":
+                        if litellm.drop_params is not True:
+                            raise ValueError(
+                                "Image content is not supported for o1 models. Set litellm.drop_params to True to drop image content."
+                            )
+                        # If drop_params is True, we simply don't add the image content to new_content
+                    else:
+                        new_content.append(content_item)
+                message["content"] = new_content
+
+        return messages
diff --git a/litellm/llms/OpenAI/openai.py b/litellm/llms/OpenAI/openai.py
index ed4d199f6..d90c04b62 100644
--- a/litellm/llms/OpenAI/openai.py
+++ b/litellm/llms/OpenAI/openai.py
@@ -550,6 +550,8 @@ class OpenAIConfig:
         ]  # works across all models
 
         model_specific_params = []
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().get_supported_openai_params(model=model)
         if (
             model != "gpt-3.5-turbo-16k" and model != "gpt-4"
         ):  # gpt-4 does not support 'response_format'
@@ -566,6 +568,12 @@ class OpenAIConfig:
     def map_openai_params(
         self, non_default_params: dict, optional_params: dict, model: str
     ) -> dict:
+        """Route o1 models to their dedicated param mapping."""
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+            )
         supported_openai_params = self.get_supported_openai_params(model)
         for param, value in non_default_params.items():
             if param in supported_openai_params:
@@ -861,6 +869,13 @@ class OpenAIChatCompletion(BaseLLM):
                 messages=messages,
                 custom_llm_provider=custom_llm_provider,
             )
+            if (
+                litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model)
+                and messages is not None
+            ):
+                messages = litellm.OpenAIO1Config().o1_prompt_factory(
+                    messages=messages,
+                )
 
             for _ in range(
                 2
diff --git a/litellm/tests/test_openai_o1.py b/litellm/tests/test_openai_o1.py
index f08c71ca9..7c450d7e7 100644
--- a/litellm/tests/test_openai_o1.py
+++ b/litellm/tests/test_openai_o1.py
@@ -51,6 +51,3 @@ async def test_o1_handle_system_role(respx_mock: MockRouter):
     print(f"response: {response}")
     assert isinstance(response, ModelResponse)
     assert response.choices[0].message.content == "Mocked response"
-
-
-# ... existing code ...
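
Reviewer note: a minimal usage sketch of the behavior this diff introduces, for
anyone exercising the patch locally. The model name "o1-preview", the prompts,
and the image URL are illustrative assumptions; the role translation, the
max_tokens mapping, and the drop_params handling are taken from the diff above.

    import litellm

    # A system message no longer errors against o1 models: o1_prompt_factory()
    # rewrites role "system" to "user" before the request is sent, and
    # map_openai_params() translates max_tokens to max_completion_tokens.
    response = litellm.completion(
        model="o1-preview",  # assumption: any model name containing "o1" takes this path
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        max_tokens=100,
    )
    print(response.choices[0].message.content)

    # Image content is rejected unless the caller opts in to dropping it:
    litellm.drop_params = True  # without this, o1_prompt_factory() raises ValueError
    response = litellm.completion(
        model="o1-preview",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/cat.png"},
                    },
                ],
            }
        ],
    )  # the image_url item is dropped; only the text content is sent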