diff --git a/litellm/llms/azure/chat/o1_handler.py b/litellm/llms/azure/chat/o1_handler.py
index 3660ffdc73..1cb6f888c3 100644
--- a/litellm/llms/azure/chat/o1_handler.py
+++ b/litellm/llms/azure/chat/o1_handler.py
@@ -4,96 +4,48 @@ Handler file for calls to Azure OpenAI's o1 family of models
 Written separately to handle faking streaming for o1 models.
 """

-import asyncio
-from typing import Any, Callable, List, Optional, Union
+from typing import Optional, Union

-from httpx._config import Timeout
+import httpx
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

-from litellm.litellm_core_utils.litellm_logging import Logging
-from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
-from litellm.types.utils import ModelResponse
-from litellm.utils import CustomStreamWrapper
-
-from ..azure import AzureChatCompletion
+from ...openai.openai import OpenAIChatCompletion
+from ..common_utils import get_azure_openai_client


-class AzureOpenAIO1ChatCompletion(AzureChatCompletion):
-
-    async def mock_async_streaming(
+class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
+    def _get_openai_client(
         self,
-        response: Any,
-        model: Optional[str],
-        logging_obj: Any,
-    ):
-        model_response = await response
-        completion_stream = MockResponseIterator(model_response=model_response)
-        streaming_response = CustomStreamWrapper(
-            completion_stream=completion_stream,
-            model=model,
-            custom_llm_provider="azure",
-            logging_obj=logging_obj,
+        is_async: bool,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+        max_retries: Optional[int] = 2,
+        organization: Optional[str] = None,
+        client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = None,
+    ) -> Optional[
+        Union[
+            OpenAI,
+            AsyncOpenAI,
+            AzureOpenAI,
+            AsyncAzureOpenAI,
+        ]
+    ]:
+
+        # Override to use Azure-specific client initialization
+        if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
+            client = None
+
+        return get_azure_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            api_version=api_version,
+            client=client,
+            _is_async=is_async,
         )
-        return streaming_response
-
-    def completion(
-        self,
-        model: str,
-        messages: List,
-        model_response: ModelResponse,
-        api_key: str,
-        api_base: str,
-        api_version: str,
-        api_type: str,
-        azure_ad_token: str,
-        dynamic_params: bool,
-        print_verbose: Callable[..., Any],
-        timeout: Union[float, Timeout],
-        logging_obj: Logging,
-        optional_params,
-        litellm_params,
-        logger_fn,
-        acompletion: bool = False,
-        headers: Optional[dict] = None,
-        client=None,
-    ):
-        stream: Optional[bool] = optional_params.pop("stream", False)
-        stream_options: Optional[dict] = optional_params.pop("stream_options", None)
-        response = super().completion(
-            model,
-            messages,
-            model_response,
-            api_key,
-            api_base,
-            api_version,
-            api_type,
-            azure_ad_token,
-            dynamic_params,
-            print_verbose,
-            timeout,
-            logging_obj,
-            optional_params,
-            litellm_params,
-            logger_fn,
-            acompletion,
-            headers,
-            client,
-        )
-
-        if stream is True:
-            if asyncio.iscoroutine(response):
-                return self.mock_async_streaming(
-                    response=response, model=model, logging_obj=logging_obj  # type: ignore
-                )
-
-            completion_stream = MockResponseIterator(model_response=response)
-            streaming_response = CustomStreamWrapper(
-                completion_stream=completion_stream,
-                model=model,
-                custom_llm_provider="openai",
-                logging_obj=logging_obj,
-                stream_options=stream_options,
-            )
-
-            return streaming_response
-        else:
-            return response
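With the base class switched from `AzureChatCompletion` to `OpenAIChatCompletion`, the o1 handler no longer re-implements `completion()` or its own mock streaming; it only overrides how the client is built. A minimal sketch of what the new override returns, assuming the import path in the hunk above; the key, endpoint, and API version below are placeholders:

```python
from litellm.llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion

handler = AzureOpenAIO1ChatCompletion()

# A plain OpenAI/AsyncOpenAI client passed via `client=` is discarded, and the
# Azure client is rebuilt from the keyword arguments instead.
client = handler._get_openai_client(
    is_async=False,
    api_key="azure-key",                              # placeholder credential
    api_base="https://my-resource.openai.azure.com",  # placeholder endpoint
    api_version="2024-08-01-preview",                 # placeholder API version
)
# `client` is an openai.AzureOpenAI instance built by get_azure_openai_client();
# is_async=True would yield an openai.AsyncAzureOpenAI instead.
```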
diff --git a/litellm/llms/azure/chat/o1_transformation.py b/litellm/llms/azure/chat/o1_transformation.py
index 5a15a884e9..a14dd06966 100644
--- a/litellm/llms/azure/chat/o1_transformation.py
+++ b/litellm/llms/azure/chat/o1_transformation.py
@@ -12,10 +12,41 @@ Translations handled by LiteLLM:
 - Temperature => drop param (if user opts in to dropping param)
 """

+from typing import Optional
+
+from litellm import verbose_logger
+from litellm.utils import get_model_info
+
 from ...openai.chat.o1_transformation import OpenAIO1Config


 class AzureOpenAIO1Config(OpenAIO1Config):
+    def should_fake_stream(
+        self,
+        model: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Optional[str] = None,
+    ) -> bool:
+        """
+        Currently no Azure OpenAI models support native streaming.
+        """
+        if stream is not True:
+            return False
+
+        if model is not None:
+            try:
+                model_info = get_model_info(
+                    model=model, custom_llm_provider=custom_llm_provider
+                )
+                if model_info.get("supports_native_streaming") is True:
+                    return False
+            except Exception as e:
+                verbose_logger.debug(
+                    f"Error getting model info in AzureOpenAIO1Config: {e}"
+                )
+
+        return True
+
     def is_o1_model(self, model: str) -> bool:
         o1_models = ["o1-mini", "o1-preview"]
         for m in o1_models:
diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py
index f374c18cf8..df954a8a67 100644
--- a/litellm/llms/azure/common_utils.py
+++ b/litellm/llms/azure/common_utils.py
@@ -1,7 +1,9 @@
 from typing import Callable, Optional, Union

 import httpx
+from openai import AsyncAzureOpenAI, AzureOpenAI

+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
@@ -25,6 +27,39 @@ class AzureOpenAIError(BaseLLMException):
         )


+def get_azure_openai_client(
+    api_key: Optional[str],
+    api_base: Optional[str],
+    timeout: Union[float, httpx.Timeout],
+    max_retries: Optional[int],
+    api_version: Optional[str] = None,
+    organization: Optional[str] = None,
+    client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+    _is_async: bool = False,
+) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
+    received_args = locals()
+    openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
+    if client is None:
+        data = {}
+        for k, v in received_args.items():
+            if k == "self" or k == "client" or k == "_is_async":
+                pass
+            elif k == "api_base" and v is not None:
+                data["azure_endpoint"] = v
+            elif v is not None:
+                data[k] = v
+        if "api_version" not in data:
+            data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
+        if _is_async is True:
+            openai_client = AsyncAzureOpenAI(**data)
+        else:
+            openai_client = AzureOpenAI(**data)  # type: ignore
+    else:
+        openai_client = client
+
+    return openai_client
+
+
 def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
     openai_headers = {}
     if "x-ratelimit-limit-requests" in headers:
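`should_fake_stream` is what the shared OpenAI completion path (further down in this diff) consults before deciding to fake a stream. A rough illustration of the decision, assuming the example model name exists in litellm's model cost map; if the lookup fails, the error is only logged at debug level and fake streaming stays on:

```python
from litellm.llms.azure.chat.o1_transformation import AzureOpenAIO1Config

config = AzureOpenAIO1Config()

# Non-streaming requests never need the workaround.
config.should_fake_stream(model="azure/o1-preview", stream=False)  # -> False

# Streaming requests fall back to fake streaming unless the model map marks the
# model with supports_native_streaming: true.
config.should_fake_stream(
    model="azure/o1-preview", stream=True, custom_llm_provider="azure"
)  # -> True for models without native streaming support
```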
diff --git a/litellm/llms/azure/files/handler.py b/litellm/llms/azure/files/handler.py
index fd1ef0d535..f442af855e 100644
--- a/litellm/llms/azure/files/handler.py
+++ b/litellm/llms/azure/files/handler.py
@@ -4,43 +4,11 @@ import httpx
 from openai import AsyncAzureOpenAI, AzureOpenAI
 from openai.types.file_deleted import FileDeleted

-import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.types.llms.openai import *
-
-
-def get_azure_openai_client(
-    api_key: Optional[str],
-    api_base: Optional[str],
-    timeout: Union[float, httpx.Timeout],
-    max_retries: Optional[int],
-    api_version: Optional[str] = None,
-    organization: Optional[str] = None,
-    client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
-    _is_async: bool = False,
-) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
-    received_args = locals()
-    openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
-    if client is None:
-        data = {}
-        for k, v in received_args.items():
-            if k == "self" or k == "client" or k == "_is_async":
-                pass
-            elif k == "api_base" and v is not None:
-                data["azure_endpoint"] = v
-            elif v is not None:
-                data[k] = v
-        if "api_version" not in data:
-            data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
-        if _is_async is True:
-            openai_client = AsyncAzureOpenAI(**data)
-        else:
-            openai_client = AzureOpenAI(**data)  # type: ignore
-    else:
-        openai_client = client
-
-    return openai_client
+from ..common_utils import get_azure_openai_client


 class AzureOpenAIFilesAPI(BaseLLM):
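With the duplicate definition deleted, `get_azure_openai_client` has a single home in `common_utils`, shared by the files API and the o1 chat handler. A small usage sketch with placeholder values, showing the `azure_endpoint` mapping and the default API version fallback implemented in the `common_utils` hunk earlier in this diff:

```python
from litellm.llms.azure.common_utils import get_azure_openai_client

client = get_azure_openai_client(
    api_key="azure-key",                              # placeholder credential
    api_base="https://my-resource.openai.azure.com",  # forwarded as azure_endpoint
    timeout=600.0,
    max_retries=2,
    _is_async=False,  # True returns AsyncAzureOpenAI instead of AzureOpenAI
)
# api_version was omitted, so the helper falls back to litellm.AZURE_DEFAULT_API_VERSION.
```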
diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py
index 0ee8e3dadd..2ec9037e32 100644
--- a/litellm/llms/openai/openai.py
+++ b/litellm/llms/openai/openai.py
@@ -275,6 +275,7 @@ class OpenAIChatCompletion(BaseLLM):
         is_async: bool,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
         timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
         max_retries: Optional[int] = 2,
         organization: Optional[str] = None,
@@ -423,6 +424,9 @@ class OpenAIChatCompletion(BaseLLM):
         print_verbose: Optional[Callable] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        dynamic_params: Optional[bool] = None,
+        azure_ad_token: Optional[str] = None,
         acompletion: bool = False,
         logger_fn=None,
         headers: Optional[dict] = None,
@@ -432,6 +436,7 @@ class OpenAIChatCompletion(BaseLLM):
         custom_llm_provider: Optional[str] = None,
         drop_params: Optional[bool] = None,
     ):
+        super().completion()
         try:
             fake_stream: bool = False
@@ -441,6 +446,7 @@ class OpenAIChatCompletion(BaseLLM):
             )
             stream: Optional[bool] = inference_params.pop("stream", False)
             provider_config: Optional[BaseConfig] = None
+
             if custom_llm_provider is not None and model is not None:
                 provider_config = ProviderConfigManager.get_provider_chat_config(
                     model=model, provider=LlmProviders(custom_llm_provider)
@@ -450,6 +456,7 @@ class OpenAIChatCompletion(BaseLLM):
                 fake_stream = provider_config.should_fake_stream(
                     model=model, custom_llm_provider=custom_llm_provider, stream=stream
                 )
+
             if headers:
                 inference_params["extra_headers"] = headers
             if model is None or messages is None:
@@ -469,7 +476,7 @@ class OpenAIChatCompletion(BaseLLM):
             if messages is not None and provider_config is not None:
                 if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
                     provider_config, OpenAIConfig
-                ):
+                ):  # [TODO]: remove. no longer needed as .transform_request can just handle this.
                     messages = provider_config._transform_messages(
                         messages=messages, model=model
                     )
@@ -504,6 +511,7 @@ class OpenAIChatCompletion(BaseLLM):
                             model=model,
                             api_base=api_base,
                             api_key=api_key,
+                            api_version=api_version,
                             timeout=timeout,
                             client=client,
                             max_retries=max_retries,
@@ -520,6 +528,7 @@ class OpenAIChatCompletion(BaseLLM):
                             model_response=model_response,
                             api_base=api_base,
                             api_key=api_key,
+                            api_version=api_version,
                             timeout=timeout,
                             client=client,
                             max_retries=max_retries,
@@ -535,6 +544,7 @@ class OpenAIChatCompletion(BaseLLM):
                         model=model,
                         api_base=api_base,
                         api_key=api_key,
+                        api_version=api_version,
                         timeout=timeout,
                         client=client,
                         max_retries=max_retries,
@@ -546,11 +556,11 @@ class OpenAIChatCompletion(BaseLLM):
                     raise OpenAIError(
                         status_code=422, message="max retries must be an int"
                     )
-
                 openai_client: OpenAI = self._get_openai_client(  # type: ignore
                     is_async=False,
                     api_key=api_key,
                     api_base=api_base,
+                    api_version=api_version,
                     timeout=timeout,
                     max_retries=max_retries,
                     organization=organization,
@@ -667,6 +677,7 @@ class OpenAIChatCompletion(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
         organization: Optional[str] = None,
         client=None,
         max_retries=None,
@@ -684,6 +695,7 @@ class OpenAIChatCompletion(BaseLLM):
             is_async=True,
             api_key=api_key,
             api_base=api_base,
+            api_version=api_version,
             timeout=timeout,
             max_retries=max_retries,
             organization=organization,
@@ -758,6 +770,7 @@ class OpenAIChatCompletion(BaseLLM):
         model: str,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
         organization: Optional[str] = None,
         client=None,
         max_retries=None,
@@ -767,10 +780,12 @@ class OpenAIChatCompletion(BaseLLM):
         data["stream"] = True
         if stream_options is not None:
             data["stream_options"] = stream_options
+
         openai_client: OpenAI = self._get_openai_client(  # type: ignore
             is_async=False,
             api_key=api_key,
             api_base=api_base,
+            api_version=api_version,
             timeout=timeout,
             max_retries=max_retries,
             organization=organization,
@@ -812,6 +827,7 @@ class OpenAIChatCompletion(BaseLLM):
         logging_obj: LiteLLMLoggingObj,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
         organization: Optional[str] = None,
         client=None,
         max_retries=None,
@@ -829,6 +845,7 @@ class OpenAIChatCompletion(BaseLLM):
             is_async=True,
             api_key=api_key,
             api_base=api_base,
+            api_version=api_version,
             timeout=timeout,
             max_retries=max_retries,
             organization=organization,
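Two threads run through the `openai.py` hunks: `api_version` is now accepted and forwarded to every `_get_openai_client` call site (the base OpenAI path can ignore it, while the Azure o1 subclass consumes it), and the decision to fake streaming is delegated to the provider config. A condensed restatement of that dispatch, not the literal litellm code:

```python
def resolve_fake_stream(provider_config, model: str, custom_llm_provider: str, stream: bool) -> bool:
    # Mirrors the control flow above: ask the provider config whether this
    # model/stream combination needs a faked stream; default to real streaming.
    if provider_config is not None and hasattr(provider_config, "should_fake_stream"):
        return provider_config.should_fake_stream(
            model=model, custom_llm_provider=custom_llm_provider, stream=stream
        )
    return False
```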
diff --git a/litellm/main.py b/litellm/main.py
index de1874b8b8..54205d150b 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1225,10 +1225,7 @@ def completion(  # type: ignore # noqa: PLR0915
             if extra_headers is not None:
                 optional_params["extra_headers"] = extra_headers

-            if (
-                litellm.enable_preview_features
-                and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
-            ):
+            if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
                 ## LOAD CONFIG - if set
                 config = litellm.AzureOpenAIO1Config.get_config()
                 for k, v in config.items():
@@ -1244,7 +1241,6 @@ def completion(  # type: ignore # noqa: PLR0915
                     api_key=api_key,
                     api_base=api_base,
                     api_version=api_version,
-                    api_type=api_type,
                     dynamic_params=dynamic_params,
                     azure_ad_token=azure_ad_token,
                     model_response=model_response,
@@ -1256,6 +1252,7 @@ def completion(  # type: ignore # noqa: PLR0915
                     acompletion=acompletion,
                     timeout=timeout,  # type: ignore
                     client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
+                    custom_llm_provider=custom_llm_provider,
                 )
             else:
                 ## LOAD CONFIG - if set
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index 9bbc1fd875..0000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-
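Taken together with the `litellm/main.py` hunks above (the `enable_preview_features` gate is gone and `custom_llm_provider` is now forwarded), a call like the following should route through `AzureOpenAIO1ChatCompletion` without any preview flag. The deployment name, endpoint, key, and API version are placeholders; substitute real values before running:

```python
import litellm

response = litellm.completion(
    model="azure/o1-preview",                         # placeholder deployment name
    messages=[{"role": "user", "content": "Hello"}],
    api_base="https://my-resource.openai.azure.com",  # placeholder endpoint
    api_version="2024-08-01-preview",                 # placeholder API version
    api_key="azure-key",                              # placeholder credential
    stream=True,  # served via the faked stream unless the model map says otherwise
)
for chunk in response:
    print(chunk)
```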