feat(ui/model_dashboard.tsx): add databricks models via admin ui

Krrish Dholakia 2024-05-23 20:28:54 -07:00
parent c14584722e
commit f04e4b921b
11 changed files with 263 additions and 9 deletions


@@ -730,6 +730,7 @@ from .utils import (
    ModelResponse,
    ImageResponse,
    ImageObject,
    get_provider_fields,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig


@@ -19,6 +19,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx  # type: ignore
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.types.utils import ProviderField

class DatabricksError(Exception):

@@ -76,6 +77,23 @@ class DatabricksConfig:
            and v is not None
        }

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Databricks API Key.",
                field_value="dapi...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Databricks API Base.",
                field_value="https://adb-..",
            ),
        ]

    def get_supported_openai_params(self):
        return ["stream", "stop", "temperature", "top_p", "max_tokens", "n"]


@@ -3390,9 +3390,10 @@
        "output_cost_per_token": 0.00000015,
        "litellm_provider": "anyscale",
        "mode": "chat",
        "supports_function_calling": true,
        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mistral-7B-Instruct-v0.1"
    },
    "anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1": {
        "max_tokens": 16384,
        "max_input_tokens": 16384,
        "max_output_tokens": 16384,
@@ -3400,7 +3401,19 @@
        "output_cost_per_token": 0.00000015,
        "litellm_provider": "anyscale",
        "mode": "chat",
        "supports_function_calling": true,
        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x7B-Instruct-v0.1"
    },
    "anyscale/mistralai/Mixtral-8x22B-Instruct-v0.1": {
        "max_tokens": 65536,
        "max_input_tokens": 65536,
        "max_output_tokens": 65536,
        "input_cost_per_token": 0.00000090,
        "output_cost_per_token": 0.00000090,
        "litellm_provider": "anyscale",
        "mode": "chat",
        "supports_function_calling": true,
        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x22B-Instruct-v0.1"
    },
    "anyscale/HuggingFaceH4/zephyr-7b-beta": {
        "max_tokens": 16384,
@@ -3411,6 +3424,16 @@
        "litellm_provider": "anyscale",
        "mode": "chat"
    },
    "anyscale/google/gemma-7b-it": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,
        "max_output_tokens": 8192,
        "input_cost_per_token": 0.00000015,
        "output_cost_per_token": 0.00000015,
        "litellm_provider": "anyscale",
        "mode": "chat",
        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/google-gemma-7b-it"
    },
    "anyscale/meta-llama/Llama-2-7b-chat-hf": {
        "max_tokens": 4096,
        "max_input_tokens": 4096,
@@ -3447,6 +3470,36 @@
        "litellm_provider": "anyscale",
        "mode": "chat"
    },
    "anyscale/codellama/CodeLlama-70b-Instruct-hf": {
        "max_tokens": 4096,
        "max_input_tokens": 4096,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.000001,
        "output_cost_per_token": 0.000001,
        "litellm_provider": "anyscale",
        "mode": "chat",
        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/codellama-CodeLlama-70b-Instruct-hf"
    },
    "anyscale/meta-llama/Meta-Llama-3-8B-Instruct": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,
        "max_output_tokens": 8192,
        "input_cost_per_token": 0.00000015,
        "output_cost_per_token": 0.00000015,
        "litellm_provider": "anyscale",
        "mode": "chat",
        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-8B-Instruct"
    },
    "anyscale/meta-llama/Meta-Llama-3-70B-Instruct": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,
        "max_output_tokens": 8192,
        "input_cost_per_token": 0.00000100,
        "output_cost_per_token": 0.00000100,
        "litellm_provider": "anyscale",
        "mode": "chat",
        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-70B-Instruct"
    },
    "cloudflare/@cf/meta/llama-2-7b-chat-fp16": {
        "max_tokens": 3072,
        "max_input_tokens": 3072,


@@ -5,6 +5,7 @@ from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime
import uuid, json, sys, os
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField

def hash_token(token: str):

@@ -364,6 +365,11 @@ class ModelInfo(LiteLLMBase):
        return values

class ProviderInfo(LiteLLMBase):
    name: str
    fields: List[ProviderField]

class BlockUsers(LiteLLMBase):
    user_ids: List[str]  # required


@@ -9364,6 +9364,36 @@ async def delete_model(model_info: ModelInfoDelete):
        )

@router.get(
    "/model/settings",
    description="Returns provider name, description, and required parameters for each provider",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
    include_in_schema=False,
)
async def model_settings():
    """
    Used by UI to generate 'model add' page
    {
        field_name=field_name,
        field_type=allowed_args[field_name]["type"], # string/int
        field_description=field_info.description or "", # human-friendly description
        field_value=general_settings.get(field_name, None), # example value
    }
    """
    returned_list = []
    for provider in litellm.provider_list:
        returned_list.append(
            ProviderInfo(
                name=provider,
                fields=litellm.get_provider_fields(custom_llm_provider=provider),
            )
        )
    return returned_list

#### EXPERIMENTAL QUEUING ####
async def _litellm_chat_completions_worker(data, user_api_key_dict):
    """


@@ -3342,10 +3342,14 @@ class Router:
            non_default_params = litellm.utils.get_non_default_params(
                passed_params=request_kwargs
            )
            special_params = ["response_object"]
            # check if all params are supported
            for k, v in non_default_params.items():
                if k not in supported_openai_params and k in special_params:
                    # if not -> invalid model
                    verbose_router_logger.debug(
                        f"INVALID MODEL INDEX @ REQUEST KWARG FILTERING, k={k}"
                    )
                    invalid_model_indices.append(idx)

        if len(invalid_model_indices) == len(_returned_deployments):

@@ -3420,6 +3424,7 @@ class Router:
        ## get healthy deployments
        ### get all deployments
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]

        if len(healthy_deployments) == 0:
            # check if the user sent in a deployment name instead
            healthy_deployments = [

@@ -3510,7 +3515,7 @@ class Router:
        if _allowed_model_region is None:
            _allowed_model_region = "n/a"
        raise ValueError(
            f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
        )

    if (


@@ -1,6 +1,14 @@
from typing import List, Optional, Union, Dict, Tuple, Literal
from typing_extensions import TypedDict

class CostPerToken(TypedDict):
    input_cost_per_token: float
    output_cost_per_token: float

class ProviderField(TypedDict):
    field_name: str
    field_type: Literal["string"]
    field_description: str
    field_value: str


@@ -34,7 +34,7 @@ from dataclasses import (
import litellm._service_logger  # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.caching import DualCache
from litellm.types.utils import CostPerToken, ProviderField

oidc_cache = DualCache()

@@ -7327,6 +7327,15 @@ def load_test_model(
    }

def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
    """Return the fields required for each provider"""
    if custom_llm_provider == "databricks":
        return litellm.DatabricksConfig().get_required_params()
    else:
        return []

def validate_environment(model: Optional[str] = None) -> dict:
    """
    Checks if the environment variables are valid for the given model.
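
Illustrative only: how the helper behaves for providers with and without registered fields, assuming get_provider_fields is re-exported at the package level as the __init__.py hunk above indicates.

import litellm

# Databricks is the only provider with UI fields in this commit; everything else falls through to [].
print(litellm.get_provider_fields(custom_llm_provider="databricks"))  # two ProviderField dicts
print(litellm.get_provider_fields(custom_llm_provider="openai"))      # []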


@@ -0,0 +1,47 @@
import React from "react";
import { Form, Input } from "antd";
import { TextInput } from "@tremor/react";

interface Field {
  field_name: string;
  field_type: string;
  field_description: string;
  field_value: string;
}

interface DynamicFieldsProps {
  fields: Field[];
  selectedProvider: string;
}

const getPlaceholder = (provider: string) => {
  // Implement your placeholder logic based on the provider
  return `Enter your ${provider} value here`;
};

const DynamicFields: React.FC<DynamicFieldsProps> = ({
  fields,
  selectedProvider,
}) => {
  if (fields.length === 0) return null;

  return (
    <>
      {fields.map((field) => (
        <Form.Item
          key={field.field_name}
          rules={[{ required: true, message: "Required" }]}
          label={field.field_name
            .replace(/_/g, " ")
            .replace(/\b\w/g, (char) => char.toUpperCase())}
          name={field.field_name}
          tooltip={field.field_description}
          className="mb-2"
        >
          <TextInput placeholder={field.field_value} type="password" />
        </Form.Item>
      ))}
    </>
  );
};

export default DynamicFields;


@@ -48,6 +48,7 @@ import {
  modelMetricsSlowResponsesCall,
  getCallbacksCall,
  setCallbacksCall,
  modelSettingsCall,
} from "./networking";
import { BarChart, AreaChart } from "@tremor/react";
import {

@@ -84,6 +85,7 @@ import { UploadOutlined } from "@ant-design/icons";
import type { UploadProps } from "antd";
import { Upload } from "antd";
import TimeToFirstToken from "./model_metrics/time_to_first_token";
import DynamicFields from "./model_add/dynamic_form";

interface ModelDashboardProps {
  accessToken: string | null;
  token: string | null;

@@ -107,14 +109,27 @@ interface RetryPolicyObject {
//["OpenAI", "Azure OpenAI", "Anthropic", "Gemini (Google AI Studio)", "Amazon Bedrock", "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)"]

interface ProviderFields {
  field_name: string;
  field_type: string;
  field_description: string;
  field_value: string;
}

interface ProviderSettings {
  name: string;
  fields: ProviderFields[];
}

enum Providers {
  OpenAI = "OpenAI",
  Azure = "Azure",
  Anthropic = "Anthropic",
  Google_AI_Studio = "Google AI Studio",
  Bedrock = "Amazon Bedrock",
  OpenAI_Compatible = "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)",
  Vertex_AI = "Vertex AI (Anthropic, Gemini, etc.)",
  Databricks = "Databricks",
}

const provider_map: Record<string, string> = {

@@ -125,6 +140,7 @@ const provider_map: Record<string, string> = {
  Bedrock: "bedrock",
  OpenAI_Compatible: "openai",
  Vertex_AI: "vertex_ai",
  Databricks: "databricks",
};

const retry_policy_map: Record<string, string> = {

@@ -247,6 +263,9 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
    isNaN(Number(key))
  );

  const [providerSettings, setProviderSettings] = useState<ProviderSettings[]>(
    []
  );
  const [selectedProvider, setSelectedProvider] = useState<String>("OpenAI");
  const [healthCheckResponse, setHealthCheckResponse] = useState<string>("");
  const [editModalVisible, setEditModalVisible] = useState<boolean>(false);

@@ -514,6 +533,9 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
    }

    const fetchData = async () => {
      try {
        const _providerSettings = await modelSettingsCall(accessToken);
        setProviderSettings(_providerSettings);

        // Replace with your actual API call for model data
        const modelDataResponse = await modelInfoCall(
          accessToken,

@@ -945,6 +967,18 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
  console.log(`selectedProvider: ${selectedProvider}`);
  console.log(`providerModels.length: ${providerModels.length}`);

  const providerKey = Object.keys(Providers).find(
    (key) => (Providers as { [index: string]: any })[key] === selectedProvider
  );

  let dynamicProviderForm: ProviderSettings | undefined = undefined;
  if (providerKey) {
    dynamicProviderForm = providerSettings.find(
      (provider) => provider.name === provider_map[providerKey]
    );
  }

  return (
    <div style={{ width: "100%", height: "100%" }}>
      <TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">

@@ -1278,6 +1312,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                ))}
              </Select>
            </Form.Item>

            <Form.Item
              rules={[{ required: true, message: "Required" }]}
              label="Public Model Name"

@@ -1340,8 +1375,16 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                </Text>
              </Col>
            </Row>
            {dynamicProviderForm !== undefined &&
              dynamicProviderForm.fields.length > 0 && (
                <DynamicFields
                  fields={dynamicProviderForm.fields}
                  selectedProvider={dynamicProviderForm.name}
                />
              )}
            {selectedProvider != Providers.Bedrock &&
              selectedProvider != Providers.Vertex_AI &&
              dynamicProviderForm === undefined && (
                <Form.Item
                  rules={[{ required: true, message: "Required" }]}
                  label="API Key"


@@ -62,6 +62,40 @@ export const modelCreateCall = async (
  }
};

export const modelSettingsCall = async (accessToken: String) => {
  /**
   * Get all configurable params for setting a model
   */
  try {
    let url = proxyBaseUrl
      ? `${proxyBaseUrl}/model/settings`
      : `/model/settings`;

    //message.info("Requesting model data");
    const response = await fetch(url, {
      method: "GET",
      headers: {
        Authorization: `Bearer ${accessToken}`,
        "Content-Type": "application/json",
      },
    });

    if (!response.ok) {
      const errorData = await response.text();
      message.error(errorData, 10);
      throw new Error("Network response was not ok");
    }

    const data = await response.json();
    //message.info("Received model data");
    return data;
    // Handle success - you might want to update some state or UI based on the created key
  } catch (error) {
    console.error("Failed to get callbacks:", error);
    throw error;
  }
};

export const modelDeleteCall = async (
  accessToken: string,
  model_id: string