o1 - add image param handling (#7312)

* fix(openai.py): fix returning o1 non-streaming requests

fixes an issue where fake streaming was always enabled for o1, even when streaming was not requested

* build(model_prices_and_context_window.json): add 'supports_vision' for o1 models

* fix: add internal server error exception mapping

* fix(base_llm_unit_tests.py): drop temperature from test

* test: mark prompt caching as a flaky test

Krish Dholakia, 2024-12-19 11:22:25 -08:00 (committed by GitHub)
commit 62b00cf28d, parent a101c1fff4
9 changed files with 68 additions and 79 deletions

Exception mapping (exception_type): also map OpenAI's "The server had an error processing your request." message to InternalServerError.

@@ -290,7 +290,10 @@ def exception_type( # type: ignore # noqa: PLR0915
                     response=getattr(original_exception, "response", None),
                     litellm_debug_info=extra_information,
                 )
-            elif "Web server is returning an unknown error" in error_str:
+            elif (
+                "Web server is returning an unknown error" in error_str
+                or "The server had an error processing your request." in error_str
+            ):
                 exception_mapping_worked = True
                 raise litellm.InternalServerError(
                     message=f"{exception_provider} - {message}",

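With this mapping, OpenAI's intermittent "The server had an error processing your request." responses surface as litellm.InternalServerError instead of a generic API error, so callers and routers can key retries off the exception type. A minimal caller-side sketch (illustrative, not part of the diff):

```python
import litellm

try:
    response = litellm.completion(
        model="o1",
        messages=[{"role": "user", "content": "Hello!"}],
    )
except litellm.InternalServerError as exc:
    # Provider-side failure that is reasonable to retry; passing
    # num_retries to completion() is the built-in alternative to a
    # hand-rolled retry loop.
    print("retryable provider error:", exc)
```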
BaseConfig.should_fake_stream: the hook now also receives the requested stream value.

@@ -83,7 +83,10 @@ class BaseConfig(ABC):
         }

     def should_fake_stream(
-        self, model: str, custom_llm_provider: Optional[str] = None
+        self,
+        model: str,
+        stream: Optional[bool],
+        custom_llm_provider: Optional[str] = None,
     ) -> bool:
         """
         Returns True if the model/provider should fake stream

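Because the hook now receives the requested stream value, a provider config can decline to fake-stream when the caller never asked for streaming. A hedged sketch of an override using the new signature (hypothetical provider class, not from this commit):

```python
from typing import Optional


class MyProviderConfig:  # hypothetical; real configs subclass BaseConfig
    def should_fake_stream(
        self,
        model: str,
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Never fake-stream a request that did not ask for streaming.
        if stream is not True:
            return False
        # Fake-stream only models that cannot stream natively (example list).
        return "my-non-streaming-model" in model
```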
OpenAIO1Config: only fake-stream when streaming was actually requested, and stop rejecting/dropping image content for o1.

@@ -37,8 +37,13 @@ class OpenAIO1Config(OpenAIGPTConfig):
         return super().get_config()

     def should_fake_stream(
-        self, model: str, custom_llm_provider: Optional[str] = None
+        self,
+        model: str,
+        stream: Optional[bool],
+        custom_llm_provider: Optional[str] = None,
     ) -> bool:
+        if stream is not True:
+            return False
         supported_stream_models = ["o1-mini", "o1-preview"]
         for supported_model in supported_stream_models:
             if supported_model in model:
@@ -142,17 +147,4 @@ class OpenAIO1Config(OpenAIGPTConfig):
                 )
                 messages[i] = new_message  # Replace the old message with the new one

-            if "content" in message and isinstance(message["content"], list):
-                new_content = []
-                for content_item in message["content"]:
-                    if content_item.get("type") == "image_url":
-                        if litellm.drop_params is not True:
-                            raise ValueError(
-                                "Image content is not supported for O-1 models. Set litellm.drop_param to True to drop image content."
-                            )
-                        # If drop_param is True, we simply don't add the image content to new_content
-                    else:
-                        new_content.append(content_item)
-                message["content"] = new_content
-
         return messages

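With the image guard removed (and supports_vision flipped in the pricing map below), image_url parts are now forwarded to o1 instead of raising "Image content is not supported for O-1 models" or requiring litellm.drop_params. A usage sketch with a placeholder image URL:

```python
import litellm

response = litellm.completion(
    model="o1",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/cat.png"},  # placeholder
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)
```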
OpenAIChatCompletion.completion (openai.py): pop stream before the fake-stream decision and pass it to the provider hook.

@@ -453,18 +453,18 @@ class OpenAIChatCompletion(BaseLLM):
         super().completion()
         try:
             fake_stream: bool = False
-            if custom_llm_provider is not None and model is not None:
-                provider_config = ProviderConfigManager.get_provider_chat_config(
-                    model=model, provider=LlmProviders(custom_llm_provider)
-                )
-                fake_stream = provider_config.should_fake_stream(
-                    model=model, custom_llm_provider=custom_llm_provider
-                )
             inference_params = optional_params.copy()
             stream_options: Optional[dict] = inference_params.pop(
                 "stream_options", None
             )
             stream: Optional[bool] = inference_params.pop("stream", False)
+            if custom_llm_provider is not None and model is not None:
+                provider_config = ProviderConfigManager.get_provider_chat_config(
+                    model=model, provider=LlmProviders(custom_llm_provider)
+                )
+                fake_stream = provider_config.should_fake_stream(
+                    model=model, custom_llm_provider=custom_llm_provider, stream=stream
+                )
             if headers:
                 inference_params["extra_headers"] = headers
             if model is None or messages is None:
@@ -502,7 +502,6 @@ class OpenAIChatCompletion(BaseLLM):
                 litellm_params=litellm_params,
                 headers=headers or {},
             )
-
            try:
                max_retries = data.pop("max_retries", 2)
                if acompletion is True:

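The reordering matters: stream has to be popped from the params before should_fake_stream runs so the hook can see it, which is what stops o1 non-streaming requests from being fake-streamed. A reduced sketch of the resulting decision (hypothetical helper, names simplified, not the actual implementation):

```python
from typing import Optional


def resolve_fake_stream(
    provider_config,
    model: Optional[str],
    optional_params: dict,
    custom_llm_provider: Optional[str],
) -> bool:
    # Pop stream first so the provider hook can see what the caller asked for.
    inference_params = optional_params.copy()
    stream: Optional[bool] = inference_params.pop("stream", False)
    if custom_llm_provider is None or model is None:
        return False
    # o1 + stream=True  -> True  (o1 is not in the native-stream list, so it is faked)
    # o1 + stream=False -> False (a normal ModelResponse is returned)
    return provider_config.should_fake_stream(
        model=model, custom_llm_provider=custom_llm_provider, stream=stream
    )
```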
model_prices_and_context_window.json: mark the o1 family as supporting vision.

@@ -205,7 +205,7 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_parallel_function_calling": true,
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true,
         "supports_system_messages": true,
         "supports_response_schema": true
@@ -219,7 +219,7 @@
         "cache_read_input_token_cost": 0.0000015,
         "litellm_provider": "openai",
         "mode": "chat",
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true
     },
     "o1-mini-2024-09-12": {
@@ -231,7 +231,7 @@
         "cache_read_input_token_cost": 0.0000015,
         "litellm_provider": "openai",
         "mode": "chat",
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true
     },
     "o1-preview": {
@@ -243,7 +243,7 @@
         "cache_read_input_token_cost": 0.0000075,
         "litellm_provider": "openai",
         "mode": "chat",
-        "supports_vision": false,
+        "supports_vision": true,
        "supports_prompt_caching": true
     },
     "o1-preview-2024-09-12": {
@@ -255,7 +255,7 @@
         "cache_read_input_token_cost": 0.0000075,
         "litellm_provider": "openai",
         "mode": "chat",
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true
     },
     "o1-2024-12-17": {
@@ -269,7 +269,7 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_parallel_function_calling": true,
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true,
         "supports_system_messages": true,
         "supports_response_schema": true

The same supports_vision updates, applied to the second copy of model_prices_and_context_window.json:

@@ -205,7 +205,7 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_parallel_function_calling": true,
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true,
         "supports_system_messages": true,
         "supports_response_schema": true
@@ -219,7 +219,7 @@
         "cache_read_input_token_cost": 0.0000015,
         "litellm_provider": "openai",
         "mode": "chat",
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true
     },
     "o1-mini-2024-09-12": {
@@ -231,7 +231,7 @@
         "cache_read_input_token_cost": 0.0000015,
         "litellm_provider": "openai",
         "mode": "chat",
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true
     },
     "o1-preview": {
@@ -243,7 +243,7 @@
         "cache_read_input_token_cost": 0.0000075,
         "litellm_provider": "openai",
         "mode": "chat",
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true
     },
     "o1-preview-2024-09-12": {
@@ -255,7 +255,7 @@
         "cache_read_input_token_cost": 0.0000075,
         "litellm_provider": "openai",
         "mode": "chat",
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true
     },
     "o1-2024-12-17": {
@@ -269,7 +269,7 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_parallel_function_calling": true,
-        "supports_vision": false,
+        "supports_vision": true,
         "supports_prompt_caching": true,
         "supports_system_messages": true,
         "supports_response_schema": true

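After the map update, the capability is queryable at runtime. A quick check, assuming litellm's public supports_vision / get_model_info helpers:

```python
import litellm

# Both helpers read the model capability map updated above.
assert litellm.supports_vision(model="o1")
print(litellm.get_model_info(model="o1").get("supports_vision"))
```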
BaseLLMChatTest (base_llm_unit_tests.py): drop the shared multilingual test, mark prompt caching as flaky, and remove temperature=0.2 from the shared test call.

@@ -140,20 +140,6 @@ class BaseLLMChatTest(ABC):
         )
         assert response is not None

-    def test_multilingual_requests(self):
-        """
-        Tests that the provider can handle multilingual requests and invalid utf-8 sequences
-
-        Context: https://github.com/openai/openai-python/issues/1921
-        """
-        base_completion_call_args = self.get_base_completion_call_args()
-        response = self.completion_function(
-            **base_completion_call_args,
-            messages=[{"role": "user", "content": "你好世界!\ud83e, ö"}],
-        )
-        print("multilingual response: ", response)
-        assert response is not None
-
     @pytest.mark.parametrize(
         "response_format",
         [
@@ -343,6 +329,7 @@ class BaseLLMChatTest(ABC):
         )
         assert response is not None

+    @pytest.mark.flaky(retries=4, delay=1)
     def test_prompt_caching(self):
         litellm.set_verbose = True
         from litellm.utils import supports_prompt_caching
@@ -399,7 +386,6 @@ class BaseLLMChatTest(ABC):
                 ],
             },
         ],
-        temperature=0.2,
         max_tokens=10,
     )


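The prompt-caching test depends on provider-side timing and quotas, hence the flaky marker, and temperature is dropped presumably because o1-style models only accept the default temperature. The shared test already gates on the capability helper it imports; a hedged sketch of that gating pattern (hypothetical helper name):

```python
import pytest

from litellm.utils import supports_prompt_caching


def skip_unless_prompt_caching(model: str) -> None:
    # Skip rather than fail for models without prompt caching in the capability map.
    if not supports_prompt_caching(model=model):
        pytest.skip(f"{model} does not support prompt caching")
```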
TestOpenAIChatCompletion (OpenAI tests): the multilingual test moves into the OpenAI-specific suite.

@@ -280,6 +280,19 @@ class TestOpenAIChatCompletion(BaseLLMChatTest):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass

+    def test_multilingual_requests(self):
+        """
+        Tests that the provider can handle multilingual requests and invalid utf-8 sequences
+
+        Context: https://github.com/openai/openai-python/issues/1921
+        """
+        base_completion_call_args = self.get_base_completion_call_args()
+        response = self.completion_function(
+            **base_completion_call_args,
+            messages=[{"role": "user", "content": "你好世界!\ud83e, ö"}],
+        )
+        assert response is not None
+

 def test_completion_bad_org():
     import litellm

o1 tests: remove the commented-out streaming e2e test, run the shared BaseLLMChatTest suite against o1, and add a supports_vision check.

@@ -15,6 +15,7 @@ from respx import MockRouter

 import litellm
 from litellm import Choices, Message, ModelResponse
+from base_llm_unit_tests import BaseLLMChatTest


 @pytest.mark.parametrize("model", ["o1-preview", "o1-mini", "o1"])
@@ -94,34 +95,6 @@ async def test_o1_handle_tool_calling_optional_params(
     assert expected_tool_calling_support == ("tools" in supported_params)


-# @pytest.mark.parametrize(
-#     "model",
-#     ["o1"],  # "o1-preview", "o1-mini",
-# )
-# @pytest.mark.asyncio
-# async def test_o1_handle_streaming_e2e(model):
-#     """
-#     Tests that:
-#     - max_tokens is translated to 'max_completion_tokens'
-#     - role 'system' is translated to 'user'
-#     """
-#     from openai import AsyncOpenAI
-#     from litellm.utils import ProviderConfigManager
-#     from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
-#     from litellm.types.utils import LlmProviders
-
-#     resp = litellm.completion(
-#         model=model,
-#         messages=[{"role": "user", "content": "Hello!"}],
-#         stream=True,
-#     )
-#     assert isinstance(resp, CustomStreamWrapper)
-
-#     for chunk in resp:
-#         print("chunk: ", chunk)
-
-#     assert True
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model", ["gpt-4", "gpt-4-0314", "gpt-4-32k", "o1-preview"])
 async def test_o1_max_completion_tokens(model: str):
@@ -177,3 +150,23 @@ def test_litellm_responses():
     print("response: ", response)

     assert isinstance(response.usage.completion_tokens_details, CompletionTokensDetails)
+
+
+class TestOpenAIO1(BaseLLMChatTest):
+    def get_base_completion_call_args(self):
+        return {
+            "model": "o1",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
+
+
+def test_o1_supports_vision():
+    """Test that o1 supports vision"""
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+    for k, v in litellm.model_cost.items():
+        if k.startswith("o1") and v.get("litellm_provider") == "openai":
+            assert v.get("supports_vision") is True, f"{k} does not support vision"
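
The new check reads the bundled cost map via LITELLM_LOCAL_MODEL_COST_MAP rather than the hosted copy; the same assertion can be run standalone, for example:

```python
import os

import litellm

# Use the packaged model cost map instead of fetching the hosted copy.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

o1_entries = {
    k: v
    for k, v in litellm.model_cost.items()
    if k.startswith("o1") and v.get("litellm_provider") == "openai"
}
assert o1_entries, "expected o1 entries in the model cost map"
assert all(v.get("supports_vision") is True for v in o1_entries.values())
```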