from typing import Any, List, Optional, Tuple, cast from urllib.parse import urlparse import httpx from httpx import Response import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.prompt_templates.common_utils import ( _audio_or_image_in_message_content, convert_content_list_to_str, ) from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error from litellm.llms.openai.openai import OpenAIConfig from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import AllMessageValues from litellm.types.utils import ModelResponse, ProviderField from litellm.utils import _add_path_to_api_base, supports_tool_choice class AzureAIStudioConfig(OpenAIConfig): def get_supported_openai_params(self, model: str) -> List: model_supports_tool_choice = True # azure ai supports this by default if not supports_tool_choice(model=f"azure_ai/{model}"): model_supports_tool_choice = False supported_params = super().get_supported_openai_params(model) if not model_supports_tool_choice: filtered_supported_params = [] for param in supported_params: if param != "tool_choice": filtered_supported_params.append(param) return filtered_supported_params return supported_params def validate_environment( self, headers: dict, model: str, messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: if api_base and self._should_use_api_key_header(api_base): headers["api-key"] = api_key else: headers["Authorization"] = f"Bearer {api_key}" return headers def _should_use_api_key_header(self, api_base: str) -> bool: """ Returns True if the request should use `api-key` header for authentication. """ parsed_url = urlparse(api_base) host = parsed_url.hostname if host and ( host.endswith(".services.ai.azure.com") or host.endswith(".openai.azure.com") ): return True return False def get_complete_url( self, api_base: Optional[str], model: str, optional_params: dict, stream: Optional[bool] = None, ) -> str: """ Constructs a complete URL for the API request. Args: - api_base: Base URL, e.g., "https://litellm8397336933.services.ai.azure.com" OR "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview" - model: Model name. - optional_params: Additional query parameters, including "api_version". - stream: If streaming is required (optional). Returns: - A complete URL string, e.g., "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview" """ if api_base is None: raise ValueError( f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`" ) original_url = httpx.URL(api_base) # Extract api_version or use default api_version = cast(Optional[str], optional_params.get("api_version")) # Check if 'api-version' is already present if "api-version" not in original_url.params and api_version: # Add api_version to optional_params original_url.params["api-version"] = api_version # Add the path to the base URL if "services.ai.azure.com" in api_base: new_url = _add_path_to_api_base( api_base=api_base, ending_path="/models/chat/completions" ) else: new_url = _add_path_to_api_base( api_base=api_base, ending_path="/chat/completions" ) # Convert optional_params to query parameters query_params = original_url.params final_url = httpx.URL(new_url).copy_with(params=query_params) return str(final_url) def get_required_params(self) -> List[ProviderField]: """For a given provider, return it's required fields with a description""" return [ ProviderField( field_name="api_key", field_type="string", field_description="Your Azure AI Studio API Key.", field_value="zEJ...", ), ProviderField( field_name="api_base", field_type="string", field_description="Your Azure AI Studio API Base.", field_value="https://Mistral-serverless.", ), ] def _transform_messages( self, messages: List[AllMessageValues], model: str, ) -> List: """ - Azure AI Studio doesn't support content as a list. This handles: 1. Transforms list content to a string. 2. If message contains an image or audio, send as is (user-intended) """ for message in messages: # Do nothing if the message contains an image or audio if _audio_or_image_in_message_content(message): continue texts = convert_content_list_to_str(message=message) if texts: message["content"] = texts return messages def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool: try: if "/" in model: model = model.split("/", 1)[1] if ( model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models or model in litellm.open_ai_embedding_models ): return True except Exception: return False return False def _get_openai_compatible_provider_info( self, model: str, api_base: Optional[str], api_key: Optional[str], custom_llm_provider: str, ) -> Tuple[Optional[str], Optional[str], str]: api_base = api_base or get_secret_str("AZURE_AI_API_BASE") dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY") if self._is_azure_openai_model(model=model, api_base=api_base): verbose_logger.debug( "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format( model ) ) custom_llm_provider = "azure" return api_base, dynamic_api_key, custom_llm_provider def transform_request( self, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, headers: dict, ) -> dict: extra_body = optional_params.pop("extra_body", {}) if extra_body and isinstance(extra_body, dict): optional_params.update(extra_body) optional_params.pop("max_retries", None) return super().transform_request( model, messages, optional_params, litellm_params, headers ) def transform_response( self, model: str, raw_response: Response, model_response: ModelResponse, logging_obj: LiteLLMLoggingObj, request_data: dict, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, encoding: Any, api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: model_response.model = f"azure_ai/{model}" return super().transform_response( model=model, raw_response=raw_response, model_response=model_response, logging_obj=logging_obj, request_data=request_data, messages=messages, optional_params=optional_params, litellm_params=litellm_params, encoding=encoding, api_key=api_key, json_mode=json_mode, ) def should_retry_llm_api_inside_llm_translation_on_http_error( self, e: httpx.HTTPStatusError, litellm_params: dict ) -> bool: should_drop_params = litellm_params.get("drop_params") or litellm.drop_params error_text = e.response.text if should_drop_params and "Extra inputs are not permitted" in error_text: return True elif ( "unknown field: parameter index is not a valid field" in error_text ): # remove index from tool calls return True return super().should_retry_llm_api_inside_llm_translation_on_http_error( e=e, litellm_params=litellm_params ) @property def max_retry_on_unprocessable_entity_error(self) -> int: return 2 def transform_request_on_unprocessable_entity_error( self, e: httpx.HTTPStatusError, request_data: dict ) -> dict: _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages")) if ( "unknown field: parameter index is not a valid field" in e.response.text and _messages is not None ): litellm.remove_index_from_tool_calls( messages=_messages, ) data = drop_params_from_unprocessable_entity_error(e=e, data=request_data) return data