litellm/litellm/llms/AzureOpenAI/azure.py
Krish Dholakia d57be47b0f
Litellm ruff linting enforcement (#5992)
* ci(config.yml): add a 'check_code_quality' step

Addresses https://github.com/BerriAI/litellm/issues/5991

* ci(config.yml): check why circle ci doesn't pick up this test

* ci(config.yml): fix to run 'check_code_quality' tests

* fix(__init__.py): fix unprotected import

* fix(__init__.py): don't remove unused imports

* build(ruff.toml): update ruff.toml to ignore unused imports

* fix: fix: ruff + pyright - fix linting + type-checking errors

* fix: fix linting errors

* fix(lago.py): fix module init error

* fix: fix linting errors

* ci(config.yml): cd into correct dir for checks

* fix(proxy_server.py): fix linting error

* fix(utils.py): fix bare except

causes ruff linting errors

* fix: ruff - fix remaining linting errors

* fix(clickhouse.py): use standard logging object

* fix(__init__.py): fix unprotected import

* fix: ruff - fix linting errors

* fix: fix linting errors

* ci(config.yml): cleanup code qa step (formatting handled in local_testing)

* fix(_health_endpoints.py): fix ruff linting errors

* ci(config.yml): just use ruff in check_code_quality pipeline for now

* build(custom_guardrail.py): include missing file

* style(embedding_handler.py): fix ruff check
2024-10-01 19:44:20 -04:00

2064 lines
76 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import os
import time
import types
from typing import Any, Callable, Coroutine, Iterable, List, Literal, Optional, Union
import httpx # type: ignore
from openai import AsyncAzureOpenAI, AzureOpenAI
from typing_extensions import overload
import litellm
from litellm.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import EmbeddingResponse
from litellm.utils import (
CustomStreamWrapper,
ModelResponse,
UnsupportedParamsError,
convert_to_model_response_object,
get_secret,
modify_url,
)
from ...types.llms.openai import (
Batch,
CancelBatchRequest,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
CreateBatchRequest,
HttpxBinaryResponseContent,
RetrieveBatchRequest,
)
from ..base import BaseLLM
from .common_utils import process_azure_headers
# Module-level cache for Azure AD access tokens produced by the OIDC
# client-credentials exchange in `get_azure_ad_token_from_oidc`; entries are
# stored with a TTL equal to the token response's `expires_in`.
azure_ad_cache = DualCache()
class AzureOpenAIError(Exception):
    """
    Error raised for failed (or invalid) Azure OpenAI calls.

    Stores the HTTP status code, message and optional headers. When the caller
    supplies no request/response, placeholder httpx objects are synthesized so
    `.request` and `.response` are always populated.
    """

    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[httpx.Headers] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers
        # Placeholder request used when none was provided by the caller.
        self.request = (
            request
            if request
            else httpx.Request(method="POST", url="https://api.openai.com/v1")
        )
        # Placeholder response carrying the status code, tied to self.request.
        self.response = (
            response
            if response
            else httpx.Response(status_code=status_code, request=self.request)
        )
        super().__init__(self.message)
class AzureOpenAIConfig:
    """
    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions

    The class `AzureOpenAIConfig` provides configuration for OpenAI's Chat API interface, for use with Azure, and maps OpenAI-style params onto what a given Azure `api_version` accepts. Below are the parameters::

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
    ) -> None:
        # Non-None arguments are stored as *class* attributes, so they are
        # shared by every instance and surfaced by `get_config()`.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """
        Return the class attributes that are set (not None), excluding dunders,
        functions, builtins, classmethods and staticmethods.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        """Return the OpenAI params accepted for Azure chat completions (each listed once)."""
        return [
            "temperature",
            "n",
            "stream",
            "stop",
            "max_tokens",
            "max_completion_tokens",
            "tools",
            "tool_choice",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "function_call",
            "functions",
            "top_p",
            "logprobs",
            "top_logprobs",
            "response_format",
            "seed",
            "extra_headers",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        api_version: str,  # Y-M-D-{optional}
        drop_params,
    ) -> dict:
        """
        Translate OpenAI params into Azure params for the given `api_version`.

        `optional_params` is updated in place and returned. `api_version` must
        begin with zero-padded "YYYY-MM-DD" (e.g. "2024-08-01-preview"). Those
        parts are compared as string tuples, which is chronological because
        each field is fixed-width.

        Raises:
            UnsupportedParamsError: (400) if `tool_choice` (or the value
                "required") is unsupported for `api_version` and neither
                `litellm.drop_params` nor `drop_params` is True.
        """
        supported_openai_params = self.get_supported_openai_params()
        api_version_times = api_version.split("-")
        api_version_year = api_version_times[0]
        api_version_month = api_version_times[1]
        api_version_day = api_version_times[2]
        api_version_ym = (api_version_year, api_version_month)
        api_version_ymd = (api_version_year, api_version_month, api_version_day)

        for param, value in non_default_params.items():
            if param == "tool_choice":
                # 'tool_choice' requires API version 2023-12-01-preview or later,
                # and tool_choice='required' is not supported up to and including
                # 2024-05 preview versions.
                if api_version_ymd < ("2023", "12", "01"):
                    if not (litellm.drop_params is True or drop_params is True):
                        raise UnsupportedParamsError(
                            status_code=400,
                            message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
                        )
                elif value == "required" and api_version_ym <= ("2024", "05"):
                    # Check whether this tool_choice value is supported.
                    if not (litellm.drop_params is True or drop_params is True):
                        raise UnsupportedParamsError(
                            status_code=400,
                            message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
                        )
                else:
                    optional_params["tool_choice"] = value
            elif param == "response_format" and isinstance(value, dict):
                json_schema: Optional[dict] = None
                schema_name: str = ""
                if "response_schema" in value:
                    json_schema = value["response_schema"]
                    schema_name = "json_tool_call"
                elif "json_schema" in value:
                    json_schema = value["json_schema"]["schema"]
                    schema_name = value["json_schema"]["name"]
                # Follow a similar approach to anthropic and translate the schema
                # into a single forced tool call
                # (https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode):
                # - provide a single tool
                # - set tool_choice to force the model to use that tool
                # - the tool name/description are written from the model's perspective
                #
                # Azure api version "2024-08-01-preview" onwards supports
                # 'json_schema' natively, but only for gpt-4o. Compare
                # (year, month) as a tuple: a per-field check such as
                # `year <= "2024" and month < "08"` would wrongly treat
                # "2023-12-01-preview" as new enough.
                if json_schema is not None and (
                    api_version_ym < ("2024", "08") or "gpt-4o" not in model
                ):
                    _tool_choice = ChatCompletionToolChoiceObjectParam(
                        type="function",
                        function=ChatCompletionToolChoiceFunctionParam(
                            name=schema_name
                        ),
                    )

                    _tool = ChatCompletionToolParam(
                        type="function",
                        function=ChatCompletionToolParamFunctionChunk(
                            name=schema_name, parameters=json_schema
                        ),
                    )

                    optional_params["tools"] = [_tool]
                    optional_params["tool_choice"] = _tool_choice
                    # Signals the response converter to turn the tool call
                    # back into JSON content.
                    optional_params["json_mode"] = True
                else:
                    optional_params["response_format"] = value
            elif param == "max_completion_tokens":
                # TODO - Azure OpenAI will probably add support for this, we should pass it through when Azure adds support
                optional_params["max_tokens"] = value
            elif param in supported_openai_params:
                optional_params[param] = value

        return optional_params

    def get_mapped_special_auth_params(self) -> dict:
        """Map litellm's generic `token` auth param to Azure's `azure_ad_token`."""
        return {"token": "azure_ad_token"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        """Copy a provided `token` into `optional_params["azure_ad_token"]` and return it."""
        for param, value in non_default_params.items():
            if param == "token":
                optional_params["azure_ad_token"] = value
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
        """
        return ["europe", "sweden", "switzerland", "france", "uk"]
class AzureOpenAIAssistantsAPIConfig:
    """
    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message

    Maps OpenAI "create message" params onto what the Azure Assistants API accepts.
    """

    def __init__(
        self,
    ) -> None:
        pass

    def get_supported_openai_create_message_params(self):
        """Return the create-message params handled by this config."""
        return [
            "role",
            "content",
            "attachments",
            "metadata",
        ]

    def map_openai_params_create_message_params(
        self, non_default_params: dict, optional_params: dict
    ):
        """
        Copy supported create-message params into `optional_params`.

        `optional_params` is updated in place and returned. Unrecognized params
        are ignored.

        Raises:
            litellm.utils.UnsupportedParamsError: (400) if `content` is not a
                string, if `attachments` is not a list, or if an attachment has
                no `file_id` and `litellm.drop_params` is not True.
        """
        for param, value in non_default_params.items():
            if param == "role":
                optional_params["role"] = value
            elif param == "metadata":
                optional_params["metadata"] = value
            elif param == "content":  # only string accepted
                if isinstance(value, str):
                    optional_params["content"] = value
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message="Azure only accepts content as a string.",
                        status_code=400,
                    )
            elif (
                param == "attachments"
            ):  # this is a v2 param. Azure currently supports the old 'file_id's param
                file_ids: List[str] = []
                if isinstance(value, list):
                    for item in value:
                        if "file_id" in item:
                            file_ids.append(item["file_id"])
                        elif litellm.drop_params is True:
                            pass
                        else:
                            raise litellm.utils.UnsupportedParamsError(
                                message="Azure doesn't support {}. To drop it from the call, set `litellm.drop_params = True`.".format(
                                    item
                                ),
                                status_code=400,
                            )
                    # NOTE(review): `file_ids` is validated and collected but
                    # never written to `optional_params`, so attachments are
                    # effectively dropped. Confirm whether the downstream
                    # Azure call should receive them (e.g. as `file_ids`).
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message="Invalid param. attachments should always be a list. Got={}, Expected=List. Raw value={}".format(
                            type(value), value
                        ),
                        status_code=400,
                    )

        return optional_params
def select_azure_base_url_or_endpoint(azure_client_params: dict):
    """
    Move a full deployment URL from `azure_endpoint` to `base_url`.

    An `azure_endpoint` that already contains "/openai/deployments" is treated
    as a complete base URL (see the linked openai-python source), so it is
    stored under `base_url` and the `azure_endpoint` key is removed. Any other
    value is left untouched.

    The dict is modified in place and also returned.
    """
    endpoint = azure_client_params.get("azure_endpoint", None)
    is_deployment_url = endpoint is not None and "/openai/deployments" in endpoint
    if not is_deployment_url:
        return azure_client_params

    # see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
    azure_client_params["base_url"] = endpoint
    del azure_client_params["azure_endpoint"]
    return azure_client_params
def get_azure_ad_token_from_oidc(azure_ad_token: str):
    """
    Exchange an OIDC token for an Azure AD access token.

    `azure_ad_token` is a secret reference (callers in this file pass values
    starting with "oidc/") resolved through `get_secret`. The resolved token
    is sent as a JWT client assertion in a client-credentials request to
    `{AZURE_AUTHORITY_HOST}/{AZURE_TENANT_ID}/oauth2/v2.0/token`. Results are
    cached in `azure_ad_cache`, keyed on client id, tenant id, authority host
    and the resolved OIDC token, for the returned `expires_in` TTL.

    Raises:
        AzureOpenAIError: 422 if AZURE_CLIENT_ID / AZURE_TENANT_ID are unset or
            the token response lacks `access_token` / `expires_in`; 401 if the
            OIDC token cannot be resolved; the upstream status code when the
            token endpoint does not answer 200.
    """
    client_id = os.getenv("AZURE_CLIENT_ID", None)
    tenant_id = os.getenv("AZURE_TENANT_ID", None)
    authority_host = os.getenv(
        "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
    )

    if client_id is None or tenant_id is None:
        raise AzureOpenAIError(
            status_code=422,
            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
        )

    oidc_token = get_secret(azure_ad_token)
    if oidc_token is None:
        raise AzureOpenAIError(
            status_code=401,
            message="OIDC token could not be retrieved from secret manager.",
        )

    cache_key = json.dumps(
        {
            "azure_client_id": client_id,
            "azure_tenant_id": tenant_id,
            "azure_authority_host": authority_host,
            "oidc_token": oidc_token,
        }
    )
    cached_token = azure_ad_cache.get_cache(cache_key)
    if cached_token is not None:
        return cached_token

    token_url = f"{authority_host}/{tenant_id}/oauth2/v2.0/token"
    form_body = {
        "client_id": client_id,
        "grant_type": "client_credentials",
        "scope": "https://cognitiveservices.azure.com/.default",
        "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        "client_assertion": oidc_token,
    }
    token_response = litellm.module_level_client.post(token_url, data=form_body)

    if token_response.status_code != 200:
        raise AzureOpenAIError(
            status_code=token_response.status_code,
            message=token_response.text,
        )

    payload = token_response.json()
    access_token = payload.get("access_token", None)
    expires_in = payload.get("expires_in", None)

    if access_token is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token access_token not returned"
        )
    if expires_in is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token expires_in not returned"
        )

    azure_ad_cache.set_cache(
        key=cache_key,
        value=access_token,
        ttl=expires_in,
    )
    return access_token
def _check_dynamic_azure_params(
    azure_client_params: dict,
    azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
) -> bool:
    """
    Tell whether the requested client params differ from an existing client.

    Returns True when there is no client. Otherwise returns True only if a
    non-None `api_version` in `azure_client_params` differs from the client's
    `api-version` query param. `api_version` is currently the only param
    compared.
    """
    if azure_client is None:
        return True
    requested_version = azure_client_params.get("api_version")
    if requested_version is None:
        return False
    if requested_version != azure_client._custom_query["api-version"]:
        return True
    return False
class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
    """Initialise the Azure chat-completion handler through `BaseLLM`'s constructor."""
    super(AzureChatCompletion, self).__init__()
def validate_environment(self, api_key, azure_ad_token):
    """
    Build request headers for a direct HTTP call to Azure OpenAI.

    `api_key` takes precedence and is sent as `api-key`. Otherwise, if
    `azure_ad_token` is given, it is sent as a bearer `Authorization` header,
    after being exchanged via `get_azure_ad_token_from_oidc` when it starts
    with "oidc/". With neither credential, only the content-type is set.
    """
    headers = {"content-type": "application/json"}
    if api_key is not None:
        headers["api-key"] = api_key
        return headers
    if azure_ad_token is None:
        return headers
    bearer_token = (
        get_azure_ad_token_from_oidc(azure_ad_token)
        if azure_ad_token.startswith("oidc/")
        else azure_ad_token
    )
    headers["Authorization"] = f"Bearer {bearer_token}"
    return headers
def _get_sync_azure_client(
    self,
    api_version: Optional[str],
    api_base: Optional[str],
    api_key: Optional[str],
    azure_ad_token: Optional[str],
    model: str,
    max_retries: int,
    timeout: Union[float, httpx.Timeout],
    client: Optional[Any],
    client_type: Literal["sync", "async"],
):
    """
    Return an `AzureOpenAI` or `AsyncAzureOpenAI` client.

    A provided `client` is reused. For a reused client, `api_version` (when
    given) is set as its `api-version` query param only if none is set yet.
    Otherwise a new `AzureOpenAI` (client_type="sync") or `AsyncAzureOpenAI`
    (client_type="async") is built from the connection settings.

    Raises:
        ValueError: if a new client is needed and `client_type` is neither
            "sync" nor "async".
    """
    # Pair the async client with the async httpx session and the sync client
    # with the sync one, matching the session choice made for async clients
    # elsewhere in this class.
    if client_type == "async":
        http_client = litellm.aclient_session
    else:
        http_client = litellm.client_session

    # init AzureOpenAI Client
    azure_client_params = {
        "api_version": api_version,
        "azure_endpoint": api_base,
        "azure_deployment": model,
        "http_client": http_client,
        "max_retries": max_retries,
        "timeout": timeout,
    }
    azure_client_params = select_azure_base_url_or_endpoint(
        azure_client_params=azure_client_params
    )
    if api_key is not None:
        azure_client_params["api_key"] = api_key
    elif azure_ad_token is not None:
        if azure_ad_token.startswith("oidc/"):
            azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
        azure_client_params["azure_ad_token"] = azure_ad_token

    if client is None:
        if client_type == "sync":
            azure_client = AzureOpenAI(**azure_client_params)  # type: ignore
        elif client_type == "async":
            azure_client = AsyncAzureOpenAI(**azure_client_params)  # type: ignore
        else:
            raise ValueError(
                "client_type must be 'sync' or 'async', got {!r}".format(client_type)
            )
    else:
        azure_client = client
        if api_version is not None and isinstance(azure_client._custom_query, dict):
            # set api_version to version passed by user
            azure_client._custom_query.setdefault("api-version", api_version)

    return azure_client
def make_sync_azure_openai_chat_completion_request(
    self,
    azure_client: AzureOpenAI,
    data: dict,
    timeout: Union[float, httpx.Timeout],
):
    """
    Run `chat.completions.with_raw_response.create` on a sync Azure client.

    The raw-response variant is always used so the HTTP response headers can
    be returned alongside the parsed body.

    Returns:
        tuple: (response headers as a plain dict, parsed response object).

    Exceptions raised by the openai SDK propagate unchanged.
    """
    raw_response = azure_client.chat.completions.with_raw_response.create(
        **data, timeout=timeout
    )
    headers = dict(raw_response.headers)
    response = raw_response.parse()
    return headers, response
async def make_azure_openai_chat_completion_request(
    self,
    azure_client: AsyncAzureOpenAI,
    data: dict,
    timeout: Union[float, httpx.Timeout],
):
    """
    Await `chat.completions.with_raw_response.create` on an async Azure client.

    The raw-response variant is always used so the HTTP response headers can
    be returned alongside the parsed body.

    Returns:
        tuple: (response headers as a plain dict, parsed response object).

    Exceptions raised by the openai SDK propagate unchanged.
    """
    raw_response = await azure_client.chat.completions.with_raw_response.create(
        **data, timeout=timeout
    )
    headers = dict(raw_response.headers)
    response = raw_response.parse()
    return headers, response
def completion(
    self,
    model: str,
    messages: list,
    model_response: ModelResponse,
    api_key: str,
    api_base: str,
    api_version: str,
    api_type: str,
    azure_ad_token: str,
    dynamic_params: bool,
    print_verbose: Callable,
    timeout: Union[float, httpx.Timeout],
    logging_obj: LiteLLMLoggingObj,
    optional_params,
    litellm_params,
    logger_fn,
    acompletion: bool = False,
    headers: Optional[dict] = None,
    client=None,
):
    """
    Entry point for Azure OpenAI chat completions.

    Dispatch:
      - acompletion=True and stream -> returns the `async_streaming` coroutine
      - acompletion=True            -> returns the `acompletion` coroutine
      - stream=True (sync)          -> returns the result of `streaming(...)`
      - otherwise                   -> performs a blocking call and returns the
                                       result of `convert_to_model_response_object`

    `max_retries` (default 2) and `json_mode` (default False) are popped from
    `optional_params`; `json_mode` is set by `AzureOpenAIConfig.map_openai_params`
    when `response_format` is translated into a forced tool call. When
    `api_base` contains "gateway.ai.cloudflare.com", the model name is appended
    to the base URL (if no client was supplied) and the body uses `"model": None`.

    `api_type`, `print_verbose`, `litellm_params` and `logger_fn` are not
    referenced in this method body. On the sync non-streaming path the
    `headers` argument is overwritten by the response headers.

    Raises:
        AzureOpenAIError: 422 for missing model/messages or a non-int
            `max_retries` (sync non-streaming path); any other exception is
            re-wrapped with its `status_code` (default 500) and headers.
    """
    super().completion()
    try:
        if model is None or messages is None:
            raise AzureOpenAIError(
                status_code=422, message="Missing model or messages"
            )

        # NOTE(review): max_retries is popped here, so `data` (built below from
        # optional_params) no longer carries it; acompletion / streaming /
        # async_streaming pop it from `data` again and fall back to 2. A
        # caller-supplied value therefore appears to reach only the sync
        # non-streaming and Cloudflare client paths — confirm intended.
        max_retries = optional_params.pop("max_retries", 2)
        json_mode: Optional[bool] = optional_params.pop("json_mode", False)

        ### CHECK IF CLOUDFLARE AI GATEWAY ###
        ### if so - set the model as part of the base url
        if "gateway.ai.cloudflare.com" in api_base:
            ## build base url - assume api base includes resource name
            if client is None:
                if not api_base.endswith("/"):
                    api_base += "/"
                api_base += f"{model}"

                azure_client_params = {
                    "api_version": api_version,
                    "base_url": f"{api_base}",
                    "http_client": litellm.client_session,
                    "max_retries": max_retries,
                    "timeout": timeout,
                }
                if api_key is not None:
                    azure_client_params["api_key"] = api_key
                elif azure_ad_token is not None:
                    if azure_ad_token.startswith("oidc/"):
                        azure_ad_token = get_azure_ad_token_from_oidc(
                            azure_ad_token
                        )

                    azure_client_params["azure_ad_token"] = azure_ad_token

                # NOTE(review): both client types are built with the sync
                # `litellm.client_session` here — confirm this is intended
                # for AsyncAzureOpenAI.
                if acompletion is True:
                    client = AsyncAzureOpenAI(**azure_client_params)
                else:
                    client = AzureOpenAI(**azure_client_params)

            # The deployment is already encoded in the gateway URL.
            data = {"model": None, "messages": messages, **optional_params}
        else:
            data = {
                "model": model,  # type: ignore
                "messages": messages,
                **optional_params,
            }

        if acompletion is True:
            # Async paths return un-awaited coroutines; the caller awaits them.
            if optional_params.get("stream", False):
                return self.async_streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    dynamic_params=dynamic_params,
                    data=data,
                    model=model,
                    api_key=api_key,
                    api_version=api_version,
                    azure_ad_token=azure_ad_token,
                    timeout=timeout,
                    client=client,
                )
            else:
                return self.acompletion(
                    api_base=api_base,
                    data=data,
                    model_response=model_response,
                    api_key=api_key,
                    api_version=api_version,
                    model=model,
                    azure_ad_token=azure_ad_token,
                    dynamic_params=dynamic_params,
                    timeout=timeout,
                    client=client,
                    logging_obj=logging_obj,
                    convert_tool_call_to_json_mode=json_mode,
                )
        elif "stream" in optional_params and optional_params["stream"] is True:
            return self.streaming(
                logging_obj=logging_obj,
                api_base=api_base,
                dynamic_params=dynamic_params,
                data=data,
                model=model,
                api_key=api_key,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                client=client,
            )
        else:
            ## LOGGING
            # NOTE(review): the raw api_key / azure_ad_token are passed to the
            # logger here; confirm the logging object redacts them.
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": {
                        "api_key": api_key,
                        "azure_ad_token": azure_ad_token,
                    },
                    "api_version": api_version,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )
            if not isinstance(max_retries, int):
                raise AzureOpenAIError(
                    status_code=422, message="max retries must be an int"
                )
            # init AzureOpenAI Client
            azure_client_params = {
                "api_version": api_version,
                "azure_endpoint": api_base,
                "azure_deployment": model,
                "http_client": litellm.client_session,
                "max_retries": max_retries,
                "timeout": timeout,
            }
            azure_client_params = select_azure_base_url_or_endpoint(
                azure_client_params=azure_client_params
            )
            if api_key is not None:
                azure_client_params["api_key"] = api_key
            elif azure_ad_token is not None:
                if azure_ad_token.startswith("oidc/"):
                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                azure_client_params["azure_ad_token"] = azure_ad_token

            # Build a fresh sync client unless a reusable AzureOpenAI client
            # was supplied and no per-request (dynamic) params override it.
            if (
                client is None
                or not isinstance(client, AzureOpenAI)
                or dynamic_params
            ):
                azure_client = AzureOpenAI(**azure_client_params)
            else:
                azure_client = client
                if api_version is not None and isinstance(
                    azure_client._custom_query, dict
                ):
                    # set api_version to version passed by user
                    azure_client._custom_query.setdefault(
                        "api-version", api_version
                    )
            if not isinstance(azure_client, AzureOpenAI):
                raise AzureOpenAIError(
                    status_code=500,
                    message="azure_client is not an instance of AzureOpenAI",
                )

            # `headers` (the parameter) is rebound to the response headers here.
            headers, response = self.make_sync_azure_openai_chat_completion_request(
                azure_client=azure_client, data=data, timeout=timeout
            )
            stringified_response = response.model_dump()
            ## LOGGING
            logging_obj.post_call(
                input=messages,
                api_key=api_key,
                original_response=stringified_response,
                additional_args={
                    "headers": headers,
                    "api_version": api_version,
                    "api_base": api_base,
                },
            )
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
                convert_tool_call_to_json_mode=json_mode,
                _response_headers=headers,
            )
    except AzureOpenAIError as e:
        raise e
    except Exception as e:
        # Preserve the upstream status code and headers when available.
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response:
            error_headers = getattr(error_response, "headers", None)
        raise AzureOpenAIError(
            status_code=status_code, message=str(e), headers=error_headers
        )
