import enum
from typing import Any, List, Optional, Tuple, cast
from urllib.parse import urlparse

import httpx
from httpx import Response

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    _audio_or_image_in_message_content,
    convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
from litellm.llms.openai.openai import OpenAIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, ProviderField
from litellm.utils import _add_path_to_api_base, supports_tool_choice


class AzureFoundryErrorStrings(str, enum.Enum):
    SET_EXTRA_PARAMETERS_TO_PASS_THROUGH = "Set extra-parameters to 'pass-through'"


class AzureAIStudioConfig(OpenAIConfig):
    def get_supported_openai_params(self, model: str) -> List:
        model_supports_tool_choice = True  # Azure AI supports tool_choice by default
        if not supports_tool_choice(model=f"azure_ai/{model}"):
            model_supports_tool_choice = False
        supported_params = super().get_supported_openai_params(model)
        if not model_supports_tool_choice:
            filtered_supported_params = []
            for param in supported_params:
                if param != "tool_choice":
                    filtered_supported_params.append(param)
            return filtered_supported_params
        return supported_params
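
    # Illustrative example (assumed model name, not taken from this module):
    #   AzureAIStudioConfig().get_supported_openai_params(model="mistral-large-2407")
    # returns the parent OpenAIConfig param list, with "tool_choice" filtered out
    # whenever litellm's registry reports "azure_ai/<model>" lacks tool-choice support.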

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        if api_base and self._should_use_api_key_header(api_base):
            headers["api-key"] = api_key
        else:
            headers["Authorization"] = f"Bearer {api_key}"

        return headers

    def _should_use_api_key_header(self, api_base: str) -> bool:
        """
        Returns True if the request should use the `api-key` header for authentication.
        """
        parsed_url = urlparse(api_base)
        host = parsed_url.hostname
        if host and (
            host.endswith(".services.ai.azure.com")
            or host.endswith(".openai.azure.com")
        ):
            return True
        return False
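
    # Illustrative examples (assumed hostnames):
    #   "https://my-project.services.ai.azure.com" -> True  (auth via `api-key` header)
    #   "https://my-deploy.openai.azure.com"       -> True  (auth via `api-key` header)
    # Any other host falls back to `Authorization: Bearer <api_key>`.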

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Constructs a complete URL for the API request.

        Args:
        - api_base: Base URL, e.g.,
            "https://litellm8397336933.services.ai.azure.com"
            OR
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        - model: Model name.
        - litellm_params: Additional parameters, including "api_version".
        - stream: If streaming is required (optional).

        Returns:
        - A complete URL string, e.g.,
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        """
        if api_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
            )
        original_url = httpx.URL(api_base)

        # Extract api_version or use default
        api_version = cast(Optional[str], litellm_params.get("api_version"))

        # Copy the existing query params from the base URL
        query_params = dict(original_url.params)

        # Add api-version if it is not already present
        if "api-version" not in query_params and api_version:
            query_params["api-version"] = api_version

        # Append the correct completions path for the endpoint type
        if "services.ai.azure.com" in api_base:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/models/chat/completions"
            )
        else:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/chat/completions"
            )

        # Re-attach the query params to the new URL
        final_url = httpx.URL(new_url).copy_with(params=query_params)

        return str(final_url)
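
    # Illustrative example, using the values from the docstring above:
    #   api_base="https://litellm8397336933.services.ai.azure.com"
    #   litellm_params={"api_version": "2024-05-01-preview"}
    # yields:
    #   "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
    # A base URL without "services.ai.azure.com" gets "/chat/completions" appended instead.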

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description."""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Azure AI Studio API Key.",
                field_value="zEJ...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Azure AI Studio API Base.",
                field_value="https://Mistral-serverless.",
            ),
        ]

    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
    ) -> List:
        """
        Azure AI Studio doesn't support content as a list, so this:
        1. Transforms list content to a string.
        2. Leaves messages that contain an image or audio as-is (user-intended).
        """
        for message in messages:
            # Do nothing if the message contains an image or audio
            if _audio_or_image_in_message_content(message):
                continue

            texts = convert_content_list_to_str(message=message)
            if texts:
                message["content"] = texts
        return messages
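
    # Illustrative example (assumed message shape): a message like
    #   {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
    # is flattened to {"role": "user", "content": "Hello"}, while any message whose
    # content list contains an image or audio part is forwarded unchanged.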

    def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
        try:
            if "/" in model:
                model = model.split("/", 1)[1]
            if (
                model in litellm.open_ai_chat_completion_models
                or model in litellm.open_ai_text_completion_models
                or model in litellm.open_ai_embedding_models
            ):
                return True
        except Exception:
            return False
        return False

    def _get_openai_compatible_provider_info(
        self,
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        custom_llm_provider: str,
    ) -> Tuple[Optional[str], Optional[str], str]:
        api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
        dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")

        if self._is_azure_openai_model(model=model, api_base=api_base):
            verbose_logger.debug(
                "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
                    model
                )
            )
            custom_llm_provider = "azure"
        return api_base, dynamic_api_key, custom_llm_provider
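
    # Illustrative example (assumed env vars and model): with AZURE_AI_API_BASE and
    # AZURE_AI_API_KEY set, a call routed here with model="azure_ai/gpt-4o" returns
    # custom_llm_provider="azure", since "gpt-4o" appears in litellm's OpenAI model lists.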

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        extra_body = optional_params.pop("extra_body", {})
        if extra_body and isinstance(extra_body, dict):
            optional_params.update(extra_body)
        optional_params.pop("max_retries", None)
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )
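
    # Illustrative example (assumed params): optional_params of
    #   {"temperature": 0.2, "extra_body": {"safe_mode": True}, "max_retries": 2}
    # is sent as {"temperature": 0.2, "safe_mode": True}: extra_body is merged into
    # the top-level request body and the client-only max_retries key is dropped.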

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        model_response.model = f"azure_ai/{model}"
        return super().transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

    def should_retry_llm_api_inside_llm_translation_on_http_error(
        self, e: httpx.HTTPStatusError, litellm_params: dict
    ) -> bool:
        should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
        error_text = e.response.text

        if should_drop_params and "Extra inputs are not permitted" in error_text:
            return True
        elif (
            "unknown field: parameter index is not a valid field" in error_text
        ):  # remove index from tool calls
            return True
        elif (
            AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
            in error_text
        ):  # remove extra-parameters from tool calls
            return True
        return super().should_retry_llm_api_inside_llm_translation_on_http_error(
            e=e, litellm_params=litellm_params
        )
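
    # Note: a True return tells litellm's translation layer to retry the request
    # after it has been rewritten by transform_request_on_unprocessable_entity_error
    # below (at most max_retry_on_unprocessable_entity_error times); any other error
    # defers to the parent class's retry decision.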

    @property
    def max_retry_on_unprocessable_entity_error(self) -> int:
        return 2

    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
        if (
            "unknown field: parameter index is not a valid field" in e.response.text
            and _messages is not None
        ):
            litellm.remove_index_from_tool_calls(
                messages=_messages,
            )
        elif (
            AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
            in e.response.text
        ):
            request_data = self._drop_extra_params_from_request_data(
                request_data, e.response.text
            )
        data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
        return data

    def _drop_extra_params_from_request_data(
        self, request_data: dict, error_text: str
    ) -> dict:
        params_to_drop = self._extract_params_to_drop_from_error_text(error_text)
        if params_to_drop:
            for param in params_to_drop:
                if param in request_data:
                    request_data.pop(param, None)
        return request_data

    def _extract_params_to_drop_from_error_text(
        self, error_text: str
    ) -> Optional[List[str]]:
        """
        Error text looks like this:
        "Extra parameters ['stream_options', 'extra-parameters'] are not allowed when extra-parameters is not set or set to be 'error'."
        """
        import re

        # Extract parameters within square brackets
        match = re.search(r"\[(.*?)\]", error_text)
        if not match:
            return []

        # Parse the extracted string into a list of parameter names
        params_str = match.group(1)
        params = []
        for param in params_str.split(","):
            # Clean up the parameter name (remove quotes, spaces)
            clean_param = param.strip().strip("'").strip('"')
            if clean_param:
                params.append(clean_param)
        return params
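

# Illustrative example of the error-text parsing above (assumed error message,
# matching the docstring of _extract_params_to_drop_from_error_text):
#   error_text = (
#       "Extra parameters ['stream_options', 'extra-parameters'] are not allowed "
#       "when extra-parameters is not set or set to be 'error'."
#   )
#   AzureAIStudioConfig()._extract_params_to_drop_from_error_text(error_text)
#   -> ["stream_options", "extra-parameters"]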