fix(health.md): add rerank model health check information (#7295)

* fix(health.md): add rerank model health check information

* build(model_prices_and_context_window.json): add gemini 2.0 for google ai studio - pricing + commercial rate limits

* build(model_prices_and_context_window.json): set 'supports_audio_output' = true for gemini-2.0

* docs(team_model_add.md): clarify allowing teams to add models is an enterprise feature

* fix(o1_transformation.py): add support for 'n', 'response_format' and 'stop' params for o1 and 'stream_options' param for o1-mini

* build(model_prices_and_context_window.json): add 'supports_system_message' to supporting openai models

needed as the o1-preview and o1-mini models don't support the 'system' message role

* fix(o1_transformation.py): translate system message based on if o1 model supports it
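
A minimal sketch of the idea behind these two changes (the helper name and the capability set are illustrative, not litellm's actual API): consult the model map, and if the model can't accept a system message, re-emit its content as a user turn instead of dropping it.

```python
from typing import Dict, List

# assumption: which o1 variants accept a system message (illustrative only)
SUPPORTS_SYSTEM_MESSAGE = {"o1"}


def translate_system_messages(messages: List[Dict], model: str) -> List[Dict]:
    """Re-emit system messages as user turns for models that reject them."""
    if model in SUPPORTS_SYSTEM_MESSAGE:
        return messages
    translated = []
    for msg in messages:
        if msg.get("role") == "system":
            # keep the instructions, but send them in a role the model accepts
            translated.append({"role": "user", "content": msg["content"]})
        else:
            translated.append(msg)
    return translated
```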

* fix(o1_transformation.py): return 'stream' param support if o1-mini/o1-preview

o1 currently doesn't support streaming, but the other model versions do

Fixes https://github.com/BerriAI/litellm/issues/7292
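
Roughly, the supported-params logic described in the bullets above branches on the model name; a hedged sketch (function name is hypothetical, not litellm's API):

```python
def supported_openai_params_for_o1(model: str) -> list:
    """Illustrative only: which request params the o1 family accepts."""
    params = ["n", "response_format", "stop"]
    if model.startswith("o1-mini") or model.startswith("o1-preview"):
        # these versions stream natively, so 'stream' / 'stream_options' pass through
        params += ["stream", "stream_options"]
    return params
```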

* fix(o1_transformation.py): return tool calling/response_format in supported params if model map says so

Fixes https://github.com/BerriAI/litellm/issues/7292
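
The capability lookup can be sketched with litellm's public helpers (assumed available in your litellm version; verify the exact names against the release you run):

```python
import litellm


def advertised_params(model: str) -> list:
    """Only advertise tool calling / response_format if the model map says so."""
    params = ["temperature", "max_completion_tokens"]
    if litellm.supports_function_calling(model=model):
        params += ["tools", "tool_choice"]
    if litellm.supports_response_schema(model=model):
        params.append("response_format")
    return params
```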

* fix: fix linting errors

* fix: update '_transform_messages'

* fix(o1_transformation.py): fix provider passed for supported param checks

* test(base_llm_unit_tests.py): skip test if api takes >5s to respond
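
A hedged sketch of that test guard (not the exact test code): time the call and skip instead of failing when the upstream API is slow.

```python
import time

import pytest


def call_or_skip(make_request, limit_s: float = 5.0):
    """Run make_request(); skip the current test if it takes longer than limit_s."""
    start = time.time()
    response = make_request()
    if time.time() - start > limit_s:
        pytest.skip(f"API took >{limit_s}s to respond")
    return response
```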

* fix(utils.py): return false in 'supports_factory' if the value can't be found
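
The behavior is essentially a safe lookup; a sketch with an illustrative helper name and model-map shape:

```python
def _supports(model_map: dict, model: str, key: str) -> bool:
    """Return False instead of raising when the model or key is missing."""
    try:
        return bool(model_map[model].get(key, False))
    except KeyError:
        # unknown model: assume the capability is unsupported
        return False
```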

* fix(o1_transformation.py): always return stream + stream_options as supported params + handle stream options being passed in for azure o1

* feat(openai.py): support stream faking natively in openai handler

Allows streaming to be faked for just the "o1" model, while o1-mini and o1-preview stream natively.

Fixes https://github.com/BerriAI/litellm/issues/7292
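
From the caller's side the behavior looks the same either way; an example of what this change enables (assumes an OPENAI_API_KEY with o1 access):

```python
import litellm

# For bare "o1" the handler fetches the full response and replays it as chunks;
# o1-mini / o1-preview stream natively through the same call.
response = litellm.completion(
    model="o1",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta)
```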

* fix(openai.py): use inference param instead of original optional param
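
The pattern, mirroring the diff below: work on a copy so the caller's optional_params stay untouched, and pull stream handling out of the request body.

```python
optional_params = {"temperature": 1, "stream": True, "stream_options": {"include_usage": True}}

inference_params = optional_params.copy()
stream_options = inference_params.pop("stream_options", None)
stream = inference_params.pop("stream", False)
# inference_params now holds only what goes into the request body;
# 'stream' / 'stream_options' are handled separately by the handler.
```
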
Krish Dholakia, 2024-12-18 19:18:10 -08:00 (committed by GitHub)
parent 6a45ee1ef7
commit 5253f639cd
34 changed files with 800 additions and 515 deletions


@@ -33,6 +33,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import (
@@ -198,7 +199,7 @@ class OpenAIConfig(BaseConfig):
return optional_params
def _transform_messages(
self, messages: List[AllMessageValues]
self, messages: List[AllMessageValues], model: str
) -> List[AllMessageValues]:
return messages
@@ -410,6 +411,24 @@ class OpenAIChatCompletion(BaseLLM):
else:
raise e
def mock_streaming(
self,
response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
model: str,
stream_options: Optional[dict] = None,
) -> CustomStreamWrapper:
completion_stream = MockResponseIterator(model_response=response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=stream_options,
)
return streaming_response
def completion( # type: ignore # noqa: PLR0915
self,
model_response: ModelResponse,
@@ -433,8 +452,21 @@ class OpenAIChatCompletion(BaseLLM):
):
super().completion()
try:
fake_stream: bool = False
if custom_llm_provider is not None and model is not None:
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
fake_stream = provider_config.should_fake_stream(
model=model, custom_llm_provider=custom_llm_provider
)
inference_params = optional_params.copy()
stream_options: Optional[dict] = inference_params.pop(
"stream_options", None
)
stream: Optional[bool] = inference_params.pop("stream", False)
if headers:
optional_params["extra_headers"] = headers
inference_params["extra_headers"] = headers
if model is None or messages is None:
raise OpenAIError(status_code=422, message="Missing model or messages")
@@ -456,7 +488,9 @@ class OpenAIChatCompletion(BaseLLM):
if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
provider_config, OpenAIConfig
):
messages = provider_config._transform_messages(messages)
messages = provider_config._transform_messages(
messages=messages, model=model
)
for _ in range(
2
@@ -464,7 +498,7 @@
data = OpenAIConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
optional_params=inference_params,
litellm_params=litellm_params,
headers=headers or {},
)
@@ -472,7 +506,7 @@
try:
max_retries = data.pop("max_retries", 2)
if acompletion is True:
if optional_params.get("stream", False):
if stream is True and fake_stream is False:
return self.async_streaming(
logging_obj=logging_obj,
headers=headers,
@@ -485,11 +519,13 @@
max_retries=max_retries,
organization=organization,
drop_params=drop_params,
stream_options=stream_options,
)
else:
return self.acompletion(
data=data,
headers=headers,
model=model,
logging_obj=logging_obj,
model_response=model_response,
api_base=api_base,
@@ -499,8 +535,9 @@
max_retries=max_retries,
organization=organization,
drop_params=drop_params,
fake_stream=fake_stream,
)
elif optional_params.get("stream", False):
elif stream is True and fake_stream is False:
return self.streaming(
logging_obj=logging_obj,
headers=headers,
@@ -512,6 +549,7 @@
client=client,
max_retries=max_retries,
organization=organization,
stream_options=stream_options,
)
else:
if not isinstance(max_retries, int):
@@ -557,16 +595,26 @@
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(
final_response_obj = convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
_response_headers=headers,
)
if fake_stream is True:
return self.mock_streaming(
response=cast(ModelResponse, final_response_obj),
logging_obj=logging_obj,
model=model,
stream_options=stream_options,
)
return final_response_obj
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
optional_params = drop_params_from_unprocessable_entity_error(
e, optional_params
inference_params = drop_params_from_unprocessable_entity_error(
e, inference_params
)
else:
raise e
@@ -623,6 +671,7 @@ class OpenAIChatCompletion(BaseLLM):
async def acompletion(
self,
data: dict,
model: str,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
timeout: Union[float, httpx.Timeout],
@@ -633,6 +682,8 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None,
headers=None,
drop_params: Optional[bool] = None,
stream_options: Optional[dict] = None,
fake_stream: bool = False,
):
response = None
for _ in range(
@@ -667,6 +718,7 @@ class OpenAIChatCompletion(BaseLLM):
openai_aclient=openai_aclient, data=data, timeout=timeout
)
stringified_response = response.model_dump()
logging_obj.post_call(
input=data["messages"],
api_key=api_key,
@@ -674,12 +726,22 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
)
logging_obj.model_call_details["response_headers"] = headers
return convert_to_model_response_object(
final_response_obj = convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
hidden_params={"headers": headers},
_response_headers=headers,
)
if fake_stream is True:
return self.mock_streaming(
response=cast(ModelResponse, final_response_obj),
logging_obj=logging_obj,
model=model,
stream_options=stream_options,
)
return final_response_obj
except openai.UnprocessableEntityError as e:
## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
if litellm.drop_params is True or drop_params is True:
@@ -710,7 +772,11 @@ class OpenAIChatCompletion(BaseLLM):
client=None,
max_retries=None,
headers=None,
stream_options: Optional[dict] = None,
):
data["stream"] = True
if stream_options is not None:
data["stream_options"] = stream_options
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
@@ -761,8 +827,12 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None,
headers=None,
drop_params: Optional[bool] = None,
stream_options: Optional[dict] = None,
):
response = None
data["stream"] = True
if stream_options is not None:
data["stream_options"] = stream_options
for _ in range(2):
try:
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore