feat: introduce APIs for retrieving chat completion requests (#2145)

# What does this PR do?
This PR introduces APIs to retrieve past chat completion requests, which
will be used in the LS UI.

Our current `Telemetry` is ill-suited for this purpose as it's untyped
so we'd need to filter by obscure attribute names, making it brittle.

Since these APIs are 'provided by stack' and don't need to be
implemented by inference providers, we introduce a new InferenceProvider
class, containing the existing inference protocol, which is implemented
by inference providers.

The APIs are OpenAI-compliant, with an additional `input_messages`
field.


## Test Plan
This PR just adds the API and marks them provided_by_stack. S
tart stack server -> doesn't crash
This commit is contained in:
ehhuang 2025-05-18 21:43:19 -07:00 committed by GitHub
parent c7015d3d60
commit 047303e339
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 1356 additions and 869 deletions

View file

@ -820,15 +820,32 @@ class BatchChatCompletionResponse(BaseModel):
batch: list[ChatCompletionResponse]
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
input_messages: list[OpenAIMessageParam]
@json_schema_type
class ListOpenAIChatCompletionResponse(BaseModel):
data: list[OpenAICompletionWithInputMessages]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"
class Order(Enum):
asc = "asc"
desc = "desc"
@runtime_checkable
@trace_protocol
class Inference(Protocol):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
class InferenceProvider(Protocol):
"""
This protocol defines the interface that should be implemented by all inference providers.
"""
API_NAMESPACE: str = "Inference"
model_store: ModelStore | None = None
@ -1062,3 +1079,39 @@ class Inference(Protocol):
:returns: An OpenAIChatCompletion.
"""
...
class Inference(InferenceProvider):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET")
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""List all chat completions.
:param after: The ID of the last chat completion to return.
:param limit: The maximum number of chat completions to return.
:param model: The model to filter by.
:param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
:returns: A ListOpenAIChatCompletionResponse.
"""
raise NotImplementedError("List chat completions is not implemented")
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.
:param completion_id: ID of the chat completion.
:returns: A OpenAICompletionWithInputMessages.
"""
raise NotImplementedError("Get chat completion is not implemented")

View file

@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
}
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
return {
**api_protocol_map(),
Api.inference: InferenceProvider,
}
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
@ -302,9 +309,6 @@ async def instantiate_provider(
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
@ -342,6 +346,8 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check()
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:

View file

@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import CerebrasCompatConfig
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .cerebras import CerebrasCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import FireworksCompatConfig
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .fireworks import FireworksCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import GroqCompatConfig
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .groq import GroqCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
@ -82,7 +82,7 @@ logger = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, url: str) -> None:

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import SambaNovaCompatConfig
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .sambanova import SambaNovaCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import TogetherCompatConfig
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .together import TogetherCompatInferenceAdapter

View file

@ -19,7 +19,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
@ -59,7 +59,7 @@ logger = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
Inference,
InferenceProvider,
NeedsRequestProviderData,
):
# TODO: avoid exposing the litellm specific model names to the user.