async def acompletion(
    self,
    api_key: str,
    api_version: str,
    model: str,
    api_base: str,
    data: dict,
    timeout: Any,
    dynamic_params: bool,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    azure_ad_token: Optional[str] = None,
    convert_tool_call_to_json_mode: Optional[bool] = None,
    client=None,  # this is the AsyncAzureOpenAI
):
    """
    Non-streaming async chat completion against Azure OpenAI.

    Builds an `AsyncAzureOpenAI` client (using `litellm.aclient_session`)
    unless `client` is supplied and `dynamic_params` is False, sends `data`,
    records the response headers on `logging_obj.model_call_details`, and
    returns the result of `convert_to_model_response_object`.

    `max_retries` is popped from `data` (default 2) before the request.

    Raises:
        AzureOpenAIError: 422 for a non-int `max_retries`; 500 wrapping an
            `asyncio.CancelledError` or any exception without `status_code`.
            Exceptions that already have a `status_code` are re-raised as-is.
    """
    response = None  # not read after assignment below
    try:
        max_retries = data.pop("max_retries", 2)
        if not isinstance(max_retries, int):
            raise AzureOpenAIError(
                status_code=422, message="max retries must be an int"
            )

        # init AzureOpenAI Client
        azure_client_params = {
            "api_version": api_version,
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "http_client": litellm.aclient_session,
            "max_retries": max_retries,
            "timeout": timeout,
        }
        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        )
        if api_key is not None:
            azure_client_params["api_key"] = api_key
        elif azure_ad_token is not None:
            # NOTE(review): the OIDC exchange performs a blocking HTTP call
            # from inside this coroutine.
            if azure_ad_token.startswith("oidc/"):
                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
            azure_client_params["azure_ad_token"] = azure_ad_token

        # setting Azure client
        if client is None or dynamic_params:
            azure_client = AsyncAzureOpenAI(**azure_client_params)
        else:
            azure_client = client

        ## LOGGING
        logging_obj.pre_call(
            input=data["messages"],
            api_key=azure_client.api_key,
            additional_args={
                "headers": {
                    "api_key": api_key,
                    "azure_ad_token": azure_ad_token,
                },
                "api_base": azure_client._base_url._uri_reference,
                "acompletion": True,
                "complete_input_dict": data,
            },
        )

        headers, response = await self.make_azure_openai_chat_completion_request(
            azure_client=azure_client,
            data=data,
            timeout=timeout,
        )
        logging_obj.model_call_details["response_headers"] = headers

        stringified_response = response.model_dump()
        logging_obj.post_call(
            input=data["messages"],
            api_key=api_key,
            original_response=stringified_response,
            additional_args={"complete_input_dict": data},
        )

        return convert_to_model_response_object(
            response_object=stringified_response,
            model_response_object=model_response,
            hidden_params={"headers": headers},
            _response_headers=headers,
            convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
        )
    except AzureOpenAIError as e:
        ## LOGGING
        logging_obj.post_call(
            input=data["messages"],
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=str(e),
        )
        raise e
    except asyncio.CancelledError as e:
        ## LOGGING
        # NOTE(review): converting CancelledError into AzureOpenAIError means
        # task cancellation is reported as an ordinary 500 error rather than
        # propagating as a cancellation — confirm this is intended.
        logging_obj.post_call(
            input=data["messages"],
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=str(e),
        )
        raise AzureOpenAIError(status_code=500, message=str(e))
    except Exception as e:
        ## LOGGING
        logging_obj.post_call(
            input=data["messages"],
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=str(e),
        )
        # SDK errors that already carry a status code are passed through.
        if hasattr(e, "status_code"):
            raise e
        else:
            raise AzureOpenAIError(status_code=500, message=str(e))
