revert openai_compat changes and use OpenAIMixin for openai_chat_completion

Author: Swapna Lekkala, 2025-09-18 16:06:53 -07:00
parent 0f5bef893a
commit a6baa7b3d4
9 changed files with 23 additions and 303 deletions

View file

@@ -6372,9 +6372,6 @@
"$ref": "#/components/schemas/TokenLogProbs"
},
"description": "Optional log probabilities for generated tokens"
},
"usage": {
"$ref": "#/components/schemas/UsageInfo"
}
},
"additionalProperties": false,
@@ -6433,31 +6430,6 @@
"title": "TokenLogProbs",
"description": "Log probabilities for generated tokens."
},
"UsageInfo": {
"type": "object",
"properties": {
"completion_tokens": {
"type": "integer",
"description": "Number of tokens generated"
},
"prompt_tokens": {
"type": "integer",
"description": "Number of tokens in the prompt"
},
"total_tokens": {
"type": "integer",
"description": "Total number of tokens processed"
}
},
"additionalProperties": false,
"required": [
"completion_tokens",
"prompt_tokens",
"total_tokens"
],
"title": "UsageInfo",
"description": "Usage information for a model."
},
"BatchCompletionRequest": {
"type": "object",
"properties": {
@@ -10967,31 +10939,6 @@
"title": "OpenAIChatCompletionToolCallFunction",
"description": "Function call details for OpenAI-compatible tool calls."
},
"OpenAIChatCompletionUsage": {
"type": "object",
"properties": {
"prompt_tokens": {
"type": "integer",
"description": "The number of tokens in the prompt"
},
"completion_tokens": {
"type": "integer",
"description": "The number of tokens in the completion"
},
"total_tokens": {
"type": "integer",
"description": "The total number of tokens used"
}
},
"additionalProperties": false,
"required": [
"prompt_tokens",
"completion_tokens",
"total_tokens"
],
"title": "OpenAIChatCompletionUsage",
"description": "Usage information for an OpenAI-compatible chat completion response."
},
"OpenAIChoice": {
"type": "object",
"properties": {
@@ -11329,13 +11276,6 @@
"OpenAICompletionWithInputMessages": {
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MetricInResponse"
},
"description": "(Optional) List of metrics associated with the API response"
},
"id": {
"type": "string",
"description": "The ID of the chat completion"
@@ -11361,9 +11301,6 @@
"type": "string",
"description": "The model that was used to generate the chat completion"
},
"usage": {
"$ref": "#/components/schemas/OpenAIChatCompletionUsage"
},
"input_messages": {
"type": "array",
"items": {
@@ -13125,13 +13062,6 @@
"items": {
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MetricInResponse"
},
"description": "(Optional) List of metrics associated with the API response"
},
"id": {
"type": "string",
"description": "The ID of the chat completion"
@@ -13157,9 +13087,6 @@
"type": "string",
"description": "The model that was used to generate the chat completion"
},
"usage": {
"$ref": "#/components/schemas/OpenAIChatCompletionUsage"
},
"input_messages": {
"type": "array",
"items": {
@@ -14551,13 +14478,6 @@
"OpenAIChatCompletion": {
"type": "object",
"properties": {
"metrics": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MetricInResponse"
},
"description": "(Optional) List of metrics associated with the API response"
},
"id": {
"type": "string",
"description": "The ID of the chat completion"
@@ -14582,9 +14502,6 @@
"model": {
"type": "string",
"description": "The model that was used to generate the chat completion"
},
"usage": {
"$ref": "#/components/schemas/OpenAIChatCompletionUsage"
}
},
"additionalProperties": false,

View file

@@ -4548,8 +4548,6 @@ components:
$ref: '#/components/schemas/TokenLogProbs'
description: >-
Optional log probabilities for generated tokens
usage:
$ref: '#/components/schemas/UsageInfo'
additionalProperties: false
required:
- completion_message
@@ -4591,25 +4589,6 @@ components:
- logprobs_by_token
title: TokenLogProbs
description: Log probabilities for generated tokens.
UsageInfo:
type: object
properties:
completion_tokens:
type: integer
description: Number of tokens generated
prompt_tokens:
type: integer
description: Number of tokens in the prompt
total_tokens:
type: integer
description: Total number of tokens processed
additionalProperties: false
required:
- completion_tokens
- prompt_tokens
- total_tokens
title: UsageInfo
description: Usage information for a model.
BatchCompletionRequest:
type: object
properties:
@@ -8124,26 +8103,6 @@ components:
title: OpenAIChatCompletionToolCallFunction
description: >-
Function call details for OpenAI-compatible tool calls.
OpenAIChatCompletionUsage:
type: object
properties:
prompt_tokens:
type: integer
description: The number of tokens in the prompt
completion_tokens:
type: integer
description: The number of tokens in the completion
total_tokens:
type: integer
description: The total number of tokens used
additionalProperties: false
required:
- prompt_tokens
- completion_tokens
- total_tokens
title: OpenAIChatCompletionUsage
description: >-
Usage information for an OpenAI-compatible chat completion response.
OpenAIChoice:
type: object
properties:
@@ -8406,12 +8365,6 @@ components:
OpenAICompletionWithInputMessages:
type: object
properties:
metrics:
type: array
items:
$ref: '#/components/schemas/MetricInResponse'
description: >-
(Optional) List of metrics associated with the API response
id:
type: string
description: The ID of the chat completion
@@ -8434,8 +8387,6 @@ components:
type: string
description: >-
The model that was used to generate the chat completion
usage:
$ref: '#/components/schemas/OpenAIChatCompletionUsage'
input_messages:
type: array
items:
@@ -9731,12 +9682,6 @@ components:
items:
type: object
properties:
metrics:
type: array
items:
$ref: '#/components/schemas/MetricInResponse'
description: >-
(Optional) List of metrics associated with the API response
id:
type: string
description: The ID of the chat completion
@@ -9759,8 +9704,6 @@ components:
type: string
description: >-
The model that was used to generate the chat completion
usage:
$ref: '#/components/schemas/OpenAIChatCompletionUsage'
input_messages:
type: array
items:
@@ -10776,12 +10719,6 @@ components:
OpenAIChatCompletion:
type: object
properties:
metrics:
type: array
items:
$ref: '#/components/schemas/MetricInResponse'
description: >-
(Optional) List of metrics associated with the API response
id:
type: string
description: The ID of the chat completion
@@ -10804,8 +10741,6 @@ components:
type: string
description: >-
The model that was used to generate the chat completion
usage:
$ref: '#/components/schemas/OpenAIChatCompletionUsage'
additionalProperties: false
required:
- id

View file

@@ -451,20 +451,6 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin):
event: ChatCompletionResponseEvent
@json_schema_type
class UsageInfo(BaseModel):
"""Usage information for a model.
:param completion_tokens: Number of tokens generated
:param prompt_tokens: Number of tokens in the prompt
:param total_tokens: Total number of tokens processed
"""
completion_tokens: int
prompt_tokens: int
total_tokens: int
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
@@ -475,7 +461,6 @@ class ChatCompletionResponse(MetricResponseMixin):
completion_message: CompletionMessage
logprobs: list[TokenLogProbs] | None = None
usage: UsageInfo | None = None
@json_schema_type
@@ -833,21 +818,7 @@ class OpenAIChoice(BaseModel):
@json_schema_type
class OpenAIChatCompletionUsage(BaseModel):
"""Usage information for an OpenAI-compatible chat completion response.
:param prompt_tokens: The number of tokens in the prompt
:param completion_tokens: The number of tokens in the completion
:param total_tokens: The total number of tokens used
"""
prompt_tokens: int
completion_tokens: int
total_tokens: int
@json_schema_type
class OpenAIChatCompletion(MetricResponseMixin):
class OpenAIChatCompletion(BaseModel):
"""Response from an OpenAI-compatible chat completion request.
:param id: The ID of the chat completion
@@ -862,7 +833,6 @@ class OpenAIChatCompletion(MetricResponseMixin):
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
usage: OpenAIChatCompletionUsage | None = None
@json_schema_type
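
For reference, the post-revert shape of the response model can be read off the hunks above: `UsageInfo` and `OpenAIChatCompletionUsage` are deleted, `ChatCompletionResponse` loses its optional `usage` field, and `OpenAIChatCompletion` goes back to a plain `BaseModel` with no `usage` property. A minimal sketch of the resulting class, reconstructed from the diff (decorators and the full `OpenAIChoice` typing elided):

```python
from typing import Literal

from pydantic import BaseModel


class OpenAIChatCompletion(BaseModel):
    """Response from an OpenAI-compatible chat completion request."""

    id: str
    choices: list  # list[OpenAIChoice] in the actual module
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    # After the revert: no `usage` field and no MetricResponseMixin base.
```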

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncGenerator
from typing import Any
from fireworks.client import Fireworks
@@ -23,11 +23,7 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@@ -43,7 +39,6 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
@@ -335,90 +330,3 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
prompt_logprobs=prompt_logprobs,
suffix=suffix,
)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await super().openai_chat_completion(
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
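
The whole `openai_chat_completion` override is deleted rather than rewritten: `OpenAIMixin` comes first in `FireworksInferenceAdapter`'s base list, so with the method gone, Python's MRO resolves the call to the mixin's OpenAI-compatible implementation, and the special-casing of Llama models through `OpenAIChatCompletionToLlamaStackMixin` disappears with it. A toy sketch of the mechanism, with hypothetical class bodies standing in for the real ones:

```python
import asyncio


class OpenAIMixin:
    # Stand-in for the real mixin's OpenAI-compatible implementation.
    async def openai_chat_completion(self, model: str, **kwargs) -> str:
        return f"mixin handled {model}"


class FireworksInferenceAdapter(OpenAIMixin):
    # No openai_chat_completion override anymore; calls fall through
    # to OpenAIMixin via the method resolution order.
    pass


adapter = FireworksInferenceAdapter()
print(asyncio.run(adapter.openai_chat_completion(model="llama-v3p1-8b-instruct")))
# -> mixin handled llama-v3p1-8b-instruct
```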

View file

@@ -61,6 +61,7 @@ MODEL_ENTRIES = [
),
ProviderModelEntry(
provider_model_id="nomic-ai/nomic-embed-text-v1.5",
aliases=["nomic-ai/nomic-embed-text-v1.5"],
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,

View file

@@ -31,8 +31,6 @@ from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.inference.inference import UsageInfo
try:
from openai.types.chat import (
ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
@@ -105,7 +103,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionUsage,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIEmbeddingData,
@@ -280,11 +277,6 @@ def process_chat_completion_response(
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
choice = response.choices[0]
usage = UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
if choice.finish_reason == "tool_calls":
if not choice.message or not choice.message.tool_calls:
raise ValueError("Tool calls are not present in the response")
@@ -298,7 +290,6 @@
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
),
logprobs=None,
usage=usage,
)
else:
# Otherwise, return tool calls as normal
@@ -310,7 +301,6 @@
content="",
),
logprobs=None,
usage=usage,
)
# TODO: This does not work well with tool calls for vLLM remote provider
@@ -345,7 +335,6 @@
tool_calls=raw_message.tool_calls,
),
logprobs=None,
usage=usage,
)
@@ -657,7 +646,7 @@ async def convert_message_to_openai_dict_new(
arguments=json.dumps(tool.arguments),
),
type="function",
).model_dump()
)
for tool in message.tool_calls
]
params = {}
@@ -668,7 +657,6 @@ async def convert_message_to_openai_dict_new(
content=await _convert_message_content(message.content),
**params,
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
@@ -1387,7 +1375,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
messages = openai_messages_to_messages(messages)
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
@@ -1414,6 +1401,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
tools=tools,
)
outstanding_responses.append(response)
if stream:
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
@@ -1488,22 +1476,12 @@
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion:
choices = []
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
for outstanding_response in outstanding_responses:
response = await outstanding_response
completion_message = response.completion_message
message = await convert_message_to_openai_dict_new(completion_message)
finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
# Aggregate usage data
if response.usage:
total_prompt_tokens += response.usage.prompt_tokens
total_completion_tokens += response.usage.completion_tokens
total_tokens += response.usage.total_tokens
choice = OpenAIChatCompletionChoice(
index=len(choices),
message=message,
@@ -1511,17 +1489,12 @@
)
choices.append(choice)
usage = OpenAIChatCompletionUsage(
prompt_tokens=total_prompt_tokens, completion_tokens=total_completion_tokens, total_tokens=total_tokens
)
return OpenAIChatCompletion(
id=f"chatcmpl-{uuid.uuid4()}",
choices=choices,
created=int(time.time()),
model=model,
object="chat.completion",
usage=usage,
)
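
The same revert runs through `OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response`: the per-response token tallies and the final `OpenAIChatCompletionUsage` construction go away, leaving only choice collection. Roughly, the surviving body reads as below (a sketch assembled from the kept lines of the hunk; `convert_message_to_openai_dict_new` and `_convert_stop_reason_to_openai_finish_reason` are the module's own helpers, not redefined here):

```python
async def _process_non_stream_response(
    self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion:
    choices = []
    for outstanding_response in outstanding_responses:
        response = await outstanding_response
        completion_message = response.completion_message
        message = await convert_message_to_openai_dict_new(completion_message)
        finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
        choices.append(
            OpenAIChatCompletionChoice(index=len(choices), message=message, finish_reason=finish_reason)
        )
    # No usage aggregation after the revert.
    return OpenAIChatCompletion(
        id=f"chatcmpl-{uuid.uuid4()}",
        choices=choices,
        created=int(time.time()),
        model=model,
        object="chat.completion",
    )
```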

View file

@@ -13,6 +13,13 @@ import pytest
from ..test_cases.test_case import TestCase
@pytest.fixture(autouse=True)
def rate_limit_delay():
"""Add delay between tests to avoid rate limiting from providers like Fireworks"""
yield
time.sleep(30) # 30 second delay after each test
def _normalize_text(text: str) -> str:
"""
Normalize Unicode text by removing diacritical marks for comparison.

View file

@@ -6,6 +6,7 @@
import base64
import struct
import time
import pytest
from openai import OpenAI
@@ -13,6 +14,13 @@ from openai import OpenAI
from llama_stack.core.library_client import LlamaStackAsLibraryClient
@pytest.fixture(autouse=True)
def rate_limit_delay():
"""Add delay between tests to avoid rate limiting from providers like Fireworks"""
yield
time.sleep(30) # 30 second delay after each test
def decode_base64_to_floats(base64_string: str) -> list[float]:
"""Helper function to decode base64 string to list of float32 values."""
embedding_bytes = base64.b64decode(base64_string)
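
Both test modules now carry an identical autouse `rate_limit_delay` fixture. If the duplication spreads further, one option (not part of this commit, placement hypothetical) is to hoist it into a shared `conftest.py` so pytest applies it package-wide:

```python
# conftest.py (hypothetical shared location)
import time

import pytest


@pytest.fixture(autouse=True)
def rate_limit_delay():
    """Add delay between tests to avoid rate limiting from providers like Fireworks."""
    yield
    time.sleep(30)  # 30 second delay after each test
```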

View file

@@ -112,9 +112,10 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
name="fireworks",
description="Fireworks provider with a text model",
defaults={
"text_model": "fireworks/accounts/fireworks/models/llama-v3p1-8b-instruct",
"vision_model": "fireworks/accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
"text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
"vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
"embedding_model": "nomic-ai/nomic-embed-text-v1.5",
# "embedding_model": "accounts/fireworks/models/qwen3-embedding-8b",
},
),
}