feat(ui/model_dashboard.tsx): add databricks models via admin ui

Krrish Dholakia 2024-05-23 20:28:54 -07:00
parent c14584722e
commit f04e4b921b
11 changed files with 263 additions and 9 deletions

View file

@@ -730,6 +730,7 @@ from .utils import (
ModelResponse,
ImageResponse,
ImageObject,
get_provider_fields,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig

View file

@@ -19,6 +19,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx # type: ignore
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.types.utils import ProviderField
class DatabricksError(Exception):
@@ -76,6 +77,23 @@ class DatabricksConfig:
and v is not None
}
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
ProviderField(
field_name="api_key",
field_type="string",
field_description="Your Databricks API Key.",
field_value="dapi...",
),
ProviderField(
field_name="api_base",
field_type="string",
field_description="Your Databricks API Base.",
field_value="https://adb-..",
),
]
def get_supported_openai_params(self):
return ["stream", "stop", "temperature", "top_p", "max_tokens", "n"]

View file

@@ -3390,9 +3390,10 @@
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat",
"supports_function_calling": true
"supports_function_calling": true,
"source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mistral-7B-Instruct-v0.1"
},
"anyscale/Mixtral-8x7B-Instruct-v0.1": {
"anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1": {
"max_tokens": 16384,
"max_input_tokens": 16384,
"max_output_tokens": 16384,
@@ -3400,7 +3401,19 @@
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat",
"supports_function_calling": true
"supports_function_calling": true,
"source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x7B-Instruct-v0.1"
},
"anyscale/mistralai/Mixtral-8x22B-Instruct-v0.1": {
"max_tokens": 65536,
"max_input_tokens": 65536,
"max_output_tokens": 65536,
"input_cost_per_token": 0.00000090,
"output_cost_per_token": 0.00000090,
"litellm_provider": "anyscale",
"mode": "chat",
"supports_function_calling": true,
"source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x22B-Instruct-v0.1"
},
"anyscale/HuggingFaceH4/zephyr-7b-beta": {
"max_tokens": 16384,
@@ -3411,6 +3424,16 @@
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/google/gemma-7b-it": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat",
"source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/google-gemma-7b-it"
},
"anyscale/meta-llama/Llama-2-7b-chat-hf": {
"max_tokens": 4096,
"max_input_tokens": 4096,
@@ -3447,6 +3470,36 @@
"litellm_provider": "anyscale",
"mode": "chat"
},
"anyscale/codellama/CodeLlama-70b-Instruct-hf": {
"max_tokens": 4096,
"max_input_tokens": 4096,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000001,
"output_cost_per_token": 0.000001,
"litellm_provider": "anyscale",
"mode": "chat",
"source" : "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/codellama-CodeLlama-70b-Instruct-hf"
},
"anyscale/meta-llama/Meta-Llama-3-8B-Instruct": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000015,
"output_cost_per_token": 0.00000015,
"litellm_provider": "anyscale",
"mode": "chat",
"source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-8B-Instruct"
},
"anyscale/meta-llama/Meta-Llama-3-70B-Instruct": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.00000100,
"output_cost_per_token": 0.00000100,
"litellm_provider": "anyscale",
"mode": "chat",
"source" : "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-70B-Instruct"
},
"cloudflare/@cf/meta/llama-2-7b-chat-fp16": {
"max_tokens": 3072,
"max_input_tokens": 3072,

View file

@@ -5,6 +5,7 @@ from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime
import uuid, json, sys, os
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField
def hash_token(token: str):
@@ -364,6 +365,11 @@ class ModelInfo(LiteLLMBase):
return values
class ProviderInfo(LiteLLMBase):
name: str
fields: List[ProviderField]
class BlockUsers(LiteLLMBase):
user_ids: List[str] # required

View file

@@ -9364,6 +9364,36 @@ async def delete_model(model_info: ModelInfoDelete):
)
@router.get(
"/model/settings",
description="Returns provider name, description, and required parameters for each provider",
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def model_settings():
"""
Used by the UI to generate the 'model add' page. For each provider, returns:
{
"name": <provider name>,
"fields": [ProviderField(field_name, field_type, field_description, field_value), ...],
}
"""
returned_list = []
for provider in litellm.provider_list:
returned_list.append(
ProviderInfo(
name=provider,
fields=litellm.get_provider_fields(custom_llm_provider=provider),
)
)
return returned_list
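
A sketch of the payload this endpoint returns (shape follows the ProviderInfo/ProviderField types in this diff; values are illustrative, and most providers report no fields):

# GET /model/settings -> a list of ProviderInfo objects, e.g.:
[
    {
        "name": "databricks",
        "fields": [
            {"field_name": "api_key", "field_type": "string", "field_description": "Your Databricks API Key.", "field_value": "dapi..."},
            {"field_name": "api_base", "field_type": "string", "field_description": "Your Databricks API Base.", "field_value": "https://adb-.."},
        ],
    },
    # ...one entry per provider in litellm.provider_list, most with "fields": []
]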
#### EXPERIMENTAL QUEUING ####
async def _litellm_chat_completions_worker(data, user_api_key_dict):
"""

View file

@@ -3342,10 +3342,14 @@ class Router:
non_default_params = litellm.utils.get_non_default_params(
passed_params=request_kwargs
)
special_params = ["response_object"]
# check if all params are supported
for k, v in non_default_params.items():
-if k not in supported_openai_params:
+if k not in supported_openai_params and k in special_params:
# if not -> invalid model
verbose_router_logger.debug(
f"INVALID MODEL INDEX @ REQUEST KWARG FILTERING, k={k}"
)
invalid_model_indices.append(idx)
if len(invalid_model_indices) == len(_returned_deployments):
@@ -3420,6 +3424,7 @@ class Router:
## get healthy deployments
### get all deployments
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = [
@@ -3510,7 +3515,7 @@ class Router:
if _allowed_model_region is None:
_allowed_model_region = "n/a"
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
)
if (

View file

@@ -1,6 +1,14 @@
-from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
+from typing import List, Optional, Union, Dict, Tuple, Literal
+from typing_extensions import TypedDict
class CostPerToken(TypedDict):
input_cost_per_token: float
output_cost_per_token: float
class ProviderField(TypedDict):
field_name: str
field_type: Literal["string"]
field_description: str
field_value: str
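
Because ProviderField is a TypedDict rather than a class, instances are plain dicts whose keys a type checker can verify; a minimal sketch:

from litellm.types.utils import ProviderField

field: ProviderField = {
    "field_name": "api_base",
    "field_type": "string",  # Literal["string"]: only string-typed fields exist so far
    "field_description": "Your Databricks API Base.",
    "field_value": "https://adb-..",
}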

View file

@@ -34,7 +34,7 @@ from dataclasses import (
import litellm._service_logger # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.caching import DualCache
-from litellm.types.utils import CostPerToken
+from litellm.types.utils import CostPerToken, ProviderField
oidc_cache = DualCache()
@@ -7327,6 +7327,15 @@ def load_test_model(
}
def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
"""Return the fields required for each provider"""
if custom_llm_provider == "databricks":
return litellm.DatabricksConfig().get_required_params()
else:
return []
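
A usage sketch (behavior as added above: providers without registered fields fall through to the empty list):

import litellm

litellm.get_provider_fields(custom_llm_provider="databricks")  # two ProviderFields: api_key, api_base
litellm.get_provider_fields(custom_llm_provider="openai")      # [] -- no UI fields registered yet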
def validate_environment(model: Optional[str] = None) -> dict:
"""
Checks if the environment variables are valid for the given model.

View file

@@ -0,0 +1,47 @@
import React from "react";
import { Form, Input } from "antd";
import { TextInput } from "@tremor/react";
interface Field {
field_name: string;
field_type: string;
field_description: string;
field_value: string;
}
interface DynamicFieldsProps {
fields: Field[];
selectedProvider: string;
}
const getPlaceholder = (provider: string) => {
// Implement your placeholder logic based on the provider
return `Enter your ${provider} value here`;
};
const DynamicFields: React.FC<DynamicFieldsProps> = ({
fields,
selectedProvider,
}) => {
if (fields.length === 0) return null;
return (
<>
{fields.map((field) => (
<Form.Item
key={field.field_name}
rules={[{ required: true, message: "Required" }]}
label={field.field_name
.replace(/_/g, " ")
.replace(/\b\w/g, (char) => char.toUpperCase())}
name={field.field_name}
tooltip={field.field_description}
className="mb-2"
>
<TextInput placeholder={field.field_value} type="password" />
</Form.Item>
))}
</>
);
};
export default DynamicFields;

View file

@@ -48,6 +48,7 @@ import {
modelMetricsSlowResponsesCall,
getCallbacksCall,
setCallbacksCall,
modelSettingsCall,
} from "./networking";
import { BarChart, AreaChart } from "@tremor/react";
import {
@@ -84,6 +85,7 @@ import { UploadOutlined } from "@ant-design/icons";
import type { UploadProps } from "antd";
import { Upload } from "antd";
import TimeToFirstToken from "./model_metrics/time_to_first_token";
import DynamicFields from "./model_add/dynamic_form";
interface ModelDashboardProps {
accessToken: string | null;
token: string | null;
@@ -107,14 +109,27 @@ interface RetryPolicyObject {
//["OpenAI", "Azure OpenAI", "Anthropic", "Gemini (Google AI Studio)", "Amazon Bedrock", "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)"]
interface ProviderFields {
field_name: string;
field_type: string;
field_description: string;
field_value: string;
}
interface ProviderSettings {
name: string;
fields: ProviderFields[];
}
enum Providers {
OpenAI = "OpenAI",
Azure = "Azure",
Anthropic = "Anthropic",
-Google_AI_Studio = "Gemini (Google AI Studio)",
+Google_AI_Studio = "Google AI Studio",
Bedrock = "Amazon Bedrock",
OpenAI_Compatible = "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)",
Vertex_AI = "Vertex AI (Anthropic, Gemini, etc.)",
Databricks = "Databricks",
}
const provider_map: Record<string, string> = {
@@ -125,6 +140,7 @@ const provider_map: Record<string, string> = {
Bedrock: "bedrock",
OpenAI_Compatible: "openai",
Vertex_AI: "vertex_ai",
Databricks: "databricks",
};
const retry_policy_map: Record<string, string> = {
@@ -247,6 +263,9 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
isNaN(Number(key))
);
const [providerSettings, setProviderSettings] = useState<ProviderSettings[]>(
[]
);
const [selectedProvider, setSelectedProvider] = useState<String>("OpenAI");
const [healthCheckResponse, setHealthCheckResponse] = useState<string>("");
const [editModalVisible, setEditModalVisible] = useState<boolean>(false);
@@ -514,6 +533,9 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
}
const fetchData = async () => {
try {
const _providerSettings = await modelSettingsCall(accessToken);
setProviderSettings(_providerSettings);
// Replace with your actual API call for model data
const modelDataResponse = await modelInfoCall(
accessToken,
@@ -945,6 +967,18 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
console.log(`selectedProvider: ${selectedProvider}`);
console.log(`providerModels.length: ${providerModels.length}`);
const providerKey = Object.keys(Providers).find(
(key) => (Providers as { [index: string]: any })[key] === selectedProvider
);
let dynamicProviderForm: ProviderSettings | undefined = undefined;
if (providerKey) {
dynamicProviderForm = providerSettings.find(
(provider) => provider.name === provider_map[providerKey]
);
}
return (
<div style={{ width: "100%", height: "100%" }}>
<TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">
@@ -1278,6 +1312,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
))}
</Select>
</Form.Item>
<Form.Item
rules={[{ required: true, message: "Required" }]}
label="Public Model Name"
@@ -1340,8 +1375,16 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
</Text>
</Col>
</Row>
{dynamicProviderForm !== undefined &&
dynamicProviderForm.fields.length > 0 && (
<DynamicFields
fields={dynamicProviderForm.fields}
selectedProvider={dynamicProviderForm.name}
/>
)}
{selectedProvider != Providers.Bedrock &&
-selectedProvider != Providers.Vertex_AI && (
+selectedProvider != Providers.Vertex_AI &&
+dynamicProviderForm === undefined && (
<Form.Item
rules={[{ required: true, message: "Required" }]}
label="API Key"

View file

@@ -62,6 +62,40 @@ export const modelCreateCall = async (
}
};
export const modelSettingsCall = async (accessToken: String) => {
/**
* Get all configurable params for adding a model
*/
try {
let url = proxyBaseUrl
? `${proxyBaseUrl}/model/settings`
: `/model/settings`;
//message.info("Requesting model data");
const response = await fetch(url, {
method: "GET",
headers: {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
});
if (!response.ok) {
const errorData = await response.text();
message.error(errorData, 10);
throw new Error("Network response was not ok");
}
const data = await response.json();
//message.info("Received model data");
return data;
} catch (error) {
console.error("Failed to get callbacks:", error);
throw error;
}
};
export const modelDeleteCall = async (
accessToken: string,
model_id: string