feat: add batch inference API to llama stack inference

This commit is contained in:
Ashwin Bharambe 2025-04-08 13:50:52 -07:00
parent ed58a94b30
commit 0cfb2e2473
24 changed files with 1041 additions and 377 deletions

View file

@ -6,37 +6,24 @@
from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionResponse,
BatchChatCompletionResponse,
BatchCompletionResponse,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
from llama_stack.schema_utils import webmethod
@runtime_checkable
class BatchInference(Protocol):
@webmethod(route="/batch-inference/completion", method="POST")
async def batch_completion(
@webmethod(route="/batch-inference/completion-inline", method="POST")
async def batch_completion_inline(
self,
model: str,
content_batch: List[InterleavedContent],
@ -45,16 +32,14 @@ class BatchInference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def batch_chat_completion(
@webmethod(route="/batch-inference/chat-completion-inline", method="POST")
async def batch_chat_completion_inline(
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ...

View file

@ -681,6 +681,16 @@ class EmbeddingTaskType(Enum):
document = "document"
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
@runtime_checkable
@trace_protocol
class Inference(Protocol):
@ -716,6 +726,17 @@ class Inference(Protocol):
"""
...
@webmethod(route="/inference/batch-completion", method="POST", experimental=True)
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
self,
@ -756,6 +777,19 @@ class Inference(Protocol):
"""
...
@webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
self,