# litellm-mirror/litellm/llms/azure.py
from typing import Callable, Optional, Union

import requests
import aiohttp

from .base import BaseLLM
from litellm import OpenAIConfig
from litellm.utils import ModelResponse, Choices, Message


class AzureOpenAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class AzureOpenAIConfig(OpenAIConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `AzureOpenAIConfig` provides configuration for OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
    - `function_call` (string or object): This optional parameter controls how the model calls functions.
    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the chat completion.
    - `n` (integer or null): This optional parameter sets how many chat completion choices to generate for each input message.
    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    def __init__(self,
                 frequency_penalty: Optional[float] = None,
                 function_call: Optional[Union[str, dict]] = None,
                 functions: Optional[list] = None,
                 logit_bias: Optional[dict] = None,
                 max_tokens: Optional[int] = None,
                 n: Optional[int] = None,
                 presence_penalty: Optional[float] = None,
                 stop: Optional[Union[str, list]] = None,
                 temperature: Optional[float] = None,
                 top_p: Optional[float] = None) -> None:
        super().__init__(frequency_penalty,
                         function_call,
                         functions,
                         logit_bias,
                         max_tokens,
                         n,
                         presence_penalty,
                         stop,
                         temperature,
                         top_p)
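

# A minimal sketch of how this config is typically consumed (hypothetical call
# site; in litellm the mapping happens in litellm.main, and `get_config()` is
# assumed to be inherited from OpenAIConfig):
#
#   AzureOpenAIConfig(max_tokens=256, temperature=0.2)  # set defaults
#   optional_params = {}
#   for k, v in AzureOpenAIConfig.get_config().items():
#       if k not in optional_params:  # user-supplied params take precedence
#           optional_params[k] = v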


class AzureChatCompletion(BaseLLM):
    _client_session: requests.Session

    def __init__(self) -> None:
        super().__init__()
        self._client_session = self.create_client_session()

    def validate_environment(self, api_key):
        headers = {
            "content-type": "application/json",
        }
        if api_key:
            headers["api-key"] = api_key
        return headers

    def convert_to_model_response_object(self, response_object: Optional[dict] = None, model_response_object: Optional[ModelResponse] = None):
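        # The expected `response_object` mirrors the OpenAI chat-completion
        # response shape (values here are illustrative):
        #   {"id": "chatcmpl-...", "model": "gpt-35-turbo",
        #    "choices": [{"index": 0, "finish_reason": "stop",
        #                 "message": {"role": "assistant", "content": "..."}}],
        #    "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}}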
        try:
            if response_object is None or model_response_object is None:
                raise AzureOpenAIError(status_code=500, message="Error in response object format")
            choice_list = []
            for idx, choice in enumerate(response_object["choices"]):
                message = Message(content=choice["message"]["content"], role=choice["message"]["role"])
                choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message)
                choice_list.append(choice)
            model_response_object.choices = choice_list
            if "usage" in response_object:
                model_response_object.usage = response_object["usage"]
            if "id" in response_object:
                model_response_object.id = response_object["id"]
            if "model" in response_object:
                model_response_object.model = response_object["model"]
            return model_response_object
        except Exception:
            raise AzureOpenAIError(status_code=500, message="Invalid response object.")

    def completion(self,
                   model: str,
                   messages: list,
                   model_response: ModelResponse,
                   api_key: str,
                   api_base: str,
                   api_version: str,
                   api_type: str,
                   print_verbose: Callable,
                   logging_obj,
                   optional_params,
                   litellm_params,
                   logger_fn,
                   acompletion: bool = False,
                   headers: Optional[dict] = None):
        super().completion()
        exception_mapping_worked = False
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            if model is None or messages is None:
                raise AzureOpenAIError(status_code=422, message="Missing model or messages")
            # Ensure api_base ends with a trailing slash
            if not api_base.endswith('/'):
                api_base += '/'
            api_base = api_base + f"openai/deployments/{model}/chat/completions?api-version={api_version}"
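            # The assembled URL follows the Azure OpenAI REST format, e.g.
            # https://my-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2023-07-01-preview
            # (resource host, deployment name, and api-version are illustrative)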
            data = {
                "messages": messages,
                **optional_params
            }
            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_version": api_version,
                    "api_base": api_base,
                },
            )
if "stream" in optional_params and optional_params["stream"] == True:
response = self._client_session.post(
url=api_base,
json=data,
headers=headers,
stream=optional_params["stream"]
)
if response.status_code != 200:
raise AzureOpenAIError(status_code=response.status_code, message=response.text)
## RESPONSE OBJECT
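                # Each yielded line is a raw server-sent-events chunk
                # (e.g. b'data: {...}'); decoding is left to the caller's
                # stream wrapper.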
                return response.iter_lines()
            elif acompletion is True:
                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
            else:
                response = self._client_session.post(
                    url=api_base,
                    json=data,
                    headers=headers,
                )
                if response.status_code != 200:
                    raise AzureOpenAIError(status_code=response.status_code, message=response.text)
                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
        except AzureOpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if exception_mapping_worked:
                raise e
            else:
                import traceback
                raise AzureOpenAIError(status_code=500, message=traceback.format_exc())

    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
        async with aiohttp.ClientSession() as session:
            async with session.post(api_base, json=data, headers=headers) as response:
                if response.status != 200:
                    # aiohttp's `text` is a coroutine and must be awaited
                    raise AzureOpenAIError(status_code=response.status, message=await response.text())
                response_json = await response.json()
                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
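
    # A minimal usage sketch (hypothetical endpoint and key; the real call
    # site lives in litellm.main and passes a fully-built request):
    #
    #   import asyncio
    #   handler = AzureChatCompletion()
    #   response = asyncio.run(handler.acompletion(
    #       api_base="https://my-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2023-07-01-preview",
    #       data={"messages": [{"role": "user", "content": "Hello"}]},
    #       headers={"content-type": "application/json", "api-key": "<AZURE_API_KEY>"},
    #       model_response=ModelResponse(),
    #   ))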

    def embedding(self,
                  model: str,
                  input: list,
                  api_key: str,
                  api_base: str,
                  api_version: str,
                  logging_obj=None,
                  model_response=None,
                  optional_params=None):
        super().embedding()
        exception_mapping_worked = False
        try:
            headers = self.validate_environment(api_key)
            # Ensure api_base ends with a trailing slash
            if not api_base.endswith('/'):
                api_base += '/'
            api_base = api_base + f"openai/deployments/{model}/embeddings?api-version={api_version}"
            data = {
                "model": model,
                "input": input,
                **optional_params
            }
            ## LOGGING
            logging_obj.pre_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
            )
            ## COMPLETION CALL
            response = self._client_session.post(
                api_base, headers=headers, json=data
            )
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response,
            )
            if response.status_code != 200:
                raise AzureOpenAIError(message=response.text, status_code=response.status_code)
            embedding_response = response.json()
            output_data = []
            for idx, embedding in enumerate(embedding_response["data"]):
                output_data.append(
                    {
                        "object": embedding["object"],
                        "index": embedding["index"],
                        "embedding": embedding["embedding"]
                    }
                )
            model_response["object"] = "list"
            model_response["data"] = output_data
            model_response["model"] = model
            model_response["usage"] = embedding_response["usage"]
            return model_response
        except AzureOpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if exception_mapping_worked:
                raise e
            else:
                import traceback
                raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
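

# A minimal usage sketch for the embedding path (hypothetical resource and
# key; litellm.main normally wires these arguments up from user config):
#
#   handler = AzureChatCompletion()
#   result = handler.embedding(
#       model="my-embedding-deployment",
#       input=["hello world"],
#       api_key="<AZURE_API_KEY>",
#       api_base="https://my-resource.openai.azure.com",
#       api_version="2023-07-01-preview",
#       logging_obj=logging_obj,  # a litellm logging object
#       model_response=ModelResponse(),
#       optional_params={},
#   )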