Mirror of https://github.com/BerriAI/litellm.git
clean up endpoint selector
parent 72d7b26811
commit 353c882574
5 changed files with 109 additions and 30 deletions
chat_ui.tsx (ChatUI component)

@@ -26,10 +26,12 @@ import { message, Select } from "antd";
 import { makeOpenAIChatCompletionRequest } from "./chat_ui/llm_calls/chat_completion";
 import { makeOpenAIImageGenerationRequest } from "./chat_ui/llm_calls/image_generation";
 import { fetchAvailableModels, ModelGroup } from "./chat_ui/llm_calls/fetch_models";
-import { litellmModeMapping, ModelMode, EndpointType } from "./chat_ui/mode_endpoint_mapping";
+import { litellmModeMapping, ModelMode, EndpointType, getEndpointType } from "./chat_ui/mode_endpoint_mapping";
 import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
 import { Typography } from "antd";
 import { coy } from 'react-syntax-highlighter/dist/esm/styles/prism';
+import EndpointSelector from "./chat_ui/EndpointSelector";
+import { determineEndpointType } from "./chat_ui/EndpointUtils";

 interface ChatUIProps {
   accessToken: string | null;
@@ -58,7 +60,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
   const [showCustomModelInput, setShowCustomModelInput] = useState<boolean>(false);
   const [modelInfo, setModelInfo] = useState<ModelGroup[]>([]);
   const customModelTimeout = useRef<NodeJS.Timeout | null>(null);
-  const [endpointType, setEndpointType] = useState<'chat' | 'image'>('chat');
+  const [endpointType, setEndpointType] = useState<string>(EndpointType.CHAT);

   const chatEndRef = useRef<HTMLDivElement>(null);

@@ -82,10 +84,11 @@ const ChatUI: React.FC<ChatUIProps> = ({
       if (uniqueModels.length > 0) {
         setModelInfo(uniqueModels);
         setSelectedModel(uniqueModels[0].model_group);
-        // Auto-set endpoint based on the first model's mode if available
-        const firstMode = uniqueModels[0].mode as ModelMode;
-        if (firstMode && litellmModeMapping[firstMode]) {
-          setEndpointType(litellmModeMapping[firstMode] as EndpointType);
+
+        // Auto-set endpoint based on the first model's mode
+        if (uniqueModels[0].mode) {
+          const initialEndpointType = determineEndpointType(uniqueModels[0].model_group, uniqueModels);
+          setEndpointType(initialEndpointType);
         }
       }
     } catch (error) {
@@ -160,7 +163,8 @@ const ChatUI: React.FC<ChatUIProps> = ({

     try {
       if (selectedModel) {
-        if (endpointType === 'chat') {
+        // Use EndpointType enum for comparison
+        if (endpointType === EndpointType.CHAT) {
           // Create chat history for API call - strip out model field and isImage field
           const apiChatHistory = [...chatHistory.filter(msg => !msg.isImage).map(({ role, content }) => ({ role, content })), newUserMessage];

@@ -170,7 +174,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
             selectedModel,
             effectiveApiKey
           );
-        } else {
+        } else if (endpointType === EndpointType.IMAGE) {
           // For image generation
           await makeOpenAIImageGenerationRequest(
             inputMessage,
@@ -206,19 +210,18 @@ const ChatUI: React.FC<ChatUIProps> = ({
   const onModelChange = (value: string) => {
     console.log(`selected ${value}`);
     setSelectedModel(value);
-    // Look up the selected model to auto-select the endpoint type
-    const selectedOption = modelInfo.find((option) => option.model_group === value);
-    if (selectedOption && selectedOption.mode) {
-      const mode = selectedOption.mode as ModelMode;
-      if (litellmModeMapping[mode]) {
-        setEndpointType(litellmModeMapping[mode] as EndpointType);
-      }
+
+    // Use the utility function to determine the endpoint type
+    if (value !== 'custom') {
+      const newEndpointType = determineEndpointType(value, modelInfo);
+      setEndpointType(newEndpointType);
     }
+
     setShowCustomModelInput(value === 'custom');
   };

   const handleEndpointChange = (value: string) => {
-    setEndpointType(value as 'chat' | 'image');
+    setEndpointType(value);
   };

   return (
@@ -287,17 +290,11 @@ const ChatUI: React.FC<ChatUIProps> = ({
             }}
           />
         )}
-        <Text>Endpoint Type:</Text>
-        <Select
-          defaultValue="chat"
-          style={{ width: "350px", marginBottom: "12px" }}
-          onChange={handleEndpointChange}
-          options={[
-            { value: 'chat', label: '/chat/completions' },
-            { value: 'image', label: '/images/generations' }
-          ]}
+        <EndpointSelector
+          endpointType={endpointType}
+          onEndpointChange={handleEndpointChange}
+          className="mt-2"
         />
-
       </Col>

     </Grid>
@@ -409,13 +406,17 @@ const ChatUI: React.FC<ChatUIProps> = ({
            value={inputMessage}
            onChange={(e) => setInputMessage(e.target.value)}
            onKeyDown={handleKeyDown}
-           placeholder={endpointType === 'chat' ? "Type your message..." : "Describe the image you want to generate..."}
+           placeholder={
+             endpointType === EndpointType.CHAT
+               ? "Type your message..."
+               : "Describe the image you want to generate..."
+           }
           />
           <Button
             onClick={handleSendMessage}
             className="ml-2"
           >
-            {endpointType === 'chat' ? "Send" : "Generate"}
+            {endpointType === EndpointType.CHAT ? "Send" : "Generate"}
           </Button>
         </div>
       </div>
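
The EndpointType and ModelMode members used above come from chat_ui/mode_endpoint_mapping.ts, which this commit only shows in part (see the last hunk below). A rough sketch of the shapes the diff assumes; the string values are guesses inferred from the old 'chat' / 'image' literal comparisons, not taken from the source:

// Sketch only: the real definitions live in chat_ui/mode_endpoint_mapping.ts.
// String values here are assumptions inferred from the old literal comparisons.
export enum EndpointType {
  CHAT = "chat",
  IMAGE = "image",
}

export enum ModelMode {
  CHAT = "chat",
  IMAGE_GENERATION = "image_generation", // assumed backend mode string
}

// Mirrors the mapping entries visible in the mode_endpoint_mapping.ts hunk below.
export const litellmModeMapping: Record<ModelMode, EndpointType> = {
  [ModelMode.CHAT]: EndpointType.CHAT,
  [ModelMode.IMAGE_GENERATION]: EndpointType.IMAGE,
};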
chat_ui/EndpointSelector.tsx (new file)

@@ -0,0 +1,39 @@
+import React from "react";
+import { Select } from "antd";
+import { Text } from "@tremor/react";
+import { EndpointType } from "./mode_endpoint_mapping";
+
+interface EndpointSelectorProps {
+  endpointType: string; // Accept string to avoid type conflicts
+  onEndpointChange: (value: string) => void;
+  className?: string;
+}
+
+/**
+ * A reusable component for selecting API endpoints
+ */
+const EndpointSelector: React.FC<EndpointSelectorProps> = ({
+  endpointType,
+  onEndpointChange,
+  className,
+}) => {
+  // Map endpoint types to their display labels
+  const endpointOptions = [
+    { value: EndpointType.CHAT, label: '/chat/completions' },
+    { value: EndpointType.IMAGE, label: '/images/generations' }
+  ];
+
+  return (
+    <div className={className}>
+      <Text>Endpoint Type:</Text>
+      <Select
+        value={endpointType}
+        style={{ width: "100%", marginBottom: "12px" }}
+        onChange={onEndpointChange}
+        options={endpointOptions}
+      />
+    </div>
+  );
+};
+
+export default EndpointSelector;
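
For reference, a minimal sketch of a parent component driving the new EndpointSelector, mirroring how the ChatUI diff above wires it up; the ExampleParent name is illustrative:

import React, { useState } from "react";
import EndpointSelector from "./chat_ui/EndpointSelector";
import { EndpointType } from "./chat_ui/mode_endpoint_mapping";

// Illustrative parent component; import paths mirror the ChatUI imports above.
const ExampleParent: React.FC = () => {
  // Keep the endpoint as a plain string, defaulting to chat, as ChatUI now does.
  const [endpointType, setEndpointType] = useState<string>(EndpointType.CHAT);

  return (
    <EndpointSelector
      endpointType={endpointType}
      onEndpointChange={(value) => setEndpointType(value)}
      className="mt-2"
    />
  );
};

export default ExampleParent;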
chat_ui/EndpointUtils.ts (new file)

@@ -0,0 +1,27 @@
+import { ModelGroup } from "./llm_calls/fetch_models";
+import { ModelMode, EndpointType, getEndpointType } from "./mode_endpoint_mapping";
+
+/**
+ * Determines the appropriate endpoint type based on the selected model
+ *
+ * @param selectedModel - The model identifier string
+ * @param modelInfo - Array of model information
+ * @returns The appropriate endpoint type
+ */
+export const determineEndpointType = (
+  selectedModel: string,
+  modelInfo: ModelGroup[]
+): EndpointType => {
+  // Find the model information for the selected model
+  const selectedModelInfo = modelInfo.find(
+    (option) => option.model_group === selectedModel
+  );
+
+  // If model info is found and it has a mode, determine the endpoint type
+  if (selectedModelInfo?.mode) {
+    return getEndpointType(selectedModelInfo.mode);
+  }
+
+  // Default to chat endpoint if no match is found
+  return EndpointType.CHAT;
+};
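
A quick usage sketch for determineEndpointType. It assumes ModelGroup is essentially { model_group: string; mode?: string } and that the backend reports mode strings such as "chat" and "image_generation" (both assumptions); the model names are illustrative:

import { determineEndpointType } from "./chat_ui/EndpointUtils";
import { EndpointType } from "./chat_ui/mode_endpoint_mapping";
import { ModelGroup } from "./chat_ui/llm_calls/fetch_models";

// Hypothetical entries; real data comes from fetchAvailableModels.
const modelInfo: ModelGroup[] = [
  { model_group: "gpt-4o", mode: "chat" },
  { model_group: "dall-e-3", mode: "image_generation" },
];

// If "image_generation" backs ModelMode.IMAGE_GENERATION, this resolves to EndpointType.IMAGE.
console.log(determineEndpointType("dall-e-3", modelInfo) === EndpointType.IMAGE);

// A model with no matching entry (or no mode) falls back to EndpointType.CHAT.
console.log(determineEndpointType("some-unlisted-model", modelInfo) === EndpointType.CHAT);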
chat_ui/llm_calls/fetch_models.ts

@@ -20,7 +20,7 @@ export const fetchAvailableModels = async (
     if (fetchedModels?.data.length > 0) {
       const models: ModelGroup[] = fetchedModels.data.map((item: any) => ({
         model_group: item.model_group, // Display the model_group to the user
-        mode: item?.model_info?.mode, // Save the mode for auto-selection of endpoint type
+        mode: item?.mode, // Save the mode for auto-selection of endpoint type
       }));

       // Sort models alphabetically by label
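
The change above reads the mode from the top-level item.mode rather than item.model_info.mode. A sketch of the entry shape this mapping assumes; the real ModelGroup interface is defined earlier in fetch_models.ts and may carry more fields:

// Assumed minimal shape of the mapped entries; illustrative only.
interface ModelGroupSketch {
  model_group: string; // shown to the user in the model dropdown
  mode?: string;       // e.g. "chat" or "image_generation" (assumed values), drives endpoint auto-selection
}

// Mirrors the mapping in the hunk above.
const toModelGroup = (item: any): ModelGroupSketch => ({
  model_group: item.model_group,
  mode: item?.mode,
});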
chat_ui/mode_endpoint_mapping.ts

@@ -19,4 +19,16 @@ export enum ModelMode
   [ModelMode.IMAGE_GENERATION]: EndpointType.IMAGE,
   [ModelMode.CHAT]: EndpointType.CHAT,
 };
-
+
+export const getEndpointType = (mode: string): EndpointType => {
+  // Check if the string mode exists as a key in ModelMode enum
+  console.log("getEndpointType:", mode);
+  if (Object.values(ModelMode).includes(mode as ModelMode)) {
+    const endpointType = litellmModeMapping[mode as ModelMode];
+    console.log("endpointType:", endpointType);
+    return endpointType;
+  }
+
+  // else default to chat
+  return EndpointType.CHAT;
+};
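
A small sketch of the expected getEndpointType behavior, assuming ModelMode.IMAGE_GENERATION is backed by the string "image_generation" (the enum's string values are not visible in this hunk):

import { EndpointType, getEndpointType } from "./chat_ui/mode_endpoint_mapping";

// If "image_generation" is the value behind ModelMode.IMAGE_GENERATION, this maps to the image endpoint.
console.log(getEndpointType("image_generation") === EndpointType.IMAGE);

// Any mode string not present in ModelMode falls back to the chat endpoint.
console.log(getEndpointType("definitely-not-a-mode") === EndpointType.CHAT);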