diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 84a9bc67d..348c5d869 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -3096,11 +3096,18 @@
"post": {
"responses": {
"200": {
- "description": "OK",
+ "description": "Response from an OpenAI-compatible chat completion request. **OR** Chunk from a streaming response to an OpenAI-compatible chat completion request.",
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/OpenAIChatCompletion"
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/OpenAIChatCompletion"
+ },
+ {
+ "$ref": "#/components/schemas/OpenAIChatCompletionChunk"
+ }
+ ]
}
}
}
@@ -9506,6 +9513,46 @@
"title": "OpenAIChatCompletion",
"description": "Response from an OpenAI-compatible chat completion request."
},
+ "OpenAIChatCompletionChunk": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "string",
+ "description": "The ID of the chat completion"
+ },
+ "choices": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/OpenAIChoice"
+ },
+ "description": "List of choices"
+ },
+ "object": {
+ "type": "string",
+ "const": "chat.completion.chunk",
+ "default": "chat.completion.chunk",
+ "description": "The object type, which will be \"chat.completion.chunk\""
+ },
+ "created": {
+ "type": "integer",
+ "description": "The Unix timestamp in seconds when the chat completion was created"
+ },
+ "model": {
+ "type": "string",
+ "description": "The model that was used to generate the chat completion"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "id",
+ "choices",
+ "object",
+ "created",
+ "model"
+ ],
+ "title": "OpenAIChatCompletionChunk",
+ "description": "Chunk from a streaming response to an OpenAI-compatible chat completion request."
+ },
"OpenAIChoice": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 3fcc83f15..18e39601d 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2135,11 +2135,15 @@ paths:
post:
responses:
'200':
- description: OK
+ description: >-
+ Response from an OpenAI-compatible chat completion request. **OR** Chunk
+ from a streaming response to an OpenAI-compatible chat completion request.
content:
application/json:
schema:
- $ref: '#/components/schemas/OpenAIChatCompletion'
+ oneOf:
+ - $ref: '#/components/schemas/OpenAIChatCompletion'
+ - $ref: '#/components/schemas/OpenAIChatCompletionChunk'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -6507,6 +6511,41 @@ components:
title: OpenAIChatCompletion
description: >-
Response from an OpenAI-compatible chat completion request.
+ OpenAIChatCompletionChunk:
+ type: object
+ properties:
+ id:
+ type: string
+ description: The ID of the chat completion
+ choices:
+ type: array
+ items:
+ $ref: '#/components/schemas/OpenAIChoice'
+ description: List of choices
+ object:
+ type: string
+ const: chat.completion.chunk
+ default: chat.completion.chunk
+ description: >-
+ The object type, which will be "chat.completion.chunk"
+ created:
+ type: integer
+ description: >-
+ The Unix timestamp in seconds when the chat completion was created
+ model:
+ type: string
+ description: >-
+ The model that was used to generate the chat completion
+ additionalProperties: false
+ required:
+ - id
+ - choices
+ - object
+ - created
+ - model
+ title: OpenAIChatCompletionChunk
+ description: >-
+ Chunk from a streaming response to an OpenAI-compatible chat completion request.
OpenAIChoice:
type: object
properties:
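
The YAML change above mirrors the JSON spec change: the 200 response of the chat-completions endpoint is now a `oneOf` over `OpenAIChatCompletion` and the new `OpenAIChatCompletionChunk`. As a quick illustration of what the chunk schema accepts, here is a sketch of a payload that satisfies it; every concrete value is invented for the example, and the `model_validate` call assumes the Pydantic model introduced in the Python change below together with the Pydantic v2 API.

```python
from llama_stack.apis.inference.inference import OpenAIChatCompletionChunk

# Illustrative payload only; the id, timestamp, and model name are made up.
chunk_payload = {
    "id": "chatcmpl-example-123",
    "object": "chat.completion.chunk",   # const / default value in the schema
    "created": 1713000000,               # Unix timestamp in seconds
    "model": "example-model-id",
    "choices": [],                       # real chunks carry OpenAIChoice entries
}

# Pydantic v2 API assumed; this round-trips the payload through the new model.
chunk = OpenAIChatCompletionChunk.model_validate(chunk_payload)
print(chunk.object)  # "chat.completion.chunk"
```
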
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 4251d37ab..f843041b3 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -674,6 +674,24 @@ class OpenAIChatCompletion(BaseModel):
model: str
+@json_schema_type
+class OpenAIChatCompletionChunk(BaseModel):
+ """Chunk from a streaming response to an OpenAI-compatible chat completion request.
+
+ :param id: The ID of the chat completion
+ :param choices: List of choices
+ :param object: The object type, which will be "chat.completion.chunk"
+ :param created: The Unix timestamp in seconds when the chat completion was created
+ :param model: The model that was used to generate the chat completion
+ """
+
+ id: str
+ choices: List[OpenAIChoice]
+ object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+ created: int
+ model: str
+
+
@json_schema_type
class OpenAICompletionLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
@@ -954,7 +972,7 @@ class Inference(Protocol):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
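
With the Protocol signature widened to `Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]`, callers pick the shape via the `stream` flag they pass. A hedged sketch of consuming both shapes; `inference` stands for any implementation of the `Inference` protocol and `messages` for a prepared list of `OpenAIMessageParam` values, neither of which is defined in this diff.

```python
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
)


async def call_chat(inference, messages, stream: bool) -> None:
    # "example-model-id" is a placeholder identifier for this sketch.
    result = await inference.openai_chat_completion(
        model="example-model-id",
        messages=messages,
        stream=stream,
    )
    if stream:
        # Streaming: iterate OpenAIChatCompletionChunk objects as they arrive.
        async for chunk in result:
            assert isinstance(chunk, OpenAIChatCompletionChunk)
            print(chunk.id, len(chunk.choices))
    else:
        # Non-streaming: a single OpenAIChatCompletion comes back.
        assert isinstance(result, OpenAIChatCompletion)
        print(result.model)
```
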
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index b9f363be0..91cb262a9 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -39,6 +39,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
+ OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
@@ -546,7 +547,7 @@ class InferenceRouter(Inference):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
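
The router itself does not need to know which shape it got: it awaits the provider call and returns the result as-is, so a lazy chunk iterator flows through untouched. A toy, self-contained sketch of that pass-through behaviour (strings stand in for the real completion and chunk models):

```python
import asyncio
from typing import AsyncIterator, Union


async def fake_provider(stream: bool) -> Union[str, AsyncIterator[str]]:
    # Stand-in for a provider's openai_chat_completion.
    if not stream:
        return "full completion"

    async def chunks() -> AsyncIterator[str]:
        for piece in ("chunk-1", "chunk-2"):
            yield piece

    return chunks()


async def router(stream: bool) -> Union[str, AsyncIterator[str]]:
    # Like InferenceRouter.openai_chat_completion after this change: await the
    # provider and hand the result back unchanged, never draining the stream.
    return await fake_provider(stream)


async def main() -> None:
    print(await router(stream=False))   # full completion
    async for piece in await router(stream=True):
        print(piece)                    # chunk-1, chunk-2


asyncio.run(main())
```
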
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 8385209f1..69bf7c863 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from fireworks.client import Fireworks
from openai import AsyncOpenAI
@@ -34,6 +34,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
+ OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
@@ -352,7 +353,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
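
The Fireworks change above, like the NVIDIA, Ollama, passthrough, vLLM, and LiteLLM changes that follow, is only a signature widening: the OpenAI-compatible client these adapters delegate to already switches shape on `stream`, returning a completion object when it is false and an async stream of chunks when it is true. A hedged sketch of that client behaviour against a placeholder endpoint (the URL, key, and model id are invented):

```python
import asyncio

from openai import AsyncOpenAI

# Placeholder endpoint and key; a real adapter builds this client from its config.
client = AsyncOpenAI(base_url="https://example.invalid/v1", api_key="example-key")


async def demo() -> None:
    # Non-streaming: a single ChatCompletion object is returned.
    completion = await client.chat.completions.create(
        model="example-model-id",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(completion.model)

    # Streaming: the same call with stream=True yields ChatCompletionChunk objects,
    # which is what the widened return annotation lets the adapter pass through.
    stream = await client.chat.completions.create(
        model="example-model-id",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices)


# asyncio.run(demo())  # requires a reachable OpenAI-compatible endpoint
```
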
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index b2a244f11..15f0e72a1 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -37,6 +37,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
+ OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
@@ -345,7 +346,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
provider_model_id = self.get_provider_model_id(model)
params = await prepare_openai_completion_params(
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index a24d35ab2..e08eb1263 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx
from ollama import AsyncClient
@@ -41,6 +41,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
+ OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
@@ -409,7 +410,7 @@ class OllamaInferenceAdapter(
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model)
params = {
k: v
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 63054ae0a..af05320b0 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack_client import AsyncLlamaStackClient
@@ -28,6 +28,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
+ OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
@@ -282,7 +283,7 @@ class PassthroughInferenceAdapter(Inference):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 4ebf9956e..001e6aac4 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from openai import AsyncOpenAI
from together import AsyncTogether
@@ -33,6 +33,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
+ OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
@@ -331,7 +332,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
@@ -358,4 +359,26 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_p=top_p,
user=user,
)
+ if params.get("stream", True):
+ return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
+
+ async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
+ # together.ai sometimes appends usage data to the stream even when include_usage
+ # is False. That produces an unexpected final chunk with an empty choices array,
+ # which some clients may not handle gracefully.
+ include_usage = False
+ if params.get("stream_options", None):
+ include_usage = params["stream_options"].get("include_usage", False)
+ stream = await self._get_openai_client().chat.completions.create(**params)
+
+ seen_finish_reason = False
+ async for chunk in stream:
+ # Drop the trailing usage-only chunk (empty choices) that the user did not request
+ if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
+ break
+ yield chunk
+ for choice in chunk.choices:
+ if choice.finish_reason:
+ seen_finish_reason = True
+ break
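
The filtering loop added above can be exercised on its own. A self-contained sketch with stand-in chunk objects (the real code iterates the OpenAI client's chunk type) showing that, once a finish_reason has been seen and usage was not requested, a trailing empty-choices chunk is dropped:

```python
import asyncio
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeChoice:
    finish_reason: Optional[str] = None


@dataclass
class FakeChunk:
    choices: List[FakeChoice] = field(default_factory=list)


async def fake_stream():
    yield FakeChunk(choices=[FakeChoice()])                      # content delta
    yield FakeChunk(choices=[FakeChoice(finish_reason="stop")])  # final choice
    yield FakeChunk(choices=[])                                  # unsolicited usage chunk


async def filtered(stream, include_usage: bool):
    # Mirrors _stream_openai_chat_completion: after a finish_reason, a chunk with
    # no choices is discarded unless usage data was explicitly requested.
    seen_finish_reason = False
    async for chunk in stream:
        if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
            break
        yield chunk
        for choice in chunk.choices:
            if choice.finish_reason:
                seen_finish_reason = True
                break


async def main() -> None:
    count = 0
    async for _ in filtered(fake_stream(), include_usage=False):
        count += 1
    print(count)  # 2 -- the trailing usage chunk was dropped


asyncio.run(main())
```
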
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index eca68e399..2b9eae1e9 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
import logging
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx
from openai import AsyncOpenAI
@@ -503,7 +503,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 6d98a0cb4..95e8b767b 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
+ OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
@@ -324,7 +325,7 @@ class LiteLLMOpenAIMixin(
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 1fa202475..07fd75ea8 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -8,7 +8,7 @@ import logging
import time
import uuid
import warnings
-from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, Iterable, List, Optional, Union
from openai import AsyncStream
from openai.types.chat import (
@@ -1196,5 +1196,5 @@ class OpenAIChatCompletionUnsupportedMixin:
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
- ) -> OpenAIChatCompletion:
+ ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion")
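
For providers that rely on OpenAIChatCompletionUnsupportedMixin, the widened annotation only keeps the signature aligned with the protocol; the method raises before either branch of the union is produced. A minimal sketch of what a caller would observe, assuming the mixin can be subclassed directly (the subclass name is invented):

```python
import asyncio

from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionUnsupportedMixin,
)


class ExampleUnsupportedProvider(OpenAIChatCompletionUnsupportedMixin):
    """Hypothetical provider without OpenAI chat completion support."""


async def main() -> None:
    provider = ExampleUnsupportedProvider()
    try:
        await provider.openai_chat_completion(model="example-model-id", messages=[])
    except ValueError as err:
        print(err)  # ExampleUnsupportedProvider doesn't support openai chat completion


asyncio.run(main())
```
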