Merge branch 'main' into litellm_bug_fix_metadata_cost

Ishaan Jaff 2025-04-19 14:01:02 -07:00
commit e9d42e3755
12 changed files with 422 additions and 72 deletions

View file

@@ -265,9 +265,10 @@ def generic_cost_per_token(
     )

     ## CALCULATE OUTPUT COST
-    text_tokens = usage.completion_tokens
+    text_tokens = 0
     audio_tokens = 0
     reasoning_tokens = 0
+    is_text_tokens_total = False
     if usage.completion_tokens_details is not None:
         audio_tokens = (
             cast(
@@ -281,7 +282,7 @@ def generic_cost_per_token(
                 Optional[int],
                 getattr(usage.completion_tokens_details, "text_tokens", None),
             )
-            or usage.completion_tokens  # default to completion tokens, if this field is not set
+            or 0  # default to completion tokens, if this field is not set
         )
         reasoning_tokens = (
             cast(
@@ -290,6 +291,11 @@ def generic_cost_per_token(
             )
             or 0
         )
+
+    if text_tokens == 0:
+        text_tokens = usage.completion_tokens
+    if text_tokens == usage.completion_tokens:
+        is_text_tokens_total = True

     ## TEXT COST
     completion_cost = float(text_tokens) * completion_base_cost
@@ -302,19 +308,21 @@ def generic_cost_per_token(
     )

     ## AUDIO COST
-    if (
-        _output_cost_per_audio_token is not None
-        and audio_tokens is not None
-        and audio_tokens > 0
-    ):
+    if not is_text_tokens_total and audio_tokens is not None and audio_tokens > 0:
+        _output_cost_per_audio_token = (
+            _output_cost_per_audio_token
+            if _output_cost_per_audio_token is not None
+            else completion_base_cost
+        )
         completion_cost += float(audio_tokens) * _output_cost_per_audio_token

     ## REASONING COST
-    if (
-        _output_cost_per_reasoning_token is not None
-        and reasoning_tokens
-        and reasoning_tokens > 0
-    ):
+    if not is_text_tokens_total and reasoning_tokens and reasoning_tokens > 0:
+        _output_cost_per_reasoning_token = (
+            _output_cost_per_reasoning_token
+            if _output_cost_per_reasoning_token is not None
+            else completion_base_cost
+        )
         completion_cost += float(reasoning_tokens) * _output_cost_per_reasoning_token

     return prompt_cost, completion_cost
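
Read outside the diff, the new behavior is: if the provider does not report text_tokens separately, completion_tokens is treated as the full text total and no extra audio/reasoning charge is added (it is already covered); if the split is reported but no dedicated audio or reasoning price is configured, those tokens fall back to the base output price instead of being dropped. A minimal standalone sketch of that logic, with illustrative parameter names and a made-up price (this is not the actual generic_cost_per_token signature):

from typing import Optional


def output_cost(
    completion_tokens: int,
    text_tokens: int = 0,
    audio_tokens: int = 0,
    reasoning_tokens: int = 0,
    base_cost_per_token: float = 4.4e-6,  # assumed price, purely for illustration
    audio_cost_per_token: Optional[float] = None,
    reasoning_cost_per_token: Optional[float] = None,
) -> float:
    is_text_tokens_total = False
    if text_tokens == 0:
        # provider did not split out text tokens -> completion_tokens is the total
        text_tokens = completion_tokens
    if text_tokens == completion_tokens:
        is_text_tokens_total = True

    cost = float(text_tokens) * base_cost_per_token

    # extra audio/reasoning cost only applies when text_tokens is a true subset
    if not is_text_tokens_total and audio_tokens > 0:
        per_audio = audio_cost_per_token if audio_cost_per_token is not None else base_cost_per_token
        cost += float(audio_tokens) * per_audio
    if not is_text_tokens_total and reasoning_tokens > 0:
        per_reasoning = reasoning_cost_per_token if reasoning_cost_per_token is not None else base_cost_per_token
        cost += float(reasoning_tokens) * per_reasoning
    return cost


# 626 text + 952 reasoning tokens with no reasoning price set: reasoning falls
# back to the base output price, so the total equals 1578 * base_cost_per_token,
# which is the same expectation the new o1-mini test below asserts.
print(output_cost(completion_tokens=1578, text_tokens=626, reasoning_tokens=952))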

View file

@@ -587,14 +587,15 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
                 _content_str += "data:{};base64,{}".format(
                     part["inlineData"]["mimeType"], part["inlineData"]["data"]
                 )
-            if part.get("thought") is True:
-                if reasoning_content_str is None:
-                    reasoning_content_str = ""
-                reasoning_content_str += _content_str
-            else:
-                if content_str is None:
-                    content_str = ""
-                content_str += _content_str
+            if len(_content_str) > 0:
+                if part.get("thought") is True:
+                    if reasoning_content_str is None:
+                        reasoning_content_str = ""
+                    reasoning_content_str += _content_str
+                else:
+                    if content_str is None:
+                        content_str = ""
+                    content_str += _content_str

         return content_str, reasoning_content_str

View file

@@ -7,15 +7,18 @@ from litellm.responses.litellm_completion_transformation.transformation import (
 )
 from litellm.responses.streaming_iterator import ResponsesAPIStreamingIterator
 from litellm.types.llms.openai import (
+    OutputTextDeltaEvent,
     ResponseCompletedEvent,
     ResponseInputParam,
     ResponsesAPIOptionalRequestParams,
     ResponsesAPIStreamEvents,
     ResponsesAPIStreamingResponse,
 )
+from litellm.types.utils import Delta as ChatCompletionDelta
 from litellm.types.utils import (
     ModelResponse,
     ModelResponseStream,
+    StreamingChoices,
     TextCompletionResponse,
 )
@@ -38,7 +41,7 @@ class LiteLLMCompletionStreamingIterator(ResponsesAPIStreamingIterator):
         self.responses_api_request: ResponsesAPIOptionalRequestParams = (
             responses_api_request
         )
-        self.collected_chunks: List[ModelResponseStream] = []
+        self.collected_chat_completion_chunks: List[ModelResponseStream] = []
         self.finished: bool = False

     async def __anext__(
@@ -51,7 +54,14 @@ class LiteLLMCompletionStreamingIterator(ResponsesAPIStreamingIterator):
                 # Get the next chunk from the stream
                 try:
                     chunk = await self.litellm_custom_stream_wrapper.__anext__()
-                    self.collected_chunks.append(chunk)
+                    self.collected_chat_completion_chunks.append(chunk)
+                    response_api_chunk = (
+                        self._transform_chat_completion_chunk_to_response_api_chunk(
+                            chunk
+                        )
+                    )
+                    if response_api_chunk:
+                        return response_api_chunk
                 except StopAsyncIteration:
                     self.finished = True
                     response_completed_event = self._emit_response_completed_event()
@@ -74,28 +84,65 @@ class LiteLLMCompletionStreamingIterator(ResponsesAPIStreamingIterator):
         try:
             while True:
                 if self.finished is True:
-                    raise StopAsyncIteration
+                    raise StopIteration
                 # Get the next chunk from the stream
                 try:
                     chunk = self.litellm_custom_stream_wrapper.__next__()
-                    self.collected_chunks.append(chunk)
-                except StopAsyncIteration:
+                    self.collected_chat_completion_chunks.append(chunk)
+                    response_api_chunk = (
+                        self._transform_chat_completion_chunk_to_response_api_chunk(
+                            chunk
+                        )
+                    )
+                    if response_api_chunk:
+                        return response_api_chunk
+                except StopIteration:
                     self.finished = True
                     response_completed_event = self._emit_response_completed_event()
                     if response_completed_event:
                         return response_completed_event
                     else:
-                        raise StopAsyncIteration
+                        raise StopIteration
         except Exception as e:
             # Handle HTTP errors
             self.finished = True
             raise e

+    def _transform_chat_completion_chunk_to_response_api_chunk(
+        self, chunk: ModelResponseStream
+    ) -> Optional[ResponsesAPIStreamingResponse]:
+        """
+        Transform a chat completion chunk to a response API chunk.
+
+        This currently only handles emitting the OutputTextDeltaEvent, which is used by other tools using the responses API.
+        """
+        return OutputTextDeltaEvent(
+            type=ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA,
+            item_id=chunk.id,
+            output_index=0,
+            content_index=0,
+            delta=self._get_delta_string_from_streaming_choices(chunk.choices),
+        )
+
+    def _get_delta_string_from_streaming_choices(
+        self, choices: List[StreamingChoices]
+    ) -> str:
+        """
+        Get the delta string from the streaming choices
+
+        For now this collects the first choice's delta string.
+
+        It's unclear how users expect litellm to translate multiple-choices-per-chunk to the responses API output.
+        """
+        choice = choices[0]
+        chat_completion_delta: ChatCompletionDelta = choice.delta
+        return chat_completion_delta.content or ""
+
     def _emit_response_completed_event(self) -> Optional[ResponseCompletedEvent]:
         litellm_model_response: Optional[
             Union[ModelResponse, TextCompletionResponse]
-        ] = stream_chunk_builder(chunks=self.collected_chunks)
+        ] = stream_chunk_builder(chunks=self.collected_chat_completion_chunks)
         if litellm_model_response and isinstance(litellm_model_response, ModelResponse):
             return ResponseCompletedEvent(

View file

@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Union
 from openai.types.responses.tool_param import FunctionToolParam

 from litellm.caching import InMemoryCache
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.responses.litellm_completion_transformation.session_handler import (
     ResponsesAPISessionElement,
     SessionHandler,
@@ -88,6 +89,18 @@ class LiteLLMCompletionResponsesConfig:
             "custom_llm_provider": custom_llm_provider,
         }

+        # Responses API `Completed` events require usage, we pass `stream_options` to litellm.completion to include usage
+        if stream is True:
+            stream_options = {
+                "include_usage": True,
+            }
+            litellm_completion_request["stream_options"] = stream_options
+            litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get(
+                "litellm_logging_obj"
+            )
+            if litellm_logging_obj:
+                litellm_logging_obj.stream_options = stream_options
+
         # only pass non-None values
         litellm_completion_request = {
             k: v for k, v in litellm_completion_request.items() if v is not None
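
In plain terms, the added block makes sure a streamed Responses API call can finish with a usage-bearing response.completed event: whenever stream is True, the bridge forces stream_options = {"include_usage": True} onto the underlying chat-completion request and mirrors it onto the logging object. A hedged sketch of just that mutation, using a pared-down illustrative request dict rather than the full set of fields built above:

# illustrative request dict; the model/messages values are assumptions for the example
litellm_completion_request = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "Basic ping"}],
    "stream": True,
}

if litellm_completion_request.get("stream") is True:
    stream_options = {"include_usage": True}
    litellm_completion_request["stream_options"] = stream_options

print(litellm_completion_request)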

View file

@@ -11,7 +11,9 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.thread_pool_executor import executor
 from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
 from litellm.types.llms.openai import (
+    OutputTextDeltaEvent,
     ResponseCompletedEvent,
+    ResponsesAPIResponse,
     ResponsesAPIStreamEvents,
     ResponsesAPIStreamingResponse,
 )
@@ -212,9 +214,14 @@ class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):

 class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
     """
-    mock iterator - some models like o1-pro do not support streaming, we need to fake a stream
+    Mock iterator - fake a stream by slicing the full response text into
+    5-char deltas, then emit a completed event.
+
+    Models like o1-pro don't support streaming, so we fake it.
     """

+    CHUNK_SIZE = 5
+
     def __init__(
         self,
         response: httpx.Response,
@@ -222,49 +229,68 @@ class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
         responses_api_provider_config: BaseResponsesAPIConfig,
         logging_obj: LiteLLMLoggingObj,
     ):
-        self.raw_http_response = response
         super().__init__(
             response=response,
             model=model,
             responses_api_provider_config=responses_api_provider_config,
             logging_obj=logging_obj,
         )
-        self.is_done = False
+        # one-time transform
+        transformed = (
+            self.responses_api_provider_config.transform_response_api_response(
+                model=self.model,
+                raw_response=response,
+                logging_obj=logging_obj,
+            )
+        )
+
+        full_text = self._collect_text(transformed)
+
+        # build a list of 5-char delta events
+        deltas = [
+            OutputTextDeltaEvent(
+                type=ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA,
+                delta=full_text[i : i + self.CHUNK_SIZE],
+                item_id=transformed.id,
+                output_index=0,
+                content_index=0,
+            )
+            for i in range(0, len(full_text), self.CHUNK_SIZE)
+        ]
+
+        # append the completed event
+        self._events = deltas + [
+            ResponseCompletedEvent(
+                type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
+                response=transformed,
+            )
+        ]
+        self._idx = 0

     def __aiter__(self):
         return self

     async def __anext__(self) -> ResponsesAPIStreamingResponse:
-        if self.is_done:
+        if self._idx >= len(self._events):
             raise StopAsyncIteration
-        self.is_done = True
-        transformed_response = (
-            self.responses_api_provider_config.transform_response_api_response(
-                model=self.model,
-                raw_response=self.raw_http_response,
-                logging_obj=self.logging_obj,
-            )
-        )
-        return ResponseCompletedEvent(
-            type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
-            response=transformed_response,
-        )
+        evt = self._events[self._idx]
+        self._idx += 1
+        return evt

     def __iter__(self):
         return self

     def __next__(self) -> ResponsesAPIStreamingResponse:
-        if self.is_done:
+        if self._idx >= len(self._events):
             raise StopIteration
-        self.is_done = True
-        transformed_response = (
-            self.responses_api_provider_config.transform_response_api_response(
-                model=self.model,
-                raw_response=self.raw_http_response,
-                logging_obj=self.logging_obj,
-            )
-        )
-        return ResponseCompletedEvent(
-            type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
-            response=transformed_response,
-        )
+        evt = self._events[self._idx]
+        self._idx += 1
+        return evt
+
+    def _collect_text(self, resp: ResponsesAPIResponse) -> str:
+        out = ""
+        for out_item in resp.output:
+            if out_item.type == "message":
+                for c in getattr(out_item, "content", []):
+                    out += c.text
+        return out
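
The rewrite moves all the work into __init__: transform the raw HTTP response once, slice its text into CHUNK_SIZE-character OutputTextDeltaEvent objects, append a single ResponseCompletedEvent, and let __next__/__anext__ just walk that prebuilt list. A toy sketch of the same slicing pattern, with plain dicts standing in for the real event classes (illustrative only):

CHUNK_SIZE = 5
full_text = "Hello from a model that cannot stream"

# slice the full text into fixed-size deltas, then terminate with a completed event
deltas = [full_text[i : i + CHUNK_SIZE] for i in range(0, len(full_text), CHUNK_SIZE)]
events = [{"type": "response.output_text.delta", "delta": d} for d in deltas]
events.append({"type": "response.completed", "response": {"output_text": full_text}})

for event in events:  # the mock iterator simply walks this list and then stops
    print(event)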

View file

@@ -26,6 +26,47 @@ from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
 from litellm.types.utils import Usage


+def test_reasoning_tokens_no_price_set():
+    model = "o1-mini"
+    custom_llm_provider = "openai"
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+    model_cost_map = litellm.model_cost[model]
+    usage = Usage(
+        completion_tokens=1578,
+        prompt_tokens=17,
+        total_tokens=1595,
+        completion_tokens_details=CompletionTokensDetailsWrapper(
+            accepted_prediction_tokens=None,
+            audio_tokens=None,
+            reasoning_tokens=952,
+            rejected_prediction_tokens=None,
+            text_tokens=626,
+        ),
+        prompt_tokens_details=PromptTokensDetailsWrapper(
+            audio_tokens=None, cached_tokens=None, text_tokens=17, image_tokens=None
+        ),
+    )
+    prompt_cost, completion_cost = generic_cost_per_token(
+        model=model,
+        usage=usage,
+        custom_llm_provider="openai",
+    )
+    assert round(prompt_cost, 10) == round(
+        model_cost_map["input_cost_per_token"] * usage.prompt_tokens,
+        10,
+    )
+    print(f"completion_cost: {completion_cost}")
+    expected_completion_cost = (
+        model_cost_map["output_cost_per_token"] * usage.completion_tokens
+    )
+    print(f"expected_completion_cost: {expected_completion_cost}")
+    assert round(completion_cost, 10) == round(
+        expected_completion_cost,
+        10,
+    )
+
+
 def test_reasoning_tokens_gemini():
     model = "gemini-2.5-flash-preview-04-17"
     custom_llm_provider = "gemini"

View file

@@ -239,3 +239,23 @@ def test_vertex_ai_thinking_output_part():
     content, reasoning_content = v.get_assistant_content_message(parts=parts)
     assert content == "Hello world"
     assert reasoning_content == "I'm thinking..."
+
+
+def test_vertex_ai_empty_content():
+    from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+        VertexGeminiConfig,
+    )
+    from litellm.types.llms.vertex_ai import HttpxPartType
+
+    v = VertexGeminiConfig()
+    parts = [
+        HttpxPartType(
+            functionCall={
+                "name": "get_current_weather",
+                "arguments": "{}",
+            },
+        ),
+    ]
+    content, reasoning_content = v.get_assistant_content_message(parts=parts)
+    assert content is None
+    assert reasoning_content is None

View file

@@ -133,11 +133,13 @@ class BaseResponsesAPITest(ABC):
         validate_responses_api_response(response, final_chunk=True)

-    @pytest.mark.parametrize("sync_mode", [True])
+    @pytest.mark.parametrize("sync_mode", [True, False])
     @pytest.mark.asyncio
     async def test_basic_openai_responses_api_streaming(self, sync_mode):
         litellm._turn_on_debug()
         base_completion_call_args = self.get_base_completion_call_args()
+        collected_content_string = ""
+        response_completed_event = None
         if sync_mode:
             response = litellm.responses(
                 input="Basic ping",
@@ -146,6 +148,10 @@ class BaseResponsesAPITest(ABC):
             )
             for event in response:
                 print("litellm response=", json.dumps(event, indent=4, default=str))
+                if event.type == "response.output_text.delta":
+                    collected_content_string += event.delta
+                elif event.type == "response.completed":
+                    response_completed_event = event
         else:
             response = await litellm.aresponses(
                 input="Basic ping",
@@ -154,5 +160,35 @@ class BaseResponsesAPITest(ABC):
             )
             async for event in response:
                 print("litellm response=", json.dumps(event, indent=4, default=str))
+                if event.type == "response.output_text.delta":
+                    collected_content_string += event.delta
+                elif event.type == "response.completed":
+                    response_completed_event = event
+
+        # assert the delta chunks content had len(collected_content_string) > 0
+        # this content is typically rendered on chat ui's
+        assert len(collected_content_string) > 0
+
+        # assert the response completed event is not None
+        assert response_completed_event is not None
+
+        # assert the response completed event has a response
+        assert response_completed_event.response is not None
+
+        # assert the response completed event includes the usage
+        assert response_completed_event.response.usage is not None
+
+        # basic test assert the usage seems reasonable
+        print("response_completed_event.response.usage=", response_completed_event.response.usage)
+        assert response_completed_event.response.usage.input_tokens > 0 and response_completed_event.response.usage.input_tokens < 100
+        assert response_completed_event.response.usage.output_tokens > 0 and response_completed_event.response.usage.output_tokens < 1000
+        assert response_completed_event.response.usage.total_tokens > 0 and response_completed_event.response.usage.total_tokens < 1000
+
+        # total tokens should be the sum of input and output tokens
+        assert response_completed_event.response.usage.total_tokens == response_completed_event.response.usage.input_tokens + response_completed_event.response.usage.output_tokens

View file

@@ -26,6 +26,7 @@ import {
 import { message, Select, Spin, Typography, Tooltip, Input } from "antd";
 import { makeOpenAIChatCompletionRequest } from "./chat_ui/llm_calls/chat_completion";
 import { makeOpenAIImageGenerationRequest } from "./chat_ui/llm_calls/image_generation";
+import { makeOpenAIResponsesRequest } from "./chat_ui/llm_calls/responses_api";
 import { fetchAvailableModels, ModelGroup } from "./chat_ui/llm_calls/fetch_models";
 import { litellmModeMapping, ModelMode, EndpointType, getEndpointType } from "./chat_ui/mode_endpoint_mapping";
 import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
@@ -137,20 +138,28 @@ const ChatUI: React.FC<ChatUIProps> = ({
   }, [chatHistory]);

   const updateTextUI = (role: string, chunk: string, model?: string) => {
-    setChatHistory((prevHistory) => {
-      const lastMessage = prevHistory[prevHistory.length - 1];
+    console.log("updateTextUI called with:", role, chunk, model);
+    setChatHistory((prev) => {
+      const last = prev[prev.length - 1];

-      if (lastMessage && lastMessage.role === role && !lastMessage.isImage) {
+      // if the last message is already from this same role, append
+      if (last && last.role === role && !last.isImage) {
+        // build a new object, but only set `model` if it wasn't there already
+        const updated: MessageType = {
+          ...last,
+          content: last.content + chunk,
+          model: last.model ?? model, // ← only use the passed-in model on the first chunk
+        };
+        return [...prev.slice(0, -1), updated];
+      } else {
+        // otherwise start a brand new assistant bubble
         return [
-          ...prevHistory.slice(0, prevHistory.length - 1),
+          ...prev,
           {
-            ...lastMessage,
-            content: lastMessage.content + chunk,
-            model
+            role,
+            content: chunk,
+            model, // model set exactly once here
           },
         ];
-      } else {
-        return [...prevHistory, { role, content: chunk, model }];
       }
     });
   };
@@ -297,7 +306,6 @@ const ChatUI: React.FC<ChatUIProps> = ({
     try {
       if (selectedModel) {
-        // Use EndpointType enum for comparison
         if (endpointType === EndpointType.CHAT) {
           // Create chat history for API call - strip out model field and isImage field
           const apiChatHistory = [...chatHistory.filter(msg => !msg.isImage).map(({ role, content }) => ({ role, content })), newUserMessage];
@@ -323,6 +331,21 @@ const ChatUI: React.FC<ChatUIProps> = ({
             selectedTags,
             signal
           );
+        } else if (endpointType === EndpointType.RESPONSES) {
+          // Create chat history for API call - strip out model field and isImage field
+          const apiChatHistory = [...chatHistory.filter(msg => !msg.isImage).map(({ role, content }) => ({ role, content })), newUserMessage];
+
+          await makeOpenAIResponsesRequest(
+            apiChatHistory,
+            (role, delta, model) => updateTextUI(role, delta, model),
+            selectedModel,
+            effectiveApiKey,
+            selectedTags,
+            signal,
+            updateReasoningContent,
+            updateTimingData,
+            updateUsageData
+          );
         }
       }
     } catch (error) {
@@ -592,7 +615,7 @@ const ChatUI: React.FC<ChatUIProps> = ({
                 onChange={(e) => setInputMessage(e.target.value)}
                 onKeyDown={handleKeyDown}
                 placeholder={
-                  endpointType === EndpointType.CHAT
+                  endpointType === EndpointType.CHAT || endpointType === EndpointType.RESPONSES
                     ? "Type your message... (Shift+Enter for new line)"
                     : "Describe the image you want to generate..."
                 }

View file

@@ -19,8 +19,9 @@ const EndpointSelector: React.FC<EndpointSelectorProps> = ({
 }) => {
   // Map endpoint types to their display labels
   const endpointOptions = [
-    { value: EndpointType.CHAT, label: '/chat/completions' },
-    { value: EndpointType.IMAGE, label: '/images/generations' }
+    { value: EndpointType.CHAT, label: '/v1/chat/completions' },
+    { value: EndpointType.RESPONSES, label: '/v1/responses' },
+    { value: EndpointType.IMAGE, label: '/v1/images/generations' },
   ];

   return (

View file

@@ -0,0 +1,131 @@
import openai from "openai";
import { message } from "antd";
import { MessageType } from "../types";
import { TokenUsage } from "../ResponseMetrics";

export async function makeOpenAIResponsesRequest(
  messages: MessageType[],
  updateTextUI: (role: string, delta: string, model?: string) => void,
  selectedModel: string,
  accessToken: string | null,
  tags: string[] = [],
  signal?: AbortSignal,
  onReasoningContent?: (content: string) => void,
  onTimingData?: (timeToFirstToken: number) => void,
  onUsageData?: (usage: TokenUsage) => void
) {
  if (!accessToken) {
    throw new Error("API key is required");
  }

  // Base URL should be the current base_url
  const isLocal = process.env.NODE_ENV === "development";
  if (isLocal !== true) {
    console.log = function () {};
  }

  const proxyBaseUrl = isLocal
    ? "http://localhost:4000"
    : window.location.origin;

  const client = new openai.OpenAI({
    apiKey: accessToken,
    baseURL: proxyBaseUrl,
    dangerouslyAllowBrowser: true,
    defaultHeaders: tags && tags.length > 0 ? { 'x-litellm-tags': tags.join(',') } : undefined,
  });

  try {
    const startTime = Date.now();
    let firstTokenReceived = false;

    // Format messages for the API
    const formattedInput = messages.map(message => ({
      role: message.role,
      content: message.content,
      type: "message"
    }));

    // Create request to OpenAI responses API
    // Use 'any' type to avoid TypeScript issues with the experimental API
    const response = await (client as any).responses.create({
      model: selectedModel,
      input: formattedInput,
      stream: true,
    }, { signal });

    for await (const event of response) {
      console.log("Response event:", event);

      // Use a type-safe approach to handle events
      if (typeof event === 'object' && event !== null) {
        // Handle output text delta
        // 1) drop any “role” streams
        if (event.type === "response.role.delta") {
          continue;
        }

        // 2) only handle actual text deltas
        if (event.type === "response.output_text.delta" && typeof event.delta === "string") {
          const delta = event.delta;
          console.log("Text delta", delta);

          // skip pure whitespace/newlines
          if (delta.trim().length > 0) {
            updateTextUI("assistant", delta, selectedModel);

            // Calculate time to first token
            if (!firstTokenReceived) {
              firstTokenReceived = true;
              const timeToFirstToken = Date.now() - startTime;
              console.log("First token received! Time:", timeToFirstToken, "ms");
              if (onTimingData) {
                onTimingData(timeToFirstToken);
              }
            }
          }
        }

        // Handle reasoning content
        if (event.type === "response.reasoning.delta" && 'delta' in event) {
          const delta = event.delta;
          if (typeof delta === 'string' && onReasoningContent) {
            onReasoningContent(delta);
          }
        }

        // Handle usage data at the response.completed event
        if (event.type === "response.completed" && 'response' in event) {
          const response_obj = event.response;
          const usage = response_obj.usage;
          console.log("Usage data:", usage);

          if (usage && onUsageData) {
            console.log("Usage data:", usage);

            // Extract usage data safely
            const usageData: TokenUsage = {
              completionTokens: usage.output_tokens,
              promptTokens: usage.input_tokens,
              totalTokens: usage.total_tokens
            };

            // Add reasoning tokens if available
            if (usage.completion_tokens_details?.reasoning_tokens) {
              usageData.reasoningTokens = usage.completion_tokens_details.reasoning_tokens;
            }

            onUsageData(usageData);
          }
        }
      }
    }
  } catch (error) {
    if (signal?.aborted) {
      console.log("Responses API request was cancelled");
    } else {
      message.error(`Error occurred while generating model response. Please try again. Error: ${error}`, 20);
    }
    throw error; // Re-throw to allow the caller to handle the error
  }
}

View file

@@ -4,6 +4,7 @@
 export enum ModelMode {
   IMAGE_GENERATION = "image_generation",
   CHAT = "chat",
+  RESPONSES = "responses",
   // add additional modes as needed
 }
@@ -11,6 +12,7 @@ export enum ModelMode {
 export enum EndpointType {
   IMAGE = "image",
   CHAT = "chat",
+  RESPONSES = "responses",
   // add additional endpoint types if required
 }
@@ -18,6 +20,7 @@ export enum ModelMode {
 export const litellmModeMapping: Record<ModelMode, EndpointType> = {
   [ModelMode.IMAGE_GENERATION]: EndpointType.IMAGE,
   [ModelMode.CHAT]: EndpointType.CHAT,
+  [ModelMode.RESPONSES]: EndpointType.RESPONSES,
 };

 export const getEndpointType = (mode: string): EndpointType => {