litellm-mirror/litellm/llms/azure_ai/chat/transformation.py
Krish Dholakia b0f570ee16 Litellm dev 12 30 2024 p2 (#7495)
* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model

* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests

skip as o1 on azure doesn't support tool calling yet

* fix: initial commit of azure o1 handler using openai caller

simplifies calling + allows the fake streaming logic already implemented for openai to just work

* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models

azure does not currently support streaming for o1

* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info

enables users to toggle this on once azure allows o1 streaming, without needing to bump versions

* style(router.py): remove 'give feedback/get help' messaging when router is used

Prevents noisy messaging

Closes https://github.com/BerriAI/litellm/issues/5942

* fix(types/utils.py): handle none logprobs

Fixes https://github.com/BerriAI/litellm/issues/328

* fix(exception_mapping_utils.py): fix error str unbound error

* refactor(azure_ai/): move to openai_like chat completion handler

allows for easy swapping of api base urls (e.g. ai.services.com)

Fixes https://github.com/BerriAI/litellm/issues/7275

* refactor(azure_ai/): move to base llm http handler

* fix(azure_ai/): handle differing api endpoints

* fix(azure_ai/): make sure all unit tests are passing

* fix: fix linting errors

* fix: fix linting errors

* fix: fix linting error

* fix: fix linting errors

* fix(azure_ai/transformation.py): handle extra body param

* fix(azure_ai/transformation.py): fix max retries param handling

* fix: fix test

* test(test_azure_o1.py): fix test

* fix(llm_http_handler.py): support handling azure ai unprocessable entity error

* fix(llm_http_handler.py): handle sync invalid param error for azure ai

* fix(azure_ai/): streaming support with base_llm_http_handler

* fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai

* fix: fix linting errors

* fix(llm_http_handler.py): fix linting error

* fix(azure_ai/): handle cohere tool call invalid index param error
2025-01-01 18:57:29 -08:00

from typing import Any, List, Optional, Tuple, cast

import httpx
from httpx import Response

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    _audio_or_image_in_message_content,
    convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
from litellm.llms.openai.openai import OpenAIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
from litellm.types.utils import ModelResponse, ProviderField
from litellm.utils import _add_path_to_api_base


class AzureAIStudioConfig(OpenAIConfig):
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        if api_base and "services.ai.azure.com" in api_base:
            headers["api-key"] = api_key
        else:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers
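
    # Usage sketch (hypothetical values, kept as comments so the class body
    # stays valid): Azure AI "services" endpoints authenticate via an
    # `api-key` header, while other Azure AI Studio endpoints expect a
    # Bearer token.
    #
    #   config = AzureAIStudioConfig()
    #   config.validate_environment(
    #       headers={},
    #       model="mistral-large",
    #       messages=[{"role": "user", "content": "hi"}],
    #       optional_params={},
    #       api_key="my-key",
    #       api_base="https://example.services.ai.azure.com",
    #   )
    #   # -> {"api-key": "my-key"}; any other api_base yields
    #   # {"Authorization": "Bearer my-key"}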

    def get_complete_url(
        self,
        api_base: str,
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Constructs a complete URL for the API request.

        Args:
        - api_base: Base URL, e.g.,
            "https://litellm8397336933.services.ai.azure.com"
            OR
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        - model: Model name.
        - optional_params: Additional query parameters, including "api_version".
        - stream: If streaming is required (optional).

        Returns:
        - A complete URL string, e.g.,
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        """
        original_url = httpx.URL(api_base)

        # Extract api_version or use default
        api_version = cast(Optional[str], optional_params.get("api_version"))

        # If 'api-version' isn't already in the base URL, add it as a query
        # param. httpx.QueryParams is immutable, so copy_add_param returns a
        # new URL rather than mutating in place.
        if "api-version" not in original_url.params and api_version:
            original_url = original_url.copy_add_param("api-version", api_version)

        # Append the provider-specific chat-completions path to the base URL
        if "services.ai.azure.com" in api_base:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/models/chat/completions"
            )
        else:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/chat/completions"
            )

        # Carry the query parameters over to the final URL
        query_params = original_url.params
        final_url = httpx.URL(new_url).copy_with(params=query_params)

        return str(final_url)
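
    # Usage sketch (hypothetical values): a "services" endpoint gets the
    # "/models/chat/completions" path, other endpoints get "/chat/completions",
    # and api_version is carried over as an "api-version" query param.
    #
    #   AzureAIStudioConfig().get_complete_url(
    #       api_base="https://litellm8397336933.services.ai.azure.com",
    #       model="mistral-large",
    #       optional_params={"api_version": "2024-05-01-preview"},
    #   )
    #   # -> "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"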

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Azure AI Studio API Key.",
                field_value="zEJ...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Azure AI Studio API Base.",
                field_value="https://Mistral-serverless.",
            ),
        ]

    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
    ) -> List:
        """
        Azure AI Studio doesn't support content as a list. This method:
        1. Transforms list content to a string.
        2. Leaves messages containing an image or audio as-is (user-intended).
        """
        for message in messages:
            # Do nothing if the message contains an image or audio
            if _audio_or_image_in_message_content(message):
                continue
            texts = convert_content_list_to_str(message=message)
            if texts:
                message["content"] = texts
        return messages
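
    # Usage sketch (hypothetical message): list-style text content is
    # flattened to a plain string; image/audio content passes through.
    #
    #   msgs = [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
    #   AzureAIStudioConfig()._transform_messages(messages=msgs, model="mistral-large")
    #   # -> [{"role": "user", "content": "hi"}]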

    def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
        try:
            if "/" in model:
                model = model.split("/", 1)[1]
            if (
                model in litellm.open_ai_chat_completion_models
                or model in litellm.open_ai_text_completion_models
                or model in litellm.open_ai_embedding_models
            ):
                return True
        except Exception:
            return False
        return False

    def _get_openai_compatible_provider_info(
        self,
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        custom_llm_provider: str,
    ) -> Tuple[Optional[str], Optional[str], str]:
        api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
        dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")

        if self._is_azure_openai_model(model=model, api_base=api_base):
            verbose_logger.debug(
                "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
                    model
                )
            )
            custom_llm_provider = "azure"
        return api_base, dynamic_api_key, custom_llm_provider
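
    # Usage sketch (assumes "gpt-4o" is in litellm's OpenAI model registry):
    # requests for known OpenAI models are re-routed to the "azure" provider,
    # since Azure AI Studio serves them through the Azure OpenAI API.
    #
    #   AzureAIStudioConfig()._get_openai_compatible_provider_info(
    #       model="gpt-4o",
    #       api_base="https://example.services.ai.azure.com",
    #       api_key="my-key",
    #       custom_llm_provider="azure_ai",
    #   )
    #   # -> ("https://example.services.ai.azure.com", "my-key", "azure")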

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Hoist user-supplied 'extra_body' fields into the top-level request body
        extra_body = optional_params.pop("extra_body", {})
        if extra_body and isinstance(extra_body, dict):
            optional_params.update(extra_body)
        # 'max_retries' is a client-side setting, not an API field - drop it
        optional_params.pop("max_retries", None)
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )
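
    # Usage sketch (hypothetical params, 'safe_prompt' used purely as an
    # illustrative extra field):
    #
    #   optional_params = {
    #       "temperature": 0.1,
    #       "extra_body": {"safe_prompt": True},
    #       "max_retries": 2,
    #   }
    #   # after transform_request, the request body contains "temperature" and
    #   # "safe_prompt", but neither "extra_body" nor "max_retries"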

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        model_response.model = f"azure_ai/{model}"
        return super().transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

    def should_retry_llm_api_inside_llm_translation_on_http_error(
        self, e: httpx.HTTPStatusError, litellm_params: dict
    ) -> bool:
        should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
        error_text = e.response.text
        if should_drop_params and "Extra inputs are not permitted" in error_text:
            return True
        elif (
            "unknown field: parameter index is not a valid field" in error_text
        ):  # remove index from tool calls
            return True
        return super().should_retry_llm_api_inside_llm_translation_on_http_error(
            e=e, litellm_params=litellm_params
        )
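
    # Two retryable cases (error texts as matched above, not exhaustive):
    #   - "Extra inputs are not permitted": retry with unsupported params
    #     dropped (only when drop_params is enabled)
    #   - "unknown field: parameter index is not a valid field": retry with
    #     the "index" field stripped from tool calls (seen with cohere models
    #     on azure ai)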

    @property
    def max_retry_on_unprocessable_entity_error(self) -> int:
        return 2

    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        _messages = cast(
            Optional[List[AllMessageValues]], request_data.get("messages")
        )
        if (
            "unknown field: parameter index is not a valid field" in e.response.text
            and _messages is not None
        ):
            litellm.remove_index_from_tool_calls(
                messages=_messages,
            )
        data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
        return data
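
# End-to-end retry sketch (hypothetical, as driven by the base HTTP handler;
# `resp` and `request_data` are assumed to exist in the caller): on an
# unprocessable-entity error from azure ai, the handler first asks whether the
# request is fixable, then rewrites it and resends, up to
# max_retry_on_unprocessable_entity_error (2) times.
#
#   config = AzureAIStudioConfig()
#   try:
#       resp.raise_for_status()
#   except httpx.HTTPStatusError as e:
#       if config.should_retry_llm_api_inside_llm_translation_on_http_error(
#           e=e, litellm_params={"drop_params": True}
#       ):
#           request_data = config.transform_request_on_unprocessable_entity_error(
#               e=e, request_data=request_data
#           )
#           # ... resend request_data ...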