diff --git a/litellm/__init__.py b/litellm/__init__.py
index 7e0f22e8f..d11242b1c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -730,6 +730,7 @@ from .utils import (
     ModelResponse,
     ImageResponse,
     ImageObject,
+    get_provider_fields,
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
index b306d425e..7b2013710 100644
--- a/litellm/llms/databricks.py
+++ b/litellm/llms/databricks.py
@@ -19,6 +19,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
 import httpx  # type: ignore
 from litellm.types.llms.databricks import GenericStreamingChunk
+from litellm.types.utils import ProviderField
 
 
 class DatabricksError(Exception):
@@ -76,6 +77,23 @@ class DatabricksConfig:
             and v is not None
         }
 
+    def get_required_params(self) -> List[ProviderField]:
+        """For a given provider, return its required fields with a description"""
+        return [
+            ProviderField(
+                field_name="api_key",
+                field_type="string",
+                field_description="Your Databricks API Key.",
+                field_value="dapi...",
+            ),
+            ProviderField(
+                field_name="api_base",
+                field_type="string",
+                field_description="Your Databricks API Base.",
+                field_value="https://adb-..",
+            ),
+        ]
+
     def get_supported_openai_params(self):
         return ["stream", "stop", "temperature", "top_p", "max_tokens", "n"]
 
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 748a3f6ae..aab9c9af1 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -3390,9 +3390,10 @@
         "output_cost_per_token": 0.00000015,
         "litellm_provider": "anyscale",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mistral-7B-Instruct-v0.1"
     },
-    "anyscale/Mixtral-8x7B-Instruct-v0.1": {
+    "anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1": {
         "max_tokens": 16384,
         "max_input_tokens": 16384,
         "max_output_tokens": 16384,
@@ -3400,7 +3401,19 @@
         "output_cost_per_token": 0.00000015,
         "litellm_provider": "anyscale",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x7B-Instruct-v0.1"
+    },
+    "anyscale/mistralai/Mixtral-8x22B-Instruct-v0.1": {
+        "max_tokens": 65536,
+        "max_input_tokens": 65536,
+        "max_output_tokens": 65536,
+        "input_cost_per_token": 0.00000090,
+        "output_cost_per_token": 0.00000090,
+        "litellm_provider": "anyscale",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x22B-Instruct-v0.1"
     },
     "anyscale/HuggingFaceH4/zephyr-7b-beta": {
         "max_tokens": 16384,
@@ -3411,6 +3424,16 @@
         "litellm_provider": "anyscale",
         "mode": "chat"
     },
+    "anyscale/google/gemma-7b-it": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000015,
+        "output_cost_per_token": 0.00000015,
+        "litellm_provider": "anyscale",
+        "mode": "chat",
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/google-gemma-7b-it"
+    },
     "anyscale/meta-llama/Llama-2-7b-chat-hf": {
         "max_tokens": 4096,
         "max_input_tokens": 4096,
@@ -3447,6 +3470,36 @@
"litellm_provider": "anyscale", "mode": "chat" }, + "anyscale/codellama/CodeLlama-70b-Instruct-hf": { + "max_tokens": 4096, + "max_input_tokens": 4096, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000001, + "output_cost_per_token": 0.000001, + "litellm_provider": "anyscale", + "mode": "chat", + "source" : "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/codellama-CodeLlama-70b-Instruct-hf" + }, + "anyscale/meta-llama/Meta-Llama-3-8B-Instruct": { + "max_tokens": 8192, + "max_input_tokens": 8192, + "max_output_tokens": 8192, + "input_cost_per_token": 0.00000015, + "output_cost_per_token": 0.00000015, + "litellm_provider": "anyscale", + "mode": "chat", + "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-8B-Instruct" + }, + "anyscale/meta-llama/Meta-Llama-3-70B-Instruct": { + "max_tokens": 8192, + "max_input_tokens": 8192, + "max_output_tokens": 8192, + "input_cost_per_token": 0.00000100, + "output_cost_per_token": 0.00000100, + "litellm_provider": "anyscale", + "mode": "chat", + "source" : "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-70B-Instruct" + }, "cloudflare/@cf/meta/llama-2-7b-chat-fp16": { "max_tokens": 3072, "max_input_tokens": 3072, diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 148f2d11c..6081d8fba 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -5,6 +5,7 @@ from typing import Optional, List, Union, Dict, Literal, Any from datetime import datetime import uuid, json, sys, os from litellm.types.router import UpdateRouterConfig +from litellm.types.utils import ProviderField def hash_token(token: str): @@ -364,6 +365,11 @@ class ModelInfo(LiteLLMBase): return values +class ProviderInfo(LiteLLMBase): + name: str + fields: List[ProviderField] + + class BlockUsers(LiteLLMBase): user_ids: List[str] # required diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0edf0c726..d9c85623a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -9364,6 +9364,36 @@ async def delete_model(model_info: ModelInfoDelete): ) +@router.get( + "/model/settings", + description="Returns provider name, description, and required parameters for each provider", + tags=["model management"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def model_settings(): + """ + Used by UI to generate 'model add' page + { + field_name=field_name, + field_type=allowed_args[field_name]["type"], # string/int + field_description=field_info.description or "", # human-friendly description + field_value=general_settings.get(field_name, None), # example value + } + """ + + returned_list = [] + for provider in litellm.provider_list: + returned_list.append( + ProviderInfo( + name=provider, + fields=litellm.get_provider_fields(custom_llm_provider=provider), + ) + ) + + return returned_list + + #### EXPERIMENTAL QUEUING #### async def _litellm_chat_completions_worker(data, user_api_key_dict): """ diff --git a/litellm/router.py b/litellm/router.py index bed72bfaa..131ddd692 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -3342,10 +3342,14 @@ class Router: non_default_params = litellm.utils.get_non_default_params( passed_params=request_kwargs ) + special_params = ["response_object"] # check if all params are supported for k, v in non_default_params.items(): - if k not in supported_openai_params: + if k not in supported_openai_params and k in 
                     # if not -> invalid model
+                    verbose_router_logger.debug(
+                        f"INVALID MODEL INDEX @ REQUEST KWARG FILTERING, k={k}"
+                    )
                     invalid_model_indices.append(idx)
 
         if len(invalid_model_indices) == len(_returned_deployments):
@@ -3420,6 +3424,7 @@ class Router:
         ## get healthy deployments
         ### get all deployments
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
+
         if len(healthy_deployments) == 0:
             # check if the user sent in a deployment name instead
             healthy_deployments = [
@@ -3510,7 +3515,7 @@ class Router:
             if _allowed_model_region is None:
                 _allowed_model_region = "n/a"
             raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
+                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
             )
 
         if (
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 10272c629..21823cc1f 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1,6 +1,14 @@
-from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
+from typing import List, Optional, Union, Dict, Tuple, Literal
+from typing_extensions import TypedDict
 
 
 class CostPerToken(TypedDict):
     input_cost_per_token: float
     output_cost_per_token: float
+
+
+class ProviderField(TypedDict):
+    field_name: str
+    field_type: Literal["string"]
+    field_description: str
+    field_value: str
diff --git a/litellm/utils.py b/litellm/utils.py
index 33dfb261e..0f2a46f68 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -34,7 +34,7 @@ from dataclasses import (
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from litellm.caching import DualCache
-from litellm.types.utils import CostPerToken
+from litellm.types.utils import CostPerToken, ProviderField
 
 oidc_cache = DualCache()
 
@@ -7327,6 +7327,15 @@ def load_test_model(
     }
 
 
+def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
+    """Return the fields required for each provider"""
+
+    if custom_llm_provider == "databricks":
+        return litellm.DatabricksConfig().get_required_params()
+    else:
+        return []
+
+
 def validate_environment(model: Optional[str] = None) -> dict:
     """
     Checks if the environment variables are valid for the given model.
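
Reviewer note (not part of the patch): a minimal sketch of how the Python-side pieces above fit together, assuming this branch is installed. `get_provider_fields` is re-exported from `litellm/__init__.py`, and in this diff only `databricks` returns a non-empty list.

```python
import litellm

# Databricks returns the two ProviderField entries defined in
# DatabricksConfig.get_required_params(); ProviderField is a TypedDict,
# so plain dict access works.
for field in litellm.get_provider_fields(custom_llm_provider="databricks"):
    print(field["field_name"], "->", field["field_description"])
# api_key -> Your Databricks API Key.
# api_base -> Your Databricks API Base.

# Every other provider falls through to the empty-list default.
assert litellm.get_provider_fields(custom_llm_provider="openai") == []
```
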
diff --git a/ui/litellm-dashboard/src/components/model_add/dynamic_form.tsx b/ui/litellm-dashboard/src/components/model_add/dynamic_form.tsx
new file mode 100644
index 000000000..aef305908
--- /dev/null
+++ b/ui/litellm-dashboard/src/components/model_add/dynamic_form.tsx
@@ -0,0 +1,47 @@
+import React from "react";
+import { Form, Input } from "antd";
+import { TextInput } from "@tremor/react";
+interface Field {
+  field_name: string;
+  field_type: string;
+  field_description: string;
+  field_value: string;
+}
+
+interface DynamicFieldsProps {
+  fields: Field[];
+  selectedProvider: string;
+}
+
+const getPlaceholder = (provider: string) => {
+  // Implement your placeholder logic based on the provider
+  return `Enter your ${provider} value here`;
+};
+
+const DynamicFields: React.FC<DynamicFieldsProps> = ({
+  fields,
+  selectedProvider,
+}) => {
+  if (fields.length === 0) return null;
+
+  return (
+    <>
+      {fields.map((field) => (
+        <Form.Item
+          key={field.field_name}
+          label={field.field_name
+            .replace(/_/g, " ")
+            .replace(/\b\w/g, (char) => char.toUpperCase())}
+          name={field.field_name}
+          tooltip={field.field_description}
+          className="mb-2"
+        >
+          <TextInput placeholder={field.field_value} />
+        </Form.Item>
+      ))}
+    </>
+  );
+};
+
+export default DynamicFields;
diff --git a/ui/litellm-dashboard/src/components/model_dashboard.tsx b/ui/litellm-dashboard/src/components/model_dashboard.tsx
index 34f0a81c4..89061d916 100644
--- a/ui/litellm-dashboard/src/components/model_dashboard.tsx
+++ b/ui/litellm-dashboard/src/components/model_dashboard.tsx
@@ -48,6 +48,7 @@ import {
   modelMetricsSlowResponsesCall,
   getCallbacksCall,
   setCallbacksCall,
+  modelSettingsCall,
 } from "./networking";
 import { BarChart, AreaChart } from "@tremor/react";
 import {
@@ -84,6 +85,7 @@ import { UploadOutlined } from "@ant-design/icons";
 import type { UploadProps } from "antd";
 import { Upload } from "antd";
 import TimeToFirstToken from "./model_metrics/time_to_first_token";
+import DynamicFields from "./model_add/dynamic_form";
 interface ModelDashboardProps {
   accessToken: string | null;
   token: string | null;
@@ -107,14 +109,27 @@ interface RetryPolicyObject {
 
 //["OpenAI", "Azure OpenAI", "Anthropic", "Gemini (Google AI Studio)", "Amazon Bedrock", "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)"]
 
+interface ProviderFields {
+  field_name: string;
+  field_type: string;
+  field_description: string;
+  field_value: string;
+}
+
+interface ProviderSettings {
+  name: string;
+  fields: ProviderFields[];
+}
+
 enum Providers {
   OpenAI = "OpenAI",
   Azure = "Azure",
   Anthropic = "Anthropic",
-  Google_AI_Studio = "Gemini (Google AI Studio)",
+  Google_AI_Studio = "Google AI Studio",
   Bedrock = "Amazon Bedrock",
   OpenAI_Compatible = "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)",
   Vertex_AI = "Vertex AI (Anthropic, Gemini, etc.)",
+  Databricks = "Databricks",
 }
 
 const provider_map: Record<string, string> = {
@@ -125,6 +140,7 @@ const provider_map: Record<string, string> = {
   Bedrock: "bedrock",
   OpenAI_Compatible: "openai",
   Vertex_AI: "vertex_ai",
+  Databricks: "databricks",
 };
 
 const retry_policy_map: Record<string, string> = {
@@ -247,6 +263,9 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
     isNaN(Number(key))
   );
 
+  const [providerSettings, setProviderSettings] = useState<ProviderSettings[]>(
+    []
+  );
   const [selectedProvider, setSelectedProvider] = useState("OpenAI");
   const [healthCheckResponse, setHealthCheckResponse] = useState("");
   const [editModalVisible, setEditModalVisible] = useState(false);
@@ -514,6 +533,9 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
     }
     const fetchData = async () => {
       try {
+        const _providerSettings = await modelSettingsCall(accessToken);
+        setProviderSettings(_providerSettings);
+
         // Replace with your actual API call for model data
         const modelDataResponse = await modelInfoCall(
           accessToken,
@@ -945,6 +967,18 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
   console.log(`selectedProvider: ${selectedProvider}`);
   console.log(`providerModels.length: ${providerModels.length}`);
+
+  const providerKey = Object.keys(Providers).find(
+    (key) => (Providers as { [index: string]: any })[key] === selectedProvider
+  );
+
+  let dynamicProviderForm: ProviderSettings | undefined = undefined;
+  if (providerKey) {
+    dynamicProviderForm = providerSettings.find(
+      (provider) => provider.name === provider_map[providerKey]
+    );
+  }
+
   return (
@@ -1278,6 +1312,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                     ))}
+
@@ ... @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
+                {dynamicProviderForm !== undefined &&
+                  dynamicProviderForm.fields.length > 0 && (
+                    <DynamicFields fields={dynamicProviderForm.fields} selectedProvider={selectedProvider} />
+                  )}
                 {selectedProvider != Providers.Bedrock &&
-                  selectedProvider != Providers.Vertex_AI && (
+                  selectedProvider != Providers.Vertex_AI &&
+                  dynamicProviderForm === undefined && (
diff --git a/ui/litellm-dashboard/src/components/networking.ts b/ui/litellm-dashboard/src/components/networking.ts
--- a/ui/litellm-dashboard/src/components/networking.ts
+++ b/ui/litellm-dashboard/src/components/networking.ts
@@ ... @@
+export const modelSettingsCall = async (accessToken: string) => {
+  /**
+   * Get all configurable params for setting a model
+   */
+  try {
+    let url = proxyBaseUrl ? `${proxyBaseUrl}/model/settings` : `/model/settings`;
+
+    //message.info("Requesting model data");
+    const response = await fetch(url, {
+      method: "GET",
+      headers: {
+        Authorization: `Bearer ${accessToken}`,
+        "Content-Type": "application/json",
+      },
+    });
+
+    if (!response.ok) {
+      const errorData = await response.text();
+      message.error(errorData, 10);
+      throw new Error("Network response was not ok");
+    }
+
+    const data = await response.json();
+    //message.info("Received model data");
+    return data;
+    // Handle success - you might want to update some state or UI based on the created key
+  } catch (error) {
+    console.error("Failed to get callbacks:", error);
+    throw error;
+  }
+};
+
 export const modelDeleteCall = async (
   accessToken: string,
   model_id: string
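
Reviewer note (not part of the patch): a rough sketch of what the new `/model/settings` endpoint returns, for clients other than the dashboard. The proxy URL and API key below are placeholders; the response is one `ProviderInfo` object per entry in `litellm.provider_list`, with `fields` populated only for providers that implement `get_required_params` (only Databricks in this diff).

```python
import requests  # assumes a running litellm proxy; URL and key are illustrative only

resp = requests.get(
    "http://localhost:4000/model/settings",        # placeholder proxy base URL
    headers={"Authorization": "Bearer sk-1234"},   # placeholder admin key
)
resp.raise_for_status()

for provider in resp.json():
    if provider["fields"]:
        print(provider["name"], [f["field_name"] for f in provider["fields"]])
# Expected output along the lines of: databricks ['api_key', 'api_base']
```
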