forked from phoenix/litellm-mirror

feat(ui/model_dashboard.tsx): add databricks models via admin ui

commit f04e4b921b (parent c14584722e)
11 changed files with 263 additions and 9 deletions

@@ -730,6 +730,7 @@ from .utils import (
     ModelResponse,
     ImageResponse,
     ImageObject,
+    get_provider_fields,
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig

@@ -19,6 +19,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
 import httpx  # type: ignore
 from litellm.types.llms.databricks import GenericStreamingChunk
+from litellm.types.utils import ProviderField
 
 
 class DatabricksError(Exception):
@@ -76,6 +77,23 @@ class DatabricksConfig:
             and v is not None
         }
 
+    def get_required_params(self) -> List[ProviderField]:
+        """For a given provider, return its required fields with a description"""
+        return [
+            ProviderField(
+                field_name="api_key",
+                field_type="string",
+                field_description="Your Databricks API Key.",
+                field_value="dapi...",
+            ),
+            ProviderField(
+                field_name="api_base",
+                field_type="string",
+                field_description="Your Databricks API Base.",
+                field_value="https://adb-..",
+            ),
+        ]
+
     def get_supported_openai_params(self):
         return ["stream", "stop", "temperature", "top_p", "max_tokens", "n"]

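A quick sketch of how the new `get_required_params()` surface behaves, assuming a litellm install that includes this commit:

```python
# Sketch: inspect the fields the Databricks provider now advertises.
# Assumes a litellm version containing this commit.
import litellm

for field in litellm.DatabricksConfig().get_required_params():
    # each entry is a ProviderField TypedDict (see the litellm/types hunk below)
    print(f'{field["field_name"]}: {field["field_description"]}')
# api_key: Your Databricks API Key.
# api_base: Your Databricks API Base.
```
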
@@ -3390,9 +3390,10 @@
         "output_cost_per_token": 0.00000015,
         "litellm_provider": "anyscale",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mistral-7B-Instruct-v0.1"
     },
-    "anyscale/Mixtral-8x7B-Instruct-v0.1": {
+    "anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1": {
         "max_tokens": 16384,
         "max_input_tokens": 16384,
         "max_output_tokens": 16384,
@@ -3400,7 +3401,19 @@
         "output_cost_per_token": 0.00000015,
         "litellm_provider": "anyscale",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x7B-Instruct-v0.1"
     },
+    "anyscale/mistralai/Mixtral-8x22B-Instruct-v0.1": {
+        "max_tokens": 65536,
+        "max_input_tokens": 65536,
+        "max_output_tokens": 65536,
+        "input_cost_per_token": 0.00000090,
+        "output_cost_per_token": 0.00000090,
+        "litellm_provider": "anyscale",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/mistralai-Mixtral-8x22B-Instruct-v0.1"
+    },
     "anyscale/HuggingFaceH4/zephyr-7b-beta": {
         "max_tokens": 16384,
@@ -3411,6 +3424,16 @@
         "litellm_provider": "anyscale",
         "mode": "chat"
     },
+    "anyscale/google/gemma-7b-it": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000015,
+        "output_cost_per_token": 0.00000015,
+        "litellm_provider": "anyscale",
+        "mode": "chat",
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/google-gemma-7b-it"
+    },
     "anyscale/meta-llama/Llama-2-7b-chat-hf": {
         "max_tokens": 4096,
         "max_input_tokens": 4096,
@@ -3447,6 +3470,36 @@
         "litellm_provider": "anyscale",
         "mode": "chat"
     },
+    "anyscale/codellama/CodeLlama-70b-Instruct-hf": {
+        "max_tokens": 4096,
+        "max_input_tokens": 4096,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000001,
+        "litellm_provider": "anyscale",
+        "mode": "chat",
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/codellama-CodeLlama-70b-Instruct-hf"
+    },
+    "anyscale/meta-llama/Meta-Llama-3-8B-Instruct": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000015,
+        "output_cost_per_token": 0.00000015,
+        "litellm_provider": "anyscale",
+        "mode": "chat",
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-8B-Instruct"
+    },
+    "anyscale/meta-llama/Meta-Llama-3-70B-Instruct": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.00000100,
+        "output_cost_per_token": 0.00000100,
+        "litellm_provider": "anyscale",
+        "mode": "chat",
+        "source": "https://docs.anyscale.com/preview/endpoints/text-generation/supported-models/meta-llama-Meta-Llama-3-70B-Instruct"
+    },
     "cloudflare/@cf/meta/llama-2-7b-chat-fp16": {
         "max_tokens": 3072,
         "max_input_tokens": 3072,

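These pricing entries feed litellm's cost tracking; a rough sketch of the arithmetic, assuming the public `litellm.model_cost` map picks up the model names added above:

```python
# Sketch: per-request cost from the pricing entries above.
# Assumes litellm.model_cost contains the newly added anyscale entries.
import litellm

info = litellm.model_cost["anyscale/meta-llama/Meta-Llama-3-70B-Instruct"]
prompt_tokens, completion_tokens = 1000, 500
cost = (
    prompt_tokens * info["input_cost_per_token"]
    + completion_tokens * info["output_cost_per_token"]
)
print(f"${cost:.6f}")  # 1000*1e-06 + 500*1e-06 = $0.001500
```
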
@@ -5,6 +5,7 @@ from typing import Optional, List, Union, Dict, Literal, Any
 from datetime import datetime
 import uuid, json, sys, os
 from litellm.types.router import UpdateRouterConfig
+from litellm.types.utils import ProviderField
 
 
 def hash_token(token: str):
@@ -364,6 +365,11 @@ class ModelInfo(LiteLLMBase):
         return values
 
 
+class ProviderInfo(LiteLLMBase):
+    name: str
+    fields: List[ProviderField]
+
+
 class BlockUsers(LiteLLMBase):
     user_ids: List[str]  # required

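`ProviderInfo` is the per-provider payload the new `/model/settings` route serializes; a small construction sketch (the import path and dict-into-TypedDict validation are assumptions based on where this hunk appears to live):

```python
# Sketch: the per-provider payload shape returned by /model/settings.
# Import path is an assumption inferred from this proxy types hunk.
from litellm.proxy._types import ProviderInfo

info = ProviderInfo(
    name="databricks",
    fields=[
        {
            "field_name": "api_key",
            "field_type": "string",
            "field_description": "Your Databricks API Key.",
            "field_value": "dapi...",
        }
    ],
)
print(info.dict())
```
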
@@ -9364,6 +9364,36 @@ async def delete_model(model_info: ModelInfoDelete):
         )
 
 
+@router.get(
+    "/model/settings",
+    description="Returns provider name, description, and required parameters for each provider",
+    tags=["model management"],
+    dependencies=[Depends(user_api_key_auth)],
+    include_in_schema=False,
+)
+async def model_settings():
+    """
+    Used by UI to generate 'model add' page
+    {
+        field_name=field_name,
+        field_type=allowed_args[field_name]["type"], # string/int
+        field_description=field_info.description or "", # human-friendly description
+        field_value=general_settings.get(field_name, None), # example value
+    }
+    """
+
+    returned_list = []
+    for provider in litellm.provider_list:
+        returned_list.append(
+            ProviderInfo(
+                name=provider,
+                fields=litellm.get_provider_fields(custom_llm_provider=provider),
+            )
+        )
+
+    return returned_list
+
+
 #### EXPERIMENTAL QUEUING ####
 async def _litellm_chat_completions_worker(data, user_api_key_dict):
     """

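A hedged client-side sketch of calling the new endpoint; the proxy URL and key below are placeholders, not values from this commit:

```python
# Sketch: fetch provider fields the way the admin UI's "model add" page does.
# http://localhost:4000 and sk-1234 are placeholders for your proxy URL and admin key.
import httpx

resp = httpx.get(
    "http://localhost:4000/model/settings",
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()
for provider in resp.json():
    if provider["fields"]:  # only providers that advertise required fields
        print(provider["name"], [f["field_name"] for f in provider["fields"]])
# e.g. databricks ['api_key', 'api_base']
```
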
@@ -3342,10 +3342,14 @@ class Router:
             non_default_params = litellm.utils.get_non_default_params(
                 passed_params=request_kwargs
             )
+            special_params = ["response_object"]
             # check if all params are supported
             for k, v in non_default_params.items():
-                if k not in supported_openai_params:
+                if k not in supported_openai_params and k in special_params:
                     # if not -> invalid model
                     verbose_router_logger.debug(
                         f"INVALID MODEL INDEX @ REQUEST KWARG FILTERING, k={k}"
                     )
                     invalid_model_indices.append(idx)
 
         if len(invalid_model_indices) == len(_returned_deployments):
@@ -3420,6 +3424,7 @@ class Router:
         ## get healthy deployments
         ### get all deployments
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
+
         if len(healthy_deployments) == 0:
             # check if the user sent in a deployment name instead
             healthy_deployments = [
@@ -3510,7 +3515,7 @@ class Router:
             if _allowed_model_region is None:
                 _allowed_model_region = "n/a"
             raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
+                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
             )
 
         if (

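The filtering change above narrows when a deployment is excluded; a behavior sketch of the new condition:

```python
# Sketch: with the new check, an unsupported param only invalidates a
# deployment when it is also a "special" param (here, response_object).
def is_invalid(non_default_params, supported_openai_params,
               special_params=("response_object",)):
    return any(
        k not in supported_openai_params and k in special_params
        for k in non_default_params
    )

print(is_invalid({"max_retries": 3}, ["temperature"]))       # False -> deployment kept
print(is_invalid({"response_object": {}}, ["temperature"]))  # True  -> deployment filtered
```
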
@@ -1,6 +1,14 @@
-from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
+from typing import List, Optional, Union, Dict, Tuple, Literal
+from typing_extensions import TypedDict
 
 
 class CostPerToken(TypedDict):
     input_cost_per_token: float
     output_cost_per_token: float
+
+
+class ProviderField(TypedDict):
+    field_name: str
+    field_type: Literal["string"]
+    field_description: str
+    field_value: str

@@ -34,7 +34,7 @@ from dataclasses import (
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from litellm.caching import DualCache
-from litellm.types.utils import CostPerToken
+from litellm.types.utils import CostPerToken, ProviderField
 
 oidc_cache = DualCache()
 
@@ -7327,6 +7327,15 @@ def load_test_model(
     }
 
 
+def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
+    """Return the fields required for each provider"""
+
+    if custom_llm_provider == "databricks":
+        return litellm.DatabricksConfig().get_required_params()
+    else:
+        return []
+
+
 def validate_environment(model: Optional[str] = None) -> dict:
     """
     Checks if the environment variables are valid for the given model.

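Usage sketch for the new helper, which the `__init__.py` hunk at the top of this diff re-exports at the top level; providers other than Databricks currently return an empty list:

```python
# Sketch: litellm.get_provider_fields is re-exported from litellm.utils
# (see the `from .utils import (...)` hunk above).
import litellm

print(litellm.get_provider_fields(custom_llm_provider="databricks"))
# [{'field_name': 'api_key', ...}, {'field_name': 'api_base', ...}]
print(litellm.get_provider_fields(custom_llm_provider="openai"))
# []
```
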
@@ -0,0 +1,47 @@
+import React from "react";
+import { Form, Input } from "antd";
+import { TextInput } from "@tremor/react";
+
+interface Field {
+  field_name: string;
+  field_type: string;
+  field_description: string;
+  field_value: string;
+}
+
+interface DynamicFieldsProps {
+  fields: Field[];
+  selectedProvider: string;
+}
+
+const getPlaceholder = (provider: string) => {
+  // Implement your placeholder logic based on the provider
+  return `Enter your ${provider} value here`;
+};
+
+const DynamicFields: React.FC<DynamicFieldsProps> = ({
+  fields,
+  selectedProvider,
+}) => {
+  if (fields.length === 0) return null;
+
+  return (
+    <>
+      {fields.map((field) => (
+        <Form.Item
+          key={field.field_name}
+          rules={[{ required: true, message: "Required" }]}
+          label={field.field_name
+            .replace(/_/g, " ")
+            .replace(/\b\w/g, (char) => char.toUpperCase())}
+          name={field.field_name}
+          tooltip={field.field_description}
+          className="mb-2"
+        >
+          <TextInput placeholder={field.field_value} type="password" />
+        </Form.Item>
+      ))}
+    </>
+  );
+};
+
+export default DynamicFields;

@@ -48,6 +48,7 @@ import {
   modelMetricsSlowResponsesCall,
   getCallbacksCall,
   setCallbacksCall,
+  modelSettingsCall,
 } from "./networking";
 import { BarChart, AreaChart } from "@tremor/react";
 import {
@@ -84,6 +85,7 @@ import { UploadOutlined } from "@ant-design/icons";
 import type { UploadProps } from "antd";
 import { Upload } from "antd";
 import TimeToFirstToken from "./model_metrics/time_to_first_token";
+import DynamicFields from "./model_add/dynamic_form";
 interface ModelDashboardProps {
   accessToken: string | null;
   token: string | null;
@@ -107,14 +109,27 @@ interface RetryPolicyObject {
 
 //["OpenAI", "Azure OpenAI", "Anthropic", "Gemini (Google AI Studio)", "Amazon Bedrock", "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)"]
 
+interface ProviderFields {
+  field_name: string;
+  field_type: string;
+  field_description: string;
+  field_value: string;
+}
+
+interface ProviderSettings {
+  name: string;
+  fields: ProviderFields[];
+}
+
 enum Providers {
   OpenAI = "OpenAI",
   Azure = "Azure",
   Anthropic = "Anthropic",
-  Google_AI_Studio = "Gemini (Google AI Studio)",
+  Google_AI_Studio = "Google AI Studio",
   Bedrock = "Amazon Bedrock",
   OpenAI_Compatible = "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)",
   Vertex_AI = "Vertex AI (Anthropic, Gemini, etc.)",
+  Databricks = "Databricks",
 }
 
 const provider_map: Record<string, string> = {
@@ -125,6 +140,7 @@ const provider_map: Record<string, string> = {
   Bedrock: "bedrock",
   OpenAI_Compatible: "openai",
   Vertex_AI: "vertex_ai",
+  Databricks: "databricks",
 };
 
 const retry_policy_map: Record<string, string> = {
@@ -247,6 +263,9 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
     isNaN(Number(key))
   );
 
+  const [providerSettings, setProviderSettings] = useState<ProviderSettings[]>(
+    []
+  );
   const [selectedProvider, setSelectedProvider] = useState<String>("OpenAI");
   const [healthCheckResponse, setHealthCheckResponse] = useState<string>("");
   const [editModalVisible, setEditModalVisible] = useState<boolean>(false);
@@ -514,6 +533,9 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
     }
     const fetchData = async () => {
       try {
+        const _providerSettings = await modelSettingsCall(accessToken);
+        setProviderSettings(_providerSettings);
+
         // Replace with your actual API call for model data
         const modelDataResponse = await modelInfoCall(
           accessToken,
@@ -945,6 +967,18 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
 
   console.log(`selectedProvider: ${selectedProvider}`);
   console.log(`providerModels.length: ${providerModels.length}`);
+
+  const providerKey = Object.keys(Providers).find(
+    (key) => (Providers as { [index: string]: any })[key] === selectedProvider
+  );
+
+  let dynamicProviderForm: ProviderSettings | undefined = undefined;
+  if (providerKey) {
+    dynamicProviderForm = providerSettings.find(
+      (provider) => provider.name === provider_map[providerKey]
+    );
+  }
+
   return (
     <div style={{ width: "100%", height: "100%" }}>
       <TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">
@@ -1278,6 +1312,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                 ))}
               </Select>
             </Form.Item>
+
             <Form.Item
               rules={[{ required: true, message: "Required" }]}
               label="Public Model Name"
@@ -1340,8 +1375,16 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
                 </Text>
               </Col>
             </Row>
+            {dynamicProviderForm !== undefined &&
+              dynamicProviderForm.fields.length > 0 && (
+                <DynamicFields
+                  fields={dynamicProviderForm.fields}
+                  selectedProvider={dynamicProviderForm.name}
+                />
+              )}
             {selectedProvider != Providers.Bedrock &&
-              selectedProvider != Providers.Vertex_AI && (
+              selectedProvider != Providers.Vertex_AI &&
+              dynamicProviderForm === undefined && (
                 <Form.Item
                   rules={[{ required: true, message: "Required" }]}
                   label="API Key"

@@ -62,6 +62,40 @@ export const modelCreateCall = async (
   }
 };
 
+export const modelSettingsCall = async (accessToken: String) => {
+  /**
+   * Get all configurable params for setting a model
+   */
+  try {
+    let url = proxyBaseUrl
+      ? `${proxyBaseUrl}/model/settings`
+      : `/model/settings`;
+
+    //message.info("Requesting model data");
+    const response = await fetch(url, {
+      method: "GET",
+      headers: {
+        Authorization: `Bearer ${accessToken}`,
+        "Content-Type": "application/json",
+      },
+    });
+
+    if (!response.ok) {
+      const errorData = await response.text();
+      message.error(errorData, 10);
+      throw new Error("Network response was not ok");
+    }
+
+    const data = await response.json();
+    //message.info("Received model data");
+    return data;
+    // Handle success - you might want to update some state or UI based on the created key
+  } catch (error) {
+    console.error("Failed to get callbacks:", error);
+    throw error;
+  }
+};
+
 export const modelDeleteCall = async (
   accessToken: string,
   model_id: string