def streaming(
    self,
    logging_obj,
    api_base: str,
    api_key: str,
    api_version: str,
    dynamic_params: bool,
    data: dict,
    model: str,
    timeout: Any,
    azure_ad_token: Optional[str] = None,
    client=None,
):
    """
    Start a synchronous streaming chat completion.

    Builds an `AzureOpenAI` client (using `litellm.client_session`) unless a
    `client` is supplied and `dynamic_params` is False, sends `data`, and
    returns a `CustomStreamWrapper` around the stream. The response headers
    are passed through `process_azure_headers`.

    `max_retries` is popped from `data` (default 2) before the request.

    Raises:
        AzureOpenAIError: 422 if `max_retries` is not an int. Errors from the
            SDK call are not caught here.
    """
    max_retries = data.pop("max_retries", 2)
    if not isinstance(max_retries, int):
        raise AzureOpenAIError(
            status_code=422, message="max retries must be an int"
        )
    # init AzureOpenAI Client
    azure_client_params = {
        "api_version": api_version,
        "azure_endpoint": api_base,
        "azure_deployment": model,
        "http_client": litellm.client_session,
        "max_retries": max_retries,
        "timeout": timeout,
    }
    azure_client_params = select_azure_base_url_or_endpoint(
        azure_client_params=azure_client_params
    )
    if api_key is not None:
        azure_client_params["api_key"] = api_key
    elif azure_ad_token is not None:
        if azure_ad_token.startswith("oidc/"):
            azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
        azure_client_params["azure_ad_token"] = azure_ad_token

    if client is None or dynamic_params:
        azure_client = AzureOpenAI(**azure_client_params)
    else:
        azure_client = client

    ## LOGGING
    logging_obj.pre_call(
        input=data["messages"],
        api_key=azure_client.api_key,
        additional_args={
            "headers": {
                "api_key": api_key,
                "azure_ad_token": azure_ad_token,
            },
            "api_base": azure_client._base_url._uri_reference,
            # This is the synchronous path, so it is not an acompletion call.
            "acompletion": False,
            "complete_input_dict": data,
        },
    )
    headers, response = self.make_sync_azure_openai_chat_completion_request(
        azure_client=azure_client, data=data, timeout=timeout
    )
    streamwrapper = CustomStreamWrapper(
        completion_stream=response,
        model=model,
        custom_llm_provider="azure",
        logging_obj=logging_obj,
        _response_headers=process_azure_headers(headers),
    )
    return streamwrapper
