diff --git a/llama_toolchain/batch_inference/__init__.py b/llama_toolchain/batch_inference/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_toolchain/batch_inference/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_toolchain/batch_inference/api/__init__.py b/llama_toolchain/batch_inference/api/__init__.py
new file mode 100644
index 000000000..a7e55ba91
--- /dev/null
+++ b/llama_toolchain/batch_inference/api/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .api import *  # noqa: F401 F403
diff --git a/llama_toolchain/batch_inference/api/api.py b/llama_toolchain/batch_inference/api/api.py
new file mode 100644
index 000000000..a02815388
--- /dev/null
+++ b/llama_toolchain/batch_inference/api/api.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List, Optional, Protocol
+
+from llama_models.schema_utils import json_schema_type, webmethod
+
+from pydantic import BaseModel, Field
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_toolchain.inference.api import *  # noqa: F403
+
+
+@json_schema_type
+class BatchCompletionRequest(BaseModel):
+    model: str
+    content_batch: List[InterleavedTextMedia]
+    sampling_params: Optional[SamplingParams] = SamplingParams()
+    logprobs: Optional[LogProbConfig] = None
+
+
+@json_schema_type
+class BatchCompletionResponse(BaseModel):
+    completion_message_batch: List[CompletionMessage]
+
+
+@json_schema_type
+class BatchChatCompletionRequest(BaseModel):
+    model: str
+    messages_batch: List[List[Message]]
+    sampling_params: Optional[SamplingParams] = SamplingParams()
+
+    # zero-shot tool definitions as input to the model
+    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.json
+    )
+    logprobs: Optional[LogProbConfig] = None
+
+
+@json_schema_type
+class BatchChatCompletionResponse(BaseModel):
+    completion_message_batch: List[CompletionMessage]
+
+
+class BatchInference(Protocol):
+    @webmethod(route="/batch_inference/completion")
+    async def batch_completion(
+        self,
+        request: BatchCompletionRequest,
+    ) -> BatchCompletionResponse: ...
+
+    @webmethod(route="/batch_inference/chat_completion")
+    async def batch_chat_completion(
+        self,
+        request: BatchChatCompletionRequest,
+    ) -> BatchChatCompletionResponse: ...
diff --git a/llama_toolchain/inference/api/api.py b/llama_toolchain/inference/api/api.py
index cf72ef5fd..7298cb27b 100644
--- a/llama_toolchain/inference/api/api.py
+++ b/llama_toolchain/inference/api/api.py
@@ -185,15 +185,3 @@ class Inference(Protocol):
         model: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse: ...
-
-    @webmethod(route="/inference/batch_completion")
-    async def batch_completion(
-        self,
-        request: BatchCompletionRequest,
-    ) -> BatchCompletionResponse: ...
-
-    @webmethod(route="/inference/batch_chat_completion")
-    async def batch_chat_completion(
-        self,
-        request: BatchChatCompletionRequest,
-    ) -> BatchChatCompletionResponse: ...
diff --git a/llama_toolchain/stack.py b/llama_toolchain/stack.py
index afea66a0c..6ec05896d 100644
--- a/llama_toolchain/stack.py
+++ b/llama_toolchain/stack.py
@@ -9,6 +9,7 @@ from llama_toolchain.agentic_system.api import *  # noqa: F403
 from llama_toolchain.dataset.api import *  # noqa: F403
 from llama_toolchain.evaluations.api import *  # noqa: F403
 from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.batch_inference.api import *  # noqa: F403
 from llama_toolchain.memory.api import *  # noqa: F403
 from llama_toolchain.observability.api import *  # noqa: F403
 from llama_toolchain.post_training.api import *  # noqa: F403
@@ -18,6 +19,7 @@ from llama_toolchain.synthetic_data_generation.api import *  # noqa: F403
 
 class LlamaStack(
     Inference,
+    BatchInference,
     AgenticSystem,
     RewardScoring,
     SyntheticDataGeneration,
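
Reviewer note: the net effect of this diff is that the batch endpoints move off the unified `Inference` protocol onto a standalone `BatchInference` protocol, with routes renamed from `/inference/batch_completion` and `/inference/batch_chat_completion` to `/batch_inference/completion` and `/batch_inference/chat_completion`. Below is a minimal sketch of what a provider conforming to the new protocol could look like; the class name `EchoBatchInference` and its canned responses are illustrative assumptions, not part of this change.

```python
# Hypothetical provider sketch (NOT part of this diff): shows the shape a
# concrete BatchInference implementation would take against the new API.
from llama_models.llama3.api.datatypes import CompletionMessage, StopReason

from llama_toolchain.batch_inference.api import (
    BatchChatCompletionRequest,
    BatchChatCompletionResponse,
    BatchCompletionRequest,
    BatchCompletionResponse,
)


class EchoBatchInference:
    """Returns one canned assistant message per item in the batch."""

    async def batch_completion(
        self, request: BatchCompletionRequest
    ) -> BatchCompletionResponse:
        # One CompletionMessage per input in content_batch, order preserved.
        return BatchCompletionResponse(
            completion_message_batch=[
                CompletionMessage(content="ok", stop_reason=StopReason.end_of_turn)
                for _ in request.content_batch
            ]
        )

    async def batch_chat_completion(
        self, request: BatchChatCompletionRequest
    ) -> BatchChatCompletionResponse:
        # One CompletionMessage per dialog in messages_batch, order preserved.
        return BatchChatCompletionResponse(
            completion_message_batch=[
                CompletionMessage(content="ok", stop_reason=StopReason.end_of_turn)
                for _ in request.messages_batch
            ]
        )
```

Because `BatchInference` is a `typing.Protocol`, conformance is structural: a provider only needs matching method signatures and does not have to inherit from the protocol class.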