From e128dc4e1f3cc31b44b91cdad0eb7e33b90cdb52 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 12 Jun 2024 20:28:16 -0700 Subject: [PATCH] feat - add azure ai studio models on litellm ui --- litellm/__init__.py | 1 + litellm/llms/openai.py | 20 ++++++++++++++++++++ litellm/utils.py | 3 +++ 3 files changed, 24 insertions(+) diff --git a/litellm/__init__.py b/litellm/__init__.py index 3ac5d0581..b7a287025 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -789,6 +789,7 @@ from .llms.openai import ( MistralConfig, MistralEmbeddingConfig, DeepInfraConfig, + AzureAIStudioConfig, ) from .llms.azure import ( AzureOpenAIConfig, diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index fa7745af8..1f2b836c3 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -28,6 +28,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt from openai import OpenAI, AsyncOpenAI from ..types.llms.openai import * import openai +from litellm.types.utils import ProviderField class OpenAIError(Exception): @@ -207,6 +208,25 @@ class MistralEmbeddingConfig: return optional_params +class AzureAIStudioConfig: + def get_required_params(self) -> List[ProviderField]: + """For a given provider, return it's required fields with a description""" + return [ + ProviderField( + field_name="api_key", + field_type="string", + field_description="Your Azure AI Studio API Key.", + field_value="zEJ...", + ), + ProviderField( + field_name="api_base", + field_type="string", + field_description="Your Azure AI Studio API Base.", + field_value="https://Mistral-serverless.", + ), + ] + + class DeepInfraConfig: """ Reference: https://deepinfra.com/docs/advanced/openai_api diff --git a/litellm/utils.py b/litellm/utils.py index 036e74767..58130b266 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7221,6 +7221,9 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]: elif custom_llm_provider == "ollama": return litellm.OllamaConfig().get_required_params() + elif custom_llm_provider == "azure_ai": + return litellm.AzureAIStudioConfig().get_required_params() + else: return []