build(ui): add vertex ai models via ui

This commit is contained in:
Krrish Dholakia 2024-04-15 15:59:15 -07:00
parent 0d2a75d301
commit e4bcc51e44
33 changed files with 131 additions and 92 deletions

View file

@ -36,6 +36,9 @@ import { Typography } from "antd";
import TextArea from "antd/es/input/TextArea";
import { InformationCircleIcon, PencilAltIcon, PencilIcon, StatusOnlineIcon, TrashIcon } from "@heroicons/react/outline";
const { Title: Title2, Link } = Typography;
import { UploadOutlined } from '@ant-design/icons';
import type { UploadProps } from 'antd';
import { Upload } from 'antd';
interface ModelDashboardProps {
accessToken: string | null;
@ -52,7 +55,8 @@ enum Providers {
Anthropic = "Anthropic",
Google_AI_Studio = "Gemini (Google AI Studio)",
Bedrock = "Amazon Bedrock",
OpenAI_Compatible = "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)"
OpenAI_Compatible = "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)",
Vertex_AI = "Vertex AI (Anthropic, Gemini, etc.)"
}
const provider_map: Record <string, string> = {
@ -61,9 +65,11 @@ const provider_map: Record <string, string> = {
"Anthropic": "anthropic",
"Google_AI_Studio": "gemini",
"Bedrock": "bedrock",
"OpenAI_Compatible": "openai"
"OpenAI_Compatible": "openai",
"Vertex_AI": "vertex_ai"
};
const ModelDashboard: React.FC<ModelDashboardProps> = ({
accessToken,
token,
@ -78,12 +84,41 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
const [providerModels, setProviderModels] = useState<Array<string>>([]); // Explicitly typing providerModels as a string array
const providers: Providers[] = [Providers.OpenAI, Providers.Azure, Providers.Anthropic, Providers.Google_AI_Studio, Providers.Bedrock, Providers.OpenAI_Compatible]
const providers = Object.values(Providers).filter(key => isNaN(Number(key)));
const [selectedProvider, setSelectedProvider] = useState<String>("OpenAI");
const [healthCheckResponse, setHealthCheckResponse] = useState<string>('');
const props: UploadProps = {
name: 'file',
accept: '.json',
beforeUpload: file => {
if (file.type === 'application/json') {
const reader = new FileReader();
reader.onload = (e) => {
if (e.target) {
const jsonStr = e.target.result as string;
form.setFieldsValue({ vertex_credentials: jsonStr });
}
};
reader.readAsText(file);
}
// Prevent upload
return false;
},
onChange(info) {
if (info.file.status !== 'uploading') {
console.log(info.file, info.fileList);
}
if (info.file.status === 'done') {
message.success(`${info.file.name} file uploaded successfully`);
} else if (info.file.status === 'error') {
message.error(`${info.file.name} file upload failed.`);
}
},
};
useEffect(() => {
if (!accessToken || !token || !userRole || !userID) {
@ -233,20 +268,29 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
const setProviderModelsFn = (provider: string) => {
console.log(`received provider string: ${provider}`)
const providerEnumValue = Providers[provider as keyof typeof Providers];
console.log(`received providerEnumValue: ${providerEnumValue}`)
const mappingResult = provider_map[providerEnumValue]; // Get the corresponding value from the mapping
console.log(`mappingResult: ${mappingResult}`)
let _providerModels: Array<string> = []
if (typeof modelMap === 'object') {
Object.entries(modelMap).forEach(([key, value]) => {
if (value !== null && typeof value === 'object' && "litellm_provider" in value && value["litellm_provider"] === mappingResult) {
_providerModels.push(key);
}
});
const providerKey = Object.keys(Providers).find(key => (Providers as {[index: string]: any})[key] === provider);
if (providerKey) {
const mappingResult = provider_map[providerKey]; // Get the corresponding value from the mapping
console.log(`mappingResult: ${mappingResult}`)
let _providerModels: Array<string> = []
if (typeof modelMap === 'object') {
Object.entries(modelMap).forEach(([key, value]) => {
if (
value !== null
&& typeof value === 'object'
&& "litellm_provider" in (value as object)
&& (
(value as any)["litellm_provider"] === mappingResult
|| (value as any)["litellm_provider"].includes(mappingResult)
)) {
_providerModels.push(key);
}
});
}
setProviderModels(_providerModels)
console.log(`providerModels: ${providerModels}`);
}
setProviderModels(_providerModels)
console.log(`providerModels: ${providerModels}`);
}
const runHealthCheck = async () => {
@ -349,6 +393,20 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
}
}
const getPlaceholder = (selectedProvider: string): string => {
if (selectedProvider === Providers.Vertex_AI) {
return 'gemini-pro';
} else if (selectedProvider == Providers.Anthropic) {
return 'claude-3-opus'
} else if (selectedProvider == Providers.Bedrock) {
return 'claude-3-opus'
} else if (selectedProvider == Providers.Google_AI_Studio) {
return 'gemini-pro'
} else {
return 'gpt-3.5-turbo';
}
};
const handleOk = () => {
form
.validateFields()
@ -455,9 +513,6 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
</Grid>
</TabPanel>
<TabPanel className="h-full">
{/* <Card className="mx-auto max-w-lg flex flex-col h-[60vh] space-between">
</Card> */}
<Title2 level={2}>Add new model</Title2>
<Card>
<Form
@ -485,7 +540,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
</Select>
</Form.Item>
<Form.Item rules={[{ required: true, message: 'Required' }]} label="Public Model Name" name="model_name" tooltip="Model name your users will pass in. Also used for load-balancing, LiteLLM will load balance between all models with this public name." className="mb-0">
<TextInput placeholder="gpt-3.5-turbo"/>
<TextInput placeholder={getPlaceholder(selectedProvider.toString())}/>
</Form.Item>
<Row>
<Col span={10}></Col>
@ -508,9 +563,9 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
</Form.Item>
<Row>
<Col span={10}></Col>
<Col span={10}><Text className="mb-3 mt-1">Actual model name used for making<Link href="https://docs.litellm.ai/docs/providers" target="_blank">litellm.completion() call</Link>.We&apos;ll<Link href="https://docs.litellm.ai/docs/proxy/reliability#step-1---set-deployments-on-config" target="_blank">loadbalance</Link> models with the same &apos;public name&apos;</Text></Col></Row>
<Col span={10}><Text className="mb-3 mt-1">Actual model name used for making <Link href="https://docs.litellm.ai/docs/providers" target="_blank">litellm.completion() call</Link>. We&apos;ll <Link href="https://docs.litellm.ai/docs/proxy/reliability#step-1---set-deployments-on-config" target="_blank">loadbalance</Link> models with the same &apos;public name&apos;</Text></Col></Row>
{
selectedProvider != Providers.Bedrock && <Form.Item
selectedProvider != Providers.Bedrock && selectedProvider != Providers.Vertex_AI && <Form.Item
rules={[{ required: true, message: 'Required' }]}
label="API Key"
name="api_key"
@ -526,6 +581,32 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
<TextInput placeholder="[OPTIONAL] my-unique-org"/>
</Form.Item>
}
{
selectedProvider == Providers.Vertex_AI && <Form.Item rules={[{ required: true, message: 'Required' }]}
label="Vertex Project"
name="vertex_project"><TextInput placeholder="adroit-cadet-1234.."/></Form.Item>
}
{
selectedProvider == Providers.Vertex_AI && <Form.Item rules={[{ required: true, message: 'Required' }]}
label="Vertex Location"
name="vertex_location"><TextInput placeholder="us-east-1"/></Form.Item>
}
{
selectedProvider == Providers.Vertex_AI && <Form.Item rules={[{ required: true, message: 'Required' }]}
label="Vertex Credentials"
name="vertex_credentials"
className="mb-0">
<Upload {...props}>
<Button2 icon={<UploadOutlined />}>Click to Upload</Button2>
</Upload>
</Form.Item>
}
{
selectedProvider == Providers.Vertex_AI && <Row>
<Col span={10}></Col>
<Col span={10}><Text className="mb-3 mt-1">Give litellm a gcp service account(.json file), so it can make the relevant calls</Text></Col></Row>
}
{
(selectedProvider == Providers.Azure || selectedProvider == Providers.OpenAI_Compatible) && <Form.Item
rules={[{ required: true, message: 'Required' }]}