async def async_streaming(
    self,
    logging_obj: LiteLLMLoggingObj,
    api_base: str,
    api_key: str,
    api_version: str,
    dynamic_params: bool,
    data: dict,
    model: str,
    timeout: Any,
    azure_ad_token: Optional[str] = None,
    client=None,
):
    """
    Start an async streaming chat completion.

    Builds an `AsyncAzureOpenAI` client (using `litellm.aclient_session`)
    unless a `client` is supplied and `dynamic_params` is False, sends `data`,
    records the response headers on `logging_obj.model_call_details`, and
    returns a `CustomStreamWrapper`. The wrapper is returned rather than
    iterated here, so request failures surface when this coroutine is awaited.

    `max_retries` is popped from `data` (default 2) and validated like the
    other request paths in this class.

    Raises:
        AzureOpenAIError: 422 if `max_retries` is not an int. Other
            AzureOpenAIErrors propagate unchanged; any other exception is
            re-wrapped with its `status_code` (default 500) and headers.
    """
    try:
        max_retries = data.pop("max_retries", 2)
        if not isinstance(max_retries, int):
            raise AzureOpenAIError(
                status_code=422, message="max retries must be an int"
            )
        # init AzureOpenAI Client
        azure_client_params = {
            "api_version": api_version,
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "http_client": litellm.aclient_session,
            "max_retries": max_retries,
            "timeout": timeout,
        }
        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        )
        if api_key is not None:
            azure_client_params["api_key"] = api_key
        elif azure_ad_token is not None:
            if azure_ad_token.startswith("oidc/"):
                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
            azure_client_params["azure_ad_token"] = azure_ad_token
        if client is None or dynamic_params:
            azure_client = AsyncAzureOpenAI(**azure_client_params)
        else:
            azure_client = client
        ## LOGGING
        logging_obj.pre_call(
            input=data["messages"],
            api_key=azure_client.api_key,
            additional_args={
                "headers": {
                    "api_key": api_key,
                    "azure_ad_token": azure_ad_token,
                },
                "api_base": azure_client._base_url._uri_reference,
                "acompletion": True,
                "complete_input_dict": data,
            },
        )

        headers, response = await self.make_azure_openai_chat_completion_request(
            azure_client=azure_client,
            data=data,
            timeout=timeout,
        )
        logging_obj.model_call_details["response_headers"] = headers

        # return response
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="azure",
            logging_obj=logging_obj,
            _response_headers=headers,
        )
        return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
    except AzureOpenAIError as e:
        # Already in the provider's error type; don't wrap it a second time.
        raise e
    except Exception as e:
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response:
            error_headers = getattr(error_response, "headers", None)
        raise AzureOpenAIError(
            status_code=status_code, message=str(e), headers=error_headers
        )
