diff --git a/litellm/__init__.py b/litellm/__init__.py index 3ac5d0581..b7a287025 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -789,6 +789,7 @@ from .llms.openai import ( MistralConfig, MistralEmbeddingConfig, DeepInfraConfig, + AzureAIStudioConfig, ) from .llms.azure import ( AzureOpenAIConfig, diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index fa7745af8..1f2b836c3 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -28,6 +28,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt from openai import OpenAI, AsyncOpenAI from ..types.llms.openai import * import openai +from litellm.types.utils import ProviderField class OpenAIError(Exception): @@ -207,6 +208,25 @@ class MistralEmbeddingConfig: return optional_params +class AzureAIStudioConfig: + def get_required_params(self) -> List[ProviderField]: + """For a given provider, return its required fields with a description""" + return [ + ProviderField( + field_name="api_key", + field_type="string", + field_description="Your Azure AI Studio API Key.", + field_value="zEJ...", + ), + ProviderField( + field_name="api_base", + field_type="string", + field_description="Your Azure AI Studio API Base.", + field_value="https://Mistral-serverless.", + ), + ] + + class DeepInfraConfig: """ Reference: https://deepinfra.com/docs/advanced/openai_api diff --git a/litellm/utils.py b/litellm/utils.py index 036e74767..58130b266 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7221,6 +7221,9 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]: elif custom_llm_provider == "ollama": return litellm.OllamaConfig().get_required_params() + elif custom_llm_provider == "azure_ai": + return litellm.AzureAIStudioConfig().get_required_params() + else: return [] diff --git a/ui/litellm-dashboard/src/components/model_dashboard.tsx b/ui/litellm-dashboard/src/components/model_dashboard.tsx index d16d8db13..e18f4233e 100644 ---
a/ui/litellm-dashboard/src/components/model_dashboard.tsx +++ b/ui/litellm-dashboard/src/components/model_dashboard.tsx @@ -139,6 +139,7 @@ interface ProviderSettings { enum Providers { OpenAI = "OpenAI", Azure = "Azure", + Azure_AI_Studio = "Azure AI Studio", Anthropic = "Anthropic", Google_AI_Studio = "Google AI Studio", Bedrock = "Amazon Bedrock", @@ -151,6 +152,7 @@ enum Providers { const provider_map: Record<string, string> = { OpenAI: "openai", Azure: "azure", + Azure_AI_Studio: "azure_ai", Anthropic: "anthropic", Google_AI_Studio: "gemini", Bedrock: "bedrock", @@ -158,6 +160,7 @@ const provider_map: Record<string, string> = { Vertex_AI: "vertex_ai", Databricks: "databricks", Ollama: "ollama", + }; const retry_policy_map: Record<string, string> = { @@ -1245,6 +1248,10 @@ const ModelDashboard: React.FC = ({ return "claude-3-opus"; } else if (selectedProvider == Providers.Google_AI_Studio) { return "gemini-pro"; + } else if (selectedProvider == Providers.Azure_AI_Studio) { + return "azure_ai/command-r-plus"; + } else if (selectedProvider == Providers.Azure) { + return "azure/my-deployment"; } else { return "gpt-3.5-turbo"; }