Litellm dev 12 30 2024 p1 (#7480)

* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model

* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests

skip as o1 on azure doesn't support tool calling yet
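For illustration only, a minimal sketch of how such a skip might look in pytest; the test name below is hypothetical, and the real skipped cases live in base_llm_unit_tests.py:

```python
import pytest


@pytest.mark.skip(reason="o1 on Azure does not support tool calling yet")
def test_azure_o1_json_response_format():
    # Hypothetical test name; the actual skipped tests live in base_llm_unit_tests.py.
    ...
```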

* fix: initial commit of azure o1 handler using openai caller

simplifies calling and lets the fake streaming logic already implemented for openai just work
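A rough sketch of the idea, assuming the handler subclasses `OpenAIChatCompletion` and only swaps in an Azure client; the class name and import path are assumptions and may differ from the actual azure/o1_handler.py:

```python
from typing import Optional, Union

import httpx
from openai import AzureOpenAI

# Assumed import path; the class name is taken from the diff below, but the
# module location can differ between litellm versions.
from litellm.llms.openai.openai import OpenAIChatCompletion


class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
    """Hypothetical handler: reuse the OpenAI caller, only swap the client."""

    def _get_openai_client(
        self,
        is_async: bool,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
        max_retries: Optional[int] = 2,
        organization: Optional[str] = None,
        client: Optional[AzureOpenAI] = None,
    ) -> AzureOpenAI:
        # Build an Azure client instead of a vanilla OpenAI one; the shared
        # completion and fake-streaming logic in OpenAIChatCompletion is then
        # reused as-is. (The async variant is omitted for brevity.)
        if client is not None:
            return client
        return AzureOpenAI(
            api_key=api_key,
            azure_endpoint=api_base,  # expected to be the Azure resource endpoint
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries if max_retries is not None else 2,
        )
```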

* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models

azure does not currently support streaming for o1
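Conceptually, fake streaming requests a normal (non-streaming) completion and replays it to the caller as a one-chunk stream. A self-contained sketch of that idea, not litellm's actual iterator implementation:

```python
from typing import Any, Callable, Iterator


def complete_with_fake_stream(
    call_provider: Callable[..., Any],
    stream: bool,
    fake_stream: bool,
    **params: Any,
) -> Any:
    """Hypothetical helper illustrating the idea, not litellm's real code."""
    if stream and fake_stream:
        # Azure o1 rejects stream=True, so request a normal completion...
        full_response = call_provider(stream=False, **params)

        def one_chunk_stream() -> Iterator[Any]:
            # ...and replay it as a single chunk, so callers that iterate
            # over a stream keep working unchanged.
            yield full_response

        return one_chunk_stream()
    return call_provider(stream=stream, **params)
```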

* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info

enables users to toggle this on once azure allows o1 streaming, without needing to bump versions
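A hedged sketch of what such an override could look like, using `litellm.get_model_info` to read the `supports_native_streaming` flag; the surrounding config class and error handling are simplified:

```python
from typing import Optional

import litellm


def should_fake_stream(
    model: Optional[str],
    custom_llm_provider: Optional[str] = None,
    stream: Optional[bool] = None,
) -> bool:
    """Simplified stand-in for the method on the azure o1 config class."""
    if stream is not True:
        return False
    if model is not None:
        try:
            model_info = litellm.get_model_info(
                model=model, custom_llm_provider=custom_llm_provider
            )
            # If the registered model info says native streaming works,
            # skip the fake-streaming path entirely.
            if model_info.get("supports_native_streaming") is True:
                return False
        except Exception:
            pass  # unknown model: fall back to fake streaming
    return True
```

Presumably, once Azure enables native o1 streaming, setting `supports_native_streaming: true` in the model's model info turns the fake-streaming path off without upgrading litellm.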

* style(router.py): remove 'give feedback/get help' messaging when router is used

Prevents noisy messaging

Closes https://github.com/BerriAI/litellm/issues/5942

* test: fix azure o1 test

* test: fix tests

* fix: fix test
Krish Dholakia · 2024-12-30 21:52:52 -08:00 · committed by GitHub
commit 347779b813 · parent 60bdfb437f
17 changed files with 273 additions and 141 deletions
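The diff below (one of the 17 changed files) mostly threads a new `api_version` parameter through `OpenAIChatCompletion`'s sync, async, and streaming code paths and into `_get_openai_client`, so Azure-style clients can be constructed from the shared OpenAI handler. A hedged usage sketch; the endpoint, key, and API version values are placeholders:

```python
import litellm

# Placeholder Azure resource, key, and API version; the api_version kwarg is
# what the diff below forwards into the underlying client.
response = litellm.completion(
    model="azure/o1-preview",
    api_base="https://my-resource.openai.azure.com",
    api_key="...",
    api_version="2024-08-01-preview",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,  # faked for o1 until Azure supports native streaming
)
for chunk in response:
    print(chunk)
```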

@@ -275,6 +275,7 @@ class OpenAIChatCompletion(BaseLLM):
is_async: bool,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
max_retries: Optional[int] = 2,
organization: Optional[str] = None,
@@ -423,6 +424,9 @@ class OpenAIChatCompletion(BaseLLM):
print_verbose: Optional[Callable] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
dynamic_params: Optional[bool] = None,
azure_ad_token: Optional[str] = None,
acompletion: bool = False,
logger_fn=None,
headers: Optional[dict] = None,
@@ -432,6 +436,7 @@ class OpenAIChatCompletion(BaseLLM):
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
super().completion()
try:
fake_stream: bool = False
@@ -441,6 +446,7 @@ class OpenAIChatCompletion(BaseLLM):
)
stream: Optional[bool] = inference_params.pop("stream", False)
provider_config: Optional[BaseConfig] = None
if custom_llm_provider is not None and model is not None:
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
@@ -450,6 +456,7 @@ class OpenAIChatCompletion(BaseLLM):
fake_stream = provider_config.should_fake_stream(
model=model, custom_llm_provider=custom_llm_provider, stream=stream
)
if headers:
inference_params["extra_headers"] = headers
if model is None or messages is None:
@@ -469,7 +476,7 @@ class OpenAIChatCompletion(BaseLLM):
if messages is not None and provider_config is not None:
if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
provider_config, OpenAIConfig
):
): # [TODO]: remove. no longer needed as .transform_request can just handle this.
messages = provider_config._transform_messages(
messages=messages, model=model
)
@@ -504,6 +511,7 @@ class OpenAIChatCompletion(BaseLLM):
model=model,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
client=client,
max_retries=max_retries,
@@ -520,6 +528,7 @@ class OpenAIChatCompletion(BaseLLM):
model_response=model_response,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
client=client,
max_retries=max_retries,
@@ -535,6 +544,7 @@ class OpenAIChatCompletion(BaseLLM):
model=model,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
client=client,
max_retries=max_retries,
@@ -546,11 +556,11 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(
status_code=422, message="max retries must be an int"
)
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
organization=organization,
@@ -667,6 +677,7 @@ class OpenAIChatCompletion(BaseLLM):
timeout: Union[float, httpx.Timeout],
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
@@ -684,6 +695,7 @@ class OpenAIChatCompletion(BaseLLM):
is_async=True,
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
organization=organization,
@@ -758,6 +770,7 @@ class OpenAIChatCompletion(BaseLLM):
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
@@ -767,10 +780,12 @@ class OpenAIChatCompletion(BaseLLM):
data["stream"] = True
if stream_options is not None:
data["stream_options"] = stream_options
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
organization=organization,
@@ -812,6 +827,7 @@ class OpenAIChatCompletion(BaseLLM):
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
@@ -829,6 +845,7 @@ class OpenAIChatCompletion(BaseLLM):
is_async=True,
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
organization=organization,