diff --git a/ui/litellm-dashboard/src/components/chat_ui.tsx b/ui/litellm-dashboard/src/components/chat_ui.tsx
index 1491a57b49..f9c731461c 100644
--- a/ui/litellm-dashboard/src/components/chat_ui.tsx
+++ b/ui/litellm-dashboard/src/components/chat_ui.tsx
@@ -25,6 +25,7 @@ import {
 import { message, Select } from "antd";
 import { modelAvailableCall } from "./networking";
 import { makeOpenAIChatCompletionRequest } from "./chat_ui/llm_calls/chat_completion";
+import { makeOpenAIImageGenerationRequest } from "./chat_ui/llm_calls/image_generation";
 import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
 import { Typography } from "antd";
 import { coy } from 'react-syntax-highlighter/dist/esm/styles/prism';
@@ -49,13 +50,14 @@ const ChatUI: React.FC = ({
   );
   const [apiKey, setApiKey] = useState("");
   const [inputMessage, setInputMessage] = useState("");
-  const [chatHistory, setChatHistory] = useState<{ role: string; content: string; model?: string }[]>([]);
+  const [chatHistory, setChatHistory] = useState<{ role: string; content: string; model?: string; isImage?: boolean }[]>([]);
   const [selectedModel, setSelectedModel] = useState<string | undefined>(
     undefined
   );
   const [showCustomModelInput, setShowCustomModelInput] = useState(false);
   const [modelInfo, setModelInfo] = useState([]);
   const customModelTimeout = useRef(null);
+  const [endpointType, setEndpointType] = useState<'chat' | 'image'>('chat');
 
   const chatEndRef = useRef(null);
 
@@ -67,8 +69,6 @@ const ChatUI: React.FC = ({
       return;
     }
 
-
-
     // Fetch model info and set the default selected model
     const fetchModelInfo = async () => {
       try {
@@ -122,11 +122,11 @@ const ChatUI: React.FC = ({
     }
   }, [chatHistory]);
 
-  const updateUI = (role: string, chunk: string, model?: string) => {
+  const updateTextUI = (role: string, chunk: string, model?: string) => {
     setChatHistory((prevHistory) => {
       const lastMessage = prevHistory[prevHistory.length - 1];
-      if (lastMessage && lastMessage.role === role) {
+      if (lastMessage && lastMessage.role === role && !lastMessage.isImage) {
         return [
           ...prevHistory.slice(0, prevHistory.length - 1),
           { role, content: lastMessage.content + chunk, model },
         ];
@@ -137,6 +137,13 @@ const ChatUI: React.FC = ({
     });
   };
 
+  const updateImageUI = (imageUrl: string, model: string) => {
+    setChatHistory((prevHistory) => [
+      ...prevHistory,
+      { role: "assistant", content: imageUrl, model, isImage: true }
+    ]);
+  };
+
   const handleKeyDown = (event: React.KeyboardEvent) => {
     if (event.key === 'Enter') {
       handleSendMessage();
@@ -160,24 +167,34 @@ const ChatUI: React.FC = ({
     // Create message object without model field for API call
     const newUserMessage = { role: "user", content: inputMessage };
 
-    // Create chat history for API call - strip out model field
-    const apiChatHistory = [...chatHistory.map(({ role, content }) => ({ role, content })), newUserMessage];
-
-    // Update UI with full message object (including model field for display)
+    // Update UI with full message object
    setChatHistory([...chatHistory, newUserMessage]);
 
     try {
       if (selectedModel) {
-        await makeOpenAIChatCompletionRequest(
-          apiChatHistory,
-          (chunk, model) => updateUI("assistant", chunk, model),
-          selectedModel,
-          effectiveApiKey
-        );
+        if (endpointType === 'chat') {
+          // Create chat history for API call - strip out model field and isImage field
+          const apiChatHistory = [...chatHistory.filter(msg => !msg.isImage).map(({ role, content }) => ({ role, content })), newUserMessage];
+
+          await makeOpenAIChatCompletionRequest(
+            apiChatHistory,
+            (chunk, model) => updateTextUI("assistant", chunk, model),
+            selectedModel,
+            effectiveApiKey
+          );
+        } else {
+          // For image generation
+          await makeOpenAIImageGenerationRequest(
+            inputMessage,
+            (imageUrl, model) => updateImageUI(imageUrl, model),
+            selectedModel,
+            effectiveApiKey
+          );
+        }
       }
     } catch (error) {
-      console.error("Error fetching model response", error);
-      updateUI("assistant", "Error fetching model response");
+      console.error("Error fetching response", error);
+      updateTextUI("assistant", "Error fetching response");
     }
 
     setInputMessage("");
@@ -198,12 +215,16 @@ const ChatUI: React.FC = ({
     );
   }
 
-  const onChange = (value: string) => {
+  const onModelChange = (value: string) => {
     console.log(`selected ${value}`);
     setSelectedModel(value);
     setShowCustomModelInput(value === 'custom');
   };
 
+  const handleEndpointChange = (value: string) => {
+    setEndpointType(value as 'chat' | 'image');
+  };
+
   return (
@@ -240,10 +261,21 @@ const ChatUI: React.FC = ({
               )}
+              <Text>Endpoint Type:</Text>
+              <Select
+                defaultValue="chat"
+                style={{ width: "100%", marginBottom: "12px" }}
+                onChange={handleEndpointChange}
+                options={[
+                  { value: 'chat', label: 'Chat' },
+                  { value: 'image', label: 'Image Generation' }
+                ]}
+              />
+
@@ ... @@ const ChatUI: React.FC = ({
                 <div style={{ wordBreak: "break-word", maxWidth: "100%" }}>
-                  <ReactMarkdown
-                    components={{
-                      code({node, inline, className, children, ...props}: React.ComponentPropsWithoutRef<'code'> & {
-                        inline?: boolean;
-                        node?: any;
-                      }) {
-                        const match = /language-(\w+)/.exec(className || '');
-                        return !inline && match ? (
-                          <SyntaxHighlighter style={coy} language={match[1]} PreTag="div" {...props}>
-                            {String(children).replace(/\n$/, '')}
-                          </SyntaxHighlighter>
-                        ) : (
-                          <code className={className} {...props}>
-                            {children}
-                          </code>
-                        );
-                      }
-                    }}
-                  >
-                    {message.content}
-                  </ReactMarkdown>
+                  {message.isImage ? (
+                    <img
+                      src={message.content}
+                      alt="Generated image"
+                      style={{ maxWidth: "100%" }}
+                    />
+                  ) : (
+                    <ReactMarkdown
+                      components={{
+                        code({node, inline, className, children, ...props}: React.ComponentPropsWithoutRef<'code'> & {
+                          inline?: boolean;
+                          node?: any;
+                        }) {
+                          const match = /language-(\w+)/.exec(className || '');
+                          return !inline && match ? (
+                            <SyntaxHighlighter style={coy} language={match[1]} PreTag="div" {...props}>
+                              {String(children).replace(/\n$/, '')}
+                            </SyntaxHighlighter>
+                          ) : (
+                            <code className={className} {...props}>
+                              {children}
+                            </code>
+                          );
+                        }
+                      }}
+                    >
+                      {message.content}
+                    </ReactMarkdown>
+                  )}
                 </div>
@@ -369,13 +409,13 @@ const ChatUI: React.FC = ({
                   value={inputMessage}
                   onChange={(e) => setInputMessage(e.target.value)}
                   onKeyDown={handleKeyDown}
-                  placeholder="Type your message..."
+                  placeholder={endpointType === 'chat' ? "Type your message..." : "Describe the image you want to generate..."}
                 />
diff --git a/ui/litellm-dashboard/src/components/chat_ui/llm_calls/image_generation.tsx b/ui/litellm-dashboard/src/components/chat_ui/llm_calls/image_generation.tsx
new file mode 100644
index 0000000000..1824b83d0b
--- /dev/null
+++ b/ui/litellm-dashboard/src/components/chat_ui/llm_calls/image_generation.tsx
@@ -0,0 +1,51 @@
+import openai from "openai";
+import { message } from "antd";
+
+export async function makeOpenAIImageGenerationRequest(
+  prompt: string,
+  updateUI: (imageUrl: string, model: string) => void,
+  selectedModel: string,
+  accessToken: string
+) {
+  // base url should be the current base_url
+  const isLocal = process.env.NODE_ENV === "development";
+  if (isLocal !== true) {
+    console.log = function () {};
+  }
+  console.log("isLocal:", isLocal);
+  const proxyBaseUrl = isLocal
+    ? "http://localhost:4000"
+    : window.location.origin;
+  const client = new openai.OpenAI({
+    apiKey: accessToken,
+    baseURL: proxyBaseUrl,
+    dangerouslyAllowBrowser: true,
+  });
+
+  try {
+    const response = await client.images.generate({
+      model: selectedModel,
+      prompt: prompt,
+    });
+
+    console.log(response.data);
+
+    if (response.data && response.data[0]) {
+      // Handle either URL or base64 data from response
+      if (response.data[0].url) {
+        // Use the URL directly
+        updateUI(response.data[0].url, selectedModel);
+      } else if (response.data[0].b64_json) {
+        // Convert base64 to data URL format
+        const base64Data = response.data[0].b64_json;
+        updateUI(`data:image/png;base64,${base64Data}`, selectedModel);
+      } else {
+        throw new Error("No image data found in response");
+      }
+    } else {
+      throw new Error("Invalid response format");
+    }
+  } catch (error) {
+    message.error(`Error occurred while generating image. Please try again. Error: ${error}`, 20);
+  }
+}
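The new helper is called from handleSendMessage with the raw inputMessage as the prompt whenever endpointType === 'image'. A minimal, hypothetical driver showing that call path outside the component; demoImageGeneration, the model name, and the key below are placeholders, not values from the patch:

```ts
import { makeOpenAIImageGenerationRequest } from "./chat_ui/llm_calls/image_generation";

async function demoImageGeneration(): Promise<void> {
  const accessToken = "sk-..."; // placeholder: a LiteLLM virtual key
  await makeOpenAIImageGenerationRequest(
    "A watercolor fox in a pine forest",  // prompt (inputMessage in the UI)
    (imageUrl, model) => {
      // chat_ui.tsx passes updateImageUI here; logging stands in for the UI update.
      console.log(`model=${model} image=${imageUrl.slice(0, 60)}...`);
    },
    "dall-e-3",                           // placeholder: any image model served by the proxy
    accessToken
  );
}

void demoImageGeneration();
```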
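image_generation.tsx accepts both shapes an OpenAI-compatible /images/generations response may carry: a hosted url or an inline b64_json payload. The branching reduces to the small normalizer sketched here (toImageSrc is hypothetical, extracted for clarity):

```ts
// Both response shapes collapse to a string usable as an <img> src.
type ImageDatum = { url?: string; b64_json?: string };

function toImageSrc(datum: ImageDatum): string {
  if (datum.url) return datum.url;                                      // hosted URL, use as-is
  if (datum.b64_json) return `data:image/png;base64,${datum.b64_json}`; // inline base64 bytes
  throw new Error("No image data found in response");
}

console.log(toImageSrc({ url: "https://example.com/img.png" }));
console.log(toImageSrc({ b64_json: "iVBORw0KGgo=" }).slice(0, 30));
```

Note that the helper hardcodes image/png for base64 payloads; if a provider returned JPEG bytes, the data URL would carry the wrong MIME type, though browsers generally sniff the actual format when rendering.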
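One subtlety in chat_ui.tsx: the !lastMessage.isImage guard added to updateTextUI is what prevents a streamed text chunk from being concatenated onto an image entry, and the matching filter(msg => !msg.isImage) keeps image URLs out of the chat-completion payload. A self-contained sketch of the merge rule, where applyTextChunk is a hypothetical stand-in for the reducer inside updateTextUI:

```ts
type ChatMessage = { role: string; content: string; model?: string; isImage?: boolean };

function applyTextChunk(history: ChatMessage[], chunk: string, model?: string): ChatMessage[] {
  const last = history[history.length - 1];
  // Streamed chunks merge into the previous assistant text message...
  if (last && last.role === "assistant" && !last.isImage) {
    return [...history.slice(0, -1), { role: "assistant", content: last.content + chunk, model }];
  }
  // ...but an image entry (or an empty history) starts a fresh message.
  return [...history, { role: "assistant", content: chunk, model }];
}

const h1 = applyTextChunk([], "Hel");
const h2 = applyTextChunk(h1, "lo");          // merges: a single "Hello" message
const h3: ChatMessage[] = [...h2, { role: "assistant", content: "https://example.com/img.png", isImage: true }];
const h4 = applyTextChunk(h3, "next reply");  // image entry is left intact
console.log(h2.length, h4.length);            // 1 3
```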