diff --git a/litellm/__init__.py b/litellm/__init__.py
index 04c2d23c7..2e3110aa4 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -455,6 +455,7 @@ provider_list: List = [
     "ai21",
     "baseten",
     "azure",
+    "azure_text",
     "sagemaker",
     "bedrock",
     "vllm",
diff --git a/litellm/llms/azure_text.py b/litellm/llms/azure_text.py
new file mode 100644
index 000000000..690d67b88
--- /dev/null
+++ b/litellm/llms/azure_text.py
@@ -0,0 +1,511 @@
+from typing import Optional, Union, Any
+import types, requests
+from .base import BaseLLM
+from litellm.utils import (
+    ModelResponse,
+    Choices,
+    Message,
+    CustomStreamWrapper,
+    convert_to_model_response_object,
+    TranscriptionResponse,
+)
+from typing import Callable, Optional, BinaryIO
+from litellm import OpenAIConfig
+import litellm, json
+import httpx
+from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
+from openai import AzureOpenAI, AsyncAzureOpenAI
+from ..llms.openai import OpenAITextCompletion
+import uuid
+from .prompt_templates.factory import prompt_factory, custom_prompt
+
+openai_text_completion = OpenAITextCompletion()
+
+
+class AzureOpenAIError(Exception):
+    def __init__(
+        self,
+        status_code,
+        message,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+    ):
+        self.status_code = status_code
+        self.message = message
+        if request:
+            self.request = request
+        else:
+            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        if response:
+            self.response = response
+        else:
+            self.response = httpx.Response(
+                status_code=status_code, request=self.request
+            )
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class AzureOpenAIConfig(OpenAIConfig):
+    """
+    Reference: https://platform.openai.com/docs/api-reference/chat/create
+
+    The class `AzureOpenAIConfig` provides configuration for OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters:
+
+    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
+
+    - `function_call` (string or object): This optional parameter controls how the model calls functions.
+
+    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
+
+    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
+
+    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
+
+    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
+
+    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
+
+    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
+
+    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
+
+    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
+    """
+
+    def __init__(
+        self,
+        frequency_penalty: Optional[int] = None,
+        function_call: Optional[Union[str, dict]] = None,
+        functions: Optional[list] = None,
+        logit_bias: Optional[dict] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[int] = None,
+        stop: Optional[Union[str, list]] = None,
+        temperature: Optional[int] = None,
+        top_p: Optional[int] = None,
+    ) -> None:
+        super().__init__(
+            frequency_penalty,
+            function_call,
+            functions,
+            logit_bias,
+            max_tokens,
+            n,
+            presence_penalty,
+            stop,
+            temperature,
+            top_p,
+        )
+
+
+def select_azure_base_url_or_endpoint(azure_client_params: dict):
+    # azure_client_params = {
+    #     "api_version": api_version,
+    #     "azure_endpoint": api_base,
+    #     "azure_deployment": model,
+    #     "http_client": litellm.client_session,
+    #     "max_retries": max_retries,
+    #     "timeout": timeout,
+    # }
+    azure_endpoint = azure_client_params.get("azure_endpoint", None)
+    if azure_endpoint is not None:
+        # see: https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
+        if "/openai/deployments" in azure_endpoint:
+            # this is base_url, not an azure_endpoint
+            azure_client_params["base_url"] = azure_endpoint
+            azure_client_params.pop("azure_endpoint")
+
+    return azure_client_params
+
+
+class AzureTextCompletion(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def validate_environment(self, api_key, azure_ad_token):
+        headers = {
+            "content-type": "application/json",
+        }
+        if api_key is not None:
+            headers["api-key"] = api_key
+        elif azure_ad_token is not None:
+            headers["Authorization"] = f"Bearer {azure_ad_token}"
+        return headers
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        api_type: str,
+        azure_ad_token: str,
+        print_verbose: Callable,
+        timeout,
+        logging_obj,
+        optional_params,
+        litellm_params,
+        logger_fn,
+        acompletion: bool = False,
+        headers: Optional[dict] = None,
+        client=None,
+    ):
+        super().completion()
+        exception_mapping_worked = False
+        try:
+            if model is None or messages is None:
+                raise AzureOpenAIError(
+                    status_code=422, message="Missing model or messages"
+                )
+
+            max_retries = optional_params.pop("max_retries", 2)
+            prompt = prompt_factory(
+                messages=messages, model=model, custom_llm_provider="azure_text"
+            )
+
+            ### CHECK IF CLOUDFLARE AI GATEWAY ###
+            ### if so - set the model as part of the base url
+            if "gateway.ai.cloudflare.com" in api_base:
+                ## build base url - assume api base includes resource name
+                if client is None:
+                    if not api_base.endswith("/"):
+                        api_base += "/"
+                    api_base += f"{model}"
+
+                    azure_client_params = {
+                        "api_version": api_version,
+                        "base_url": f"{api_base}",
+                        "http_client": litellm.client_session,
+                        "max_retries": max_retries,
+                        "timeout": timeout,
+                    }
+                    if api_key is not None:
+                        azure_client_params["api_key"] = api_key
+                    elif azure_ad_token is not None:
+                        azure_client_params["azure_ad_token"] = azure_ad_token
+
+                    if acompletion is True:
+                        client = AsyncAzureOpenAI(**azure_client_params)
+                    else:
+                        client = AzureOpenAI(**azure_client_params)
+
+                data = {"model": None, "prompt": prompt, **optional_params}
+            else:
+                data = {
+                    "model": model,  # type: ignore
+                    "prompt": prompt,
+                    **optional_params,
+                }
+
+            if acompletion is True:
+                if optional_params.get("stream", False):
+                    return self.async_streaming(
+                        logging_obj=logging_obj,
+                        api_base=api_base,
+                        data=data,
+                        model=model,
+                        api_key=api_key,
+                        api_version=api_version,
+                        azure_ad_token=azure_ad_token,
+                        timeout=timeout,
+                        client=client,
+                    )
+                else:
+                    return self.acompletion(
+                        api_base=api_base,
+                        data=data,
+                        model_response=model_response,
+                        api_key=api_key,
+                        api_version=api_version,
+                        model=model,
+                        azure_ad_token=azure_ad_token,
+                        timeout=timeout,
+                        client=client,
+                        logging_obj=logging_obj,
+                    )
+            elif "stream" in optional_params and optional_params["stream"] == True:
+                return self.streaming(
+                    logging_obj=logging_obj,
+                    api_base=api_base,
+                    data=data,
+                    model=model,
+                    api_key=api_key,
+                    api_version=api_version,
+                    azure_ad_token=azure_ad_token,
+                    timeout=timeout,
+                    client=client,
+                )
+            else:
+                ## LOGGING
+                logging_obj.pre_call(
+                    input=prompt,
+                    api_key=api_key,
+                    additional_args={
+                        "headers": {
+                            "api_key": api_key,
+                            "azure_ad_token": azure_ad_token,
+                        },
+                        "api_version": api_version,
+                        "api_base": api_base,
+                        "complete_input_dict": data,
+                    },
+                )
+                if not isinstance(max_retries, int):
+                    raise AzureOpenAIError(
+                        status_code=422, message="max retries must be an int"
+                    )
+                # init AzureOpenAI Client
+                azure_client_params = {
+                    "api_version": api_version,
+                    "azure_endpoint": api_base,
+                    "azure_deployment": model,
+                    "http_client": litellm.client_session,
+                    "max_retries": max_retries,
+                    "timeout": timeout,
+                }
+                azure_client_params = select_azure_base_url_or_endpoint(
+                    azure_client_params=azure_client_params
+                )
+                if api_key is not None:
+                    azure_client_params["api_key"] = api_key
+                elif azure_ad_token is not None:
+                    azure_client_params["azure_ad_token"] = azure_ad_token
+                if client is None:
+                    azure_client = AzureOpenAI(**azure_client_params)
+                else:
+                    azure_client = client
+                    if api_version is not None and isinstance(
+                        azure_client._custom_query, dict
+                    ):
+                        # set api_version to version passed by user
+                        azure_client._custom_query.setdefault(
+                            "api-version", api_version
+                        )
+
+                response = azure_client.completions.create(**data, timeout=timeout)  # type: ignore
+                stringified_response = response.model_dump()
+                ## LOGGING
+                logging_obj.post_call(
+                    input=prompt,
+                    api_key=api_key,
+                    original_response=stringified_response,
+                    additional_args={
+                        "headers": headers,
+                        "api_version": api_version,
+                        "api_base": api_base,
+                    },
+                )
+                return openai_text_completion.convert_to_model_response_object(
+                    response_object=stringified_response,
+                    model_response_object=model_response,
+                )
+        except AzureOpenAIError as e:
+            exception_mapping_worked = True
+            raise e
+        except Exception as e:
+            if hasattr(e, "status_code"):
+                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
+            else:
+                raise AzureOpenAIError(status_code=500, message=str(e))
+
+    async def acompletion(
+        self,
+        api_key: str,
+        api_version: str,
+        model: str,
+        api_base: str,
+        data: dict,
+        timeout: Any,
+        model_response: ModelResponse,
+        azure_ad_token: Optional[str] = None,
+        client=None,  # this is the AsyncAzureOpenAI
+        logging_obj=None,
+    ):
+        response = None
+        try:
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise AzureOpenAIError(
+                    status_code=422, message="max retries must be an int"
+                )
+
+            # init AzureOpenAI Client
+            azure_client_params = {
+                "api_version": api_version,
+                "azure_endpoint": api_base,
+                "azure_deployment": model,
+                "http_client": litellm.client_session,
+                "max_retries": max_retries,
+                "timeout": timeout,
+            }
+            azure_client_params = select_azure_base_url_or_endpoint(
+                azure_client_params=azure_client_params
+            )
+            if api_key is not None:
+                azure_client_params["api_key"] = api_key
+            elif azure_ad_token is not None:
+                azure_client_params["azure_ad_token"] = azure_ad_token
+
+            # setting Azure client
+            if client is None:
+                azure_client = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                azure_client = client
+                if api_version is not None and isinstance(
+                    azure_client._custom_query, dict
+                ):
+                    # set api_version to version passed by user
+                    azure_client._custom_query.setdefault("api-version", api_version)
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["prompt"],
+                api_key=azure_client.api_key,
+                additional_args={
+                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                    "api_base": azure_client._base_url._uri_reference,
+                    "acompletion": True,
+                    "complete_input_dict": data,
+                },
+            )
+            response = await azure_client.completions.create(**data, timeout=timeout)
+            return openai_text_completion.convert_to_model_response_object(
+                response_object=response.model_dump(),
+                model_response_object=model_response,
+            )
+        except AzureOpenAIError as e:
+            exception_mapping_worked = True
+            raise e
+        except Exception as e:
+            if hasattr(e, "status_code"):
+                raise e
+            else:
+                raise AzureOpenAIError(status_code=500, message=str(e))
+
+    def streaming(
+        self,
+        logging_obj,
+        api_base: str,
+        api_key: str,
+        api_version: str,
+        data: dict,
+        model: str,
+        timeout: Any,
+        azure_ad_token: Optional[str] = None,
+        client=None,
+    ):
+        max_retries = data.pop("max_retries", 2)
+        if not isinstance(max_retries, int):
+            raise AzureOpenAIError(
+                status_code=422, message="max retries must be an int"
+            )
+        # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "http_client": litellm.client_session,
+            "max_retries": max_retries,
+            "timeout": timeout,
+        }
+        azure_client_params = select_azure_base_url_or_endpoint(
+            azure_client_params=azure_client_params
+        )
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            azure_client_params["azure_ad_token"] = azure_ad_token
+        if client is None:
+            azure_client = AzureOpenAI(**azure_client_params)
+        else:
+            azure_client = client
+        if api_version is not None and isinstance(azure_client._custom_query, dict):
+            # set api_version to version passed by user
+            azure_client._custom_query.setdefault("api-version", api_version)
+        ## LOGGING
+        logging_obj.pre_call(
+            input=data["prompt"],
+            api_key=azure_client.api_key,
+            additional_args={
+                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                "api_base": azure_client._base_url._uri_reference,
+                "acompletion": True,
+                "complete_input_dict": data,
+            },
+        )
+        response = azure_client.completions.create(**data, timeout=timeout)
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="azure",
+            logging_obj=logging_obj,
+        )
+        return streamwrapper
+
+    async def async_streaming(
+        self,
+        logging_obj,
+        api_base: str,
+        api_key: str,
+        api_version: str,
+        data: dict,
+        model: str,
+        timeout: Any,
+        azure_ad_token: Optional[str] = None,
+        client=None,
+    ):
+        try:
+            # init AzureOpenAI Client
+            azure_client_params = {
+                "api_version": api_version,
+                "azure_endpoint": api_base,
+                "azure_deployment": model,
+                "http_client": litellm.client_session,
+                "max_retries": data.pop("max_retries", 2),
+                "timeout": timeout,
+            }
+            azure_client_params = select_azure_base_url_or_endpoint(
+                azure_client_params=azure_client_params
+            )
+            if api_key is not None:
+                azure_client_params["api_key"] = api_key
+            elif azure_ad_token is not None:
+                azure_client_params["azure_ad_token"] = azure_ad_token
+            if client is None:
+                azure_client = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                azure_client = client
+                if api_version is not None and isinstance(
+                    azure_client._custom_query, dict
+                ):
+                    # set api_version to version passed by user
+                    azure_client._custom_query.setdefault("api-version", api_version)
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["prompt"],
+                api_key=azure_client.api_key,
+                additional_args={
+                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                    "api_base": azure_client._base_url._uri_reference,
+                    "acompletion": True,
+                    "complete_input_dict": data,
+                },
+            )
+            response = await azure_client.completions.create(**data, timeout=timeout)
+            # return response
+            streamwrapper = CustomStreamWrapper(
+                completion_stream=response,
+                model=model,
+                custom_llm_provider="azure",
+                logging_obj=logging_obj,
+            )
+            return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
+        except Exception as e:
+            if hasattr(e, "status_code"):
+                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
+            else:
+                raise AzureOpenAIError(status_code=500, message=str(e))
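A note on `select_azure_base_url_or_endpoint` above: the openai SDK accepts either an `azure_endpoint` (the resource root) or a `base_url` that already names a deployment, and the helper simply moves the value between those two keys. Below is a minimal standalone sketch of the same check; the URLs are hypothetical, and `http_client`, retries, and timeouts are omitted so the snippet runs on its own:

```python
# Sketch of the select_azure_base_url_or_endpoint check: an endpoint that
# already contains "/openai/deployments" is treated as a full base_url,
# otherwise it is left as an azure_endpoint.
def select_azure_base_url_or_endpoint(azure_client_params: dict) -> dict:
    azure_endpoint = azure_client_params.get("azure_endpoint", None)
    if azure_endpoint is not None and "/openai/deployments" in azure_endpoint:
        # this is a base_url, not an azure_endpoint
        azure_client_params["base_url"] = azure_endpoint
        azure_client_params.pop("azure_endpoint")
    return azure_client_params


# Plain resource endpoint -> left as azure_endpoint (hypothetical URL)
print(select_azure_base_url_or_endpoint(
    {"azure_endpoint": "https://my-resource.openai.azure.com"}
))
# {'azure_endpoint': 'https://my-resource.openai.azure.com'}

# Deployment-qualified URL -> moved to base_url (hypothetical URL)
print(select_azure_base_url_or_endpoint(
    {"azure_endpoint": "https://my-resource.openai.azure.com/openai/deployments/instruct-model"}
))
# {'base_url': 'https://my-resource.openai.azure.com/openai/deployments/instruct-model'}
```

This mirrors the SDK behavior linked in the code comment: a URL containing `/openai/deployments` is already deployment-scoped, so passing it as `azure_endpoint` would double up the path.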
azure_client_params["azure_ad_token"] = azure_ad_token + if client is None: + azure_client = AsyncAzureOpenAI(**azure_client_params) + else: + azure_client = client + if api_version is not None and isinstance( + azure_client._custom_query, dict + ): + # set api_version to version passed by user + azure_client._custom_query.setdefault("api-version", api_version) + ## LOGGING + logging_obj.pre_call( + input=data["prompt"], + api_key=azure_client.api_key, + additional_args={ + "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) + response = await azure_client.completions.create(**data, timeout=timeout) + # return response + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="azure", + logging_obj=logging_obj, + ) + return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails + except Exception as e: + if hasattr(e, "status_code"): + raise AzureOpenAIError(status_code=e.status_code, message=str(e)) + else: + raise AzureOpenAIError(status_code=500, message=str(e)) diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index dc4207a05..ae12d954a 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -807,6 +807,20 @@ def gemini_text_image_pt(messages: list): return content +def azure_text_pt(messages: list): + prompt = "" + for message in messages: + if isinstance(message["content"], str): + prompt += message["content"] + elif isinstance(message["content"], list): + # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models + for element in message["content"]: + if isinstance(element, dict): + if element["type"] == "text": + prompt += element["text"] + return prompt + + # Function call template def function_call_prompt(messages: list, functions: list): function_prompt = ( @@ -907,6 +921,8 @@ def prompt_factory( for message in messages: message.pop("name", None) return messages + elif custom_llm_provider == "azure_text": + return azure_text_pt(messages=messages) try: if "meta-llama/llama-2" in model and "chat" in model: return llama_2_chat_pt(messages=messages) diff --git a/litellm/main.py b/litellm/main.py index 114b46948..a73cf6b88 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -64,6 +64,7 @@ from .llms import ( ) from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion +from .llms.azure_text import AzureTextCompletion from .llms.huggingface_restapi import Huggingface from .llms.prompt_templates.factory import ( prompt_factory, @@ -96,6 +97,7 @@ dotenv.load_dotenv() # Loading env variables using dotenv openai_chat_completions = OpenAIChatCompletion() openai_text_completions = OpenAITextCompletion() azure_chat_completions = AzureChatCompletion() +azure_text_completions = AzureTextCompletion() huggingface = Huggingface() ####### COMPLETION ENDPOINTS ################ @@ -254,6 +256,7 @@ async def acompletion( if ( custom_llm_provider == "openai" or custom_llm_provider == "azure" + or custom_llm_provider == "azure_text" or custom_llm_provider == "custom_openai" or custom_llm_provider == "anyscale" or custom_llm_provider == "mistral" @@ -800,6 +803,71 @@ def completion( client=client, # pass AsyncAzureOpenAI, AzureOpenAI client ) + if optional_params.get("stream", False) 
diff --git a/litellm/main.py b/litellm/main.py
index 114b46948..a73cf6b88 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -64,6 +64,7 @@ from .llms import (
 )
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.azure import AzureChatCompletion
+from .llms.azure_text import AzureTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.prompt_templates.factory import (
     prompt_factory,
@@ -96,6 +97,7 @@ dotenv.load_dotenv()  # Loading env variables using dotenv
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
 azure_chat_completions = AzureChatCompletion()
+azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 
 ####### COMPLETION ENDPOINTS ################
@@ -254,6 +256,7 @@ async def acompletion(
         if (
             custom_llm_provider == "openai"
             or custom_llm_provider == "azure"
+            or custom_llm_provider == "azure_text"
             or custom_llm_provider == "custom_openai"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
@@ -800,6 +803,71 @@ def completion(
                 client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
             )
 
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={
+                        "headers": headers,
+                        "api_version": api_version,
+                        "api_base": api_base,
+                    },
+                )
+        elif custom_llm_provider == "azure_text":
+            # azure configs
+            api_type = get_secret("AZURE_API_TYPE") or "azure"
+
+            api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+
+            api_version = (
+                api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
+            )
+
+            api_key = (
+                api_key
+                or litellm.api_key
+                or litellm.azure_key
+                or get_secret("AZURE_OPENAI_API_KEY")
+                or get_secret("AZURE_API_KEY")
+            )
+
+            azure_ad_token = optional_params.get("extra_body", {}).pop(
+                "azure_ad_token", None
+            ) or get_secret("AZURE_AD_TOKEN")
+
+            headers = headers or litellm.headers
+
+            ## LOAD CONFIG - if set
+            config = litellm.AzureOpenAIConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in optional_params
+                ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
+                    optional_params[k] = v
+
+            ## COMPLETION CALL
+            response = azure_text_completions.completion(
+                model=model,
+                messages=messages,
+                headers=headers,
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                api_type=api_type,
+                azure_ad_token=azure_ad_token,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                logging_obj=logging,
+                acompletion=acompletion,
+                timeout=timeout,
+                client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
+            )
+
             if optional_params.get("stream", False) or acompletion == True:
                 ## LOGGING
                 logging.post_call(
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 729bf7bd9..c4760a10a 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1152,6 +1152,16 @@ def test_completion_azure_key_completion_arg():
 # test_completion_azure_key_completion_arg()
 
 
+def test_azure_instruct():
+    litellm.set_verbose = True
+    response = completion(
+        model="azure_text/instruct-model",
+        messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
+        max_tokens=10,
+    )
+    print("response", response)
+
+
 async def test_re_use_azure_async_client():
     try:
         print("azure gpt-3.5 ASYNC with clie nttest\n\n")
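For reference, a usage sketch based on the new `test_azure_instruct` test and the `acompletion` routing added above. It assumes `AZURE_API_BASE`, `AZURE_API_KEY`, and `AZURE_API_VERSION` are set in the environment, and `instruct-model` stands in for a real Azure deployment name:

```python
import asyncio
from litellm import acompletion, completion

# Hypothetical deployment name; the "azure_text/" prefix routes the call
# through AzureTextCompletion and the azure_text_pt prompt flattening.
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

# Sync call -> AzureTextCompletion.completion()
response = completion(
    model="azure_text/instruct-model", messages=messages, max_tokens=10
)
print(response)


# Async call -> acompletion() now recognizes "azure_text" and awaits the coroutine
async def main():
    response = await acompletion(
        model="azure_text/instruct-model", messages=messages, max_tokens=10
    )
    print(response)


asyncio.run(main())
```

Streaming (`stream=True`) follows the same split: the sync path returns a `CustomStreamWrapper` over the Azure completions stream, the async path an awaitable one.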