Test fixes in openai_compat

This commit is contained in:
Swapna Lekkala 2025-09-17 16:50:46 -07:00
parent e56a3f266c
commit d60514b57b
7 changed files with 221 additions and 7 deletions

View file

@ -451,6 +451,20 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin):
event: ChatCompletionResponseEvent
@json_schema_type
class UsageInfo(BaseModel):
"""Usage information for a model.
:param completion_tokens: Number of tokens generated
:param prompt_tokens: Number of tokens in the prompt
:param total_tokens: Total number of tokens processed
"""
completion_tokens: int
prompt_tokens: int
total_tokens: int
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
@ -461,6 +475,7 @@ class ChatCompletionResponse(MetricResponseMixin):
completion_message: CompletionMessage
logprobs: list[TokenLogProbs] | None = None
usage: UsageInfo | None = None
@json_schema_type
@ -818,7 +833,21 @@ class OpenAIChoice(BaseModel):
@json_schema_type
class OpenAIChatCompletion(BaseModel):
class OpenAIChatCompletionUsage(BaseModel):
"""Usage information for an OpenAI-compatible chat completion response.
:param prompt_tokens: The number of tokens in the prompt
:param completion_tokens: The number of tokens in the completion
:param total_tokens: The total number of tokens used
"""
prompt_tokens: int
completion_tokens: int
total_tokens: int
@json_schema_type
class OpenAIChatCompletion(MetricResponseMixin):
"""Response from an OpenAI-compatible chat completion request.
:param id: The ID of the chat completion
@ -833,6 +862,7 @@ class OpenAIChatCompletion(BaseModel):
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
usage: OpenAIChatCompletionUsage | None = None
@json_schema_type

View file

@ -590,6 +590,7 @@ class InferenceRouter(Inference):
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
@ -739,7 +740,6 @@ class InferenceRouter(Inference):
id = None
created = None
choices_data: dict[int, dict[str, Any]] = {}
try:
async for chunk in response:
# Skip None chunks

View file

@ -130,7 +130,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = self.client.completions.create(**params)
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
@ -208,9 +208,9 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
params = await self._get_params(request)
if "messages" in params:
stream = self.client.chat.completions.create(**params)
stream = await self.client.chat.completions.create(**params)
else:
stream = self.client.completions.create(**params)
stream = await self.client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk

View file

@ -31,6 +31,8 @@ from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.inference.inference import UsageInfo
try:
from openai.types.chat import (
ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
@ -103,6 +105,7 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionUsage,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIEmbeddingData,
@ -277,6 +280,11 @@ def process_chat_completion_response(
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
choice = response.choices[0]
usage = UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
if choice.finish_reason == "tool_calls":
if not choice.message or not choice.message.tool_calls:
raise ValueError("Tool calls are not present in the response")
@ -290,6 +298,7 @@ def process_chat_completion_response(
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
),
logprobs=None,
usage=usage,
)
else:
# Otherwise, return tool calls as normal
@ -301,6 +310,7 @@ def process_chat_completion_response(
content="",
),
logprobs=None,
usage=usage,
)
# TODO: This does not work well with tool calls for vLLM remote provider
@ -335,6 +345,7 @@ def process_chat_completion_response(
tool_calls=raw_message.tool_calls,
),
logprobs=None,
usage=usage,
)
@ -646,7 +657,7 @@ async def convert_message_to_openai_dict_new(
arguments=json.dumps(tool.arguments),
),
type="function",
)
).model_dump()
for tool in message.tool_calls
]
params = {}
@ -657,6 +668,7 @@ async def convert_message_to_openai_dict_new(
content=await _convert_message_content(message.content),
**params,
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
@ -1375,6 +1387,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
messages = openai_messages_to_messages(messages)
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
@ -1401,7 +1414,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
tools=tools,
)
outstanding_responses.append(response)
if stream:
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
@ -1476,12 +1488,22 @@ class OpenAIChatCompletionToLlamaStackMixin:
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion:
choices = []
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
for outstanding_response in outstanding_responses:
response = await outstanding_response
completion_message = response.completion_message
message = await convert_message_to_openai_dict_new(completion_message)
finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
# Aggregate usage data
if response.usage:
total_prompt_tokens += response.usage.prompt_tokens
total_completion_tokens += response.usage.completion_tokens
total_tokens += response.usage.total_tokens
choice = OpenAIChatCompletionChoice(
index=len(choices),
message=message,
@ -1489,12 +1511,17 @@ class OpenAIChatCompletionToLlamaStackMixin:
)
choices.append(choice)
usage = OpenAIChatCompletionUsage(
prompt_tokens=total_prompt_tokens, completion_tokens=total_completion_tokens, total_tokens=total_tokens
)
return OpenAIChatCompletion(
id=f"chatcmpl-{uuid.uuid4()}",
choices=choices,
created=int(time.time()),
model=model,
object="chat.completion",
usage=usage,
)