Complete o3 model support (#8183)
* fix(o_series_transformation.py): add 'reasoning_effort' as an o-series model param. Closes https://github.com/BerriAI/litellm/issues/8182
* fix(main.py): ensure `reasoning_effort` is a mapped openai param
* refactor(azure/): rename o1_[x] files to o_series_[x]
* refactor(base_llm_unit_tests.py): refactor testing for o-series reasoning effort
* test(test_azure_o_series.py): have azure o-series tests correctly inherit from base o-series model tests
* feat(base_utils.py): support translating the 'developer' role to the 'system' role for non-openai providers. Makes it easy to switch from openai to anthropic
* fix: fix linting errors
* fix(base_llm_unit_tests.py): fix test
* fix(main.py): add missing param
This commit is contained in: parent e4566d7b1c · commit 1105e35538
15 changed files with 230 additions and 11 deletions
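In practice, the changes below enable calls like the following — a minimal sketch; the model names are illustrative, and the Anthropic example stands in for any non-OpenAI provider:

import litellm

# 1. `reasoning_effort` is now a mapped OpenAI param for o-series models.
response = litellm.completion(
    model="o3-mini",
    reasoning_effort="low",
    messages=[{"role": "user", "content": "Hello!"}],
)

# 2. The `developer` role is forwarded unchanged to OpenAI/Azure o-series models,
#    and translated to `system` for providers that don't support it.
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # illustrative non-OpenAI model
    messages=[
        {"role": "developer", "content": "Be a good bot!"},  # sent as `system`
        {"role": "user", "content": "Hello!"},
    ],
)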
@@ -939,7 +939,7 @@ from .llms.deepseek.chat.transformation import DeepSeekChatConfig
 from .llms.lm_studio.chat.transformation import LMStudioChatConfig
 from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig
 from .llms.perplexity.chat.transformation import PerplexityChatConfig
-from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
+from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config
 from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
 from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
 from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
@@ -118,6 +118,7 @@ OPENAI_CHAT_COMPLETION_PARAMS = [
     "parallel_tool_calls",
     "logprobs",
     "top_logprobs",
+    "reasoning_effort",
     "extra_headers",
 ]

@@ -2175,7 +2175,7 @@ from litellm.types.llms.bedrock import ToolUseBlock as BedrockToolUseBlock

 def _parse_content_type(content_type: str) -> str:
     m = Message()
-    m['content-type'] = content_type
+    m["content-type"] = content_type
     return m.get_content_type()

@ -1,3 +1,7 @@
|
||||||
|
"""
|
||||||
|
Utility functions for base LLM classes.
|
||||||
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Type, Union
|
from typing import List, Optional, Type, Union
|
||||||
|
@@ -5,6 +9,7 @@ from typing import List, Optional, Type, Union
 from openai.lib import _parsing, _pydantic
 from pydantic import BaseModel

+from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ProviderSpecificModelInfo

@@ -105,3 +110,18 @@ def type_to_response_format_param(
             "strict": True,
         },
     }
+
+
+def map_developer_role_to_system_role(
+    messages: List[AllMessageValues],
+) -> List[AllMessageValues]:
+    """
+    Translate `developer` role to `system` role for non-OpenAI providers.
+    """
+    new_messages: List[AllMessageValues] = []
+    for m in messages:
+        if m["role"] == "developer":
+            new_messages.append({"role": "system", "content": m["content"]})
+        else:
+            new_messages.append(m)
+    return new_messages
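For intuition, here is what the helper does to a hypothetical message list (input invented for illustration):

messages = [
    {"role": "developer", "content": "Be concise."},
    {"role": "user", "content": "Hi"},
]
map_developer_role_to_system_role(messages=messages)
# -> [{"role": "system", "content": "Be concise."},
#     {"role": "user", "content": "Hi"}]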
@@ -18,10 +18,14 @@ from typing import (
 import httpx
 from pydantic import BaseModel

+from litellm._logging import verbose_logger
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse

-from ..base_utils import type_to_response_format_param
+from ..base_utils import (
+    map_developer_role_to_system_role,
+    type_to_response_format_param,
+)

 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -99,6 +103,20 @@ class BaseConfig(ABC):
         """
         return False

+    def translate_developer_role_to_system_role(
+        self,
+        messages: List[AllMessageValues],
+    ) -> List[AllMessageValues]:
+        """
+        Translate `developer` role to `system` role for non-OpenAI providers.
+
+        Overridden by OpenAI/Azure.
+        """
+        verbose_logger.debug(
+            "Translating developer role to system role for non-OpenAI providers."
+        )  # ensure the user knows what's happening with their input
+        return map_developer_role_to_system_role(messages=messages)
+
     def should_retry_llm_api_inside_llm_translation_on_http_error(
         self, e: httpx.HTTPStatusError, litellm_params: dict
     ) -> bool:
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
Support for o1 model family
|
Support for o1/o3 model family
|
||||||
|
|
||||||
https://platform.openai.com/docs/guides/reasoning
|
https://platform.openai.com/docs/guides/reasoning
|
||||||
|
|
||||||
|
@@ -35,6 +35,14 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
     def get_config(cls):
         return super().get_config()

+    def translate_developer_role_to_system_role(
+        self, messages: List[AllMessageValues]
+    ) -> List[AllMessageValues]:
+        """
+        O-series models support `developer` role.
+        """
+        return messages
+
     def should_fake_stream(
         self,
         model: Optional[str],
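This override is the counterpart of the `BaseConfig` default added above: non-OpenAI providers inherit the translation, while o-series configs pass `developer` messages through untouched. A hedged sketch of the resulting dispatch (instantiation details assumed, not from the diff):

msgs = [{"role": "developer", "content": "Be brief."}]

# o-series config: `developer` is supported, so messages come back unchanged
OpenAIOSeriesConfig().translate_developer_role_to_system_role(messages=msgs)
# -> [{"role": "developer", "content": "Be brief."}]

# any provider config that keeps the BaseConfig default rewrites the role
# -> [{"role": "system", "content": "Be brief."}]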
@@ -67,6 +75,10 @@ class OpenAIOSeriesConfig(OpenAIGPTConfig):
             "top_logprobs",
         ]

+        o_series_only_param = ["reasoning_effort"]
+
+        all_openai_params.extend(o_series_only_param)
+
         try:
             model, custom_llm_provider, api_base, api_key = get_llm_provider(
                 model=model
@@ -67,6 +67,7 @@ from litellm.litellm_core_utils.mock_functions import (
 from litellm.litellm_core_utils.prompt_templates.common_utils import (
     get_content_from_model_response,
 )
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
@@ -115,7 +116,7 @@ from .llms import baseten, maritalk, ollama_chat
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
-from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
+from .llms.azure.chat.o_series_handler import AzureOpenAIO1ChatCompletion
 from .llms.azure.completion.handler import AzureTextCompletion
 from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
@ -331,6 +332,7 @@ async def acompletion(
|
||||||
logprobs: Optional[bool] = None,
|
logprobs: Optional[bool] = None,
|
||||||
top_logprobs: Optional[int] = None,
|
top_logprobs: Optional[int] = None,
|
||||||
deployment_id=None,
|
deployment_id=None,
|
||||||
|
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
|
||||||
# set api_base, api_version, api_key
|
# set api_base, api_version, api_key
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
|
@ -425,6 +427,7 @@ async def acompletion(
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
"model_list": model_list,
|
"model_list": model_list,
|
||||||
|
"reasoning_effort": reasoning_effort,
|
||||||
"extra_headers": extra_headers,
|
"extra_headers": extra_headers,
|
||||||
"acompletion": True, # assuming this is a required parameter
|
"acompletion": True, # assuming this is a required parameter
|
||||||
}
|
}
|
||||||
|
@@ -777,6 +780,7 @@ def completion( # type: ignore # noqa: PLR0915
     logit_bias: Optional[dict] = None,
     user: Optional[str] = None,
     # openai v1.0+ new params
+    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
     response_format: Optional[Union[dict, Type[BaseModel]]] = None,
     seed: Optional[int] = None,
     tools: Optional[List] = None,
@@ -1046,6 +1050,19 @@ def completion( # type: ignore # noqa: PLR0915
         if eos_token:
             custom_prompt_dict[model]["eos_token"] = eos_token

+    provider_config: Optional[BaseConfig] = None
+    if custom_llm_provider is not None and custom_llm_provider in [
+        provider.value for provider in LlmProviders
+    ]:
+        provider_config = ProviderConfigManager.get_provider_chat_config(
+            model=model, provider=LlmProviders(custom_llm_provider)
+        )
+
+    if provider_config is not None:
+        messages = provider_config.translate_developer_role_to_system_role(
+            messages=messages
+        )
+
     if (
         supports_system_message is not None
         and isinstance(supports_system_message, bool)
@@ -1087,6 +1104,7 @@ def completion( # type: ignore # noqa: PLR0915
             api_version=api_version,
             parallel_tool_calls=parallel_tool_calls,
             messages=messages,
+            reasoning_effort=reasoning_effort,
             **non_default_params,
         )

@@ -442,10 +442,20 @@ class OpenAIChatCompletionSystemMessage(TypedDict, total=False):
     name: str


+class OpenAIChatCompletionDeveloperMessage(TypedDict, total=False):
+    role: Required[Literal["developer"]]
+    content: Required[Union[str, List]]
+    name: str
+
+
 class ChatCompletionSystemMessage(OpenAIChatCompletionSystemMessage, total=False):
     cache_control: ChatCompletionCachedContent


+class ChatCompletionDeveloperMessage(OpenAIChatCompletionDeveloperMessage, total=False):
+    cache_control: ChatCompletionCachedContent
+
+
 ValidUserMessageContentTypes = [
     "text",
     "image_url",
@ -458,6 +468,7 @@ AllMessageValues = Union[
|
||||||
ChatCompletionToolMessage,
|
ChatCompletionToolMessage,
|
||||||
ChatCompletionSystemMessage,
|
ChatCompletionSystemMessage,
|
||||||
ChatCompletionFunctionMessage,
|
ChatCompletionFunctionMessage,
|
||||||
|
ChatCompletionDeveloperMessage,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
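A short sketch of the new message type in use; the import path mirrors the `AllMessageValues` import seen earlier in this diff, so treat it as an assumption:

from litellm.types.llms.openai import ChatCompletionDeveloperMessage

# `role` and `content` are Required; `name` and `cache_control` stay optional
dev_msg: ChatCompletionDeveloperMessage = {
    "role": "developer",
    "content": "Be a good bot!",
}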
@@ -2686,6 +2686,7 @@ def get_optional_params( # noqa: PLR0915
     api_version=None,
     parallel_tool_calls=None,
     drop_params=None,
+    reasoning_effort=None,
     additional_drop_params=None,
     messages: Optional[List[AllMessageValues]] = None,
     **kwargs,
@@ -2771,6 +2772,7 @@ def get_optional_params( # noqa: PLR0915
         "drop_params": None,
         "additional_drop_params": None,
         "messages": None,
+        "reasoning_effort": None,
     }

     # filter out those parameters that were passed with non-default values
@@ -18,9 +18,11 @@ from litellm.utils import (
     get_supported_openai_params,
     get_optional_params,
 )
+from typing import Union

 # test_example.py
 from abc import ABC, abstractmethod
+from openai import OpenAI


 def _usage_format_tests(usage: litellm.Usage):
@@ -70,6 +72,34 @@ class BaseLLMChatTest(ABC):
         """Must return the base completion call args"""
         pass

+    def test_developer_role_translation(self):
+        """
+        Test that the developer role is translated correctly for non-OpenAI providers.
+
+        Translate `developer` role to `system` role for non-OpenAI providers.
+        """
+        base_completion_call_args = self.get_base_completion_call_args()
+        messages = [
+            {
+                "role": "developer",
+                "content": "Be a good bot!",
+            },
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "Hello, how are you?"}],
+            },
+        ]
+        try:
+            response = self.completion_function(
+                **base_completion_call_args,
+                messages=messages,
+            )
+            assert response is not None
+        except litellm.InternalServerError:
+            pytest.skip("Model is overloaded")
+
+        assert response.choices[0].message.content is not None
+
     def test_content_list_handling(self):
         """Check if content list is supported by LLM API"""
         base_completion_call_args = self.get_base_completion_call_args()
@@ -605,3 +635,72 @@ class BaseLLMChatTest(ABC):
         cost = completion_cost(response)

         assert cost > 0
+
+
+class BaseOSeriesModelsTest(ABC):  # test across azure/openai
+    @abstractmethod
+    def get_base_completion_call_args(self):
+        pass
+
+    @abstractmethod
+    def get_client(self) -> OpenAI:
+        pass
+
+    def test_reasoning_effort(self):
+        """Test that reasoning_effort is passed correctly to the model"""
+        from litellm import completion
+
+        client = self.get_client()
+
+        completion_args = self.get_base_completion_call_args()
+
+        with patch.object(
+            client.chat.completions.with_raw_response, "create"
+        ) as mock_client:
+            try:
+                completion(
+                    **completion_args,
+                    reasoning_effort="low",
+                    messages=[{"role": "user", "content": "Hello!"}],
+                    client=client,
+                )
+            except Exception as e:
+                print(f"Error: {e}")
+
+            mock_client.assert_called_once()
+            request_body = mock_client.call_args.kwargs
+            print("request_body: ", request_body)
+            assert request_body["reasoning_effort"] == "low"
+
+    def test_developer_role_translation(self):
+        """Test that the developer role is preserved (not rewritten to system) for o-series models"""
+        from litellm import completion
+
+        client = self.get_client()
+
+        completion_args = self.get_base_completion_call_args()
+
+        with patch.object(
+            client.chat.completions.with_raw_response, "create"
+        ) as mock_client:
+            try:
+                completion(
+                    **completion_args,
+                    reasoning_effort="low",
+                    messages=[
+                        {"role": "developer", "content": "Be a good bot!"},
+                        {"role": "user", "content": "Hello!"},
+                    ],
+                    client=client,
+                )
+            except Exception as e:
+                print(f"Error: {e}")
+
+            mock_client.assert_called_once()
+            request_body = mock_client.call_args.kwargs
+            print("request_body: ", request_body)
+            assert (
+                request_body["messages"][0]["role"] == "developer"
+            ), "Got={} instead of developer".format(request_body["messages"][0]["role"])
+            assert request_body["messages"][0]["content"] == "Be a good bot!"
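Worth noting about the test design above: patching `create` on the client's `with_raw_response` namespace means no network request is ever made, and the kwargs captured by the mock are exactly what litellm would have sent. A condensed sketch of the pattern (the litellm call elided for brevity):

from unittest.mock import patch
from openai import OpenAI

client = OpenAI(api_key="fake-api-key")

with patch.object(client.chat.completions.with_raw_response, "create") as mock_client:
    # ... call litellm.completion(..., client=client) here ...
    request_body = mock_client.call_args.kwargs  # the outgoing request payload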
@@ -15,10 +15,10 @@ from respx import MockRouter

 import litellm
 from litellm import Choices, Message, ModelResponse
-from base_llm_unit_tests import BaseLLMChatTest
+from base_llm_unit_tests import BaseLLMChatTest, BaseOSeriesModelsTest


-class TestAzureOpenAIO1(BaseLLMChatTest):
+class TestAzureOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
     def get_base_completion_call_args(self):
         return {
             "model": "azure/o1-preview",
@ -26,6 +26,15 @@ class TestAzureOpenAIO1(BaseLLMChatTest):
|
||||||
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
|
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_client(self):
|
||||||
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
|
return AzureOpenAI(
|
||||||
|
api_key="my-fake-o1-key",
|
||||||
|
base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
|
||||||
|
api_version="2024-02-15-preview",
|
||||||
|
)
|
||||||
|
|
||||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||||
pass
|
pass
|
||||||
|
@@ -65,6 +74,24 @@ class TestAzureOpenAIO1(BaseLLMChatTest):
         assert fake_stream is False


+class TestAzureOpenAIO3(BaseOSeriesModelsTest):
+    def get_base_completion_call_args(self):
+        return {
+            "model": "azure/o3-mini",
+            "api_key": "my-fake-o1-key",
+            "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
+        }
+
+    def get_client(self):
+        from openai import AzureOpenAI
+
+        return AzureOpenAI(
+            api_key="my-fake-o1-key",
+            base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
+            api_version="2024-02-15-preview",
+        )
+
+
 def test_azure_o3_streaming():
     """
     Test that o3 models handle fake streaming correctly.
@@ -15,7 +15,7 @@ from respx import MockRouter

 import litellm
 from litellm import Choices, Message, ModelResponse
-from base_llm_unit_tests import BaseLLMChatTest
+from base_llm_unit_tests import BaseLLMChatTest, BaseOSeriesModelsTest


 @pytest.mark.parametrize("model", ["o1-preview", "o1-mini", "o1"])
@ -152,12 +152,17 @@ def test_litellm_responses():
|
||||||
assert isinstance(response.usage.completion_tokens_details, CompletionTokensDetails)
|
assert isinstance(response.usage.completion_tokens_details, CompletionTokensDetails)
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIO1(BaseLLMChatTest):
|
class TestOpenAIO1(BaseOSeriesModelsTest, BaseLLMChatTest):
|
||||||
def get_base_completion_call_args(self):
|
def get_base_completion_call_args(self):
|
||||||
return {
|
return {
|
||||||
"model": "o1",
|
"model": "o1",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_client(self):
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
return OpenAI(api_key="fake-api-key")
|
||||||
|
|
||||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||||
pass
|
pass
|
||||||
|
@ -167,12 +172,17 @@ class TestOpenAIO1(BaseLLMChatTest):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIO3(BaseLLMChatTest):
|
class TestOpenAIO3(BaseOSeriesModelsTest, BaseLLMChatTest):
|
||||||
def get_base_completion_call_args(self):
|
def get_base_completion_call_args(self):
|
||||||
return {
|
return {
|
||||||
"model": "o3-mini",
|
"model": "o3-mini",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_client(self):
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
return OpenAI(api_key="fake-api-key")
|
||||||
|
|
||||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -29,7 +29,8 @@ def test_acompletion_params():
|
||||||
# Assert that the parameters are the same
|
# Assert that the parameters are the same
|
||||||
if keys_acompletion != keys_completion:
|
if keys_acompletion != keys_completion:
|
||||||
pytest.fail(
|
pytest.fail(
|
||||||
"The parameters of the litellm.acompletion function and litellm.completion are not the same."
|
"The parameters of the litellm.acompletion function and litellm.completion are not the same. "
|
||||||
|
f"Completion has extra keys: {keys_completion - keys_acompletion}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|