clean up endpoint selector

This commit is contained in:
Ishaan Jaff 2025-04-03 21:33:39 -07:00
parent 72d7b26811
commit 353c882574
5 changed files with 109 additions and 30 deletions

View file

@ -26,10 +26,12 @@ import { message, Select } from "antd";
import { makeOpenAIChatCompletionRequest } from "./chat_ui/llm_calls/chat_completion";
import { makeOpenAIImageGenerationRequest } from "./chat_ui/llm_calls/image_generation";
import { fetchAvailableModels, ModelGroup } from "./chat_ui/llm_calls/fetch_models";
import { litellmModeMapping, ModelMode, EndpointType } from "./chat_ui/mode_endpoint_mapping";
import { litellmModeMapping, ModelMode, EndpointType, getEndpointType } from "./chat_ui/mode_endpoint_mapping";
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
import { Typography } from "antd";
import { coy } from 'react-syntax-highlighter/dist/esm/styles/prism';
import EndpointSelector from "./chat_ui/EndpointSelector";
import { determineEndpointType } from "./chat_ui/EndpointUtils";
interface ChatUIProps {
accessToken: string | null;
@ -58,7 +60,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
const [showCustomModelInput, setShowCustomModelInput] = useState<boolean>(false);
const [modelInfo, setModelInfo] = useState<ModelGroup[]>([]);
const customModelTimeout = useRef<NodeJS.Timeout | null>(null);
const [endpointType, setEndpointType] = useState<'chat' | 'image'>('chat');
const [endpointType, setEndpointType] = useState<string>(EndpointType.CHAT);
const chatEndRef = useRef<HTMLDivElement>(null);
@ -82,10 +84,11 @@ const ChatUI: React.FC<ChatUIProps> = ({
if (uniqueModels.length > 0) {
setModelInfo(uniqueModels);
setSelectedModel(uniqueModels[0].model_group);
// Auto-set endpoint based on the first model's mode if available
const firstMode = uniqueModels[0].mode as ModelMode;
if (firstMode && litellmModeMapping[firstMode]) {
setEndpointType(litellmModeMapping[firstMode] as EndpointType);
// Auto-set endpoint based on the first model's mode
if (uniqueModels[0].mode) {
const initialEndpointType = determineEndpointType(uniqueModels[0].model_group, uniqueModels);
setEndpointType(initialEndpointType);
}
}
} catch (error) {
@ -160,7 +163,8 @@ const ChatUI: React.FC<ChatUIProps> = ({
try {
if (selectedModel) {
if (endpointType === 'chat') {
// Use EndpointType enum for comparison
if (endpointType === EndpointType.CHAT) {
// Create chat history for API call - strip out model field and isImage field
const apiChatHistory = [...chatHistory.filter(msg => !msg.isImage).map(({ role, content }) => ({ role, content })), newUserMessage];
@ -170,7 +174,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
selectedModel,
effectiveApiKey
);
} else {
} else if (endpointType === EndpointType.IMAGE) {
// For image generation
await makeOpenAIImageGenerationRequest(
inputMessage,
@ -206,19 +210,18 @@ const ChatUI: React.FC<ChatUIProps> = ({
const onModelChange = (value: string) => {
console.log(`selected ${value}`);
setSelectedModel(value);
// Look up the selected model to auto-select the endpoint type
const selectedOption = modelInfo.find((option) => option.model_group === value);
if (selectedOption && selectedOption.mode) {
const mode = selectedOption.mode as ModelMode;
if (litellmModeMapping[mode]) {
setEndpointType(litellmModeMapping[mode] as EndpointType);
}
// Use the utility function to determine the endpoint type
if (value !== 'custom') {
const newEndpointType = determineEndpointType(value, modelInfo);
setEndpointType(newEndpointType);
}
setShowCustomModelInput(value === 'custom');
};
const handleEndpointChange = (value: string) => {
setEndpointType(value as 'chat' | 'image');
setEndpointType(value);
};
return (
@ -287,17 +290,11 @@ const ChatUI: React.FC<ChatUIProps> = ({
}}
/>
)}
<Text>Endpoint Type:</Text>
<Select
defaultValue="chat"
style={{ width: "350px", marginBottom: "12px" }}
onChange={handleEndpointChange}
options={[
{ value: 'chat', label: '/chat/completions' },
{ value: 'image', label: '/images/generations' }
]}
<EndpointSelector
endpointType={endpointType}
onEndpointChange={handleEndpointChange}
className="mt-2"
/>
</Col>
</Grid>
@ -409,13 +406,17 @@ const ChatUI: React.FC<ChatUIProps> = ({
value={inputMessage}
onChange={(e) => setInputMessage(e.target.value)}
onKeyDown={handleKeyDown}
placeholder={endpointType === 'chat' ? "Type your message..." : "Describe the image you want to generate..."}
placeholder={
endpointType === EndpointType.CHAT
? "Type your message..."
: "Describe the image you want to generate..."
}
/>
<Button
onClick={handleSendMessage}
className="ml-2"
>
{endpointType === 'chat' ? "Send" : "Generate"}
{endpointType === EndpointType.CHAT ? "Send" : "Generate"}
</Button>
</div>
</div>

View file

@ -0,0 +1,39 @@
import React from "react";
import { Select } from "antd";
import { Text } from "@tremor/react";
import { EndpointType } from "./mode_endpoint_mapping";
interface EndpointSelectorProps {
endpointType: string; // Accept string to avoid type conflicts
onEndpointChange: (value: string) => void;
className?: string;
}
/**
* A reusable component for selecting API endpoints
*/
const EndpointSelector: React.FC<EndpointSelectorProps> = ({
endpointType,
onEndpointChange,
className,
}) => {
// Map endpoint types to their display labels
const endpointOptions = [
{ value: EndpointType.CHAT, label: '/chat/completions' },
{ value: EndpointType.IMAGE, label: '/images/generations' }
];
return (
<div className={className}>
<Text>Endpoint Type:</Text>
<Select
value={endpointType}
style={{ width: "100%", marginBottom: "12px" }}
onChange={onEndpointChange}
options={endpointOptions}
/>
</div>
);
};
export default EndpointSelector;

View file

@ -0,0 +1,27 @@
import { ModelGroup } from "./llm_calls/fetch_models";
import { ModelMode, EndpointType, getEndpointType } from "./mode_endpoint_mapping";
/**
* Determines the appropriate endpoint type based on the selected model
*
* @param selectedModel - The model identifier string
* @param modelInfo - Array of model information
* @returns The appropriate endpoint type
*/
export const determineEndpointType = (
selectedModel: string,
modelInfo: ModelGroup[]
): EndpointType => {
// Find the model information for the selected model
const selectedModelInfo = modelInfo.find(
(option) => option.model_group === selectedModel
);
// If model info is found and it has a mode, determine the endpoint type
if (selectedModelInfo?.mode) {
return getEndpointType(selectedModelInfo.mode);
}
// Default to chat endpoint if no match is found
return EndpointType.CHAT;
};

View file

@ -20,7 +20,7 @@ export const fetchAvailableModels = async (
if (fetchedModels?.data.length > 0) {
const models: ModelGroup[] = fetchedModels.data.map((item: any) => ({
model_group: item.model_group, // Display the model_group to the user
mode: item?.model_info?.mode, // Save the mode for auto-selection of endpoint type
mode: item?.mode, // Save the mode for auto-selection of endpoint type
}));
// Sort models alphabetically by label

View file

@ -19,4 +19,16 @@ export enum ModelMode {
[ModelMode.IMAGE_GENERATION]: EndpointType.IMAGE,
[ModelMode.CHAT]: EndpointType.CHAT,
};
export const getEndpointType = (mode: string): EndpointType => {
// Check if the string mode exists as a key in ModelMode enum
console.log("getEndpointType:", mode);
if (Object.values(ModelMode).includes(mode as ModelMode)) {
const endpointType = litellmModeMapping[mode as ModelMode];
console.log("endpointType:", endpointType);
return endpointType;
}
// else default to chat
return EndpointType.CHAT;
};