Complete o3 model support (#8183)

* fix(o_series_transformation.py): add 'reasoning_effort' as o series model param

Closes https://github.com/BerriAI/litellm/issues/8182

* fix(main.py): ensure `reasoning_effort` is a mapped openai param
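
A minimal usage sketch of the new param (model name illustrative; `reasoning_effort` accepts "low" | "medium" | "high" per the signatures added below):

import litellm

# reasoning_effort is now a first-class, mapped OpenAI param for o-series models
response = litellm.completion(
    model="o3-mini",
    reasoning_effort="low",
    messages=[{"role": "user", "content": "Hello!"}],
)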

* refactor(azure/): rename o1_[x] files to o_series_[x]

* refactor(base_llm_unit_tests.py): refactor testing for o series reasoning effort

* test(test_azure_o_series.py): have azure o series tests correctly inherit from base o series model tests

* feat(base_utils.py): support translating 'developer' role to 'system' role for non-openai providers

Makes it easy to switch from OpenAI to Anthropic.
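
A short sketch of the intended switch (Anthropic model name illustrative): the same `developer` message works against both providers, because non-OpenAI configs now rewrite it to `system` before the request.

import litellm

messages = [
    {"role": "developer", "content": "Be a good bot!"},
    {"role": "user", "content": "Hello!"},
]

# OpenAI o-series: the developer role is passed through unchanged
litellm.completion(model="o1", messages=messages)

# Anthropic: litellm rewrites developer -> system before the request
litellm.completion(model="claude-3-5-sonnet-20240620", messages=messages)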

* fix: fix linting errors

* fix(base_llm_unit_tests.py): fix test

* fix(main.py): add missing param
Krish Dholakia 2025-02-02 22:36:37 -08:00 committed by GitHub
parent e4566d7b1c
commit 1105e35538
15 changed files with 230 additions and 11 deletions


@@ -939,7 +939,7 @@ from .llms.deepseek.chat.transformation import DeepSeekChatConfig
 from .llms.lm_studio.chat.transformation import LMStudioChatConfig
 from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig
 from .llms.perplexity.chat.transformation import PerplexityChatConfig
-from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
+from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config
 from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
 from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
 from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig


@@ -118,6 +118,7 @@ OPENAI_CHAT_COMPLETION_PARAMS = [
     "parallel_tool_calls",
     "logprobs",
     "top_logprobs",
+    "reasoning_effort",
     "extra_headers",
 ]
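
For context, this whitelist is how OpenAI-compatible kwargs get picked out of a call. A minimal sketch with a hypothetical helper (`split_openai_params` is illustrative, not litellm's actual code path; it assumes `OPENAI_CHAT_COMPLETION_PARAMS` is in scope from the module above):

def split_openai_params(kwargs: dict) -> tuple[dict, dict]:
    # Partition caller kwargs into OpenAI chat-completion params and the rest.
    openai_params = {k: v for k, v in kwargs.items() if k in OPENAI_CHAT_COMPLETION_PARAMS}
    other = {k: v for k, v in kwargs.items() if k not in OPENAI_CHAT_COMPLETION_PARAMS}
    return openai_params, other

# With this change, "reasoning_effort" lands in openai_params instead of being dropped.
openai_params, _ = split_openai_params({"reasoning_effort": "low", "metadata": {}})
assert "reasoning_effort" in openai_params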


@@ -2175,7 +2175,7 @@ from litellm.types.llms.bedrock import ToolUseBlock as BedrockToolUseBlock
 def _parse_content_type(content_type: str) -> str:
     m = Message()
-    m['content-type'] = content_type
+    m["content-type"] = content_type
     return m.get_content_type()
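
The function being touched here is tiny but worth illustrating: `email.message.Message` parses the header and drops parameters such as the charset.

from email.message import Message

def _parse_content_type(content_type: str) -> str:
    m = Message()
    m["content-type"] = content_type
    return m.get_content_type()

# Parameters after the media type are stripped.
assert _parse_content_type("application/json; charset=utf-8") == "application/json"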


@@ -1,3 +1,7 @@
+"""
+Utility functions for base LLM classes.
+"""
+
 import copy
 from abc import ABC, abstractmethod
 from typing import List, Optional, Type, Union
@@ -5,6 +9,7 @@ from typing import List, Optional, Type, Union
 from openai.lib import _parsing, _pydantic
 from pydantic import BaseModel
 
+from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ProviderSpecificModelInfo
@@ -105,3 +110,18 @@ def type_to_response_format_param(
             "strict": True,
         },
     }
+
+
+def map_developer_role_to_system_role(
+    messages: List[AllMessageValues],
+) -> List[AllMessageValues]:
+    """
+    Translate `developer` role to `system` role for non-OpenAI providers.
+    """
+    new_messages: List[AllMessageValues] = []
+    for m in messages:
+        if m["role"] == "developer":
+            new_messages.append({"role": "system", "content": m["content"]})
+        else:
+            new_messages.append(m)
+    return new_messages
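
A usage sketch of the new helper, using the message shapes from the tests below:

messages = [
    {"role": "developer", "content": "Be a good bot!"},
    {"role": "user", "content": "Hello!"},
]

mapped = map_developer_role_to_system_role(messages=messages)
assert mapped[0] == {"role": "system", "content": "Be a good bot!"}
assert mapped[1]["role"] == "user"  # non-developer messages pass through untouched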


@@ -18,10 +18,14 @@ from typing import (
 import httpx
 from pydantic import BaseModel
 
+from litellm._logging import verbose_logger
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse
-from ..base_utils import type_to_response_format_param
+from ..base_utils import (
+    map_developer_role_to_system_role,
+    type_to_response_format_param,
+)
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -99,6 +103,20 @@ class BaseConfig(ABC):
         """
         return False
 
+    def translate_developer_role_to_system_role(
+        self,
+        messages: List[AllMessageValues],
+    ) -> List[AllMessageValues]:
+        """
+        Translate `developer` role to `system` role for non-OpenAI providers.
+
+        Overridden by OpenAI/Azure.
+        """
+        verbose_logger.debug(
+            "Translating developer role to system role for non-OpenAI providers."
+        )  # ensure user knows what's happening with their input
+        return map_developer_role_to_system_role(messages=messages)
+
     def should_retry_llm_api_inside_llm_translation_on_http_error(
         self, e: httpx.HTTPStatusError, litellm_params: dict
     ) -> bool:
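
Because this lives on `BaseConfig`, every provider config inherits the system-role rewrite by default; only OpenAI-style o-series configs opt out (see the next hunk). A sketch, with a hypothetical provider config class:

config = SomeNonOpenAIProviderConfig()  # hypothetical BaseConfig subclass with no override
out = config.translate_developer_role_to_system_role(
    messages=[{"role": "developer", "content": "Be terse."}]
)
assert out == [{"role": "system", "content": "Be terse."}]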


@@ -1,5 +1,5 @@
 """
-Support for o1 model family
+Support for o1/o3 model family
 
 https://platform.openai.com/docs/guides/reasoning
@@ -35,6 +35,14 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
     def get_config(cls):
         return super().get_config()
 
+    def translate_developer_role_to_system_role(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        """
+        O-series models support `developer` role.
+        """
+        return messages
+
     def should_fake_stream(
         self,
         model: Optional[str],
@@ -67,6 +75,10 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
             "top_logprobs",
         ]
 
+        o_series_only_param = ["reasoning_effort"]
+        all_openai_params.extend(o_series_only_param)
+
         try:
             model, custom_llm_provider, api_base, api_key = get_llm_provider(
                 model=model
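
The observable effect of extending the param list; a sketch using litellm's public helper (imported the same way as in the tests below):

from litellm.utils import get_supported_openai_params

params = get_supported_openai_params(model="o3-mini", custom_llm_provider="openai")
assert "reasoning_effort" in params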


@@ -67,6 +67,7 @@ from litellm.litellm_core_utils.mock_functions import (
 from litellm.litellm_core_utils.prompt_templates.common_utils import (
     get_content_from_model_response,
 )
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
@@ -115,7 +116,7 @@ from .llms import baseten, maritalk, ollama_chat
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
-from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
+from .llms.azure.chat.o_series_handler import AzureOpenAIO1ChatCompletion
 from .llms.azure.completion.handler import AzureTextCompletion
 from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
@@ -331,6 +332,7 @@ async def acompletion(
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
     deployment_id=None,
+    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
     # set api_base, api_version, api_key
     base_url: Optional[str] = None,
     api_version: Optional[str] = None,
@@ -425,6 +427,7 @@ async def acompletion(
         "api_version": api_version,
         "api_key": api_key,
         "model_list": model_list,
+        "reasoning_effort": reasoning_effort,
         "extra_headers": extra_headers,
         "acompletion": True,  # assuming this is a required parameter
     }
@@ -777,6 +780,7 @@ def completion(  # type: ignore # noqa: PLR0915
     logit_bias: Optional[dict] = None,
     user: Optional[str] = None,
     # openai v1.0+ new params
+    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
     response_format: Optional[Union[dict, Type[BaseModel]]] = None,
     seed: Optional[int] = None,
     tools: Optional[List] = None,
@@ -1046,6 +1050,19 @@ def completion(  # type: ignore # noqa: PLR0915
         if eos_token:
             custom_prompt_dict[model]["eos_token"] = eos_token
 
+    provider_config: Optional[BaseConfig] = None
+    if custom_llm_provider is not None and custom_llm_provider in [
+        provider.value for provider in LlmProviders
+    ]:
+        provider_config = ProviderConfigManager.get_provider_chat_config(
+            model=model, provider=LlmProviders(custom_llm_provider)
+        )
+
+    if provider_config is not None:
+        messages = provider_config.translate_developer_role_to_system_role(
+            messages=messages
+        )
+
     if (
         supports_system_message is not None
         and isinstance(supports_system_message, bool)
@@ -1087,6 +1104,7 @@ def completion(  # type: ignore # noqa: PLR0915
         api_version=api_version,
         parallel_tool_calls=parallel_tool_calls,
         messages=messages,
+        reasoning_effort=reasoning_effort,
         **non_default_params,
     )


@@ -442,10 +442,20 @@ class OpenAIChatCompletionSystemMessage(TypedDict, total=False):
     name: str
 
 
+class OpenAIChatCompletionDeveloperMessage(TypedDict, total=False):
+    role: Required[Literal["developer"]]
+    content: Required[Union[str, List]]
+    name: str
+
+
 class ChatCompletionSystemMessage(OpenAIChatCompletionSystemMessage, total=False):
     cache_control: ChatCompletionCachedContent
 
 
+class ChatCompletionDeveloperMessage(OpenAIChatCompletionDeveloperMessage, total=False):
+    cache_control: ChatCompletionCachedContent
+
+
 ValidUserMessageContentTypes = [
     "text",
     "image_url",
@@ -458,6 +468,7 @@ AllMessageValues = Union[
     ChatCompletionToolMessage,
     ChatCompletionSystemMessage,
     ChatCompletionFunctionMessage,
+    ChatCompletionDeveloperMessage,
 ]
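
The new typed dict in use; since it is now part of `AllMessageValues`, developer messages type-check anywhere a chat message is accepted:

from litellm.types.llms.openai import ChatCompletionDeveloperMessage

msg: ChatCompletionDeveloperMessage = {
    "role": "developer",
    "content": "Be a good bot!",
}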


@@ -2686,6 +2686,7 @@ def get_optional_params(  # noqa: PLR0915
     api_version=None,
     parallel_tool_calls=None,
     drop_params=None,
+    reasoning_effort=None,
     additional_drop_params=None,
     messages: Optional[List[AllMessageValues]] = None,
     **kwargs,
@@ -2771,6 +2772,7 @@ def get_optional_params(  # noqa: PLR0915
         "drop_params": None,
         "additional_drop_params": None,
         "messages": None,
+        "reasoning_effort": None,
     }
     # filter out those parameters that were passed with non-default values
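
A sketch of the param flowing through (keyword names from the diff above; the exact output shape depends on the provider mapping):

from litellm.utils import get_optional_params

optional_params = get_optional_params(
    model="o3-mini",
    custom_llm_provider="openai",
    reasoning_effort="low",
)
assert optional_params.get("reasoning_effort") == "low"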
# filter out those parameters that were passed with non-default values # filter out those parameters that were passed with non-default values


@@ -18,9 +18,11 @@ from litellm.utils import (
     get_supported_openai_params,
     get_optional_params,
 )
+from typing import Union
 
 # test_example.py
 from abc import ABC, abstractmethod
+from openai import OpenAI
 
 
 def _usage_format_tests(usage: litellm.Usage):
@@ -70,6 +72,34 @@ class BaseLLMChatTest(ABC):
         """Must return the base completion call args"""
         pass
 
+    def test_developer_role_translation(self):
+        """
+        Test that the `developer` role is translated to the `system` role for
+        non-OpenAI providers.
+        """
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {
+                "role": "developer",
+                "content": "Be a good bot!",
+            },
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "Hello, how are you?"}],
+            },
+        ]
+        try:
+            response = self.completion_function(
+                **base_completion_call_args,
+                messages=messages,
+            )
+            assert response is not None
+        except litellm.InternalServerError:
+            pytest.skip("Model is overloaded")
+
+        assert response.choices[0].message.content is not None
+
     def test_content_list_handling(self):
         """Check if content list is supported by LLM API"""
         base_completion_call_args = self.get_base_completion_call_args()
@@ -605,3 +635,72 @@ class BaseLLMChatTest(ABC):
 
         cost = completion_cost(response)
         assert cost > 0
+
+
+class BaseOSeriesModelsTest(ABC):  # test across azure/openai
+    @abstractmethod
+    def get_base_completion_call_args(self):
+        pass
+
+    @abstractmethod
+    def get_client(self) -> OpenAI:
+        pass
+
+    def test_reasoning_effort(self):
+        """Test that reasoning_effort is passed correctly to the model"""
+        from litellm import completion
+
+        client = self.get_client()
+
+        completion_args = self.get_base_completion_call_args()
+
+        with patch.object(
+            client.chat.completions.with_raw_response, "create"
+        ) as mock_client:
+            try:
+                completion(
+                    **completion_args,
+                    reasoning_effort="low",
+                    messages=[{"role": "user", "content": "Hello!"}],
+                    client=client,
+                )
+            except Exception as e:
+                print(f"Error: {e}")
+
+            mock_client.assert_called_once()
+            request_body = mock_client.call_args.kwargs
+            print("request_body: ", request_body)
+            assert request_body["reasoning_effort"] == "low"
+
+    def test_developer_role_translation(self):
+        """Test that the `developer` role is preserved (not translated to `system`) for OpenAI o-series models"""
+        from litellm import completion
+
+        client = self.get_client()
+
+        completion_args = self.get_base_completion_call_args()
+
+        with patch.object(
+            client.chat.completions.with_raw_response, "create"
+        ) as mock_client:
+            try:
+                completion(
+                    **completion_args,
+                    reasoning_effort="low",
+                    messages=[
+                        {"role": "developer", "content": "Be a good bot!"},
+                        {"role": "user", "content": "Hello!"},
+                    ],
+                    client=client,
+                )
+            except Exception as e:
+                print(f"Error: {e}")
+
+            mock_client.assert_called_once()
+            request_body = mock_client.call_args.kwargs
+            print("request_body: ", request_body)
+            assert (
+                request_body["messages"][0]["role"] == "developer"
+            ), "Got={} instead of developer".format(request_body["messages"][0]["role"])
+            assert request_body["messages"][0]["content"] == "Be a good bot!"


@@ -15,10 +15,10 @@ from respx import MockRouter
 import litellm
 from litellm import Choices, Message, ModelResponse
 
-from base_llm_unit_tests import BaseLLMChatTest
+from base_llm_unit_tests import BaseLLMChatTest, BaseOSeriesModelsTest
 
 
-class TestAzureOpenAIO1(BaseLLMChatTest):
+class TestAzureOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
     def get_base_completion_call_args(self):
         return {
             "model": "azure/o1-preview",
@@ -26,6 +26,15 @@ class TestAzureOpenAIO1(BaseLLMChatTest):
             "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
         }
 
+    def get_client(self):
+        from openai import AzureOpenAI
+
+        return AzureOpenAI(
+            api_key="my-fake-o1-key",
+            base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
+            api_version="2024-02-15-preview",
+        )
+
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass
@@ -65,6 +74,24 @@ class TestAzureOpenAIO1(BaseLLMChatTest):
         assert fake_stream is False
 
 
+class TestAzureOpenAIO3(BaseOSeriesModelsTest):
+    def get_base_completion_call_args(self):
+        return {
+            "model": "azure/o3-mini",
+            "api_key": "my-fake-o1-key",
+            "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
+        }
+
+    def get_client(self):
+        from openai import AzureOpenAI
+
+        return AzureOpenAI(
+            api_key="my-fake-o1-key",
+            base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
+            api_version="2024-02-15-preview",
+        )
+
+
 def test_azure_o3_streaming():
     """
     Test that o3 models handle fake streaming correctly.


@@ -15,7 +15,7 @@ from respx import MockRouter
 import litellm
 from litellm import Choices, Message, ModelResponse
 
-from base_llm_unit_tests import BaseLLMChatTest
+from base_llm_unit_tests import BaseLLMChatTest, BaseOSeriesModelsTest
 
 
 @pytest.mark.parametrize("model", ["o1-preview", "o1-mini", "o1"])
@@ -152,12 +152,17 @@ def test_litellm_responses():
     assert isinstance(response.usage.completion_tokens_details, CompletionTokensDetails)
 
 
-class TestOpenAIO1(BaseLLMChatTest):
+class TestOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
     def get_base_completion_call_args(self):
         return {
             "model": "o1",
         }
 
+    def get_client(self):
+        from openai import OpenAI
+
+        return OpenAI(api_key="fake-api-key")
+
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass
@@ -167,12 +172,17 @@ class TestOpenAIO1(BaseLLMChatTest):
         pass
 
 
-class TestOpenAIO3(BaseLLMChatTest):
+class TestOpenAIO3(BaseOSeriesModelsTest, BaseLLMChatTest):
     def get_base_completion_call_args(self):
         return {
             "model": "o3-mini",
        }
 
+    def get_client(self):
+        from openai import OpenAI
+
+        return OpenAI(api_key="fake-api-key")
+
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass


@@ -29,7 +29,8 @@ def test_acompletion_params():
     # Assert that the parameters are the same
     if keys_acompletion != keys_completion:
         pytest.fail(
-            "The parameters of the litellm.acompletion function and litellm.completion are not the same."
+            "The parameters of the litellm.acompletion function and litellm.completion are not the same. "
+            f"Completion has extra keys: {keys_completion - keys_acompletion}"
         )