mirror of https://github.com/BerriAI/litellm.git
Complete o3 model support (#8183)
* fix(o_series_transformation.py): add 'reasoning_effort' as o series model param
  Closes https://github.com/BerriAI/litellm/issues/8182
* fix(main.py): ensure `reasoning_effort` is a mapped openai param
* refactor(azure/): rename o1_[x] files to o_series_[x]
* refactor(base_llm_unit_tests.py): refactor testing for o series reasoning effort
* test(test_azure_o_series.py): have azure o series tests correctly inherit from base o series model tests
* feat(base_utils.py): support translating 'developer' role to 'system' role for non-openai providers
  Makes it easy to switch from openai to anthropic
* fix: fix linting errors
* fix(base_llm_unit_tests.py): fix test
* fix(main.py): add missing param
parent e4566d7b1c
commit 1105e35538
15 changed files with 230 additions and 11 deletions
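
Taken together, the changes below enable two user-facing behaviors. A minimal usage sketch (model names are illustrative, not taken from this diff):

import litellm

# 1. `reasoning_effort` is now a mapped OpenAI param for o-series models.
response = litellm.completion(
    model="o3-mini",
    reasoning_effort="low",
    messages=[{"role": "user", "content": "Hello!"}],
)

# 2. A `developer` message passes through unchanged for o-series models and is
#    translated to a `system` message for non-OpenAI providers.
response = litellm.completion(
    model="claude-3-5-sonnet-20240620",  # illustrative non-OpenAI model
    messages=[
        {"role": "developer", "content": "Be a good bot!"},
        {"role": "user", "content": "Hello!"},
    ],
)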
@@ -939,7 +939,7 @@ from .llms.deepseek.chat.transformation import DeepSeekChatConfig
 from .llms.lm_studio.chat.transformation import LMStudioChatConfig
 from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig
 from .llms.perplexity.chat.transformation import PerplexityChatConfig
-from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
+from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config
 from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
 from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
 from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig

@@ -118,6 +118,7 @@ OPENAI_CHAT_COMPLETION_PARAMS = [
     "parallel_tool_calls",
     "logprobs",
     "top_logprobs",
+    "reasoning_effort",
     "extra_headers",
 ]
 

@@ -2175,7 +2175,7 @@ from litellm.types.llms.bedrock import ToolUseBlock as BedrockToolUseBlock
 
 def _parse_content_type(content_type: str) -> str:
     m = Message()
-    m['content-type'] = content_type
+    m["content-type"] = content_type
     return m.get_content_type()
 
 

@@ -1,3 +1,7 @@
+"""
+Utility functions for base LLM classes.
+"""
+
 import copy
 from abc import ABC, abstractmethod
 from typing import List, Optional, Type, Union
@@ -5,6 +9,7 @@ from typing import List, Optional, Type, Union
 from openai.lib import _parsing, _pydantic
 from pydantic import BaseModel
 
+from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ProviderSpecificModelInfo
 
 
@@ -105,3 +110,18 @@ def type_to_response_format_param(
             "strict": True,
         },
     }
+
+
+def map_developer_role_to_system_role(
+    messages: List[AllMessageValues],
+) -> List[AllMessageValues]:
+    """
+    Translate `developer` role to `system` role for non-OpenAI providers.
+    """
+    new_messages: List[AllMessageValues] = []
+    for m in messages:
+        if m["role"] == "developer":
+            new_messages.append({"role": "system", "content": m["content"]})
+        else:
+            new_messages.append(m)
+    return new_messages

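A quick sketch of what the helper above does, using only names defined in this hunk:

messages = [
    {"role": "developer", "content": "Be a good bot!"},
    {"role": "user", "content": "Hello!"},
]
translated = map_developer_role_to_system_role(messages=messages)
# translated == [
#     {"role": "system", "content": "Be a good bot!"},
#     {"role": "user", "content": "Hello!"},
# ]
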
@@ -18,10 +18,14 @@ from typing import (
 import httpx
 from pydantic import BaseModel
 
+from litellm._logging import verbose_logger
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse
 
-from ..base_utils import type_to_response_format_param
+from ..base_utils import (
+    map_developer_role_to_system_role,
+    type_to_response_format_param,
+)
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -99,6 +103,20 @@ class BaseConfig(ABC):
         """
         return False
 
+    def translate_developer_role_to_system_role(
+        self,
+        messages: List[AllMessageValues],
+    ) -> List[AllMessageValues]:
+        """
+        Translate `developer` role to `system` role for non-OpenAI providers.
+
+        Overridden by OpenAI/Azure
+        """
+        verbose_logger.debug(
+            "Translating developer role to system role for non-OpenAI providers."
+        )  # ensure user knows what's happening with their input.
+        return map_developer_role_to_system_role(messages=messages)
+
     def should_retry_llm_api_inside_llm_translation_on_http_error(
         self, e: httpx.HTTPStatusError, litellm_params: dict
     ) -> bool:

@@ -1,5 +1,5 @@
 """
-Support for o1 model family
+Support for o1/o3 model family
 
 https://platform.openai.com/docs/guides/reasoning
 
@@ -35,6 +35,14 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
     def get_config(cls):
         return super().get_config()
 
+    def translate_developer_role_to_system_role(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        """
+        O-series models support `developer` role.
+        """
+        return messages
+
     def should_fake_stream(
         self,
         model: Optional[str],
@@ -67,6 +75,10 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
             "top_logprobs",
         ]
 
+        o_series_only_param = ["reasoning_effort"]
+
+        all_openai_params.extend(o_series_only_param)
+
         try:
             model, custom_llm_provider, api_base, api_key = get_llm_provider(
                 model=model

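With `reasoning_effort` appended to the o-series param list above, a supported-params lookup should now report it. A hedged check (`get_supported_openai_params` is imported from `litellm.utils` in the tests further down):

from litellm.utils import get_supported_openai_params

# Expected per the hunk above: o-series models now advertise `reasoning_effort`.
params = get_supported_openai_params(model="o3-mini", custom_llm_provider="openai")
assert params is not None and "reasoning_effort" in params
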
@@ -67,6 +67,7 @@ from litellm.litellm_core_utils.mock_functions import (
 from litellm.litellm_core_utils.prompt_templates.common_utils import (
     get_content_from_model_response,
 )
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
@@ -115,7 +116,7 @@ from .llms import baseten, maritalk, ollama_chat
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
-from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
+from .llms.azure.chat.o_series_handler import AzureOpenAIO1ChatCompletion
 from .llms.azure.completion.handler import AzureTextCompletion
 from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
@@ -331,6 +332,7 @@ async def acompletion(
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
     deployment_id=None,
+    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
     # set api_base, api_version, api_key
     base_url: Optional[str] = None,
     api_version: Optional[str] = None,
@@ -425,6 +427,7 @@ async def acompletion(
         "api_version": api_version,
         "api_key": api_key,
         "model_list": model_list,
+        "reasoning_effort": reasoning_effort,
         "extra_headers": extra_headers,
         "acompletion": True,  # assuming this is a required parameter
     }
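Since `acompletion` now forwards `reasoning_effort` into its kwargs dict, the async path mirrors the sync one. A minimal sketch, assuming a configured OpenAI key:

import asyncio

import litellm

async def main():
    # `reasoning_effort` travels through the kwargs dict shown above.
    response = await litellm.acompletion(
        model="o3-mini",
        reasoning_effort="medium",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())
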
@@ -777,6 +780,7 @@ def completion(  # type: ignore  # noqa: PLR0915
     logit_bias: Optional[dict] = None,
     user: Optional[str] = None,
     # openai v1.0+ new params
+    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
     response_format: Optional[Union[dict, Type[BaseModel]]] = None,
     seed: Optional[int] = None,
     tools: Optional[List] = None,
@@ -1046,6 +1050,19 @@ def completion(  # type: ignore  # noqa: PLR0915
         if eos_token:
             custom_prompt_dict[model]["eos_token"] = eos_token
 
+    provider_config: Optional[BaseConfig] = None
+    if custom_llm_provider is not None and custom_llm_provider in [
+        provider.value for provider in LlmProviders
+    ]:
+        provider_config = ProviderConfigManager.get_provider_chat_config(
+            model=model, provider=LlmProviders(custom_llm_provider)
+        )
+
+    if provider_config is not None:
+        messages = provider_config.translate_developer_role_to_system_role(
+            messages=messages
+        )
+
     if (
         supports_system_message is not None
         and isinstance(supports_system_message, bool)
@@ -1087,6 +1104,7 @@ def completion(  # type: ignore  # noqa: PLR0915
         api_version=api_version,
         parallel_tool_calls=parallel_tool_calls,
         messages=messages,
+        reasoning_effort=reasoning_effort,
         **non_default_params,
     )
 

@@ -442,10 +442,20 @@ class OpenAIChatCompletionSystemMessage(TypedDict, total=False):
     name: str
 
 
+class OpenAIChatCompletionDeveloperMessage(TypedDict, total=False):
+    role: Required[Literal["developer"]]
+    content: Required[Union[str, List]]
+    name: str
+
+
 class ChatCompletionSystemMessage(OpenAIChatCompletionSystemMessage, total=False):
     cache_control: ChatCompletionCachedContent
 
 
+class ChatCompletionDeveloperMessage(OpenAIChatCompletionDeveloperMessage, total=False):
+    cache_control: ChatCompletionCachedContent
+
+
 ValidUserMessageContentTypes = [
     "text",
     "image_url",
@@ -458,6 +468,7 @@ AllMessageValues = Union[
     ChatCompletionToolMessage,
     ChatCompletionSystemMessage,
     ChatCompletionFunctionMessage,
+    ChatCompletionDeveloperMessage,
 ]
 
 

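A short sketch of the new message type in use. Assumption: these TypedDicts live in `litellm.types.llms.openai`, the module that `AllMessageValues` is imported from elsewhere in this diff:

from litellm.types.llms.openai import ChatCompletionDeveloperMessage

# Type-checks because `developer` is now part of AllMessageValues.
dev_msg: ChatCompletionDeveloperMessage = {
    "role": "developer",
    "content": "Be a good bot!",
}
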
@@ -2686,6 +2686,7 @@ def get_optional_params(  # noqa: PLR0915
     api_version=None,
     parallel_tool_calls=None,
     drop_params=None,
+    reasoning_effort=None,
     additional_drop_params=None,
     messages: Optional[List[AllMessageValues]] = None,
     **kwargs,
@@ -2771,6 +2772,7 @@ def get_optional_params(  # noqa: PLR0915
         "drop_params": None,
         "additional_drop_params": None,
         "messages": None,
+        "reasoning_effort": None,
     }
 
     # filter out those parameters that were passed with non-default values

@@ -18,9 +18,11 @@ from litellm.utils import (
     get_supported_openai_params,
     get_optional_params,
 )
+from typing import Union
 
 # test_example.py
 from abc import ABC, abstractmethod
+from openai import OpenAI
 
 
 def _usage_format_tests(usage: litellm.Usage):
@@ -70,6 +72,34 @@ class BaseLLMChatTest(ABC):
         """Must return the base completion call args"""
         pass
 
+    def test_developer_role_translation(self):
+        """
+        Test that the developer role is translated correctly for non-OpenAI providers.
+
+        Translate `developer` role to `system` role for non-OpenAI providers.
+        """
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {
+                "role": "developer",
+                "content": "Be a good bot!",
+            },
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "Hello, how are you?"}],
+            },
+        ]
+        try:
+            response = self.completion_function(
+                **base_completion_call_args,
+                messages=messages,
+            )
+            assert response is not None
+        except litellm.InternalServerError:
+            pytest.skip("Model is overloaded")
+
+        assert response.choices[0].message.content is not None
+
     def test_content_list_handling(self):
         """Check if content list is supported by LLM API"""
         base_completion_call_args = self.get_base_completion_call_args()
@@ -605,3 +635,72 @@ class BaseLLMChatTest(ABC):
         cost = completion_cost(response)
 
         assert cost > 0
+
+
+class BaseOSeriesModelsTest(ABC):  # test across azure/openai
+    @abstractmethod
+    def get_base_completion_call_args(self):
+        pass
+
+    @abstractmethod
+    def get_client(self) -> OpenAI:
+        pass
+
+    def test_reasoning_effort(self):
+        """Test that reasoning_effort is passed correctly to the model"""
+        from litellm import completion
+
+        client = self.get_client()
+
+        completion_args = self.get_base_completion_call_args()
+
+        with patch.object(
+            client.chat.completions.with_raw_response, "create"
+        ) as mock_client:
+            try:
+                completion(
+                    **completion_args,
+                    reasoning_effort="low",
+                    messages=[{"role": "user", "content": "Hello!"}],
+                    client=client,
+                )
+            except Exception as e:
+                print(f"Error: {e}")
+
+            mock_client.assert_called_once()
+            request_body = mock_client.call_args.kwargs
+            print("request_body: ", request_body)
+            assert request_body["reasoning_effort"] == "low"
+
+    def test_developer_role_translation(self):
+        """Test that the developer role is preserved (not translated to system) for OpenAI o-series models"""
+        from litellm import completion
+
+        client = self.get_client()
+
+        completion_args = self.get_base_completion_call_args()
+
+        with patch.object(
+            client.chat.completions.with_raw_response, "create"
+        ) as mock_client:
+            try:
+                completion(
+                    **completion_args,
+                    reasoning_effort="low",
+                    messages=[
+                        {"role": "developer", "content": "Be a good bot!"},
+                        {"role": "user", "content": "Hello!"},
+                    ],
+                    client=client,
+                )
+            except Exception as e:
+                print(f"Error: {e}")
+
+            mock_client.assert_called_once()
+            request_body = mock_client.call_args.kwargs
+            print("request_body: ", request_body)
+            assert (
+                request_body["messages"][0]["role"] == "developer"
+            ), "Got={} instead of developer".format(request_body["messages"][0]["role"])
+            assert request_body["messages"][0]["content"] == "Be a good bot!"

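As the Azure and OpenAI subclasses below show, a concrete suite only needs the two abstract factory methods. A hypothetical example (provider and model name are placeholders):

class TestMyProviderOSeries(BaseOSeriesModelsTest):  # hypothetical suite
    def get_base_completion_call_args(self):
        return {"model": "o3-mini"}  # a provider-prefixed model would go here

    def get_client(self) -> OpenAI:
        return OpenAI(api_key="fake-api-key")  # mocked; no real key needed
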
@@ -15,10 +15,10 @@ from respx import MockRouter
 
 import litellm
 from litellm import Choices, Message, ModelResponse
-from base_llm_unit_tests import BaseLLMChatTest
+from base_llm_unit_tests import BaseLLMChatTest, BaseOSeriesModelsTest
 
 
-class TestAzureOpenAIO1(BaseLLMChatTest):
+class TestAzureOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
     def get_base_completion_call_args(self):
         return {
             "model": "azure/o1-preview",
@@ -26,6 +26,15 @@ class TestAzureOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
             "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
         }
 
+    def get_client(self):
+        from openai import AzureOpenAI
+
+        return AzureOpenAI(
+            api_key="my-fake-o1-key",
+            base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
+            api_version="2024-02-15-preview",
+        )
+
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments are translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass
@@ -65,6 +74,24 @@ class TestAzureOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
         assert fake_stream is False
 
 
+class TestAzureOpenAIO3(BaseOSeriesModelsTest):
+    def get_base_completion_call_args(self):
+        return {
+            "model": "azure/o3-mini",
+            "api_key": "my-fake-o1-key",
+            "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
+        }
+
+    def get_client(self):
+        from openai import AzureOpenAI
+
+        return AzureOpenAI(
+            api_key="my-fake-o1-key",
+            base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
+            api_version="2024-02-15-preview",
+        )
+
+
 def test_azure_o3_streaming():
     """
     Test that o3 models handle fake streaming correctly.

@@ -15,7 +15,7 @@ from respx import MockRouter
 
 import litellm
 from litellm import Choices, Message, ModelResponse
-from base_llm_unit_tests import BaseLLMChatTest
+from base_llm_unit_tests import BaseLLMChatTest, BaseOSeriesModelsTest
 
 
 @pytest.mark.parametrize("model", ["o1-preview", "o1-mini", "o1"])
@@ -152,12 +152,17 @@ def test_litellm_responses():
     assert isinstance(response.usage.completion_tokens_details, CompletionTokensDetails)
 
 
-class TestOpenAIO1(BaseLLMChatTest):
+class TestOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
     def get_base_completion_call_args(self):
         return {
             "model": "o1",
         }
 
+    def get_client(self):
+        from openai import OpenAI
+
+        return OpenAI(api_key="fake-api-key")
+
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments are translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass
@@ -167,12 +172,17 @@ class TestOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
         pass
 
 
-class TestOpenAIO3(BaseLLMChatTest):
+class TestOpenAIO3(BaseOSeriesModelsTest, BaseLLMChatTest):
     def get_base_completion_call_args(self):
         return {
             "model": "o3-mini",
         }
 
+    def get_client(self):
+        from openai import OpenAI
+
+        return OpenAI(api_key="fake-api-key")
+
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments are translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass

@@ -29,7 +29,8 @@ def test_acompletion_params():
     # Assert that the parameters are the same
     if keys_acompletion != keys_completion:
         pytest.fail(
-            "The parameters of the litellm.acompletion function and litellm.completion are not the same."
+            "The parameters of the litellm.acompletion function and litellm.completion are not the same. "
+            f"Completion has extra keys: {keys_completion - keys_acompletion}"
         )
 
 
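For context, the improved failure message above reports which keyword arguments exist on `completion` but not on `acompletion`. A hedged sketch of how such a comparison can be made (the test's exact setup is not shown in this hunk):

import inspect

import litellm

keys_acompletion = set(inspect.signature(litellm.acompletion).parameters.keys())
keys_completion = set(inspect.signature(litellm.completion).parameters.keys())

# With this commit, `reasoning_effort` is present on both signatures.
print(keys_completion - keys_acompletion)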