Complete o3 model support (#8183)

* fix(o_series_transformation.py): add 'reasoning_effort' as o series model param

Closes https://github.com/BerriAI/litellm/issues/8182

* fix(main.py): ensure `reasoning_effort` is a mapped openai param

* refactor(azure/): rename o1_[x] files to o_series_[x]

* refactor(base_llm_unit_tests.py): refactor testing for o series reasoning effort

* test(test_azure_o_series.py): have azure o series tests correctly inherit from base o series model tests

* feat(base_utils.py): support translating 'developer' role to 'system' role for non-openai providers

Makes it easy to switch from openai to anthropic

* fix: fix linting errors

* fix(base_llm_unit_tests.py): fix test

* fix(main.py): add missing param
This commit is contained in:
Krish Dholakia 2025-02-02 22:36:37 -08:00 committed by GitHub
parent e4566d7b1c
commit 1105e35538
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 230 additions and 11 deletions

View file

@ -939,7 +939,7 @@ from .llms.deepseek.chat.transformation import DeepSeekChatConfig
from .llms.lm_studio.chat.transformation import LMStudioChatConfig
from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig
from .llms.perplexity.chat.transformation import PerplexityChatConfig
from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig

View file

@ -118,6 +118,7 @@ OPENAI_CHAT_COMPLETION_PARAMS = [
"parallel_tool_calls",
"logprobs",
"top_logprobs",
"reasoning_effort",
"extra_headers",
]

View file

@ -2175,7 +2175,7 @@ from litellm.types.llms.bedrock import ToolUseBlock as BedrockToolUseBlock
def _parse_content_type(content_type: str) -> str:
m = Message()
m['content-type'] = content_type
m["content-type"] = content_type
return m.get_content_type()

View file

@ -1,3 +1,7 @@
"""
Utility functions for base LLM classes.
"""
import copy
from abc import ABC, abstractmethod
from typing import List, Optional, Type, Union
@ -5,6 +9,7 @@ from typing import List, Optional, Type, Union
from openai.lib import _parsing, _pydantic
from pydantic import BaseModel
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderSpecificModelInfo
@ -105,3 +110,18 @@ def type_to_response_format_param(
"strict": True,
},
}
def map_developer_role_to_system_role(
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
"""
Translate `developer` role to `system` role for non-OpenAI providers.
"""
new_messages: List[AllMessageValues] = []
for m in messages:
if m["role"] == "developer":
new_messages.append({"role": "system", "content": m["content"]})
else:
new_messages.append(m)
return new_messages

View file

@ -18,10 +18,14 @@ from typing import (
import httpx
from pydantic import BaseModel
from litellm._logging import verbose_logger
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ..base_utils import type_to_response_format_param
from ..base_utils import (
map_developer_role_to_system_role,
type_to_response_format_param,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@ -99,6 +103,20 @@ class BaseConfig(ABC):
"""
return False
def translate_developer_role_to_system_role(
self,
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
"""
Translate `developer` role to `system` role for non-OpenAI providers.
Overriden by OpenAI/Azure
"""
verbose_logger.debug(
"Translating developer role to system role for non-OpenAI providers."
) # ensure user knows what's happening with their input.
return map_developer_role_to_system_role(messages=messages)
def should_retry_llm_api_inside_llm_translation_on_http_error(
self, e: httpx.HTTPStatusError, litellm_params: dict
) -> bool:

View file

@ -1,5 +1,5 @@
"""
Support for o1 model family
Support for o1/o3 model family
https://platform.openai.com/docs/guides/reasoning
@ -35,6 +35,14 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
def get_config(cls):
return super().get_config()
def translate_developer_role_to_system_role(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
O-series models support `developer` role.
"""
return messages
def should_fake_stream(
self,
model: Optional[str],
@ -67,6 +75,10 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
"top_logprobs",
]
o_series_only_param = ["reasoning_effort"]
all_openai_params.extend(o_series_only_param)
try:
model, custom_llm_provider, api_base, api_key = get_llm_provider(
model=model

View file

@ -67,6 +67,7 @@ from litellm.litellm_core_utils.mock_functions import (
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_content_from_model_response,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.realtime_api.main import _realtime_health_check
from litellm.secret_managers.main import get_secret_str
@ -115,7 +116,7 @@ from .llms import baseten, maritalk, ollama_chat
from .llms.anthropic.chat import AnthropicChatCompletion
from .llms.azure.audio_transcriptions import AzureAudioTranscription
from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
from .llms.azure.chat.o_series_handler import AzureOpenAIO1ChatCompletion
from .llms.azure.completion.handler import AzureTextCompletion
from .llms.azure_ai.embed import AzureAIEmbedding
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
@ -331,6 +332,7 @@ async def acompletion(
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
deployment_id=None,
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
# set api_base, api_version, api_key
base_url: Optional[str] = None,
api_version: Optional[str] = None,
@ -425,6 +427,7 @@ async def acompletion(
"api_version": api_version,
"api_key": api_key,
"model_list": model_list,
"reasoning_effort": reasoning_effort,
"extra_headers": extra_headers,
"acompletion": True, # assuming this is a required parameter
}
@ -777,6 +780,7 @@ def completion( # type: ignore # noqa: PLR0915
logit_bias: Optional[dict] = None,
user: Optional[str] = None,
# openai v1.0+ new params
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
response_format: Optional[Union[dict, Type[BaseModel]]] = None,
seed: Optional[int] = None,
tools: Optional[List] = None,
@ -1046,6 +1050,19 @@ def completion( # type: ignore # noqa: PLR0915
if eos_token:
custom_prompt_dict[model]["eos_token"] = eos_token
provider_config: Optional[BaseConfig] = None
if custom_llm_provider is not None and custom_llm_provider in [
provider.value for provider in LlmProviders
]:
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
if provider_config is not None:
messages = provider_config.translate_developer_role_to_system_role(
messages=messages
)
if (
supports_system_message is not None
and isinstance(supports_system_message, bool)
@ -1087,6 +1104,7 @@ def completion( # type: ignore # noqa: PLR0915
api_version=api_version,
parallel_tool_calls=parallel_tool_calls,
messages=messages,
reasoning_effort=reasoning_effort,
**non_default_params,
)

View file

@ -442,10 +442,20 @@ class OpenAIChatCompletionSystemMessage(TypedDict, total=False):
name: str
class OpenAIChatCompletionDeveloperMessage(TypedDict, total=False):
role: Required[Literal["developer"]]
content: Required[Union[str, List]]
name: str
class ChatCompletionSystemMessage(OpenAIChatCompletionSystemMessage, total=False):
cache_control: ChatCompletionCachedContent
class ChatCompletionDeveloperMessage(OpenAIChatCompletionDeveloperMessage, total=False):
cache_control: ChatCompletionCachedContent
ValidUserMessageContentTypes = [
"text",
"image_url",
@ -458,6 +468,7 @@ AllMessageValues = Union[
ChatCompletionToolMessage,
ChatCompletionSystemMessage,
ChatCompletionFunctionMessage,
ChatCompletionDeveloperMessage,
]

View file

@ -2686,6 +2686,7 @@ def get_optional_params( # noqa: PLR0915
api_version=None,
parallel_tool_calls=None,
drop_params=None,
reasoning_effort=None,
additional_drop_params=None,
messages: Optional[List[AllMessageValues]] = None,
**kwargs,
@ -2771,6 +2772,7 @@ def get_optional_params( # noqa: PLR0915
"drop_params": None,
"additional_drop_params": None,
"messages": None,
"reasoning_effort": None,
}
# filter out those parameters that were passed with non-default values

View file

@ -18,9 +18,11 @@ from litellm.utils import (
get_supported_openai_params,
get_optional_params,
)
from typing import Union
# test_example.py
from abc import ABC, abstractmethod
from openai import OpenAI
def _usage_format_tests(usage: litellm.Usage):
@ -70,6 +72,34 @@ class BaseLLMChatTest(ABC):
"""Must return the base completion call args"""
pass
def test_developer_role_translation(self):
"""
Test that the developer role is translated correctly for non-OpenAI providers.
Translate `developer` role to `system` role for non-OpenAI providers.
"""
base_completion_call_args = self.get_base_completion_call_args()
messages = [
{
"role": "developer",
"content": "Be a good bot!",
},
{
"role": "user",
"content": [{"type": "text", "text": "Hello, how are you?"}],
},
]
try:
response = self.completion_function(
**base_completion_call_args,
messages=messages,
)
assert response is not None
except litellm.InternalServerError:
pytest.skip("Model is overloaded")
assert response.choices[0].message.content is not None
def test_content_list_handling(self):
"""Check if content list is supported by LLM API"""
base_completion_call_args = self.get_base_completion_call_args()
@ -605,3 +635,72 @@ class BaseLLMChatTest(ABC):
cost = completion_cost(response)
assert cost > 0
class BaseOSeriesModelsTest(ABC): # test across azure/openai
@abstractmethod
def get_base_completion_call_args(self):
pass
@abstractmethod
def get_client(self) -> OpenAI:
pass
def test_reasoning_effort(self):
"""Test that reasoning_effort is passed correctly to the model"""
from litellm import completion
client = self.get_client()
completion_args = self.get_base_completion_call_args()
with patch.object(
client.chat.completions.with_raw_response, "create"
) as mock_client:
try:
completion(
**completion_args,
reasoning_effort="low",
messages=[{"role": "user", "content": "Hello!"}],
client=client,
)
except Exception as e:
print(f"Error: {e}")
mock_client.assert_called_once()
request_body = mock_client.call_args.kwargs
print("request_body: ", request_body)
assert request_body["reasoning_effort"] == "low"
def test_developer_role_translation(self):
"""Test that developer role is translated correctly to system role for non-OpenAI providers"""
from litellm import completion
client = self.get_client()
completion_args = self.get_base_completion_call_args()
with patch.object(
client.chat.completions.with_raw_response, "create"
) as mock_client:
try:
completion(
**completion_args,
reasoning_effort="low",
messages=[
{"role": "developer", "content": "Be a good bot!"},
{"role": "user", "content": "Hello!"},
],
client=client,
)
except Exception as e:
print(f"Error: {e}")
mock_client.assert_called_once()
request_body = mock_client.call_args.kwargs
print("request_body: ", request_body)
assert (
request_body["messages"][0]["role"] == "developer"
), "Got={} instead of system".format(request_body["messages"][0]["role"])
assert request_body["messages"][0]["content"] == "Be a good bot!"

View file

@ -15,10 +15,10 @@ from respx import MockRouter
import litellm
from litellm import Choices, Message, ModelResponse
from base_llm_unit_tests import BaseLLMChatTest
from base_llm_unit_tests import BaseLLMChatTest, BaseOSeriesModelsTest
class TestAzureOpenAIO1(BaseLLMChatTest):
class TestAzureOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
def get_base_completion_call_args(self):
return {
"model": "azure/o1-preview",
@ -26,6 +26,15 @@ class TestAzureOpenAIO1(BaseLLMChatTest):
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
}
def get_client(self):
from openai import AzureOpenAI
return AzureOpenAI(
api_key="my-fake-o1-key",
base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
api_version="2024-02-15-preview",
)
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pass
@ -65,6 +74,24 @@ class TestAzureOpenAIO1(BaseLLMChatTest):
assert fake_stream is False
class TestAzureOpenAIO3(BaseOSeriesModelsTest):
def get_base_completion_call_args(self):
return {
"model": "azure/o3-mini",
"api_key": "my-fake-o1-key",
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
}
def get_client(self):
from openai import AzureOpenAI
return AzureOpenAI(
api_key="my-fake-o1-key",
base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
api_version="2024-02-15-preview",
)
def test_azure_o3_streaming():
"""
Test that o3 models handles fake streaming correctly.

View file

@ -15,7 +15,7 @@ from respx import MockRouter
import litellm
from litellm import Choices, Message, ModelResponse
from base_llm_unit_tests import BaseLLMChatTest
from base_llm_unit_tests import BaseLLMChatTest, BaseOSeriesModelsTest
@pytest.mark.parametrize("model", ["o1-preview", "o1-mini", "o1"])
@ -152,12 +152,17 @@ def test_litellm_responses():
assert isinstance(response.usage.completion_tokens_details, CompletionTokensDetails)
class TestOpenAIO1(BaseLLMChatTest):
class TestOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
def get_base_completion_call_args(self):
return {
"model": "o1",
}
def get_client(self):
from openai import OpenAI
return OpenAI(api_key="fake-api-key")
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pass
@ -167,12 +172,17 @@ class TestOpenAIO1(BaseLLMChatTest):
pass
class TestOpenAIO3(BaseLLMChatTest):
class TestOpenAIO3(BaseOSeriesModelsTest, BaseLLMChatTest):
def get_base_completion_call_args(self):
return {
"model": "o3-mini",
}
def get_client(self):
from openai import OpenAI
return OpenAI(api_key="fake-api-key")
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pass

View file

@ -29,7 +29,8 @@ def test_acompletion_params():
# Assert that the parameters are the same
if keys_acompletion != keys_completion:
pytest.fail(
"The parameters of the litellm.acompletion function and litellm.completion are not the same."
"The parameters of the litellm.acompletion function and litellm.completion are not the same. "
f"Completion has extra keys: {keys_completion - keys_acompletion}"
)