Merge pull request #4167 from BerriAI/litellm_add_azure_ai_litellm_ui

[UI] add Azure AI studio models on UI
This commit is contained in:
Ishaan Jaff 2024-06-12 20:33:14 -07:00 committed by GitHub
commit 9a42ddfd6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 31 additions and 0 deletions

View file

@ -789,6 +789,7 @@ from .llms.openai import (
MistralConfig, MistralConfig,
MistralEmbeddingConfig, MistralEmbeddingConfig,
DeepInfraConfig, DeepInfraConfig,
AzureAIStudioConfig,
) )
from .llms.azure import ( from .llms.azure import (
AzureOpenAIConfig, AzureOpenAIConfig,

View file

@ -28,6 +28,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import * from ..types.llms.openai import *
import openai import openai
from litellm.types.utils import ProviderField
class OpenAIError(Exception): class OpenAIError(Exception):
@ -207,6 +208,25 @@ class MistralEmbeddingConfig:
return optional_params return optional_params
class AzureAIStudioConfig:
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Azure AI Studio API Key.",
field_value="zEJ...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Azure AI Studio API Base.",
field_value="https://Mistral-serverless.",
),
]
class DeepInfraConfig: class DeepInfraConfig:
""" """
Reference: https://deepinfra.com/docs/advanced/openai_api Reference: https://deepinfra.com/docs/advanced/openai_api

View file

@ -7221,6 +7221,9 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
elif custom_llm_provider == "ollama": elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_required_params() return litellm.OllamaConfig().get_required_params()
elif custom_llm_provider == "azure_ai":
return litellm.AzureAIStudioConfig().get_required_params()
else: else:
return [] return []

View file

@ -139,6 +139,7 @@ interface ProviderSettings {
enum Providers { enum Providers {
OpenAI = "OpenAI", OpenAI = "OpenAI",
Azure = "Azure", Azure = "Azure",
Azure_AI_Studio = "Azure AI Studio",
Anthropic = "Anthropic", Anthropic = "Anthropic",
Google_AI_Studio = "Google AI Studio", Google_AI_Studio = "Google AI Studio",
Bedrock = "Amazon Bedrock", Bedrock = "Amazon Bedrock",
@ -151,6 +152,7 @@ enum Providers {
const provider_map: Record<string, string> = { const provider_map: Record<string, string> = {
OpenAI: "openai", OpenAI: "openai",
Azure: "azure", Azure: "azure",
Azure_AI_Studio: "azure_ai",
Anthropic: "anthropic", Anthropic: "anthropic",
Google_AI_Studio: "gemini", Google_AI_Studio: "gemini",
Bedrock: "bedrock", Bedrock: "bedrock",
@ -158,6 +160,7 @@ const provider_map: Record<string, string> = {
Vertex_AI: "vertex_ai", Vertex_AI: "vertex_ai",
Databricks: "databricks", Databricks: "databricks",
Ollama: "ollama", Ollama: "ollama",
}; };
const retry_policy_map: Record<string, string> = { const retry_policy_map: Record<string, string> = {
@ -1245,6 +1248,10 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
return "claude-3-opus"; return "claude-3-opus";
} else if (selectedProvider == Providers.Google_AI_Studio) { } else if (selectedProvider == Providers.Google_AI_Studio) {
return "gemini-pro"; return "gemini-pro";
} else if (selectedProvider == Providers.Azure_AI_Studio) {
return "azure_ai/command-r-plus";
} else if (selectedProvider == Providers.Azure) {
return "azure/my-deployment";
} else { } else {
return "gpt-3.5-turbo"; return "gpt-3.5-turbo";
} }