diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 84a9bc67d..348c5d869 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3096,11 +3096,18 @@ "post": { "responses": { "200": { - "description": "OK", + "description": "Response from an OpenAI-compatible chat completion request. **OR** Chunk from a streaming response to an OpenAI-compatible chat completion request.", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/OpenAIChatCompletion" + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIChatCompletion" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionChunk" + } + ] } } } @@ -9506,6 +9513,46 @@ "title": "OpenAIChatCompletion", "description": "Response from an OpenAI-compatible chat completion request." }, + "OpenAIChatCompletionChunk": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The ID of the chat completion" + }, + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIChoice" + }, + "description": "List of choices" + }, + "object": { + "type": "string", + "const": "chat.completion.chunk", + "default": "chat.completion.chunk", + "description": "The object type, which will be \"chat.completion.chunk\"" + }, + "created": { + "type": "integer", + "description": "The Unix timestamp in seconds when the chat completion was created" + }, + "model": { + "type": "string", + "description": "The model that was used to generate the chat completion" + } + }, + "additionalProperties": false, + "required": [ + "id", + "choices", + "object", + "created", + "model" + ], + "title": "OpenAIChatCompletionChunk", + "description": "Chunk from a streaming response to an OpenAI-compatible chat completion request." + }, "OpenAIChoice": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 3fcc83f15..18e39601d 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2135,11 +2135,15 @@ paths: post: responses: '200': - description: OK + description: >- + Response from an OpenAI-compatible chat completion request. **OR** Chunk + from a streaming response to an OpenAI-compatible chat completion request. content: application/json: schema: - $ref: '#/components/schemas/OpenAIChatCompletion' + oneOf: + - $ref: '#/components/schemas/OpenAIChatCompletion' + - $ref: '#/components/schemas/OpenAIChatCompletionChunk' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -6507,6 +6511,41 @@ components: title: OpenAIChatCompletion description: >- Response from an OpenAI-compatible chat completion request. 
+ OpenAIChatCompletionChunk: + type: object + properties: + id: + type: string + description: The ID of the chat completion + choices: + type: array + items: + $ref: '#/components/schemas/OpenAIChoice' + description: List of choices + object: + type: string + const: chat.completion.chunk + default: chat.completion.chunk + description: >- + The object type, which will be "chat.completion.chunk" + created: + type: integer + description: >- + The Unix timestamp in seconds when the chat completion was created + model: + type: string + description: >- + The model that was used to generate the chat completion + additionalProperties: false + required: + - id + - choices + - object + - created + - model + title: OpenAIChatCompletionChunk + description: >- + Chunk from a streaming response to an OpenAI-compatible chat completion request. OpenAIChoice: type: object properties: diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 4251d37ab..f843041b3 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -674,6 +674,24 @@ class OpenAIChatCompletion(BaseModel): model: str +@json_schema_type +class OpenAIChatCompletionChunk(BaseModel): + """Chunk from a streaming response to an OpenAI-compatible chat completion request. + + :param id: The ID of the chat completion + :param choices: List of choices + :param object: The object type, which will be "chat.completion.chunk" + :param created: The Unix timestamp in seconds when the chat completion was created + :param model: The model that was used to generate the chat completion + """ + + id: str + choices: List[OpenAIChoice] + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int + model: str + + @json_schema_type class OpenAICompletionLogprobs(BaseModel): """The log probabilities for the tokens in the message from an OpenAI-compatible completion response. @@ -954,7 +972,7 @@ class Inference(Protocol): top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: """Generate an OpenAI-compatible chat completion for the given messages using the specified model. :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. 
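Not part of the diff above, but for orientation: a minimal sketch of what the new return type of `Inference.openai_chat_completion` means for a caller. With `stream=False` the call resolves to a single `OpenAIChatCompletion`; with `stream=True` it resolves to an async iterator of `OpenAIChatCompletionChunk` objects. The `consume` helper below is hypothetical and only illustrates branching on the union.

```python
# Illustrative sketch only (not part of this PR): handling the
# Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]] result.
from typing import AsyncIterator, Union

from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
)


async def consume(
    result: Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]],
) -> None:
    if isinstance(result, OpenAIChatCompletion):
        # Non-streaming: a single "chat.completion" object.
        print(result.choices[0].message)
    else:
        # Streaming: one "chat.completion.chunk" object at a time.
        async for chunk in result:
            print(chunk.id, chunk.object, len(chunk.choices))
```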
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index b9f363be0..91cb262a9 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -39,6 +39,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, + OpenAIChatCompletionChunk, OpenAICompletion, OpenAIMessageParam, OpenAIResponseFormatParam, @@ -546,7 +547,7 @@ class InferenceRouter(Inference): top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: logger.debug( f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", ) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 8385209f1..69bf7c863 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from fireworks.client import Fireworks from openai import AsyncOpenAI @@ -34,6 +34,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, + OpenAIChatCompletionChunk, OpenAICompletion, OpenAIMessageParam, OpenAIResponseFormatParam, @@ -352,7 +353,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: model_obj = await self.model_store.get_model(model) params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index b2a244f11..15f0e72a1 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -37,6 +37,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, + OpenAIChatCompletionChunk, OpenAICompletion, OpenAIMessageParam, OpenAIResponseFormatParam, @@ -345,7 +346,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: provider_model_id = self.get_provider_model_id(model) params = await prepare_openai_completion_params( diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index a24d35ab2..e08eb1263 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union import httpx from ollama import AsyncClient @@ -41,6 +41,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, + OpenAIChatCompletionChunk, OpenAICompletion, OpenAIMessageParam, OpenAIResponseFormatParam, @@ -409,7 +410,7 @@ class OllamaInferenceAdapter( top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: model_obj = await self._get_model(model) params = { k: v diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 63054ae0a..af05320b0 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack_client import AsyncLlamaStackClient @@ -28,6 +28,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, + OpenAIChatCompletionChunk, OpenAICompletion, OpenAIMessageParam, OpenAIResponseFormatParam, @@ -282,7 +283,7 @@ class PassthroughInferenceAdapter(Inference): top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: client = self._get_client() model_obj = await self.model_store.get_model(model) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 4ebf9956e..001e6aac4 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from openai import AsyncOpenAI from together import AsyncTogether @@ -33,6 +33,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, + OpenAIChatCompletionChunk, OpenAICompletion, OpenAIMessageParam, OpenAIResponseFormatParam, @@ -331,7 +332,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: model_obj = await self.model_store.get_model(model) params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, @@ -358,4 +359,26 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi top_p=top_p, user=user, ) + if params.get("stream", True): + return self._stream_openai_chat_completion(params) return await self._get_openai_client().chat.completions.create(**params) # type: ignore + + async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator: + # together.ai sometimes adds usage data to the stream, even if include_usage is False + # This causes an unexpected final chunk with empty choices array to be sent + # to clients that may not handle it gracefully. + include_usage = False + if params.get("stream_options", None): + include_usage = params["stream_options"].get("include_usage", False) + stream = await self._get_openai_client().chat.completions.create(**params) + + seen_finish_reason = False + async for chunk in stream: + # Final usage chunk with no choices that the user didn't request, so discard + if not include_usage and seen_finish_reason and len(chunk.choices) == 0: + break + yield chunk + for choice in chunk.choices: + if choice.finish_reason: + seen_finish_reason = True + break diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index eca68e399..2b9eae1e9 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import json import logging -from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union import httpx from openai import AsyncOpenAI @@ -503,7 +503,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: model_obj = await self._get_model(model) params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 6d98a0cb4..95e8b767b 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -32,6 +32,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, + OpenAIChatCompletionChunk, OpenAICompletion, OpenAIMessageParam, OpenAIResponseFormatParam, @@ -324,7 +325,7 @@ class LiteLLMOpenAIMixin( top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: model_obj = await self.model_store.get_model(model) params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 1fa202475..07fd75ea8 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -8,7 +8,7 @@ import logging import time import uuid import warnings -from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union +from typing import Any, AsyncGenerator, AsyncIterator, Dict, Iterable, List, Optional, Union from openai import AsyncStream from openai.types.chat import ( @@ -1196,5 +1196,5 @@ class OpenAIChatCompletionUnsupportedMixin: top_logprobs: Optional[int] = None, top_p: Optional[float] = None, user: Optional[str] = None, - ) -> OpenAIChatCompletion: + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion")
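For context, this is roughly how the streaming path looks from a client's point of view once a provider that supports it is wired up. The snippet is not part of this PR; the base URL, port, and model ID are assumptions about a local deployment and should be adjusted to match your server configuration.

```python
# Illustrative usage sketch (not part of this PR). Assumes a locally running
# Llama Stack server exposing OpenAI-compatible routes; the base_url, port, and
# model ID below are placeholders, not values defined by this change.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    stream=True,  # the server now yields chat.completion.chunk objects
)
for chunk in stream:
    # Streaming chunks carry incremental deltas; skip empty ones.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()
```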