diff --git a/litellm/__init__.py b/litellm/__init__.py
index 3ac5d0581..b7a287025 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -789,6 +789,7 @@ from .llms.openai import (
     MistralConfig,
     MistralEmbeddingConfig,
     DeepInfraConfig,
+    AzureAIStudioConfig,
 )
 from .llms.azure import (
     AzureOpenAIConfig,
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index fa7745af8..1f2b836c3 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -28,6 +28,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt
 from openai import OpenAI, AsyncOpenAI
 from ..types.llms.openai import *
 import openai
+from litellm.types.utils import ProviderField
 
 
 class OpenAIError(Exception):
@@ -207,6 +208,25 @@ class MistralEmbeddingConfig:
         return optional_params
 
 
+class AzureAIStudioConfig:
+    def get_required_params(self) -> List[ProviderField]:
+        """For a given provider, return its required fields with a description"""
+        return [
+            ProviderField(
+                field_name="api_key",
+                field_type="string",
+                field_description="Your Azure AI Studio API Key.",
+                field_value="zEJ...",
+            ),
+            ProviderField(
+                field_name="api_base",
+                field_type="string",
+                field_description="Your Azure AI Studio API Base.",
+                field_value="https://Mistral-serverless.",
+            ),
+        ]
+
+
 class DeepInfraConfig:
     """
     Reference: https://deepinfra.com/docs/advanced/openai_api
diff --git a/litellm/utils.py b/litellm/utils.py
index 036e74767..58130b266 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -7221,6 +7221,9 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
     elif custom_llm_provider == "ollama":
         return litellm.OllamaConfig().get_required_params()
 
+    elif custom_llm_provider == "azure_ai":
+        return litellm.AzureAIStudioConfig().get_required_params()
+
     else:
         return []