mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: introduce APIs for retrieving chat completion requests (#2145)
# What does this PR do? This PR introduces APIs to retrieve past chat completion requests, which will be used in the LS UI. Our current `Telemetry` is ill-suited for this purpose as it's untyped so we'd need to filter by obscure attribute names, making it brittle. Since these APIs are 'provided by stack' and don't need to be implemented by inference providers, we introduce a new InferenceProvider class, containing the existing inference protocol, which is implemented by inference providers. The APIs are OpenAI-compliant, with an additional `input_messages` field. ## Test Plan This PR just adds the API and marks them provided_by_stack. S tart stack server -> doesn't crash
This commit is contained in:
parent
c7015d3d60
commit
047303e339
15 changed files with 1356 additions and 869 deletions
|
@ -820,15 +820,32 @@ class BatchChatCompletionResponse(BaseModel):
|
|||
batch: list[ChatCompletionResponse]
|
||||
|
||||
|
||||
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
|
||||
input_messages: list[OpenAIMessageParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIChatCompletionResponse(BaseModel):
|
||||
data: list[OpenAICompletionWithInputMessages]
|
||||
has_more: bool
|
||||
first_id: str
|
||||
last_id: str
|
||||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
class Order(Enum):
|
||||
asc = "asc"
|
||||
desc = "desc"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Inference(Protocol):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
class InferenceProvider(Protocol):
|
||||
"""
|
||||
This protocol defines the interface that should be implemented by all inference providers.
|
||||
"""
|
||||
|
||||
API_NAMESPACE: str = "Inference"
|
||||
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
|
@ -1062,3 +1079,39 @@ class Inference(Protocol):
|
|||
:returns: An OpenAIChatCompletion.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Inference(InferenceProvider):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET")
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 20,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIChatCompletionResponse:
|
||||
"""List all chat completions.
|
||||
|
||||
:param after: The ID of the last chat completion to return.
|
||||
:param limit: The maximum number of chat completions to return.
|
||||
:param model: The model to filter by.
|
||||
:param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
|
||||
:returns: A ListOpenAIChatCompletionResponse.
|
||||
"""
|
||||
raise NotImplementedError("List chat completions is not implemented")
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
|
||||
:param completion_id: ID of the chat completion.
|
||||
:returns: A OpenAICompletionWithInputMessages.
|
||||
"""
|
||||
raise NotImplementedError("Get chat completion is not implemented")
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
|
|||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import Inference, InferenceProvider
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
|
@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
|
|||
}
|
||||
|
||||
|
||||
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||
return {
|
||||
**api_protocol_map(),
|
||||
Api.inference: InferenceProvider,
|
||||
}
|
||||
|
||||
|
||||
def additional_protocols_map() -> dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
|
@ -302,9 +309,6 @@ async def instantiate_provider(
|
|||
inner_impls: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
):
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
@ -342,6 +346,8 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
protocols = api_protocol_map_for_compliance_check()
|
||||
additional_protocols = additional_protocols_map()
|
||||
# TODO: check compliance for special tool groups
|
||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||
|
|
|
@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
|
|||
OpenAICompletionToLlamaStackMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
|
|
|
@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator
|
|||
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionResponse,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
|
|||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import CerebrasCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .cerebras import CerebrasCompatInferenceAdapter
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import FireworksCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .fireworks import FireworksCompatInferenceAdapter
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import GroqCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqCompatInferenceAdapter
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import LlamaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .llama import LlamaCompatInferenceAdapter
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
|
|||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -82,7 +82,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, url: str) -> None:
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import SambaNovaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .sambanova import SambaNovaCompatInferenceAdapter
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import TogetherCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .together import TogetherCompatInferenceAdapter
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -59,7 +59,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
class LiteLLMOpenAIMixin(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
NeedsRequestProviderData,
|
||||
):
|
||||
# TODO: avoid exposing the litellm specific model names to the user.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue