revert openai_compat changes and use OpenAIMixin for openai_chat_completion

Swapna Lekkala 2025-09-18 16:06:53 -07:00
parent 0f5bef893a
commit a6baa7b3d4
9 changed files with 23 additions and 303 deletions
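In short: the hand-rolled usage plumbing (UsageInfo, OpenAIChatCompletionUsage, and the Fireworks-specific openai_chat_completion override) is reverted, and the adapter falls back to the openai_chat_completion it inherits from OpenAIMixin. A minimal sketch of that inheritance pattern follows; class and helper names are illustrative, not the exact llama_stack API.

# Hedged sketch of the mixin pattern this commit falls back to: the provider
# adapter stops defining openai_chat_completion and inherits it from the mixin.
# Names here are illustrative, not the exact llama_stack signatures.
from typing import Any


class OpenAIMixinSketch:
    """Supplies a default OpenAI-compatible chat-completion entry point."""

    def get_async_client(self) -> Any:
        # The real mixin builds an AsyncOpenAI-style client from the provider's
        # base URL and API key; this stub just marks the extension point.
        raise NotImplementedError

    async def openai_chat_completion(
        self, model: str, messages: list[dict[str, Any]], **params: Any
    ) -> Any:
        # Forward the request unchanged to the OpenAI-compatible backend.
        client = self.get_async_client()
        return await client.chat.completions.create(model=model, messages=messages, **params)


class FireworksAdapterSketch(OpenAIMixinSketch):
    # After this commit the Fireworks adapter no longer overrides
    # openai_chat_completion; method resolution finds the mixin's version.
    def get_async_client(self) -> Any:
        raise NotImplementedError("construct the provider client here")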

View file

@@ -6372,9 +6372,6 @@
             "$ref": "#/components/schemas/TokenLogProbs"
           },
           "description": "Optional log probabilities for generated tokens"
-        },
-        "usage": {
-          "$ref": "#/components/schemas/UsageInfo"
         }
       },
       "additionalProperties": false,
@@ -6433,31 +6430,6 @@
         "title": "TokenLogProbs",
         "description": "Log probabilities for generated tokens."
       },
-      "UsageInfo": {
-        "type": "object",
-        "properties": {
-          "completion_tokens": {
-            "type": "integer",
-            "description": "Number of tokens generated"
-          },
-          "prompt_tokens": {
-            "type": "integer",
-            "description": "Number of tokens in the prompt"
-          },
-          "total_tokens": {
-            "type": "integer",
-            "description": "Total number of tokens processed"
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "completion_tokens",
-          "prompt_tokens",
-          "total_tokens"
-        ],
-        "title": "UsageInfo",
-        "description": "Usage information for a model."
-      },
       "BatchCompletionRequest": {
         "type": "object",
         "properties": {
@@ -10967,31 +10939,6 @@
         "title": "OpenAIChatCompletionToolCallFunction",
         "description": "Function call details for OpenAI-compatible tool calls."
       },
-      "OpenAIChatCompletionUsage": {
-        "type": "object",
-        "properties": {
-          "prompt_tokens": {
-            "type": "integer",
-            "description": "The number of tokens in the prompt"
-          },
-          "completion_tokens": {
-            "type": "integer",
-            "description": "The number of tokens in the completion"
-          },
-          "total_tokens": {
-            "type": "integer",
-            "description": "The total number of tokens used"
-          }
-        },
-        "additionalProperties": false,
-        "required": [
-          "prompt_tokens",
-          "completion_tokens",
-          "total_tokens"
-        ],
-        "title": "OpenAIChatCompletionUsage",
-        "description": "Usage information for an OpenAI-compatible chat completion response."
-      },
       "OpenAIChoice": {
         "type": "object",
         "properties": {
@@ -11329,13 +11276,6 @@
       "OpenAICompletionWithInputMessages": {
         "type": "object",
         "properties": {
-          "metrics": {
-            "type": "array",
-            "items": {
-              "$ref": "#/components/schemas/MetricInResponse"
-            },
-            "description": "(Optional) List of metrics associated with the API response"
-          },
           "id": {
             "type": "string",
             "description": "The ID of the chat completion"
@@ -11361,9 +11301,6 @@
           "type": "string",
           "description": "The model that was used to generate the chat completion"
         },
-        "usage": {
-          "$ref": "#/components/schemas/OpenAIChatCompletionUsage"
-        },
         "input_messages": {
           "type": "array",
           "items": {
@@ -13125,13 +13062,6 @@
           "items": {
             "type": "object",
             "properties": {
-              "metrics": {
-                "type": "array",
-                "items": {
-                  "$ref": "#/components/schemas/MetricInResponse"
-                },
-                "description": "(Optional) List of metrics associated with the API response"
-              },
               "id": {
                 "type": "string",
                 "description": "The ID of the chat completion"
@@ -13157,9 +13087,6 @@
               "type": "string",
               "description": "The model that was used to generate the chat completion"
             },
-            "usage": {
-              "$ref": "#/components/schemas/OpenAIChatCompletionUsage"
-            },
             "input_messages": {
               "type": "array",
               "items": {
@@ -14551,13 +14478,6 @@
       "OpenAIChatCompletion": {
         "type": "object",
         "properties": {
-          "metrics": {
-            "type": "array",
-            "items": {
-              "$ref": "#/components/schemas/MetricInResponse"
-            },
-            "description": "(Optional) List of metrics associated with the API response"
-          },
           "id": {
             "type": "string",
             "description": "The ID of the chat completion"
@@ -14582,9 +14502,6 @@
         "model": {
           "type": "string",
           "description": "The model that was used to generate the chat completion"
-        },
-        "usage": {
-          "$ref": "#/components/schemas/OpenAIChatCompletionUsage"
         }
       },
       "additionalProperties": false,

View file

@@ -4548,8 +4548,6 @@ components:
             $ref: '#/components/schemas/TokenLogProbs'
           description: >-
             Optional log probabilities for generated tokens
-        usage:
-          $ref: '#/components/schemas/UsageInfo'
       additionalProperties: false
       required:
         - completion_message
@@ -4591,25 +4589,6 @@
         - logprobs_by_token
       title: TokenLogProbs
       description: Log probabilities for generated tokens.
-    UsageInfo:
-      type: object
-      properties:
-        completion_tokens:
-          type: integer
-          description: Number of tokens generated
-        prompt_tokens:
-          type: integer
-          description: Number of tokens in the prompt
-        total_tokens:
-          type: integer
-          description: Total number of tokens processed
-      additionalProperties: false
-      required:
-        - completion_tokens
-        - prompt_tokens
-        - total_tokens
-      title: UsageInfo
-      description: Usage information for a model.
     BatchCompletionRequest:
       type: object
       properties:
@@ -8124,26 +8103,6 @@ components:
       title: OpenAIChatCompletionToolCallFunction
       description: >-
         Function call details for OpenAI-compatible tool calls.
-    OpenAIChatCompletionUsage:
-      type: object
-      properties:
-        prompt_tokens:
-          type: integer
-          description: The number of tokens in the prompt
-        completion_tokens:
-          type: integer
-          description: The number of tokens in the completion
-        total_tokens:
-          type: integer
-          description: The total number of tokens used
-      additionalProperties: false
-      required:
-        - prompt_tokens
-        - completion_tokens
-        - total_tokens
-      title: OpenAIChatCompletionUsage
-      description: >-
-        Usage information for an OpenAI-compatible chat completion response.
     OpenAIChoice:
       type: object
       properties:
@@ -8406,12 +8365,6 @@
     OpenAICompletionWithInputMessages:
       type: object
      properties:
-        metrics:
-          type: array
-          items:
-            $ref: '#/components/schemas/MetricInResponse'
-          description: >-
-            (Optional) List of metrics associated with the API response
        id:
          type: string
          description: The ID of the chat completion
@@ -8434,8 +8387,6 @@
           type: string
           description: >-
             The model that was used to generate the chat completion
-        usage:
-          $ref: '#/components/schemas/OpenAIChatCompletionUsage'
         input_messages:
           type: array
           items:
@@ -9731,12 +9682,6 @@ components:
           items:
             type: object
             properties:
-              metrics:
-                type: array
-                items:
-                  $ref: '#/components/schemas/MetricInResponse'
-                description: >-
-                  (Optional) List of metrics associated with the API response
               id:
                 type: string
                 description: The ID of the chat completion
@@ -9759,8 +9704,6 @@
               type: string
               description: >-
                 The model that was used to generate the chat completion
-              usage:
-                $ref: '#/components/schemas/OpenAIChatCompletionUsage'
             input_messages:
               type: array
               items:
@@ -10776,12 +10719,6 @@ components:
     OpenAIChatCompletion:
       type: object
       properties:
-        metrics:
-          type: array
-          items:
-            $ref: '#/components/schemas/MetricInResponse'
-          description: >-
-            (Optional) List of metrics associated with the API response
         id:
           type: string
           description: The ID of the chat completion
@@ -10804,8 +10741,6 @@
         type: string
         description: >-
           The model that was used to generate the chat completion
-      usage:
-        $ref: '#/components/schemas/OpenAIChatCompletionUsage'
       additionalProperties: false
       required:
         - id

View file

@@ -451,20 +451,6 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin):
     event: ChatCompletionResponseEvent


-@json_schema_type
-class UsageInfo(BaseModel):
-    """Usage information for a model.
-
-    :param completion_tokens: Number of tokens generated
-    :param prompt_tokens: Number of tokens in the prompt
-    :param total_tokens: Total number of tokens processed
-    """
-
-    completion_tokens: int
-    prompt_tokens: int
-    total_tokens: int
-
-
 @json_schema_type
 class ChatCompletionResponse(MetricResponseMixin):
     """Response from a chat completion request.
@@ -475,7 +461,6 @@ class ChatCompletionResponse(MetricResponseMixin):

     completion_message: CompletionMessage
     logprobs: list[TokenLogProbs] | None = None
-    usage: UsageInfo | None = None


 @json_schema_type
@@ -833,21 +818,7 @@ class OpenAIChoice(BaseModel):

 @json_schema_type
-class OpenAIChatCompletionUsage(BaseModel):
-    """Usage information for an OpenAI-compatible chat completion response.
-
-    :param prompt_tokens: The number of tokens in the prompt
-    :param completion_tokens: The number of tokens in the completion
-    :param total_tokens: The total number of tokens used
-    """
-
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-
-@json_schema_type
-class OpenAIChatCompletion(MetricResponseMixin):
+class OpenAIChatCompletion(BaseModel):
     """Response from an OpenAI-compatible chat completion request.

     :param id: The ID of the chat completion
@@ -862,7 +833,6 @@ class OpenAIChatCompletion(MetricResponseMixin):
     object: Literal["chat.completion"] = "chat.completion"
     created: int
     model: str
-    usage: OpenAIChatCompletionUsage | None = None


 @json_schema_type
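Note that switching OpenAIChatCompletion's base class from MetricResponseMixin to BaseModel, and dropping the usage fields, is what removes the metrics arrays and the usage/$ref entries from the spec hunks above; the JSON and YAML spec files appear to be generated from these pydantic models.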

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator
 from typing import Any

 from fireworks.client import Fireworks
@@ -23,11 +23,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
@@ -43,7 +39,6 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
     convert_message_to_openai_dict,
     get_sampling_options,
     process_chat_completion_response,
@@ -335,90 +330,3 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
             prompt_logprobs=prompt_logprobs,
             suffix=suffix,
         )
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-
-        # Divert Llama Models through Llama Stack inference APIs because
-        # Fireworks chat completions OpenAI-compatible API does not support
-        # tool calls properly.
-        llama_model = self.get_llama_model(model_obj.provider_resource_id)
-        if llama_model:
-            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
-                self,
-                model=model,
-                messages=messages,
-                frequency_penalty=frequency_penalty,
-                function_call=function_call,
-                functions=functions,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_completion_tokens=max_completion_tokens,
-                max_tokens=max_tokens,
-                n=n,
-                parallel_tool_calls=parallel_tool_calls,
-                presence_penalty=presence_penalty,
-                response_format=response_format,
-                seed=seed,
-                stop=stop,
-                stream=stream,
-                stream_options=stream_options,
-                temperature=temperature,
-                tool_choice=tool_choice,
-                tools=tools,
-                top_logprobs=top_logprobs,
-                top_p=top_p,
-                user=user,
-            )
-
-        return await super().openai_chat_completion(
-            model=model,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
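With this override gone, a call to openai_chat_completion on the Fireworks adapter resolves through OpenAIMixin, which comes first in the class's method resolution order. Llama models are no longer special-cased through OpenAIChatCompletionToLlamaStackMixin; every model now goes straight to the Fireworks OpenAI-compatible endpoint.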

View file

@@ -61,6 +61,7 @@ MODEL_ENTRIES = [
     ),
     ProviderModelEntry(
         provider_model_id="nomic-ai/nomic-embed-text-v1.5",
+        aliases=["nomic-ai/nomic-embed-text-v1.5"],
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimension": 768,

View file

@@ -31,8 +31,6 @@ from openai.types.chat import (
     ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
 )

-from llama_stack.apis.inference.inference import UsageInfo
-
 try:
     from openai.types.chat import (
         ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
@@ -105,7 +103,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     Message,
     OpenAIChatCompletion,
-    OpenAIChatCompletionUsage,
     OpenAICompletion,
     OpenAICompletionChoice,
     OpenAIEmbeddingData,
@@ -280,11 +277,6 @@
     request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
-    usage = UsageInfo(
-        prompt_tokens=response.usage.prompt_tokens,
-        completion_tokens=response.usage.completion_tokens,
-        total_tokens=response.usage.total_tokens,
-    )
     if choice.finish_reason == "tool_calls":
         if not choice.message or not choice.message.tool_calls:
             raise ValueError("Tool calls are not present in the response")
@@ -298,7 +290,6 @@
                 content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
             ),
             logprobs=None,
-            usage=usage,
         )
     else:
         # Otherwise, return tool calls as normal
@@ -310,7 +301,6 @@
                 content="",
             ),
             logprobs=None,
-            usage=usage,
         )


 # TODO: This does not work well with tool calls for vLLM remote provider
@@ -345,7 +335,6 @@
             tool_calls=raw_message.tool_calls,
         ),
         logprobs=None,
-        usage=usage,
     )
@@ -657,7 +646,7 @@ async def convert_message_to_openai_dict_new(
                 arguments=json.dumps(tool.arguments),
             ),
             type="function",
-        ).model_dump()
+        )
         for tool in message.tool_calls
     ]
     params = {}
@@ -668,7 +657,6 @@
             content=await _convert_message_content(message.content),
             **params,
         )
-
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
             role="tool",
@@ -1387,7 +1375,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         messages = openai_messages_to_messages(messages)
-
         response_format = _convert_openai_request_response_format(response_format)
         sampling_params = _convert_openai_sampling_params(
             max_tokens=max_tokens,
@@ -1414,6 +1401,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                 tools=tools,
             )
             outstanding_responses.append(response)
+
         if stream:
             return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
@@ -1488,22 +1476,12 @@
         self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
     ) -> OpenAIChatCompletion:
         choices = []
-        total_prompt_tokens = 0
-        total_completion_tokens = 0
-        total_tokens = 0
-
         for outstanding_response in outstanding_responses:
             response = await outstanding_response
             completion_message = response.completion_message
             message = await convert_message_to_openai_dict_new(completion_message)
             finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
-
-            # Aggregate usage data
-            if response.usage:
-                total_prompt_tokens += response.usage.prompt_tokens
-                total_completion_tokens += response.usage.completion_tokens
-                total_tokens += response.usage.total_tokens
-
             choice = OpenAIChatCompletionChoice(
                 index=len(choices),
                 message=message,
@@ -1511,17 +1489,12 @@
             )
             choices.append(choice)

-        usage = OpenAIChatCompletionUsage(
-            prompt_tokens=total_prompt_tokens, completion_tokens=total_completion_tokens, total_tokens=total_tokens
-        )
-
         return OpenAIChatCompletion(
             id=f"chatcmpl-{uuid.uuid4()}",
             choices=choices,
             created=int(time.time()),
             model=model,
             object="chat.completion",
-            usage=usage,
         )
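After these hunks, OpenAIChatCompletionToLlamaStackMixin builds its OpenAIChatCompletion without any usage field, matching the trimmed model definition above; token accounting, where available, now comes only from responses returned by the OpenAI-compatible providers themselves.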

View file

@@ -13,6 +13,13 @@ import pytest

 from ..test_cases.test_case import TestCase


+@pytest.fixture(autouse=True)
+def rate_limit_delay():
+    """Add delay between tests to avoid rate limiting from providers like Fireworks"""
+    yield
+    time.sleep(30)  # 30 second delay after each test
+
+
 def _normalize_text(text: str) -> str:
     """
     Normalize Unicode text by removing diacritical marks for comparison.
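Because the fixture is declared with autouse=True, pytest applies it to every test in the module without an explicit opt-in; the code after the yield runs at teardown, so each test is followed by the 30-second pause.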

View file

@@ -6,6 +6,7 @@

 import base64
 import struct
+import time

 import pytest
 from openai import OpenAI
@@ -13,6 +14,13 @@

 from llama_stack.core.library_client import LlamaStackAsLibraryClient


+@pytest.fixture(autouse=True)
+def rate_limit_delay():
+    """Add delay between tests to avoid rate limiting from providers like Fireworks"""
+    yield
+    time.sleep(30)  # 30 second delay after each test
+
+
 def decode_base64_to_floats(base64_string: str) -> list[float]:
     """Helper function to decode base64 string to list of float32 values."""
     embedding_bytes = base64.b64decode(base64_string)

View file

@@ -112,9 +112,10 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
         name="fireworks",
         description="Fireworks provider with a text model",
         defaults={
-            "text_model": "fireworks/accounts/fireworks/models/llama-v3p1-8b-instruct",
-            "vision_model": "fireworks/accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
+            "text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
+            "vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
             "embedding_model": "nomic-ai/nomic-embed-text-v1.5",
+            # "embedding_model": "accounts/fireworks/models/qwen3-embedding-8b",
         },
     ),
 }
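The text and vision defaults drop the fireworks/ routing prefix and name the Fireworks model ids directly, which is consistent with routing requests through the OpenAIMixin-backed adapter; the commented-out line leaves a candidate Fireworks-hosted embedding model noted for later.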