Merge pull request #9748 from BerriAI/litellm_ui_allow_testing_image_endpoints
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 18s
Helm unit test / unit-test (push) Successful in 22s

[Feat] UI - Test Key v2 page - allow testing image endpoints + polish the page
This commit is contained in:
Ishaan Jaff 2025-04-03 22:39:45 -07:00 committed by GitHub
commit e67d16d5bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 566 additions and 270 deletions

View file

@ -20,15 +20,28 @@ import {
SelectItem, SelectItem,
TextInput, TextInput,
Button, Button,
Divider,
} from "@tremor/react"; } from "@tremor/react";
import { message, Select } from "antd"; import { message, Select, Spin, Typography, Tooltip } from "antd";
import { modelAvailableCall } from "./networking"; import { makeOpenAIChatCompletionRequest } from "./chat_ui/llm_calls/chat_completion";
import openai from "openai"; import { makeOpenAIImageGenerationRequest } from "./chat_ui/llm_calls/image_generation";
import { ChatCompletionMessageParam } from "openai/resources/chat/completions"; import { fetchAvailableModels, ModelGroup } from "./chat_ui/llm_calls/fetch_models";
import { litellmModeMapping, ModelMode, EndpointType, getEndpointType } from "./chat_ui/mode_endpoint_mapping";
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter"; import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
import { Typography } from "antd";
import { coy } from 'react-syntax-highlighter/dist/esm/styles/prism'; import { coy } from 'react-syntax-highlighter/dist/esm/styles/prism';
import EndpointSelector from "./chat_ui/EndpointSelector";
import { determineEndpointType } from "./chat_ui/EndpointUtils";
import {
SendOutlined,
ApiOutlined,
KeyOutlined,
ClearOutlined,
RobotOutlined,
UserOutlined,
DeleteOutlined,
LoadingOutlined
} from "@ant-design/icons";
interface ChatUIProps { interface ChatUIProps {
accessToken: string | null; accessToken: string | null;
@ -38,45 +51,6 @@ interface ChatUIProps {
disabledPersonalKeyCreation: boolean; disabledPersonalKeyCreation: boolean;
} }
async function generateModelResponse(
chatHistory: { role: string; content: string }[],
updateUI: (chunk: string, model: string) => void,
selectedModel: string,
accessToken: string
) {
// base url should be the current base_url
const isLocal = process.env.NODE_ENV === "development";
if (isLocal !== true) {
console.log = function () {};
}
console.log("isLocal:", isLocal);
const proxyBaseUrl = isLocal
? "http://localhost:4000"
: window.location.origin;
const client = new openai.OpenAI({
apiKey: accessToken, // Replace with your OpenAI API key
baseURL: proxyBaseUrl, // Replace with your OpenAI API base URL
dangerouslyAllowBrowser: true, // using a temporary litellm proxy key
});
try {
const response = await client.chat.completions.create({
model: selectedModel,
stream: true,
messages: chatHistory as ChatCompletionMessageParam[],
});
for await (const chunk of response) {
console.log(chunk);
if (chunk.choices[0].delta.content) {
updateUI(chunk.choices[0].delta.content, chunk.model);
}
}
} catch (error) {
message.error(`Error occurred while generating model response. Please try again. Error: ${error}`, 20);
}
}
const ChatUI: React.FC<ChatUIProps> = ({ const ChatUI: React.FC<ChatUIProps> = ({
accessToken, accessToken,
token, token,
@ -89,63 +63,55 @@ const ChatUI: React.FC<ChatUIProps> = ({
); );
const [apiKey, setApiKey] = useState(""); const [apiKey, setApiKey] = useState("");
const [inputMessage, setInputMessage] = useState(""); const [inputMessage, setInputMessage] = useState("");
const [chatHistory, setChatHistory] = useState<{ role: string; content: string; model?: string }[]>([]); const [chatHistory, setChatHistory] = useState<{ role: string; content: string; model?: string; isImage?: boolean }[]>([]);
const [selectedModel, setSelectedModel] = useState<string | undefined>( const [selectedModel, setSelectedModel] = useState<string | undefined>(
undefined undefined
); );
const [showCustomModelInput, setShowCustomModelInput] = useState<boolean>(false); const [showCustomModelInput, setShowCustomModelInput] = useState<boolean>(false);
const [modelInfo, setModelInfo] = useState<any[]>([]); const [modelInfo, setModelInfo] = useState<ModelGroup[]>([]);
const customModelTimeout = useRef<NodeJS.Timeout | null>(null); const customModelTimeout = useRef<NodeJS.Timeout | null>(null);
const [endpointType, setEndpointType] = useState<string>(EndpointType.CHAT);
const [isLoading, setIsLoading] = useState<boolean>(false);
const abortControllerRef = useRef<AbortController | null>(null);
const chatEndRef = useRef<HTMLDivElement>(null); const chatEndRef = useRef<HTMLDivElement>(null);
useEffect(() => { useEffect(() => {
let useApiKey = apiKeySource === 'session' ? accessToken : apiKey; let userApiKey = apiKeySource === 'session' ? accessToken : apiKey;
console.log("useApiKey:", useApiKey); if (!userApiKey || !token || !userRole || !userID) {
if (!useApiKey || !token || !userRole || !userID) { console.log("userApiKey or token or userRole or userID is missing = ", userApiKey, token, userRole, userID);
console.log("useApiKey or token or userRole or userID is missing = ", useApiKey, token, userRole, userID);
return; return;
} }
// Fetch model info and set the default selected model // Fetch model info and set the default selected model
const fetchModelInfo = async () => { const loadModels = async () => {
try { try {
const fetchedAvailableModels = await modelAvailableCall( if (!userApiKey) {
useApiKey ?? '', // Use empty string if useApiKey is null, console.log("userApiKey is missing");
userID, return;
userRole }
const uniqueModels = await fetchAvailableModels(
userApiKey,
); );
console.log("model_info:", fetchedAvailableModels); console.log("Fetched models:", uniqueModels);
if (fetchedAvailableModels?.data.length > 0) {
// Create a Map to store unique models using the model ID as key
const uniqueModelsMap = new Map();
fetchedAvailableModels["data"].forEach((item: { id: string }) => {
uniqueModelsMap.set(item.id, {
value: item.id,
label: item.id
});
});
// Convert Map values back to array
const uniqueModels = Array.from(uniqueModelsMap.values());
// Sort models alphabetically
uniqueModels.sort((a, b) => a.label.localeCompare(b.label));
if (uniqueModels.length > 0) {
setModelInfo(uniqueModels); setModelInfo(uniqueModels);
setSelectedModel(uniqueModels[0].value); setSelectedModel(uniqueModels[0].model_group);
// Auto-set endpoint based on the first model's mode
if (uniqueModels[0].mode) {
const initialEndpointType = determineEndpointType(uniqueModels[0].model_group, uniqueModels);
setEndpointType(initialEndpointType);
}
} }
} catch (error) { } catch (error) {
console.error("Error fetching model info:", error); console.error("Error fetching model info:", error);
} }
}; };
fetchModelInfo(); loadModels();
}, [accessToken, userID, userRole, apiKeySource, apiKey]); }, [accessToken, userID, userRole, apiKeySource, apiKey]);
@ -162,11 +128,11 @@ const ChatUI: React.FC<ChatUIProps> = ({
} }
}, [chatHistory]); }, [chatHistory]);
const updateUI = (role: string, chunk: string, model?: string) => { const updateTextUI = (role: string, chunk: string, model?: string) => {
setChatHistory((prevHistory) => { setChatHistory((prevHistory) => {
const lastMessage = prevHistory[prevHistory.length - 1]; const lastMessage = prevHistory[prevHistory.length - 1];
if (lastMessage && lastMessage.role === role) { if (lastMessage && lastMessage.role === role && !lastMessage.isImage) {
return [ return [
...prevHistory.slice(0, prevHistory.length - 1), ...prevHistory.slice(0, prevHistory.length - 1),
{ role, content: lastMessage.content + chunk, model }, { role, content: lastMessage.content + chunk, model },
@ -177,12 +143,28 @@ const ChatUI: React.FC<ChatUIProps> = ({
}); });
}; };
const updateImageUI = (imageUrl: string, model: string) => {
setChatHistory((prevHistory) => [
...prevHistory,
{ role: "assistant", content: imageUrl, model, isImage: true }
]);
};
const handleKeyDown = (event: React.KeyboardEvent<HTMLInputElement>) => { const handleKeyDown = (event: React.KeyboardEvent<HTMLInputElement>) => {
if (event.key === 'Enter') { if (event.key === 'Enter') {
handleSendMessage(); handleSendMessage();
} }
}; };
const handleCancelRequest = () => {
if (abortControllerRef.current) {
abortControllerRef.current.abort();
abortControllerRef.current = null;
setIsLoading(false);
message.info("Request cancelled");
}
};
const handleSendMessage = async () => { const handleSendMessage = async () => {
if (inputMessage.trim() === "") return; if (inputMessage.trim() === "") return;
@ -197,27 +179,52 @@ const ChatUI: React.FC<ChatUIProps> = ({
return; return;
} }
// Create new abort controller for this request
abortControllerRef.current = new AbortController();
const signal = abortControllerRef.current.signal;
// Create message object without model field for API call // Create message object without model field for API call
const newUserMessage = { role: "user", content: inputMessage }; const newUserMessage = { role: "user", content: inputMessage };
// Create chat history for API call - strip out model field // Update UI with full message object
const apiChatHistory = [...chatHistory.map(({ role, content }) => ({ role, content })), newUserMessage];
// Update UI with full message object (including model field for display)
setChatHistory([...chatHistory, newUserMessage]); setChatHistory([...chatHistory, newUserMessage]);
setIsLoading(true);
try { try {
if (selectedModel) { if (selectedModel) {
await generateModelResponse( // Use EndpointType enum for comparison
if (endpointType === EndpointType.CHAT) {
// Create chat history for API call - strip out model field and isImage field
const apiChatHistory = [...chatHistory.filter(msg => !msg.isImage).map(({ role, content }) => ({ role, content })), newUserMessage];
await makeOpenAIChatCompletionRequest(
apiChatHistory, apiChatHistory,
(chunk, model) => updateUI("assistant", chunk, model), (chunk, model) => updateTextUI("assistant", chunk, model),
selectedModel, selectedModel,
effectiveApiKey effectiveApiKey,
signal
);
} else if (endpointType === EndpointType.IMAGE) {
// For image generation
await makeOpenAIImageGenerationRequest(
inputMessage,
(imageUrl, model) => updateImageUI(imageUrl, model),
selectedModel,
effectiveApiKey,
signal
); );
} }
}
} catch (error) { } catch (error) {
console.error("Error fetching model response", error); if (signal.aborted) {
updateUI("assistant", "Error fetching model response"); console.log("Request was cancelled");
} else {
console.error("Error fetching response", error);
updateTextUI("assistant", "Error fetching response");
}
} finally {
setIsLoading(false);
abortControllerRef.current = null;
} }
setInputMessage(""); setInputMessage("");
@ -238,27 +245,37 @@ const ChatUI: React.FC<ChatUIProps> = ({
); );
} }
const onChange = (value: string) => { const onModelChange = (value: string) => {
console.log(`selected ${value}`); console.log(`selected ${value}`);
setSelectedModel(value); setSelectedModel(value);
// Use the utility function to determine the endpoint type
if (value !== 'custom') {
const newEndpointType = determineEndpointType(value, modelInfo);
setEndpointType(newEndpointType);
}
setShowCustomModelInput(value === 'custom'); setShowCustomModelInput(value === 'custom');
}; };
return ( const handleEndpointChange = (value: string) => {
<div style={{ width: "100%", position: "relative" }}> setEndpointType(value);
<Grid className="gap-2 p-8 h-[80vh] w-full mt-2"> };
<Card>
<TabGroup> const antIcon = <LoadingOutlined style={{ fontSize: 24 }} spin />;
<TabList>
<Tab>Chat</Tab> return (
</TabList> <div className="w-full h-screen p-4 bg-white">
<TabPanels> <Card className="w-full rounded-xl shadow-md overflow-hidden">
<TabPanel> <div className="flex h-[80vh] w-full">
<div className="sm:max-w-2xl"> {/* Left Sidebar with Controls */}
<Grid numItems={2}> <div className="w-1/4 p-4 border-r border-gray-200 bg-gray-50">
<Col> <div className="mb-6">
<Text>API Key Source</Text> <div className="space-y-6">
<div>
<Text className="font-medium block mb-2 text-gray-700 flex items-center">
<KeyOutlined className="mr-2" /> API Key Source
</Text>
<Select <Select
disabled={disabledPersonalKeyCreation} disabled={disabledPersonalKeyCreation}
defaultValue="session" defaultValue="session"
@ -268,6 +285,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
{ value: 'session', label: 'Current UI Session' }, { value: 'session', label: 'Current UI Session' },
{ value: 'custom', label: 'Virtual Key' }, { value: 'custom', label: 'Virtual Key' },
]} ]}
className="rounded-md"
/> />
{apiKeySource === 'custom' && ( {apiKeySource === 'custom' && (
<TextInput <TextInput
@ -276,20 +294,28 @@ const ChatUI: React.FC<ChatUIProps> = ({
type="password" type="password"
onValueChange={setApiKey} onValueChange={setApiKey}
value={apiKey} value={apiKey}
icon={KeyOutlined}
/> />
)} )}
</Col> </div>
<Col className="mx-2">
<Text>Select Model:</Text> <div>
<Text className="font-medium block mb-2 text-gray-700 flex items-center">
<RobotOutlined className="mr-2" /> Select Model
</Text>
<Select <Select
placeholder="Select a Model" placeholder="Select a Model"
onChange={onChange} onChange={onModelChange}
options={[ options={[
...modelInfo, ...modelInfo.map((option) => ({
value: option.model_group,
label: option.model_group
})),
{ value: 'custom', label: 'Enter custom model' } { value: 'custom', label: 'Enter custom model' }
]} ]}
style={{ width: "350px" }} style={{ width: "100%" }}
showSearch={true} showSearch={true}
className="rounded-md"
/> />
{showCustomModelInput && ( {showCustomModelInput && (
<TextInput <TextInput
@ -307,61 +333,75 @@ const ChatUI: React.FC<ChatUIProps> = ({
}} }}
/> />
)} )}
</Col> </div>
</Grid>
<div>
<Text className="font-medium block mb-2 text-gray-700 flex items-center">
<ApiOutlined className="mr-2" /> Endpoint Type
</Text>
<EndpointSelector
endpointType={endpointType}
onEndpointChange={handleEndpointChange}
className="mb-4"
/>
</div>
{/* Clear Chat Button */}
<Button <Button
onClick={clearChatHistory} onClick={clearChatHistory}
className="mt-4" className="w-full bg-gray-100 hover:bg-gray-200 text-gray-700 border-gray-300 mt-4"
icon={ClearOutlined}
> >
Clear Chat Clear Chat
</Button> </Button>
</div> </div>
<Table </div>
className="mt-5" </div>
style={{
display: "block", {/* Main Chat Area */}
maxHeight: "60vh", <div className="w-3/4 flex flex-col bg-white">
overflowY: "auto", <div className="flex-1 overflow-auto p-4 pb-0">
}} {chatHistory.length === 0 && (
> <div className="h-full flex flex-col items-center justify-center text-gray-400">
<TableHead> <RobotOutlined style={{ fontSize: '48px', marginBottom: '16px' }} />
<TableRow> <Text>Start a conversation or generate an image</Text>
<TableCell> </div>
{/* <Title>Chat</Title> */} )}
</TableCell>
</TableRow>
</TableHead>
<TableBody>
{chatHistory.map((message, index) => ( {chatHistory.map((message, index) => (
<TableRow key={index}> <div
<TableCell> key={index}
<div style={{ className={`mb-4 ${message.role === "user" ? "text-right" : "text-left"}`}
display: 'flex', >
alignItems: 'center', <div className="inline-block max-w-[80%] rounded-lg shadow-sm p-3.5 px-4" style={{
gap: '8px', backgroundColor: message.role === "user" ? '#f0f8ff' : '#ffffff',
marginBottom: '4px' border: message.role === "user" ? '1px solid #e6f0fa' : '1px solid #f0f0f0',
textAlign: 'left'
}}> }}>
<strong>{message.role}</strong> <div className="flex items-center gap-2 mb-1.5">
<div className="flex items-center justify-center w-6 h-6 rounded-full mr-1" style={{
backgroundColor: message.role === "user" ? '#e6f0fa' : '#f5f5f5',
}}>
{message.role === "user" ?
<UserOutlined style={{ fontSize: '12px', color: '#2563eb' }} /> :
<RobotOutlined style={{ fontSize: '12px', color: '#4b5563' }} />
}
</div>
<strong className="text-sm capitalize">{message.role}</strong>
{message.role === "assistant" && message.model && ( {message.role === "assistant" && message.model && (
<span style={{ <span className="text-xs px-2 py-0.5 rounded bg-gray-100 text-gray-600 font-normal">
fontSize: '12px',
color: '#666',
backgroundColor: '#f5f5f5',
padding: '2px 6px',
borderRadius: '4px',
fontWeight: 'normal'
}}>
{message.model} {message.model}
</span> </span>
)} )}
</div> </div>
<div style={{ <div className="whitespace-pre-wrap break-words max-w-full message-content">
whiteSpace: "pre-wrap", {message.isImage ? (
wordBreak: "break-word", <img
maxWidth: "100%" src={message.content}
}}> alt="Generated image"
className="max-w-full rounded-md border border-gray-200 shadow-sm"
style={{ maxHeight: '500px' }}
/>
) : (
<ReactMarkdown <ReactMarkdown
components={{ components={{
code({node, inline, className, children, ...props}: React.ComponentPropsWithoutRef<'code'> & { code({node, inline, className, children, ...props}: React.ComponentPropsWithoutRef<'code'> & {
@ -374,12 +414,13 @@ const ChatUI: React.FC<ChatUIProps> = ({
style={coy as any} style={coy as any}
language={match[1]} language={match[1]}
PreTag="div" PreTag="div"
className="rounded-md my-2"
{...props} {...props}
> >
{String(children).replace(/\n$/, '')} {String(children).replace(/\n$/, '')}
</SyntaxHighlighter> </SyntaxHighlighter>
) : ( ) : (
<code className={className} {...props}> <code className={`${className} px-1.5 py-0.5 rounded bg-gray-100 text-sm font-mono`} {...props}>
{children} {children}
</code> </code>
); );
@ -388,43 +429,56 @@ const ChatUI: React.FC<ChatUIProps> = ({
> >
{message.content} {message.content}
</ReactMarkdown> </ReactMarkdown>
)}
</div>
</div>
</div> </div>
</TableCell>
</TableRow>
))} ))}
<TableRow> {isLoading && (
<TableCell> <div className="flex justify-center items-center my-4">
<Spin indicator={antIcon} />
</div>
)}
<div ref={chatEndRef} style={{ height: "1px" }} /> <div ref={chatEndRef} style={{ height: "1px" }} />
</TableCell> </div>
</TableRow>
</TableBody> <div className="p-4 border-t border-gray-200 bg-white">
</Table> <div className="flex items-center">
<div
className="mt-3"
style={{ position: "absolute", bottom: 5, width: "95%" }}
>
<div className="flex" style={{ marginTop: "16px" }}>
<TextInput <TextInput
type="text" type="text"
value={inputMessage} value={inputMessage}
onChange={(e) => setInputMessage(e.target.value)} onChange={(e) => setInputMessage(e.target.value)}
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
placeholder="Type your message..." placeholder={
endpointType === EndpointType.CHAT
? "Type your message..."
: "Describe the image you want to generate..."
}
disabled={isLoading}
className="flex-1"
/> />
{isLoading ? (
<Button
onClick={handleCancelRequest}
className="ml-2 bg-red-50 hover:bg-red-100 text-red-600 border-red-200"
icon={DeleteOutlined}
>
Cancel
</Button>
) : (
<Button <Button
onClick={handleSendMessage} onClick={handleSendMessage}
className="ml-2" className="ml-2 text-white"
icon={endpointType === EndpointType.CHAT ? SendOutlined : RobotOutlined}
> >
Send {endpointType === EndpointType.CHAT ? "Send" : "Generate"}
</Button> </Button>
)}
</div>
</div>
</div> </div>
</div> </div>
</TabPanel>
</TabPanels>
</TabGroup>
</Card> </Card>
</Grid>
</div> </div>
); );
}; };

View file

@ -0,0 +1,40 @@
import React from "react";
import { Select } from "antd";
import { Text } from "@tremor/react";
import { EndpointType } from "./mode_endpoint_mapping";
interface EndpointSelectorProps {
endpointType: string; // Accept string to avoid type conflicts
onEndpointChange: (value: string) => void;
className?: string;
}
/**
* A reusable component for selecting API endpoints
*/
const EndpointSelector: React.FC<EndpointSelectorProps> = ({
endpointType,
onEndpointChange,
className,
}) => {
// Map endpoint types to their display labels
const endpointOptions = [
{ value: EndpointType.CHAT, label: '/chat/completions' },
{ value: EndpointType.IMAGE, label: '/images/generations' }
];
return (
<div className={className}>
<Text>Endpoint Type:</Text>
<Select
value={endpointType}
style={{ width: "100%" }}
onChange={onEndpointChange}
options={endpointOptions}
className="rounded-md"
/>
</div>
);
};
export default EndpointSelector;

View file

@ -0,0 +1,27 @@
import { ModelGroup } from "./llm_calls/fetch_models";
import { ModelMode, EndpointType, getEndpointType } from "./mode_endpoint_mapping";
/**
* Determines the appropriate endpoint type based on the selected model
*
* @param selectedModel - The model identifier string
* @param modelInfo - Array of model information
* @returns The appropriate endpoint type
*/
export const determineEndpointType = (
selectedModel: string,
modelInfo: ModelGroup[]
): EndpointType => {
// Find the model information for the selected model
const selectedModelInfo = modelInfo.find(
(option) => option.model_group === selectedModel
);
// If model info is found and it has a mode, determine the endpoint type
if (selectedModelInfo?.mode) {
return getEndpointType(selectedModelInfo.mode);
}
// Default to chat endpoint if no match is found
return EndpointType.CHAT;
};

View file

@ -0,0 +1,49 @@
import openai from "openai";
import { ChatCompletionMessageParam } from "openai/resources/chat/completions";
import { message } from "antd";
export async function makeOpenAIChatCompletionRequest(
chatHistory: { role: string; content: string }[],
updateUI: (chunk: string, model: string) => void,
selectedModel: string,
accessToken: string,
signal?: AbortSignal
) {
// base url should be the current base_url
const isLocal = process.env.NODE_ENV === "development";
if (isLocal !== true) {
console.log = function () {};
}
console.log("isLocal:", isLocal);
const proxyBaseUrl = isLocal
? "http://localhost:4000"
: window.location.origin;
const client = new openai.OpenAI({
apiKey: accessToken, // Replace with your OpenAI API key
baseURL: proxyBaseUrl, // Replace with your OpenAI API base URL
dangerouslyAllowBrowser: true, // using a temporary litellm proxy key
});
try {
const response = await client.chat.completions.create({
model: selectedModel,
stream: true,
messages: chatHistory as ChatCompletionMessageParam[],
}, { signal });
for await (const chunk of response) {
console.log(chunk);
if (chunk.choices[0].delta.content) {
updateUI(chunk.choices[0].delta.content, chunk.model);
}
}
} catch (error) {
if (signal?.aborted) {
console.log("Chat completion request was cancelled");
} else {
message.error(`Error occurred while generating model response. Please try again. Error: ${error}`, 20);
}
throw error; // Re-throw to allow the caller to handle the error
}
}

View file

@ -0,0 +1,35 @@
// fetch_models.ts
import { modelHubCall } from "../../networking";
export interface ModelGroup {
model_group: string;
mode?: string;
}
/**
* Fetches available models using modelHubCall and formats them for the selection dropdown.
*/
export const fetchAvailableModels = async (
accessToken: string
): Promise<ModelGroup[]> => {
try {
const fetchedModels = await modelHubCall(accessToken);
console.log("model_info:", fetchedModels);
if (fetchedModels?.data.length > 0) {
const models: ModelGroup[] = fetchedModels.data.map((item: any) => ({
model_group: item.model_group, // Display the model_group to the user
mode: item?.mode, // Save the mode for auto-selection of endpoint type
}));
// Sort models alphabetically by label
models.sort((a, b) => a.model_group.localeCompare(b.model_group));
return models;
}
return [];
} catch (error) {
console.error("Error fetching model info:", error);
throw error;
}
};

View file

@ -0,0 +1,57 @@
import openai from "openai";
import { message } from "antd";
export async function makeOpenAIImageGenerationRequest(
prompt: string,
updateUI: (imageUrl: string, model: string) => void,
selectedModel: string,
accessToken: string,
signal?: AbortSignal
) {
// base url should be the current base_url
const isLocal = process.env.NODE_ENV === "development";
if (isLocal !== true) {
console.log = function () {};
}
console.log("isLocal:", isLocal);
const proxyBaseUrl = isLocal
? "http://localhost:4000"
: window.location.origin;
const client = new openai.OpenAI({
apiKey: accessToken,
baseURL: proxyBaseUrl,
dangerouslyAllowBrowser: true,
});
try {
const response = await client.images.generate({
model: selectedModel,
prompt: prompt,
}, { signal });
console.log(response.data);
if (response.data && response.data[0]) {
// Handle either URL or base64 data from response
if (response.data[0].url) {
// Use the URL directly
updateUI(response.data[0].url, selectedModel);
} else if (response.data[0].b64_json) {
// Convert base64 to data URL format
const base64Data = response.data[0].b64_json;
updateUI(`data:image/png;base64,${base64Data}`, selectedModel);
} else {
throw new Error("No image data found in response");
}
} else {
throw new Error("Invalid response format");
}
} catch (error) {
if (signal?.aborted) {
console.log("Image generation request was cancelled");
} else {
message.error(`Error occurred while generating image. Please try again. Error: ${error}`, 20);
}
throw error; // Re-throw to allow the caller to handle the error
}
}

View file

@ -0,0 +1,34 @@
// litellmMapping.ts
// Define an enum for the modes as returned in model_info
export enum ModelMode {
IMAGE_GENERATION = "image_generation",
CHAT = "chat",
// add additional modes as needed
}
// Define an enum for the endpoint types your UI calls
export enum EndpointType {
IMAGE = "image",
CHAT = "chat",
// add additional endpoint types if required
}
// Create a mapping between the model mode and the corresponding endpoint type
export const litellmModeMapping: Record<ModelMode, EndpointType> = {
[ModelMode.IMAGE_GENERATION]: EndpointType.IMAGE,
[ModelMode.CHAT]: EndpointType.CHAT,
};
export const getEndpointType = (mode: string): EndpointType => {
// Check if the string mode exists as a key in ModelMode enum
console.log("getEndpointType:", mode);
if (Object.values(ModelMode).includes(mode as ModelMode)) {
const endpointType = litellmModeMapping[mode as ModelMode];
console.log("endpointType:", endpointType);
return endpointType;
}
// else default to chat
return EndpointType.CHAT;
};