Merge pull request #5666 from BerriAI/litellm_add_openai_o1

[Feat] Add OpenAI O1 Family Param mapping / config
Ishaan Jaff 2024-09-12 16:15:53 -07:00 committed by GitHub
commit fe5e0bcd15
9 changed files with 205 additions and 17 deletions
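
Illustrative usage (not part of this diff): with these changes, LiteLLM rewrites requests for any model whose name contains "o1". A minimal sketch, assuming a valid OPENAI_API_KEY and the public "o1-preview" model name:

import litellm

# max_tokens is sent to OpenAI as max_completion_tokens,
# and the "system" message is re-sent with role "user".
response = litellm.completion(
    model="o1-preview",
    max_tokens=100,
    messages=[
        {"role": "system", "content": "You are a terse math tutor."},
        {"role": "user", "content": "What is 12 * 13?"},
    ],
)
print(response.choices[0].message.content)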


@@ -49,7 +49,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
-pip install openai==1.40.0
+pip install openai==1.45.0
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@@ -313,7 +313,7 @@ jobs:
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
-pip install "openai==1.40.0"
+pip install "openai==1.45.0"
# Run pytest and generate JUnit XML report
- run:
name: Build Docker image
@@ -406,7 +406,7 @@ jobs:
pip install "pytest-retry==1.6.3"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
-pip install "openai==1.40.0"
+pip install "openai==1.45.0"
python -m pip install --upgrade pip
pip install "pydantic==2.7.1"
pip install "pytest==7.3.1"
@@ -513,7 +513,7 @@ jobs:
pip install "pytest-asyncio==0.21.1"
pip install "google-cloud-aiplatform==1.43.0"
pip install aiohttp
-pip install "openai==1.40.0"
+pip install "openai==1.45.0"
python -m pip install --upgrade pip
pip install "pydantic==2.7.1"
pip install "pytest==7.3.1"


@@ -944,6 +944,9 @@ from .llms.OpenAI.openai import (
    GroqConfig,
    AzureAIStudioConfig,
)
+from .llms.OpenAI.o1_reasoning import (
+    OpenAIO1Config,
+)
from .llms.nvidia_nim import NvidiaNimConfig
from .llms.cerebras.chat import CerebrasConfig
from .llms.AI21.chat import AI21ChatConfig


@@ -0,0 +1,109 @@
"""
Support for the o1 model family
https://platform.openai.com/docs/guides/reasoning

Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
"""

import types
from typing import Any, List, Optional, Union

import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage

from .openai import OpenAIConfig


class OpenAIO1Config(OpenAIConfig):
    """
    Reference: https://platform.openai.com/docs/guides/reasoning
    """

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the given model.
        """
        all_openai_params = litellm.OpenAIConfig().get_supported_openai_params(
            model="gpt-4o"
        )
        non_supported_params = [
            "logprobs",
            "tools",
            "tool_choice",
            "parallel_tool_calls",
            "function_call",
            "functions",
        ]
        return [
            param for param in all_openai_params if param not in non_supported_params
        ]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        for param, value in non_default_params.items():
            if param == "max_tokens":
                # o1 models expect 'max_completion_tokens' instead of 'max_tokens'
                optional_params["max_completion_tokens"] = value
        return optional_params

    def is_model_o1_reasoning_model(self, model: str) -> bool:
        if "o1" in model:
            return True
        return False

    def o1_prompt_factory(self, messages: List[AllMessageValues]):
        """
        Handles limitations of the o1 model family.
        - modalities: image => drop param (if user opts in to dropping param)
        - role: system ==> translate to role 'user'
        """
        for i, message in enumerate(messages):
            if message["role"] == "system":
                new_message = ChatCompletionUserMessage(
                    content=message["content"], role="user"
                )
                messages[i] = new_message  # Replace the old message with the new one

            if isinstance(message["content"], list):
                new_content = []
                for content_item in message["content"]:
                    if content_item.get("type") == "image_url":
                        if litellm.drop_params is not True:
                            raise ValueError(
                                "Image content is not supported for o1 models. Set litellm.drop_params to True to drop image content."
                            )
                        # If drop_params is True, the image content is simply not added to new_content
                    else:
                        new_content.append(content_item)
                message["content"] = new_content

        return messages
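
The two translations in this new config can also be exercised directly, since OpenAIO1Config is exported from the package root by the __init__.py change above. A small sketch with illustrative values:

import litellm

config = litellm.OpenAIO1Config()

# max_tokens -> max_completion_tokens
print(config.map_openai_params(
    non_default_params={"max_tokens": 256},
    optional_params={},
    model="o1-preview",
))  # {'max_completion_tokens': 256}

# role "system" -> role "user"
print(config.o1_prompt_factory(
    messages=[{"role": "system", "content": "Be terse."}],
))  # [{'content': 'Be terse.', 'role': 'user'}]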


@@ -550,6 +550,8 @@ class OpenAIConfig:
        ]  # works across all models
        model_specific_params = []
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().get_supported_openai_params(model=model)
        if (
            model != "gpt-3.5-turbo-16k" and model != "gpt-4"
        ):  # gpt-4 does not support 'response_format'
@@ -566,6 +568,13 @@
    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ) -> dict:
        """ """
+        if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+            return litellm.OpenAIO1Config().map_openai_params(
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+                model=model,
+            )
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
@@ -861,6 +870,13 @@ class OpenAIChatCompletion(BaseLLM):
                messages=messages,
                custom_llm_provider=custom_llm_provider,
            )
+        if (
+            litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model)
+            and messages is not None
+        ):
+            messages = litellm.OpenAIO1Config().o1_prompt_factory(
+                messages=messages,
+            )
        for _ in range(
            2
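
With the delegation added above, the shared OpenAI config now advertises a reduced parameter surface for o1 models while leaving other models untouched. A quick illustration (a sketch; the base list comes from OpenAIConfig's gpt-4o defaults):

import litellm

o1_params = litellm.OpenAIConfig().get_supported_openai_params(model="o1-preview")
assert "tools" not in o1_params and "logprobs" not in o1_params

gpt4o_params = litellm.OpenAIConfig().get_supported_openai_params(model="gpt-4o")
assert "tools" in gpt4o_params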


@@ -0,0 +1,52 @@
import json
from datetime import datetime
from unittest.mock import AsyncMock

import httpx
import pytest
from respx import MockRouter

import litellm
from litellm import Choices, Message, ModelResponse


@pytest.mark.asyncio
@pytest.mark.respx
async def test_o1_handle_system_role(respx_mock: MockRouter):
    """
    Tests that:
    - max_tokens is translated to 'max_completion_tokens'
    - role 'system' is translated to 'user'
    """
    litellm.set_verbose = True

    mock_response = ModelResponse(
        id="cmpl-mock",
        choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
        created=int(datetime.now().timestamp()),
        model="o1-preview",
    )

    mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
        return_value=httpx.Response(200, json=mock_response.dict())
    )

    response = await litellm.acompletion(
        model="o1-preview",
        max_tokens=10,
        messages=[{"role": "system", "content": "Hello!"}],
    )

    assert mock_request.called
    request_body = json.loads(mock_request.calls[0].request.content)

    print("request_body: ", request_body)

    assert request_body == {
        "model": "o1-preview",
        "max_completion_tokens": 10,
        "messages": [{"role": "user", "content": "Hello!"}],
    }

    print(f"response: {response}")

    assert isinstance(response, ModelResponse)


@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject
from openai.types.audio.transcription_create_params import FileTypes
-from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage
from pydantic import ConfigDict, Field, PrivateAttr
from typing_extensions import Callable, Dict, Required, TypedDict, override
@@ -473,6 +473,7 @@ class Usage(CompletionUsage):
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
        total_tokens: Optional[int] = None,
+        reasoning_tokens: Optional[int] = None,
        **params,
    ):
        ## DEEPSEEK PROMPT TOKEN HANDLING ## - follow the anthropic format, of having prompt tokens be just the non-cached token input. Enables accurate cost-tracking - Relevant issue: https://github.com/BerriAI/litellm/issues/5285
@@ -482,12 +483,19 @@ class Usage(CompletionUsage):
            and prompt_tokens is not None
        ):
            prompt_tokens = params["prompt_cache_miss_tokens"]
-        data = {
-            "prompt_tokens": prompt_tokens or 0,
-            "completion_tokens": completion_tokens or 0,
-            "total_tokens": total_tokens or 0,
-        }
-        super().__init__(**data)
+        # handle reasoning_tokens
+        completion_tokens_details = None
+        if reasoning_tokens:
+            completion_tokens_details = CompletionTokensDetails(
+                reasoning_tokens=reasoning_tokens
+            )
+        super().__init__(
+            prompt_tokens=prompt_tokens or 0,
+            completion_tokens=completion_tokens or 0,
+            total_tokens=total_tokens or 0,
+            completion_tokens_details=completion_tokens_details or None,
+        )

        ## ANTHROPIC MAPPING ##
        if "cache_creation_input_tokens" in params and isinstance(

poetry.lock generated

@@ -1761,13 +1761,13 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
[[package]]
name = "openai"
-version = "1.40.1"
+version = "1.45.0"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.7.1"
files = [
-    {file = "openai-1.40.1-py3-none-any.whl", hash = "sha256:cf5929076c6ca31c26f1ed207e9fd19eb05404cc9104f64c9d29bb0ac0c5bcd4"},
-    {file = "openai-1.40.1.tar.gz", hash = "sha256:cb1294ac1f8c6a1acbb07e090698eb5ad74a7a88484e77126612a4f22579673d"},
+    {file = "openai-1.45.0-py3-none-any.whl", hash = "sha256:2f1f7b7cf90f038a9f1c24f0d26c0f1790c102ec5acd07ffd70a9b7feac1ff4e"},
+    {file = "openai-1.45.0.tar.gz", hash = "sha256:731207d10637335413aa3c0955f8f8df30d7636a4a0f9c381f2209d32cf8de97"},
]
[package.dependencies]
@@ -3484,4 +3484,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "cryptography", "fastapi", "fastapi-
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0, !=3.9.7"
-content-hash = "ad04b75d2f51072f1ee86bf000a236914b30b02184dcc8b3475c14cd300219f0"
+content-hash = "6795344f245df1fac99329e370f6a997bbf5010e6841c723dc5e73cf22c3885d"


@@ -17,7 +17,7 @@ documentation = "https://docs.litellm.ai"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0, !=3.9.7"
-openai = ">=1.40.0"
+openai = ">=1.45.0"
python-dotenv = ">=0.2.0"
tiktoken = ">=0.7.0"
importlib-metadata = ">=6.8.0"


@@ -1,6 +1,6 @@
# LITELLM PROXY DEPENDENCIES #
anyio==4.4.0 # openai + http req.
-openai==1.40.0 # openai req.
+openai==1.45.0 # openai req.
fastapi==0.111.0 # server dep
backoff==2.2.1 # server dep
pyyaml==6.0.0 # server dep