async def aembedding(
    self,
    data: dict,
    model_response: EmbeddingResponse,
    azure_client_params: dict,
    input: list,
    logging_obj: LiteLLMLoggingObj,
    api_key: Optional[str] = None,
    client: Optional[AsyncAzureOpenAI] = None,
    timeout=None,
):
    """
    Async embeddings call against Azure OpenAI.

    Reuses `client` when provided, otherwise builds an `AsyncAzureOpenAI`
    from `azure_client_params`. The outcome (dumped response or error text)
    is sent to `logging_obj.post_call`. On success, returns the result of
    `convert_to_model_response_object` with response_type="embedding".
    Any exception is logged and then re-raised unchanged.
    """
    try:
        async_client = (
            client if client is not None else AsyncAzureOpenAI(**azure_client_params)
        )
        raw = await async_client.embeddings.with_raw_response.create(
            **data, timeout=timeout
        )
        response_headers = dict(raw.headers)
        dumped = raw.parse().model_dump()

        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=dumped,
        )
        return convert_to_model_response_object(
            response_object=dumped,
            model_response_object=model_response,
            hidden_params={"headers": response_headers},
            _response_headers=process_azure_headers(response_headers),
            response_type="embedding",
        )
    except Exception as e:
        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=str(e),
        )
        raise
def embedding(
    self,
    model: str,
    input: list,
    api_base: str,
    api_version: str,
    timeout: float,
    logging_obj: LiteLLMLoggingObj,
    model_response: EmbeddingResponse,
    optional_params: dict,
    api_key: Optional[str] = None,
    azure_ad_token: Optional[str] = None,
    client=None,
    aembedding=None,
):
    """
    Create embeddings via Azure OpenAI (sync, or async via `aembedding`).

    Builds client params from the connection settings. When `aembedding` is
    True, it returns the (un-awaited) `self.aembedding(...)` coroutine.
    Otherwise it uses `client` or a new `AzureOpenAI`, calls
    `embeddings.with_raw_response.create`, and returns the result of
    `convert_to_model_response_object` with response_type="embedding".

    `max_retries` is popped from the request data (default 2).

    Raises:
        AzureOpenAIError: 422 for a non-int `max_retries`; any other exception
            is re-wrapped with its `status_code` (default 500) and headers.
    """
    super().embedding()
    # NOTE(review): the session created here is not referenced in the rest
    # of this method body; presumably used elsewhere by BaseLLM — confirm.
    if self._client_session is None:
        self._client_session = self.create_client_session()
    try:
        data = {"model": model, "input": input, **optional_params}
        max_retries = data.pop("max_retries", 2)
        if not isinstance(max_retries, int):
            raise AzureOpenAIError(
                status_code=422, message="max retries must be an int"
            )

        # init AzureOpenAI Client
        azure_client_params = {
            "api_version": api_version,
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "max_retries": max_retries,
            "timeout": timeout,
        }
        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        )
        # NOTE(review): this branch tests truthiness while the dispatch below
        # tests `aembedding is True`; a truthy non-True value would pair the
        # async session with the sync client.
        if aembedding:
            azure_client_params["http_client"] = litellm.aclient_session
        else:
            azure_client_params["http_client"] = litellm.client_session
        if api_key is not None:
            azure_client_params["api_key"] = api_key
        elif azure_ad_token is not None:
            if azure_ad_token.startswith("oidc/"):
                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
            azure_client_params["azure_ad_token"] = azure_ad_token

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": {"api_key": api_key, "azure_ad_token": azure_ad_token},
            },
        )

        if aembedding is True:
            # Returns a coroutine; the caller awaits it.
            response = self.aembedding(
                data=data,
                input=input,
                logging_obj=logging_obj,
                api_key=api_key,
                model_response=model_response,
                azure_client_params=azure_client_params,
                timeout=timeout,
                client=client,
            )
            return response
        if client is None:
            azure_client = AzureOpenAI(**azure_client_params)  # type: ignore
        else:
            azure_client = client
        ## COMPLETION CALL
        raw_response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout)  # type: ignore
        headers = dict(raw_response.headers)
        response = raw_response.parse()
        ## LOGGING
        # Unlike `aembedding`, the parsed response object (not its
        # model_dump()) is logged here.
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data, "api_base": api_base},
            original_response=response,
        )

        return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding", _response_headers=process_azure_headers(headers))  # type: ignore
    except AzureOpenAIError as e:
        raise e
    except Exception as e:
        # Preserve the upstream status code and headers when available.
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response:
            error_headers = getattr(error_response, "headers", None)
        raise AzureOpenAIError(
            status_code=status_code, message=str(e), headers=error_headers
        )
async def make_async_azure_httpx_request(
    self,
    client: Optional[AsyncHTTPHandler],
    timeout: Optional[Union[float, httpx.Timeout]],
    api_base: str,
    api_version: str,
    api_key: str,
    data: dict,
) -> httpx.Response:
    """
    Implemented for azure dall-e-2 image gen calls
    Alternative to needing a custom transport implementation

    For image-generation URLs with a dall-e-2 preview api-version, this does
    a create + poll flow:
    1. POST the body (without `model`) to `/openai/images/generations:submit`.
    2. Poll the `operation-location` URL until the status is `succeeded` or
       `failed`.
    3. Return the final `result` wrapped in a synthetic 200 `httpx.Response`.
    Every other request is a single JSON POST to `api_base`.

    Args:
        client: handler to reuse. When None, a new `AsyncHTTPHandler` is
            built from `timeout`.
        timeout: seconds (int/float) or an `httpx.Timeout`. None falls back
            to 600s total / 5s connect.
        api_base: full request URL.
        api_version: Azure api-version; it selects the dall-e-2 polling flow.
        api_key: sent as the `api-key` header.
        data: JSON request body. It is not mutated.

    Raises:
        AzureOpenAIError: 500 if the submit response lacks an
            `operation-location` header, 408 if polling exceeds 120 seconds,
            400 if the operation reports `failed`.
        Exception: if the first poll response has no `status` field.
    """
    if client is None:
        _params = {}
        if timeout is not None:
            if isinstance(timeout, (int, float)):
                _params["timeout"] = httpx.Timeout(timeout)
            else:
                # Already an `httpx.Timeout`: pass it through instead of dropping it.
                _params["timeout"] = timeout
        else:
            _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
        async_handler = AsyncHTTPHandler(**_params)  # type: ignore
    else:
        async_handler = client  # type: ignore

    if "images/generations" in api_base and api_version in (
        # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
        "2023-06-01-preview",
        "2023-07-01-preview",
        "2023-08-01-preview",
        "2023-09-01-preview",
        "2023-10-01-preview",
    ):  # CREATE + POLL for azure dall-e-2 calls
        api_base = modify_url(
            original_url=api_base, new_path="/openai/images/generations:submit"
        )
        # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
        # Build a copy so the caller's `data` (also used for logging) is left intact.
        submit_body = {key: value for key, value in data.items() if key != "model"}
        submit_response = await async_handler.post(
            url=api_base,
            data=json.dumps(submit_body),
            headers={
                "Content-Type": "application/json",
                "api-key": api_key,
            },
        )
        if "operation-location" not in submit_response.headers:
            raise AzureOpenAIError(status_code=500, message=submit_response.text)
        operation_location_url = submit_response.headers["operation-location"]

        response = await async_handler.get(
            url=operation_location_url,
            headers={
                "api-key": api_key,
            },
        )
        await response.aread()
        timeout_secs: int = 120
        # Monotonic clock: the elapsed-time check is unaffected by wall-clock changes.
        start_time = time.monotonic()
        # Parse each poll response once and reuse the result.
        status_body = response.json()
        if "status" not in status_body:
            raise Exception(
                "Expected 'status' in response. Got={}".format(status_body)
            )
        while status_body["status"] not in ["succeeded", "failed"]:
            if time.monotonic() - start_time > timeout_secs:
                raise AzureOpenAIError(
                    status_code=408, message="Operation polling timed out."
                )
            # Wait for the server's `retry-after` hint, defaulting to 10 seconds.
            await asyncio.sleep(int(response.headers.get("retry-after") or 10))
            response = await async_handler.get(
                url=operation_location_url,
                headers={
                    "api-key": api_key,
                },
            )
            await response.aread()
            status_body = response.json()
        if status_body["status"] == "failed":
            raise AzureOpenAIError(status_code=400, message=json.dumps(status_body))
        result = status_body["result"]
        return httpx.Response(
            status_code=200,
            headers=response.headers,
            content=json.dumps(result).encode("utf-8"),
            request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
        )
    return await async_handler.post(
        url=api_base,
        json=data,
        headers={
            "Content-Type": "application/json;",
            "api-key": api_key,
        },
    )
