mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Improved O3 + Azure O3 support (#8181)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 13s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 13s
* fix: support azure o3 model family for fake streaming workaround (#8162) * fix: support azure o3 model family for fake streaming workaround * refactor: rename helper to is_o_series_model for clarity * update function calling parameters for o3 models (#8178) * refactor(o1_transformation.py): refactor o1 config to be o series config, expand o series model check to o3 ensures max_tokens is correctly translated for o3 * feat(openai/): refactor o1 files to be 'o_series' files expands naming to cover o3 * fix(azure/chat/o1_handler.py): azure openai is an instance of openai - was causing resets * test(test_azure_o_series.py): assert stream faked for azure o3 mini Resolves https://github.com/BerriAI/litellm/pull/8162 * fix(o1_transformation.py): fix o1 transformation logic to handle explicit o1_series routing * docs(azure.md): update doc with `o_series/` model name --------- Co-authored-by: byrongrogan <47910641+byrongrogan@users.noreply.github.com> Co-authored-by: Low Jian Sheng <15527690+lowjiansheng@users.noreply.github.com>
This commit is contained in:
parent
91ed05df29
commit
23f458d2da
14 changed files with 211 additions and 37 deletions
|
@ -10,7 +10,7 @@ import TabItem from '@theme/TabItem';
|
|||
| Property | Details |
|
||||
|-------|-------|
|
||||
| Description | Azure OpenAI Service provides REST API access to OpenAI's powerful language models including o1, o1-mini, GPT-4o, GPT-4o mini, GPT-4 Turbo with Vision, GPT-4, GPT-3.5-Turbo, and Embeddings model series |
|
||||
| Provider Route on LiteLLM | `azure/` |
|
||||
| Provider Route on LiteLLM | `azure/`, [`azure/o_series/`](#azure-o-series-models) |
|
||||
| Supported Operations | [`/chat/completions`](#azure-openai-chat-completion-models), [`/completions`](#azure-instruct-models), [`/embeddings`](../embedding/supported_embedding#azure-openai-embedding-models), [`/audio/speech`](#azure-text-to-speech-tts), [`/audio/transcriptions`](../audio_transcription), `/fine_tuning`, [`/batches`](#azure-batches-api), `/files`, [`/images`](../image_generation#azure-openai-image-generation-models) |
|
||||
| Link to Provider Doc | [Azure OpenAI ↗](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
|
||||
|
||||
|
@ -948,6 +948,65 @@ Expected Response:
|
|||
{"data":[{"id":"batch_R3V...}
|
||||
```
|
||||
|
||||
## O-Series Models
|
||||
|
||||
Azure OpenAI O-Series models are supported on LiteLLM.
|
||||
|
||||
LiteLLM routes any deployment name with `o1` or `o3` in the model name, to the O-Series [transformation](https://github.com/BerriAI/litellm/blob/91ed05df2962b8eee8492374b048d27cc144d08c/litellm/llms/azure/chat/o1_transformation.py#L4) logic.
|
||||
|
||||
To set this explicitly, set `model` to `azure/o_series/<your-deployment-name>`.
|
||||
|
||||
**Automatic Routing**
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
litellm.completion(model="azure/my-o3-deployment", messages=[{"role": "user", "content": "Hello, world!"}]) # 👈 Note: 'o3' in the deployment name
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: o3-mini
|
||||
litellm_params:
|
||||
model: azure/o3-model
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
**Explicit Routing**
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
litellm.completion(model="azure/o_series/my-random-deployment-name", messages=[{"role": "user", "content": "Hello, world!"}]) # 👈 Note: 'o_series/' in the deployment name
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: o3-mini
|
||||
litellm_params:
|
||||
model: azure/o_series/my-random-deployment-name
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
|
||||
## Advanced
|
||||
### Azure API Load-Balancing
|
||||
|
||||
|
|
|
@ -886,11 +886,12 @@ from .llms.groq.chat.transformation import GroqChatConfig
|
|||
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
|
||||
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
|
||||
from .llms.mistral.mistral_chat_transformation import MistralConfig
|
||||
from .llms.openai.chat.o1_transformation import (
|
||||
OpenAIO1Config,
|
||||
from .llms.openai.chat.o_series_transformation import (
|
||||
OpenAIOSeriesConfig as OpenAIO1Config, # maintain backwards compatibility
|
||||
OpenAIOSeriesConfig,
|
||||
)
|
||||
|
||||
openAIO1Config = OpenAIO1Config()
|
||||
openaiOSeriesConfig = OpenAIOSeriesConfig()
|
||||
from .llms.openai.chat.gpt_transformation import (
|
||||
OpenAIGPTConfig,
|
||||
)
|
||||
|
|
|
@ -81,7 +81,7 @@ def get_supported_openai_params( # noqa: PLR0915
|
|||
elif custom_llm_provider == "openai":
|
||||
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "azure":
|
||||
if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
|
||||
if litellm.AzureOpenAIO1Config().is_o_series_model(model=model):
|
||||
return litellm.AzureOpenAIO1Config().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Handler file for calls to Azure OpenAI's o1 family of models
|
||||
Handler file for calls to Azure OpenAI's o1/o3 family of models
|
||||
|
||||
Written separately to handle faking streaming for o1 models.
|
||||
Written separately to handle faking streaming for o1 and o3 models.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
@ -36,7 +36,9 @@ class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
|
|||
]:
|
||||
|
||||
# Override to use Azure-specific client initialization
|
||||
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
|
||||
if not isinstance(client, AzureOpenAI) and not isinstance(
|
||||
client, AsyncAzureOpenAI
|
||||
):
|
||||
client = None
|
||||
|
||||
return get_azure_openai_client(
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Support for o1 model family
|
||||
Support for o1 and o3 model families
|
||||
|
||||
https://platform.openai.com/docs/guides/reasoning
|
||||
|
||||
|
@ -12,15 +12,16 @@ Translations handled by LiteLLM:
|
|||
- Temperature => drop param (if user opts in to dropping param)
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
from ...openai.chat.o1_transformation import OpenAIO1Config
|
||||
from ...openai.chat.o_series_transformation import OpenAIOSeriesConfig
|
||||
|
||||
|
||||
class AzureOpenAIO1Config(OpenAIO1Config):
|
||||
class AzureOpenAIO1Config(OpenAIOSeriesConfig):
|
||||
def should_fake_stream(
|
||||
self,
|
||||
model: Optional[str],
|
||||
|
@ -28,8 +29,9 @@ class AzureOpenAIO1Config(OpenAIO1Config):
|
|||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Currently no Azure OpenAI models support native streaming.
|
||||
Currently no Azure O Series models support native streaming.
|
||||
"""
|
||||
|
||||
if stream is not True:
|
||||
return False
|
||||
|
||||
|
@ -38,14 +40,31 @@ class AzureOpenAIO1Config(OpenAIO1Config):
|
|||
model_info = get_model_info(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
if model_info.get("supports_native_streaming") is True:
|
||||
|
||||
if (
|
||||
model_info.get("supports_native_streaming") is True
|
||||
): # allow user to override default with model_info={"supports_native_streaming": true}
|
||||
return False
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Error getting model info in AzureOpenAIO1Config: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def is_o1_model(self, model: str) -> bool:
|
||||
return "o1" in model
|
||||
def is_o_series_model(self, model: str) -> bool:
|
||||
return "o1" in model or "o3" in model or "o_series/" in model
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
model = model.replace(
|
||||
"o_series/", ""
|
||||
) # handle o_series/my-random-deployment-name
|
||||
return super().transform_request(
|
||||
model, messages, optional_params, litellm_params, headers
|
||||
)
|
||||
|
|
|
@ -26,7 +26,7 @@ from litellm.utils import (
|
|||
from .gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class OpenAIO1Config(OpenAIGPTConfig):
|
||||
class OpenAIOSeriesConfig(OpenAIGPTConfig):
|
||||
"""
|
||||
Reference: https://platform.openai.com/docs/guides/reasoning
|
||||
"""
|
||||
|
@ -128,8 +128,10 @@ class OpenAIO1Config(OpenAIGPTConfig):
|
|||
non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
|
||||
def is_model_o1_reasoning_model(self, model: str) -> bool:
|
||||
if model in litellm.open_ai_chat_completion_models and "o1" in model:
|
||||
def is_model_o_series_model(self, model: str) -> bool:
|
||||
if model in litellm.open_ai_chat_completion_models and (
|
||||
"o1" in model or "o3" in model
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
|
@ -47,8 +47,11 @@ from litellm.utils import (
|
|||
|
||||
from ...types.llms.openai import *
|
||||
from ..base import BaseLLM
|
||||
from .chat.o_series_transformation import OpenAIOSeriesConfig
|
||||
from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
|
||||
|
||||
openaiOSeriesConfig = OpenAIOSeriesConfig()
|
||||
|
||||
|
||||
class MistralEmbeddingConfig:
|
||||
"""
|
||||
|
@ -174,8 +177,8 @@ class OpenAIConfig(BaseConfig):
|
|||
Returns:
|
||||
list: List of supported openai parameters
|
||||
"""
|
||||
if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
|
||||
return litellm.openAIO1Config.get_supported_openai_params(model=model)
|
||||
if openaiOSeriesConfig.is_model_o_series_model(model=model):
|
||||
return openaiOSeriesConfig.get_supported_openai_params(model=model)
|
||||
elif litellm.openAIGPTAudioConfig.is_model_gpt_audio_model(model=model):
|
||||
return litellm.openAIGPTAudioConfig.get_supported_openai_params(model=model)
|
||||
else:
|
||||
|
@ -203,8 +206,8 @@ class OpenAIConfig(BaseConfig):
|
|||
drop_params: bool,
|
||||
) -> dict:
|
||||
""" """
|
||||
if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
|
||||
return litellm.openAIO1Config.map_openai_params(
|
||||
if openaiOSeriesConfig.is_model_o_series_model(model=model):
|
||||
return openaiOSeriesConfig.map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
|
|
|
@ -1201,7 +1201,8 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
|
||||
if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
|
||||
if litellm.AzureOpenAIO1Config().is_o_series_model(model=model):
|
||||
|
||||
## LOAD CONFIG - if set
|
||||
config = litellm.AzureOpenAIO1Config.get_config()
|
||||
for k, v in config.items():
|
||||
|
|
|
@ -211,8 +211,11 @@
|
|||
"cache_read_input_token_cost": 0.00000055,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": false,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
"supports_prompt_caching": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"o3-mini-2025-01-31": {
|
||||
"max_tokens": 100000,
|
||||
|
@ -223,8 +226,11 @@
|
|||
"cache_read_input_token_cost": 0.00000055,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": false,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
"supports_prompt_caching": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"o1-mini-2024-09-12": {
|
||||
"max_tokens": 65536,
|
||||
|
@ -978,8 +984,9 @@
|
|||
"cache_read_input_token_cost": 0.00000055,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat",
|
||||
"supports_vision": true,
|
||||
"supports_prompt_caching": true
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"azure/o1-mini": {
|
||||
"max_tokens": 65536,
|
||||
|
|
|
@ -3485,7 +3485,7 @@ def get_optional_params( # noqa: PLR0915
|
|||
),
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
|
||||
if litellm.AzureOpenAIO1Config().is_o_series_model(model=model):
|
||||
optional_params = litellm.AzureOpenAIO1Config().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
|
@ -5918,9 +5918,9 @@ class ProviderConfigManager:
|
|||
"""
|
||||
if (
|
||||
provider == LlmProviders.OPENAI
|
||||
and litellm.openAIO1Config.is_model_o1_reasoning_model(model=model)
|
||||
and litellm.openaiOSeriesConfig.is_model_o_series_model(model=model)
|
||||
):
|
||||
return litellm.OpenAIO1Config()
|
||||
return litellm.openaiOSeriesConfig
|
||||
elif litellm.LlmProviders.DEEPSEEK == provider:
|
||||
return litellm.DeepSeekChatConfig()
|
||||
elif litellm.LlmProviders.GROQ == provider:
|
||||
|
@ -5993,7 +5993,7 @@ class ProviderConfigManager:
|
|||
):
|
||||
return litellm.AI21ChatConfig()
|
||||
elif litellm.LlmProviders.AZURE == provider:
|
||||
if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
|
||||
if litellm.AzureOpenAIO1Config().is_o_series_model(model=model):
|
||||
return litellm.AzureOpenAIO1Config()
|
||||
return litellm.AzureOpenAIConfig()
|
||||
elif litellm.LlmProviders.AZURE_AI == provider:
|
||||
|
|
|
@ -211,8 +211,11 @@
|
|||
"cache_read_input_token_cost": 0.00000055,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": false,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
"supports_prompt_caching": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"o3-mini-2025-01-31": {
|
||||
"max_tokens": 100000,
|
||||
|
@ -223,8 +226,11 @@
|
|||
"cache_read_input_token_cost": 0.00000055,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": false,
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true
|
||||
"supports_prompt_caching": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"o1-mini-2024-09-12": {
|
||||
"max_tokens": 65536,
|
||||
|
@ -978,8 +984,9 @@
|
|||
"cache_read_input_token_cost": 0.00000055,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat",
|
||||
"supports_vision": true,
|
||||
"supports_prompt_caching": true
|
||||
"supports_vision": false,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_response_schema": true
|
||||
},
|
||||
"azure/o1-mini": {
|
||||
"max_tokens": 65536,
|
||||
|
|
|
@ -63,3 +63,65 @@ class TestAzureOpenAIO1(BaseLLMChatTest):
|
|||
model="azure/o1-preview", stream=True
|
||||
)
|
||||
assert fake_stream is False
|
||||
|
||||
|
||||
def test_azure_o3_streaming():
|
||||
"""
|
||||
Test that o3 models handles fake streaming correctly.
|
||||
"""
|
||||
from openai import AzureOpenAI
|
||||
from litellm import completion
|
||||
|
||||
client = AzureOpenAI(
|
||||
api_key="my-fake-o1-key",
|
||||
base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
|
||||
api_version="2024-02-15-preview",
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_create:
|
||||
try:
|
||||
completion(
|
||||
model="azure/o3-mini",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
stream=True,
|
||||
client=client,
|
||||
)
|
||||
except (
|
||||
Exception
|
||||
) as e: # expect output translation error as mock response doesn't return a json
|
||||
print(e)
|
||||
assert mock_create.call_count == 1
|
||||
assert "stream" not in mock_create.call_args.kwargs
|
||||
|
||||
|
||||
def test_azure_o_series_routing():
|
||||
"""
|
||||
Allows user to pass model="azure/o_series/<any-deployment-name>" for explicit o_series model routing.
|
||||
"""
|
||||
from openai import AzureOpenAI
|
||||
from litellm import completion
|
||||
|
||||
client = AzureOpenAI(
|
||||
api_key="my-fake-o1-key",
|
||||
base_url="https://openai-gpt-4-test-v-1.openai.azure.com",
|
||||
api_version="2024-02-15-preview",
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_create:
|
||||
try:
|
||||
completion(
|
||||
model="azure/o_series/my-random-deployment-name",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
stream=True,
|
||||
client=client,
|
||||
)
|
||||
except (
|
||||
Exception
|
||||
) as e: # expect output translation error as mock response doesn't return a json
|
||||
print(e)
|
||||
assert mock_create.call_count == 1
|
||||
assert "stream" not in mock_create.call_args.kwargs
|
|
@ -167,6 +167,17 @@ class TestOpenAIO1(BaseLLMChatTest):
|
|||
pass
|
||||
|
||||
|
||||
class TestOpenAIO3(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self):
|
||||
return {
|
||||
"model": "o3-mini",
|
||||
}
|
||||
|
||||
def test_tool_call_no_arguments(self, tool_call_no_arguments):
|
||||
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
|
||||
pass
|
||||
|
||||
|
||||
def test_o1_supports_vision():
|
||||
"""Test that o1 supports vision"""
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue