Merge pull request #5666 from BerriAI/litellm_add_openai_o1

[Feat] Add OpenAI O1 Family Param mapping / config

Commit fe5e0bcd15
9 changed files with 205 additions and 17 deletions
@@ -49,7 +49,7 @@ jobs:
                 pip install opentelemetry-api==1.25.0
                 pip install opentelemetry-sdk==1.25.0
                 pip install opentelemetry-exporter-otlp==1.25.0
-                pip install openai==1.40.0
+                pip install openai==1.45.0
                 pip install prisma==0.11.0
                 pip install "detect_secrets==1.5.0"
                 pip install "httpx==0.24.1"
@@ -313,7 +313,7 @@ jobs:
             pip install "aiodynamo==23.10.1"
             pip install "asyncio==3.4.3"
             pip install "PyGithub==1.59.1"
-            pip install "openai==1.40.0"
+            pip install "openai==1.45.0"
       # Run pytest and generate JUnit XML report
       - run:
           name: Build Docker image
@@ -406,7 +406,7 @@ jobs:
             pip install "pytest-retry==1.6.3"
             pip install "pytest-asyncio==0.21.1"
             pip install aiohttp
-            pip install "openai==1.40.0"
+            pip install "openai==1.45.0"
             python -m pip install --upgrade pip
             pip install "pydantic==2.7.1"
             pip install "pytest==7.3.1"
@@ -513,7 +513,7 @@ jobs:
             pip install "pytest-asyncio==0.21.1"
             pip install "google-cloud-aiplatform==1.43.0"
             pip install aiohttp
-            pip install "openai==1.40.0"
+            pip install "openai==1.45.0"
             python -m pip install --upgrade pip
             pip install "pydantic==2.7.1"
             pip install "pytest==7.3.1"
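All four CI jobs move from openai==1.40.0 to openai==1.45.0, matching the library bump in poetry.lock and requirements.txt further down in this diff. A quick runtime sanity check for the new pin (a sketch; the exact version-string comparison is an assumption about how strictly the pin should be enforced):

import openai

# openai>=1.45.0 ships the o1-era surface this PR relies on
# (e.g. CompletionTokensDetails, imported later in this diff).
assert openai.__version__.startswith("1.45"), openai.__version__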
@@ -944,6 +944,9 @@ from .llms.OpenAI.openai import (
     GroqConfig,
     AzureAIStudioConfig,
 )
+from .llms.OpenAI.o1_reasoning import (
+    OpenAIO1Config,
+)
 from .llms.nvidia_nim import NvidiaNimConfig
 from .llms.cerebras.chat import CerebrasConfig
 from .llms.AI21.chat import AI21ChatConfig
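This hunk re-exports the new config at the package top level, next to the existing GroqConfig and AzureAIStudioConfig exports; the call sites later in this diff reference it as litellm.OpenAIO1Config(). A minimal sketch of what the export enables:

import litellm

config = litellm.OpenAIO1Config()
print(config.is_model_o1_reasoning_model(model="o1-preview"))  # True
print(config.is_model_o1_reasoning_model(model="gpt-4o"))      # False: "o1" is not a substring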
litellm/llms/OpenAI/o1_reasoning.py (new file, +109)
@@ -0,0 +1,109 @@
+"""
+Support for o1 model family
+
+https://platform.openai.com/docs/guides/reasoning
+
+Translations handled by LiteLLM:
+- modalities: image => drop param (if user opts in to dropping param)
+- role: system ==> translate to role 'user'
+- streaming => faked by LiteLLM
+- Tools, response_format => drop param (if user opts in to dropping param)
+- Logprobs => drop param (if user opts in to dropping param)
+"""
+
+import types
+from typing import Any, List, Optional, Union
+
+import litellm
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
+
+from .openai import OpenAIConfig
+
+
+class OpenAIO1Config(OpenAIConfig):
+    """
+    Reference: https://platform.openai.com/docs/guides/reasoning
+    """
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self, model: str) -> list:
+        """
+        Get the supported OpenAI params for the given model
+        """
+        all_openai_params = litellm.OpenAIConfig().get_supported_openai_params(
+            model="gpt-4o"
+        )
+        non_supported_params = [
+            "logprobs",
+            "tools",
+            "tool_choice",
+            "parallel_tool_calls",
+            "function_call",
+            "functions",
+        ]
+
+        return [
+            param for param in all_openai_params if param not in non_supported_params
+        ]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict, model: str
+    ):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_completion_tokens"] = value
+        return optional_params
+
+    def is_model_o1_reasoning_model(self, model: str) -> bool:
+        if "o1" in model:
+            return True
+        return False
+
+    def o1_prompt_factory(self, messages: List[AllMessageValues]):
+        """
+        Handles limitations of O-1 model family.
+        - modalities: image => drop param (if user opts in to dropping param)
+        - role: system ==> translate to role 'user'
+        """
+        for i, message in enumerate(messages):
+            if message["role"] == "system":
+                new_message = ChatCompletionUserMessage(
+                    content=message["content"], role="user"
+                )
+                messages[i] = new_message  # Replace the old message with the new one
+
+            if isinstance(message["content"], list):
+                new_content = []
+                for content_item in message["content"]:
+                    if content_item.get("type") == "image_url":
+                        if litellm.drop_params is not True:
+                            raise ValueError(
+                                "Image content is not supported for O-1 models. Set litellm.drop_param to True to drop image content."
+                            )
+                        # If drop_param is True, we simply don't add the image content to new_content
+                    else:
+                        new_content.append(content_item)
+                message["content"] = new_content
+
+        return messages
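Taken on its own, the new module gives three hooks: param filtering, param mapping, and prompt rewriting. An illustrative sketch of each, using only the methods defined above (the asserted values follow from the code, not from verified output):

import litellm
from litellm import OpenAIO1Config

config = OpenAIO1Config()

# 1. tools / logprobs / function params are filtered from the supported list.
params = config.get_supported_openai_params(model="o1-preview")
assert "tools" not in params and "logprobs" not in params

# 2. max_tokens is remapped to max_completion_tokens.
optional = config.map_openai_params(
    non_default_params={"max_tokens": 256}, optional_params={}, model="o1-preview"
)
assert optional == {"max_completion_tokens": 256}

# 3. system messages are rewritten as user messages.
messages = config.o1_prompt_factory(
    messages=[{"role": "system", "content": "You are a helpful assistant."}]
)
assert messages[0]["role"] == "user"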
@@ -550,6 +550,8 @@ class OpenAIConfig:
         ] # works across all models

         model_specific_params = []
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().get_supported_openai_params(model=model)
         if (
             model != "gpt-3.5-turbo-16k" and model != "gpt-4"
         ): # gpt-4 does not support 'response_format'

@@ -566,6 +568,13 @@ class OpenAIConfig:
     def map_openai_params(
         self, non_default_params: dict, optional_params: dict, model: str
     ) -> dict:
+        """ """
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+                model=model,
+            )
         supported_openai_params = self.get_supported_openai_params(model)
         for param, value in non_default_params.items():
             if param in supported_openai_params:

@@ -861,6 +870,13 @@ class OpenAIChatCompletion(BaseLLM):
                     messages=messages,
                     custom_llm_provider=custom_llm_provider,
                 )
+            if (
+                litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model)
+                and messages is not None
+            ):
+                messages = litellm.OpenAIO1Config().o1_prompt_factory(
+                    messages=messages,
+                )
+
             for _ in range(
                 2
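With the three hunks above, OpenAIConfig and OpenAIChatCompletion both delegate to OpenAIO1Config whenever the model name contains "o1", so the rewrites happen transparently for ordinary completion calls. A hedged end-to-end sketch (requires a live OPENAI_API_KEY; it mirrors the mocked test below):

import litellm

# max_tokens arrives at the API as max_completion_tokens, and the
# system message is sent with role "user" -- both rewritten by
# OpenAIO1Config before the request is built.
response = litellm.completion(
    model="o1-preview",
    max_tokens=10,
    messages=[{"role": "system", "content": "Hello!"}],
)
print(response.choices[0].message.content)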
litellm/tests/test_openai_o1.py (new file, +52)
@@ -0,0 +1,52 @@
+import json
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+import httpx
+import pytest
+from respx import MockRouter
+
+import litellm
+from litellm import Choices, Message, ModelResponse
+
+
+@pytest.mark.asyncio
+@pytest.mark.respx
+async def test_o1_handle_system_role(respx_mock: MockRouter):
+    """
+    Tests that:
+    - max_tokens is translated to 'max_completion_tokens'
+    - role 'system' is translated to 'user'
+    """
+    litellm.set_verbose = True
+
+    mock_response = ModelResponse(
+        id="cmpl-mock",
+        choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
+        created=int(datetime.now().timestamp()),
+        model="o1-preview",
+    )
+
+    mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
+        return_value=httpx.Response(200, json=mock_response.dict())
+    )
+
+    response = await litellm.acompletion(
+        model="o1-preview",
+        max_tokens=10,
+        messages=[{"role": "system", "content": "Hello!"}],
+    )
+
+    assert mock_request.called
+    request_body = json.loads(mock_request.calls[0].request.content)
+
+    print("request_body: ", request_body)
+
+    assert request_body == {
+        "model": "o1-preview",
+        "max_completion_tokens": 10,
+        "messages": [{"role": "user", "content": "Hello!"}],
+    }
+
+    print(f"response: {response}")
+    assert isinstance(response, ModelResponse)
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 from openai._models import BaseModel as OpenAIObject
 from openai.types.audio.transcription_create_params import FileTypes
-from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage
 from pydantic import ConfigDict, Field, PrivateAttr
 from typing_extensions import Callable, Dict, Required, TypedDict, override

@@ -473,6 +473,7 @@ class Usage(CompletionUsage):
         prompt_tokens: Optional[int] = None,
         completion_tokens: Optional[int] = None,
         total_tokens: Optional[int] = None,
+        reasoning_tokens: Optional[int] = None,
         **params,
     ):
         ## DEEPSEEK PROMPT TOKEN HANDLING ## - follow the anthropic format, of having prompt tokens be just the non-cached token input. Enables accurate cost-tracking - Relevant issue: https://github.com/BerriAI/litellm/issues/5285

@@ -482,12 +483,19 @@ class Usage(CompletionUsage):
             and prompt_tokens is not None
         ):
             prompt_tokens = params["prompt_cache_miss_tokens"]
-        data = {
-            "prompt_tokens": prompt_tokens or 0,
-            "completion_tokens": completion_tokens or 0,
-            "total_tokens": total_tokens or 0,
-        }
-        super().__init__(**data)
+        # handle reasoning_tokens
+        completion_tokens_details = None
+        if reasoning_tokens:
+            completion_tokens_details = CompletionTokensDetails(
+                reasoning_tokens=reasoning_tokens
+            )
+        super().__init__(
+            prompt_tokens=prompt_tokens or 0,
+            completion_tokens=completion_tokens or 0,
+            total_tokens=total_tokens or 0,
+            completion_tokens_details=completion_tokens_details or None,
+        )

         ## ANTHROPIC MAPPING ##
         if "cache_creation_input_tokens" in params and isinstance(
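The Usage change means a reasoning-token count can now round-trip through the standard completion_tokens_details field. A minimal sketch (the litellm.types.utils import path is an assumption; the capture shows only the class, not its module):

from litellm.types.utils import Usage  # assumed module path for Usage

usage = Usage(
    prompt_tokens=10,
    completion_tokens=20,
    total_tokens=30,
    reasoning_tokens=15,  # new keyword added in this commit
)
# __init__ wraps reasoning_tokens into a CompletionTokensDetails object
assert usage.completion_tokens_details is not None
assert usage.completion_tokens_details.reasoning_tokens == 15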
poetry.lock (generated, 8 lines changed)
@@ -1761,13 +1761,13 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]

 [[package]]
 name = "openai"
-version = "1.40.1"
+version = "1.45.0"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.40.1-py3-none-any.whl", hash = "sha256:cf5929076c6ca31c26f1ed207e9fd19eb05404cc9104f64c9d29bb0ac0c5bcd4"},
-    {file = "openai-1.40.1.tar.gz", hash = "sha256:cb1294ac1f8c6a1acbb07e090698eb5ad74a7a88484e77126612a4f22579673d"},
+    {file = "openai-1.45.0-py3-none-any.whl", hash = "sha256:2f1f7b7cf90f038a9f1c24f0d26c0f1790c102ec5acd07ffd70a9b7feac1ff4e"},
+    {file = "openai-1.45.0.tar.gz", hash = "sha256:731207d10637335413aa3c0955f8f8df30d7636a4a0f9c381f2209d32cf8de97"},
 ]

 [package.dependencies]

@@ -3484,4 +3484,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "cryptography", "fastapi", "fastapi-
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0, !=3.9.7"
-content-hash = "ad04b75d2f51072f1ee86bf000a236914b30b02184dcc8b3475c14cd300219f0"
+content-hash = "6795344f245df1fac99329e370f6a997bbf5010e6841c723dc5e73cf22c3885d"
@@ -17,7 +17,7 @@ documentation = "https://docs.litellm.ai"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0, !=3.9.7"
-openai = ">=1.40.0"
+openai = ">=1.45.0"
 python-dotenv = ">=0.2.0"
 tiktoken = ">=0.7.0"
 importlib-metadata = ">=6.8.0"
@@ -1,6 +1,6 @@
 # LITELLM PROXY DEPENDENCIES #
 anyio==4.4.0 # openai + http req.
-openai==1.40.0 # openai req.
+openai==1.45.0 # openai req.
 fastapi==0.111.0 # server dep
 backoff==2.2.1 # server dep
 pyyaml==6.0.0 # server dep