def make_sync_azure_httpx_request(
    self,
    client: Optional[HTTPHandler],
    timeout: Optional[Union[float, httpx.Timeout]],
    api_base: str,
    api_version: str,
    api_key: str,
    data: dict,
) -> httpx.Response:
    """
    Implemented for azure dall-e-2 image gen calls
    Alternative to needing a custom transport implementation

    For image-generation URLs with a dall-e-2 preview api-version, this does
    a create + poll flow:
    1. POST the body (without `model`) to `/openai/images/generations:submit`.
    2. Poll the `operation-location` URL until the status is `succeeded` or
       `failed`.
    3. Return the final `result` wrapped in a synthetic 200 `httpx.Response`.
    Every other request is a single JSON POST to `api_base`.

    Args:
        client: handler to reuse. When None, a new `HTTPHandler` is built
            from `timeout`.
        timeout: seconds (int/float) or an `httpx.Timeout`. None falls back
            to 600s total / 5s connect.
        api_base: full request URL.
        api_version: Azure api-version; it selects the dall-e-2 polling flow.
        api_key: sent as the `api-key` header.
        data: JSON request body. It is not mutated.

    Raises:
        AzureOpenAIError: 500 if the submit response lacks an
            `operation-location` header, 408 if polling exceeds 120 seconds,
            400 if the operation reports `failed`.
        Exception: if the first poll response has no `status` field.
    """
    if client is None:
        _params = {}
        if timeout is not None:
            if isinstance(timeout, (int, float)):
                _params["timeout"] = httpx.Timeout(timeout)
            else:
                # Already an `httpx.Timeout`: pass it through instead of dropping it.
                _params["timeout"] = timeout
        else:
            _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
        sync_handler = HTTPHandler(**_params)  # type: ignore
    else:
        sync_handler = client  # type: ignore

    if "images/generations" in api_base and api_version in (
        # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
        "2023-06-01-preview",
        "2023-07-01-preview",
        "2023-08-01-preview",
        "2023-09-01-preview",
        "2023-10-01-preview",
    ):  # CREATE + POLL for azure dall-e-2 calls
        api_base = modify_url(
            original_url=api_base, new_path="/openai/images/generations:submit"
        )
        # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
        # Build a copy so the caller's `data` (also used for logging) is left intact.
        submit_body = {key: value for key, value in data.items() if key != "model"}
        submit_response = sync_handler.post(
            url=api_base,
            data=json.dumps(submit_body),
            headers={
                "Content-Type": "application/json",
                "api-key": api_key,
            },
        )
        if "operation-location" not in submit_response.headers:
            raise AzureOpenAIError(status_code=500, message=submit_response.text)
        operation_location_url = submit_response.headers["operation-location"]

        response = sync_handler.get(
            url=operation_location_url,
            headers={
                "api-key": api_key,
            },
        )
        response.read()
        timeout_secs: int = 120
        # Monotonic clock: the elapsed-time check is unaffected by wall-clock changes.
        start_time = time.monotonic()
        # Parse each poll response once and reuse the result.
        status_body = response.json()
        if "status" not in status_body:
            raise Exception(
                "Expected 'status' in response. Got={}".format(status_body)
            )
        while status_body["status"] not in ["succeeded", "failed"]:
            if time.monotonic() - start_time > timeout_secs:
                raise AzureOpenAIError(
                    status_code=408, message="Operation polling timed out."
                )
            # Wait for the server's `retry-after` hint, defaulting to 10 seconds.
            time.sleep(int(response.headers.get("retry-after") or 10))
            response = sync_handler.get(
                url=operation_location_url,
                headers={
                    "api-key": api_key,
                },
            )
            response.read()
            status_body = response.json()
        if status_body["status"] == "failed":
            raise AzureOpenAIError(status_code=400, message=json.dumps(status_body))
        result = status_body["result"]
        return httpx.Response(
            status_code=200,
            headers=response.headers,
            content=json.dumps(result).encode("utf-8"),
            request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
        )
    return sync_handler.post(
        url=api_base,
        json=data,
        headers={
            "Content-Type": "application/json;",
            "api-key": api_key,
        },
    )
def create_azure_base_url(
    self, azure_client_params: dict, model: Optional[str]
) -> str:
    """
    Build the Azure image-generation URL for a deployment.

    Returns
    `{azure_endpoint}/openai/deployments/{model}/images/generations?api-version={api_version}`.

    Trailing slashes on the endpoint are stripped. A missing or None
    `azure_endpoint`, `api_version`, or `model` contributes an empty string
    instead of raising.
    """
    # `or ""` also covers keys that are present but explicitly set to None.
    api_base: str = (azure_client_params.get("azure_endpoint") or "").rstrip(
        "/"
    )  # "https://example-endpoint.openai.azure.com"
    api_version: str = azure_client_params.get("api_version") or ""
    deployment = model or ""
    return (
        f"{api_base}/openai/deployments/{deployment}"
        f"/images/generations?api-version={api_version}"
    )
async def aimage_generation(
    self,
    data: dict,
    model_response: ModelResponse,
    azure_client_params: dict,
    api_key: str,
    input: list,
    logging_obj: LiteLLMLoggingObj,
    client=None,
    timeout=None,
):
    """
    Async image generation against an Azure deployment via a raw httpx request.

    Args:
        data: request body; `data["prompt"]` is logged as the input, and
            `data["model"]` selects the deployment in the URL.
        model_response: object populated from the JSON response.
        azure_client_params: supplies `azure_endpoint` and `api_version` for
            the request URL.
        api_key: sent as the `api-key` header.
        input: value passed to the post-call logging hook.
        logging_obj: receives pre_call / post_call events.
        client: accepted for interface compatibility; currently unused (a
            fresh handler is created per call).
        timeout: forwarded to the HTTP handler.

    Raises:
        Re-raises any exception after logging it via `post_call`.
    """
    try:
        api_version: str = azure_client_params.get("api_version") or ""
        img_gen_api_base = self.create_azure_base_url(
            azure_client_params=azure_client_params, model=data.get("model", "")
        )
        ## LOGGING
        logging_obj.pre_call(
            input=data["prompt"],
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": img_gen_api_base,
                "headers": {"api_key": api_key},
            },
        )
        httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
            client=None,
            timeout=timeout,
            api_base=img_gen_api_base,
            api_version=api_version,
            api_key=api_key,
            data=data,
        )
        response = httpx_response.json()
        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response,
        )
        return convert_to_model_response_object(
            response_object=response,
            model_response_object=model_response,
            response_type="image_generation",
        )
    except Exception as e:
        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=str(e),
        )
        raise
