fix: OpenAI API - together.ai extra usage chunks

This fixes an issue where, with some models (e.g., the Llama 4 models),
together.ai sends a final usage chunk in streaming responses even when the
user did not ask to include usage.
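
For illustration, the tail of an affected stream looks roughly like the sketch
below: a normal final content chunk followed by an extra chunk that has an
empty choices list and only usage data. The IDs and token counts are made up.

```python
# Illustration only: roughly the last two chunks of an affected stream
# (field values are invented, the shape is what matters).
last_two_chunks = [
    {
        "id": "chatcmpl-123",
        "object": "chat.completion.chunk",
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    },
    {
        # Extra chunk sent by together.ai even though include_usage was not
        # requested: no choices, only usage. Strict clients reject this.
        "id": "chatcmpl-123",
        "object": "chat.completion.chunk",
        "choices": [],
        "usage": {"prompt_tokens": 25, "completion_tokens": 7, "total_tokens": 32},
    },
]
```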

With this change, the OpenAI API verification tests now pass 100% when
using Llama Stack as your API server and together.ai as the backend
provider.

As part of this, I also cleaned up the streaming/non-streaming return
types of the `openai_chat_completion` method to keep type checking happy.
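
For context, the cleaned-up signature returns either a complete
`OpenAIChatCompletion` or an `AsyncIterator[OpenAIChatCompletionChunk]`,
depending on `stream`. A minimal, hypothetical caller (not part of this
change; `client` and `show_completion` are stand-ins) might branch like this:

```python
from typing import Any, List


async def show_completion(client: Any, model: str, messages: List[Any], stream: bool) -> None:
    # client stands in for any object implementing the Inference protocol below.
    result = await client.openai_chat_completion(model=model, messages=messages, stream=stream)
    if stream:
        # AsyncIterator[OpenAIChatCompletionChunk]: consume chunk by chunk.
        async for chunk in result:
            print(chunk.model_dump_json())
    else:
        # OpenAIChatCompletion: a single, complete response object.
        print(result.model_dump_json())
```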

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Ben Browning 2025-04-12 17:27:43 -04:00
parent a4b573d750
commit c014571258
12 changed files with 153 additions and 20 deletions


@@ -3096,11 +3096,18 @@
       "post": {
         "responses": {
           "200": {
-            "description": "OK",
+            "description": "Response from an OpenAI-compatible chat completion request. **OR** Chunk from a streaming response to an OpenAI-compatible chat completion request.",
             "content": {
               "application/json": {
                 "schema": {
-                  "$ref": "#/components/schemas/OpenAIChatCompletion"
+                  "oneOf": [
+                    {
+                      "$ref": "#/components/schemas/OpenAIChatCompletion"
+                    },
+                    {
+                      "$ref": "#/components/schemas/OpenAIChatCompletionChunk"
+                    }
+                  ]
                 }
               }
             }
@@ -9506,6 +9513,46 @@
       "title": "OpenAIChatCompletion",
       "description": "Response from an OpenAI-compatible chat completion request."
     },
+    "OpenAIChatCompletionChunk": {
+      "type": "object",
+      "properties": {
+        "id": {
+          "type": "string",
+          "description": "The ID of the chat completion"
+        },
+        "choices": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/OpenAIChoice"
+          },
+          "description": "List of choices"
+        },
+        "object": {
+          "type": "string",
+          "const": "chat.completion.chunk",
+          "default": "chat.completion.chunk",
+          "description": "The object type, which will be \"chat.completion.chunk\""
+        },
+        "created": {
+          "type": "integer",
+          "description": "The Unix timestamp in seconds when the chat completion was created"
+        },
+        "model": {
+          "type": "string",
+          "description": "The model that was used to generate the chat completion"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "id",
+        "choices",
+        "object",
+        "created",
+        "model"
+      ],
+      "title": "OpenAIChatCompletionChunk",
+      "description": "Chunk from a streaming response to an OpenAI-compatible chat completion request."
+    },
     "OpenAIChoice": {
       "type": "object",
       "properties": {


@@ -2135,11 +2135,15 @@ paths:
     post:
       responses:
         '200':
-          description: OK
+          description: >-
+            Response from an OpenAI-compatible chat completion request. **OR** Chunk
+            from a streaming response to an OpenAI-compatible chat completion request.
           content:
             application/json:
               schema:
-                $ref: '#/components/schemas/OpenAIChatCompletion'
+                oneOf:
+                  - $ref: '#/components/schemas/OpenAIChatCompletion'
+                  - $ref: '#/components/schemas/OpenAIChatCompletionChunk'
         '400':
           $ref: '#/components/responses/BadRequest400'
         '429':
@@ -6507,6 +6511,41 @@ components:
       title: OpenAIChatCompletion
       description: >-
         Response from an OpenAI-compatible chat completion request.
+    OpenAIChatCompletionChunk:
+      type: object
+      properties:
+        id:
+          type: string
+          description: The ID of the chat completion
+        choices:
+          type: array
+          items:
+            $ref: '#/components/schemas/OpenAIChoice'
+          description: List of choices
+        object:
+          type: string
+          const: chat.completion.chunk
+          default: chat.completion.chunk
+          description: >-
+            The object type, which will be "chat.completion.chunk"
+        created:
+          type: integer
+          description: >-
+            The Unix timestamp in seconds when the chat completion was created
+        model:
+          type: string
+          description: >-
+            The model that was used to generate the chat completion
+      additionalProperties: false
+      required:
+        - id
+        - choices
+        - object
+        - created
+        - model
+      title: OpenAIChatCompletionChunk
+      description: >-
+        Chunk from a streaming response to an OpenAI-compatible chat completion request.
     OpenAIChoice:
       type: object
       properties:


@@ -674,6 +674,24 @@ class OpenAIChatCompletion(BaseModel):
     model: str
 
 
+@json_schema_type
+class OpenAIChatCompletionChunk(BaseModel):
+    """Chunk from a streaming response to an OpenAI-compatible chat completion request.
+
+    :param id: The ID of the chat completion
+    :param choices: List of choices
+    :param object: The object type, which will be "chat.completion.chunk"
+    :param created: The Unix timestamp in seconds when the chat completion was created
+    :param model: The model that was used to generate the chat completion
+    """
+
+    id: str
+    choices: List[OpenAIChoice]
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int
+    model: str
+
+
 @json_schema_type
 class OpenAICompletionLogprobs(BaseModel):
     """The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
@@ -954,7 +972,7 @@ class Inference(Protocol):
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         """Generate an OpenAI-compatible chat completion for the given messages using the specified model.
 
         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
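
As a quick illustration of the new model (a sketch, not part of the change;
the id, timestamp, and model name are made up), even a choices-free chunk like
the stray usage chunk validates, and `object` defaults to the literal
"chat.completion.chunk":

```python
from llama_stack.apis.inference.inference import OpenAIChatCompletionChunk

chunk = OpenAIChatCompletionChunk(
    id="chatcmpl-123",
    choices=[],  # the trailing usage-only chunk carries no choices
    created=1744490863,
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
)
assert chunk.object == "chat.completion.chunk"  # literal default from the model
```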


@@ -39,6 +39,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
@@ -546,7 +547,7 @@ class InferenceRouter(Inference):
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         logger.debug(
             f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
         )


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from fireworks.client import Fireworks
 from openai import AsyncOpenAI
@@ -34,6 +34,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
@@ -352,7 +353,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,


@@ -37,6 +37,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
@@ -345,7 +346,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         provider_model_id = self.get_provider_model_id(model)
 
         params = await prepare_openai_completion_params(


@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 import httpx
 from ollama import AsyncClient
@@ -41,6 +41,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
@@ -409,7 +410,7 @@ class OllamaInferenceAdapter(
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         model_obj = await self._get_model(model)
         params = {
             k: v


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from llama_stack_client import AsyncLlamaStackClient
 
@@ -28,6 +28,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
@@ -282,7 +283,7 @@ class PassthroughInferenceAdapter(Inference):
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         client = self._get_client()
         model_obj = await self.model_store.get_model(model)
 


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from openai import AsyncOpenAI
 from together import AsyncTogether
@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
@@ -331,7 +332,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
@@ -358,4 +359,26 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             top_p=top_p,
             user=user,
         )
+        if params.get("stream", True):
+            return self._stream_openai_chat_completion(params)
         return await self._get_openai_client().chat.completions.create(**params)  # type: ignore
+
+    async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
+        # together.ai sometimes adds usage data to the stream, even if include_usage is False
+        # This causes an unexpected final chunk with empty choices array to be sent
+        # to clients that may not handle it gracefully.
+        include_usage = False
+        if params.get("stream_options", None):
+            include_usage = params["stream_options"].get("include_usage", False)
+        stream = await self._get_openai_client().chat.completions.create(**params)
+
+        seen_finish_reason = False
+        async for chunk in stream:
+            # Final usage chunk with no choices that the user didn't request, so discard
+            if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
+                break
+            yield chunk
+            for choice in chunk.choices:
+                if choice.finish_reason:
+                    seen_finish_reason = True
+                    break
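
To see the filtering behavior in isolation, here is a self-contained sketch
(not part of the change) that re-implements the same loop against a fake
stream; `FakeChunk`, `FakeChoice`, and `filter_stream` are stand-ins for
illustration only:

```python
import asyncio
from dataclasses import dataclass, field
from typing import AsyncIterator, List, Optional


@dataclass
class FakeChoice:
    finish_reason: Optional[str] = None


@dataclass
class FakeChunk:
    choices: List[FakeChoice] = field(default_factory=list)


async def fake_stream() -> AsyncIterator[FakeChunk]:
    yield FakeChunk(choices=[FakeChoice()])                      # normal delta chunk
    yield FakeChunk(choices=[FakeChoice(finish_reason="stop")])  # final content chunk
    yield FakeChunk(choices=[])                                  # stray usage-only chunk


async def filter_stream(stream: AsyncIterator[FakeChunk], include_usage: bool) -> AsyncIterator[FakeChunk]:
    # Same idea as _stream_openai_chat_completion above: once a finish_reason
    # has been seen, a chunk with no choices is the unrequested usage chunk.
    seen_finish_reason = False
    async for chunk in stream:
        if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
            break
        yield chunk
        for choice in chunk.choices:
            if choice.finish_reason:
                seen_finish_reason = True
                break


async def main() -> None:
    chunks = [c async for c in filter_stream(fake_stream(), include_usage=False)]
    assert len(chunks) == 2  # the stray usage-only chunk was dropped


asyncio.run(main())
```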


@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import json
 import logging
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 import httpx
 from openai import AsyncOpenAI
@@ -503,7 +503,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         model_obj = await self._get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,


@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
@@ -324,7 +325,7 @@ class LiteLLMOpenAIMixin(
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,


@@ -8,7 +8,7 @@ import logging
 import time
 import uuid
 import warnings
-from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, Iterable, List, Optional, Union
 
 from openai import AsyncStream
 from openai.types.chat import (
@@ -1196,5 +1196,5 @@ class OpenAIChatCompletionUnsupportedMixin:
         top_logprobs: Optional[int] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion")