apis, alt

# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-05-18 21:20:00 -07:00
parent c7015d3d60
commit 3bc175320b
15 changed files with 1356 additions and 869 deletions

View file

@ -820,15 +820,32 @@ class BatchChatCompletionResponse(BaseModel):
batch: list[ChatCompletionResponse]
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
input_messages: list[OpenAIMessageParam]
@json_schema_type
class ListOpenAIChatCompletionResponse(BaseModel):
data: list[OpenAICompletionWithInputMessages]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"
class Order(Enum):
asc = "asc"
desc = "desc"
@runtime_checkable
@trace_protocol
class Inference(Protocol):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
class InferenceProvider(Protocol):
"""
This protocol defines the interface that should be implemented by all inference providers.
"""
API_NAMESPACE: str = "Inference"
model_store: ModelStore | None = None
@ -1062,3 +1079,39 @@ class Inference(Protocol):
:returns: An OpenAIChatCompletion.
"""
...
class Inference(InferenceProvider):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET")
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""List all chat completions.
:param after: The ID of the last chat completion to return.
:param limit: The maximum number of chat completions to return.
:param model: The model to filter by.
:param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
:returns: A ListOpenAIChatCompletionResponse.
"""
raise NotImplementedError("List chat completions is not implemented")
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.
:param completion_id: ID of the chat completion.
:returns: A OpenAICompletionWithInputMessages.
"""
raise NotImplementedError("Get chat completion is not implemented")

View file

@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
}
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
return {
**api_protocol_map(),
Api.inference: InferenceProvider,
}
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
@ -302,9 +309,6 @@ async def instantiate_provider(
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
@ -342,6 +346,8 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check()
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:

View file

@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import CerebrasCompatConfig
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .cerebras import CerebrasCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import FireworksCompatConfig
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .fireworks import FireworksCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import GroqCompatConfig
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .groq import GroqCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
@ -82,7 +82,7 @@ logger = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(
Inference,
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, url: str) -> None:

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import SambaNovaCompatConfig
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .sambanova import SambaNovaCompatInferenceAdapter

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import InferenceProvider
from .config import TogetherCompatConfig
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> Inference:
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .together import TogetherCompatInferenceAdapter

View file

@ -19,7 +19,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
@ -59,7 +59,7 @@ logger = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
Inference,
InferenceProvider,
NeedsRequestProviderData,
):
# TODO: avoid exposing the litellm specific model names to the user.