def image_generation(
    self,
    prompt: str,
    timeout: float,
    optional_params: dict,
    logging_obj: LiteLLMLoggingObj,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    model_response: Optional[litellm.utils.ImageResponse] = None,
    azure_ad_token: Optional[str] = None,
    client=None,
    aimg_generation=None,
):
    """
    Generate images through an Azure OpenAI image-generation deployment.

    Args:
        prompt: text prompt, sent as `prompt` in the request body.
        timeout: request timeout forwarded to the HTTP handler.
        optional_params: extra request-body fields.
            - `base_model`: if set and `model_response` is given, it is moved
              onto `model_response._hidden_params["model"]`.
            - `max_retries`: default 2; removed from the body and must be an int.
        logging_obj: receives pre_call / post_call events.
        model: deployment name. An empty string is treated as None.
        api_key / azure_ad_token: credentials. `api_key` takes precedence; an
            `oidc/`-prefixed AD token is first resolved through
            `get_azure_ad_token_from_oidc`.
        api_base / api_version: Azure endpoint and api-version.
        model_response: object populated from the JSON response.
        client: forwarded to `aimage_generation` on the async path.
        aimg_generation: when True, return the `aimage_generation` coroutine.

    Raises:
        AzureOpenAIError: 422 for a non-int `max_retries`. Any other failure
            is wrapped using its `status_code` attribute, or 500 if it has none.
    """
    try:
        # An empty deployment name is treated as "not provided".
        model = model or None
        ## BASE MODEL CHECK
        if (
            model_response is not None
            and optional_params.get("base_model", None) is not None
        ):
            model_response._hidden_params["model"] = optional_params.pop(
                "base_model"
            )

        data = {"model": model, "prompt": prompt, **optional_params}
        max_retries = data.pop("max_retries", 2)
        if not isinstance(max_retries, int):
            raise AzureOpenAIError(
                status_code=422, message="max retries must be an int"
            )

        # init AzureOpenAI Client
        azure_client_params = {
            "api_version": api_version,
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "max_retries": max_retries,
            "timeout": timeout,
        }
        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        )
        if api_key is not None:
            azure_client_params["api_key"] = api_key
        elif azure_ad_token is not None:
            if azure_ad_token.startswith("oidc/"):
                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
            azure_client_params["azure_ad_token"] = azure_ad_token

        if aimg_generation is True:
            # Log the prompt as the input. This method has no `input` parameter,
            # so a bare `input` would resolve to the `input()` builtin.
            return self.aimage_generation(data=data, input=prompt, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout)  # type: ignore

        img_gen_api_base = self.create_azure_base_url(
            azure_client_params=azure_client_params, model=data.get("model", "")
        )
        ## LOGGING
        logging_obj.pre_call(
            input=data["prompt"],
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": img_gen_api_base,
                "headers": {"api_key": api_key},
            },
        )
        # NOTE(review): only `api_key` is sent on this raw httpx path. An
        # `azure_ad_token` stored in azure_client_params is not forwarded;
        # confirm whether AD-token auth is expected to work here.
        httpx_response: httpx.Response = self.make_sync_azure_httpx_request(
            client=None,
            timeout=timeout,
            api_base=img_gen_api_base,
            api_version=api_version or "",
            api_key=api_key or "",
            data=data,
        )
        response = httpx_response.json()
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response,
        )
        return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
    except AzureOpenAIError as e:
        raise e
    except Exception as e:
        error_code = getattr(e, "status_code", None)
        raise AzureOpenAIError(
            status_code=error_code if error_code is not None else 500,
            message=str(e),
        ) from e
def audio_speech(
    self,
    model: str,
    input: str,
    voice: str,
    optional_params: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    organization: Optional[str],
    max_retries: int,
    timeout: Union[float, httpx.Timeout],
    azure_ad_token: Optional[str] = None,
    aspeech: Optional[bool] = None,
    client=None,
) -> HttpxBinaryResponseContent:
    """
    Text-to-speech through an Azure OpenAI deployment.

    Args:
        model: deployment / model name.
        input: text to synthesize.
        voice: voice name forwarded to the speech endpoint.
        optional_params: extra fields forwarded to `audio.speech.create`.
            A `max_retries` entry is removed and overrides the argument below.
        api_key / azure_ad_token / api_base / api_version: client configuration.
        organization: accepted for interface compatibility; currently unused.
        max_retries: retry count for the client. It is used when
            optional_params does not supply one; if both are None, 2 is used.
        timeout: client timeout.
        aspeech: when True, return the `async_audio_speech` coroutine.
        client: pre-built client to reuse, if any.

    Returns:
        The speech response (or a coroutine resolving to it when `aspeech`).
    """
    # A `max_retries` in optional_params wins. It is removed so it is not
    # forwarded as a request field. Otherwise keep the explicit argument,
    # falling back to 2.
    retries_override = optional_params.pop("max_retries", None)
    if retries_override is not None:
        max_retries = retries_override
    elif max_retries is None:
        max_retries = 2

    if aspeech is True:
        return self.async_audio_speech(
            model=model,
            input=input,
            voice=voice,
            optional_params=optional_params,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            max_retries=max_retries,
            timeout=timeout,
            client=client,
        )  # type: ignore

    azure_client: AzureOpenAI = self._get_sync_azure_client(
        api_base=api_base,
        api_version=api_version,
        api_key=api_key,
        azure_ad_token=azure_ad_token,
        model=model,
        max_retries=max_retries,
        timeout=timeout,
        client=client,
        client_type="sync",
    )  # type: ignore

    response = azure_client.audio.speech.create(
        model=model,
        voice=voice,  # type: ignore
        input=input,
        **optional_params,
    )
    return response
async def async_audio_speech(
    self,
    model: str,
    input: str,
    voice: str,
    optional_params: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    max_retries: int,
    timeout: Union[float, httpx.Timeout],
    client=None,
) -> HttpxBinaryResponseContent:
    """
    Async text-to-speech through an Azure OpenAI deployment.

    An async Azure client is obtained from `_get_sync_azure_client` with
    `client_type="async"`, reusing `client` when one is supplied. The client's
    `audio.speech.create` is awaited with `model`, `voice`, `input` and any
    `optional_params`, and its result is returned unchanged.
    """
    async_client: AsyncAzureOpenAI = self._get_sync_azure_client(
        client_type="async",
        client=client,
        timeout=timeout,
        max_retries=max_retries,
        model=model,
        azure_ad_token=azure_ad_token,
        api_key=api_key,
        api_version=api_version,
        api_base=api_base,
    )  # type: ignore

    return await async_client.audio.speech.create(
        model=model,
        voice=voice,  # type: ignore
        input=input,
        **optional_params,
    )
def get_headers(
    self,
    model: Optional[str],
    api_key: str,
    api_base: str,
    api_version: str,
    timeout: float,
    mode: str,
    messages: Optional[list] = None,
    input: Optional[list] = None,
    prompt: Optional[str] = None,
) -> dict:
    """
    Make one raw chat-completion call and return selected response headers.

    Args:
        model: deployment name. For Cloudflare AI Gateway bases it is appended
            to `api_base` and then sent as None.
        api_key / api_base / api_version / timeout: client configuration.
        mode: used only to exempt "image_generation" from the model check.
        messages: chat messages; defaults to a single "Hey" user message.
        input / prompt: accepted for interface compatibility; unused.

    Returns:
        dict containing whichever of `x-ratelimit-remaining-requests`,
        `x-ratelimit-remaining-tokens` and `x-ms-region` the response carried.

    Raises:
        Exception: if `model` is missing (non-Cloudflare, non-image modes) or
            the response has no headers.
    """
    shared_session = litellm.client_session
    client_session = shared_session or httpx.Client()
    # Close the HTTP client only if this call created it; a shared
    # `litellm.client_session` must stay open for other callers.
    owns_client_session = client_session is not shared_session
    try:
        if "gateway.ai.cloudflare.com" in api_base:
            ## build base url - assume api base includes resource name
            if not api_base.endswith("/"):
                api_base += "/"
            api_base += f"{model}"
            client = AzureOpenAI(
                base_url=api_base,
                api_version=api_version,
                api_key=api_key,
                timeout=timeout,
                http_client=client_session,
            )
            model = None
            # cloudflare ai gateway, needs model=None
        else:
            client = AzureOpenAI(
                api_version=api_version,
                azure_endpoint=api_base,
                api_key=api_key,
                timeout=timeout,
                http_client=client_session,
            )

            # only run this check if it's not cloudflare ai gateway
            if model is None and mode != "image_generation":
                raise Exception("model is not set")

        if messages is None:
            messages = [{"role": "user", "content": "Hey"}]
        completion = client.chat.completions.with_raw_response.create(
            model=model,  # type: ignore
            messages=messages,  # type: ignore
        )
    finally:
        if owns_client_session:
            client_session.close()

    if completion is None or not hasattr(completion, "headers"):
        raise Exception("invalid completion response")

    response = {}
    # Rate-limit headers are not provided for dall-e requests, so each is optional.
    for header_name in (
        "x-ratelimit-remaining-requests",
        "x-ratelimit-remaining-tokens",
        "x-ms-region",
    ):
        header_value = completion.headers.get(header_name, None)
        if header_value is not None:
            response[header_name] = header_value
    return response
async def ahealth_check(
    self,
    model: Optional[str],
    api_key: str,
    api_base: str,
    api_version: str,
    timeout: float,
    mode: str,
    messages: Optional[list] = None,
    input: Optional[list] = None,
    prompt: Optional[str] = None,
) -> dict:
    """
    Issue one raw request for `mode` and return selected response headers.

    Supported modes: "completion", "chat", "embedding", "image_generation",
    "audio_transcription", "audio_speech", "batch".

    Args:
        model: deployment name. For Cloudflare AI Gateway bases it is appended
            to `api_base` and then sent as None.
        api_key / api_base / api_version / timeout: client configuration.
        mode: which endpoint to exercise.
        messages: required for "chat".
        input: required for "embedding".
        prompt: required for "image_generation"; also forwarded for
            "completion", "audio_transcription" and "audio_speech".

    Returns:
        dict containing whichever of `x-ratelimit-remaining-requests`,
        `x-ratelimit-remaining-tokens` and `x-ms-region` the response carried.

    Raises:
        Exception: on a missing model or required input, an unknown mode, or
            a response without headers.
    """
    shared_session = litellm.aclient_session
    client_session = shared_session or httpx.AsyncClient()  # handle dall-e-2 calls
    # Close the HTTP client only if this call created it; a shared
    # `litellm.aclient_session` must stay open for other callers.
    owns_client_session = client_session is not shared_session
    try:
        if "gateway.ai.cloudflare.com" in api_base:
            ## build base url - assume api base includes resource name
            if not api_base.endswith("/"):
                api_base += "/"
            api_base += f"{model}"
            client = AsyncAzureOpenAI(
                base_url=api_base,
                api_version=api_version,
                api_key=api_key,
                timeout=timeout,
                http_client=client_session,
            )
            model = None
            # cloudflare ai gateway, needs model=None
        else:
            client = AsyncAzureOpenAI(
                api_version=api_version,
                azure_endpoint=api_base,
                api_key=api_key,
                timeout=timeout,
                http_client=client_session,
            )

            # only run this check if it's not cloudflare ai gateway
            if model is None and mode != "image_generation":
                raise Exception("model is not set")

        completion = None

        if mode == "completion":
            completion = await client.completions.with_raw_response.create(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "chat":
            if messages is None:
                raise Exception("messages is not set")
            completion = await client.chat.completions.with_raw_response.create(
                model=model,  # type: ignore
                messages=messages,  # type: ignore
            )
        elif mode == "embedding":
            if input is None:
                raise Exception("input is not set")
            completion = await client.embeddings.with_raw_response.create(
                model=model,  # type: ignore
                input=input,  # type: ignore
            )
        elif mode == "image_generation":
            if prompt is None:
                raise Exception("prompt is not set")
            completion = await client.images.with_raw_response.generate(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "audio_transcription":
            # Resolve the sample clip relative to this source file.
            # NOTE(review): this module lives under `llms/AzureOpenAI/`, so
            # `../tests/` may not reach the intended tests directory; confirm the path.
            pwd = os.path.dirname(os.path.realpath(__file__))
            file_path = os.path.join(pwd, "../tests/gettysburg.wav")
            # `with` guarantees the file handle is closed, even on failure.
            with open(file_path, "rb") as audio_file:
                completion = await client.audio.transcriptions.with_raw_response.create(
                    file=audio_file,
                    model=model,  # type: ignore
                    prompt=prompt,  # type: ignore
                )
        elif mode == "audio_speech":
            completion = await client.audio.speech.with_raw_response.create(
                model=model,  # type: ignore
                input=prompt,  # type: ignore
                voice="alloy",
            )
        elif mode == "batch":
            completion = await client.batches.with_raw_response.list(limit=1)  # type: ignore
        else:
            raise Exception("mode not set")
    finally:
        if owns_client_session:
            await client_session.aclose()

    if completion is None or not hasattr(completion, "headers"):
        raise Exception("invalid completion response")

    response = {}
    # Rate-limit headers are not provided for dall-e requests, so each is optional.
    for header_name in (
        "x-ratelimit-remaining-requests",
        "x-ratelimit-remaining-tokens",
        "x-ms-region",
    ):
        header_value = completion.headers.get(header_name, None)
        if header_value is not None:
            response[header_name] = header_value
    return response
class AzureBatchesAPI(BaseLLM):
    """
    Azure methods to support for batches
    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batch()

    Each public method takes `_is_async`. When it is True, the method returns
    a coroutine from the matching `a*` method (which must be awaited) and the
    client must be an `AsyncAzureOpenAI`.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_azure_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        api_version: Optional[str] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
        """
        Return `client` if given; otherwise build an (Async)AzureOpenAI client.

        Only non-None arguments are forwarded to the constructor. `api_base`
        is passed as `azure_endpoint`, and `api_version` defaults to
        `litellm.AZURE_DEFAULT_API_VERSION`.
        """
        # Must be the first statement: captures exactly the call arguments.
        received_args = locals()
        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client" or k == "_is_async":
                    pass
                elif k == "api_base" and v is not None:
                    data["azure_endpoint"] = v
                elif v is not None:
                    data[k] = v
            if "api_version" not in data:
                data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
            if _is_async is True:
                openai_client = AsyncAzureOpenAI(**data)
            else:
                openai_client = AzureOpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

    async def acreate_batch(
        self,
        create_batch_data: CreateBatchRequest,
        azure_client: AsyncAzureOpenAI,
    ) -> Batch:
        response = await azure_client.batches.create(**create_batch_data)
        return response

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        """Create a batch; returns a coroutine when `_is_async` is True."""
        azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            self.get_azure_openai_client(
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                api_version=api_version,
                max_retries=max_retries,
                client=client,
                _is_async=_is_async,
            )
        )
        if azure_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.acreate_batch(  # type: ignore
                create_batch_data=create_batch_data, azure_client=azure_client
            )
        response = azure_client.batches.create(**create_batch_data)
        return response

    async def aretrieve_batch(
        self,
        retrieve_batch_data: RetrieveBatchRequest,
        client: AsyncAzureOpenAI,
    ) -> Batch:
        response = await client.batches.retrieve(**retrieve_batch_data)
        return response

    def retrieve_batch(
        self,
        _is_async: bool,
        retrieve_batch_data: RetrieveBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ):
        """Retrieve a batch; returns a coroutine when `_is_async` is True."""
        azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            self.get_azure_openai_client(
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                _is_async=_is_async,
            )
        )
        if azure_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.aretrieve_batch(  # type: ignore
                retrieve_batch_data=retrieve_batch_data, client=azure_client
            )
        response = azure_client.batches.retrieve(**retrieve_batch_data)
        return response

    async def acancel_batch(
        self,
        cancel_batch_data: CancelBatchRequest,
        client: AsyncAzureOpenAI,
    ) -> Batch:
        response = await client.batches.cancel(**cancel_batch_data)
        return response

    def cancel_batch(
        self,
        _is_async: bool,
        cancel_batch_data: CancelBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        api_version: Optional[str] = None,
    ):
        """
        Cancel a batch; returns a coroutine when `_is_async` is True.

        `organization` is accepted for interface compatibility and is unused.
        `api_version` is forwarded to the client (falls back to
        `litellm.AZURE_DEFAULT_API_VERSION` when None), matching the other
        batch methods.
        """
        azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            self.get_azure_openai_client(
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                _is_async=_is_async,
            )
        )
        if azure_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.acancel_batch(  # type: ignore
                cancel_batch_data=cancel_batch_data, client=azure_client
            )
        response = azure_client.batches.cancel(**cancel_batch_data)
        return response

    async def alist_batches(
        self,
        client: AsyncAzureOpenAI,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        response = await client.batches.list(after=after, limit=limit)  # type: ignore
        return response

    def list_batches(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        after: Optional[str] = None,
        limit: Optional[int] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ):
        """List batches; returns a coroutine when `_is_async` is True."""
        azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            self.get_azure_openai_client(
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                api_version=api_version,
                client=client,
                _is_async=_is_async,
            )
        )
        if azure_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.alist_batches(  # type: ignore
                client=azure_client, after=after, limit=limit
            )
        response = azure_client.batches.list(after=after, limit=limit)  # type: ignore
        return response