split batch_inference from inference

This commit is contained in:
Ashwin Bharambe 2024-08-26 13:17:59 -07:00
parent 986a865e62
commit dc433f6c90
5 changed files with 75 additions and 12 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F401 F403

View file

@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
@json_schema_type
class BatchCompletionRequest(BaseModel):
model: str
content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class BatchCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@json_schema_type
class BatchChatCompletionRequest(BaseModel):
model: str
messages_batch: List[List[Message]]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
class BatchInference(Protocol):
@webmethod(route="/batch_inference/completion")
async def batch_completion(
self,
request: BatchCompletionRequest,
) -> BatchCompletionResponse: ...
@webmethod(route="/batch_inference/chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse: ...

View file

@ -185,15 +185,3 @@ class Inference(Protocol):
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...
@webmethod(route="/inference/batch_completion")
async def batch_completion(
self,
request: BatchCompletionRequest,
) -> BatchCompletionResponse: ...
@webmethod(route="/inference/batch_chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse: ...

View file

@ -9,6 +9,7 @@ from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.dataset.api import * # noqa: F403
from llama_toolchain.evaluations.api import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.batch_inference.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.observability.api import * # noqa: F403
from llama_toolchain.post_training.api import * # noqa: F403
@ -18,6 +19,7 @@ from llama_toolchain.synthetic_data_generation.api import * # noqa: F403
class LlamaStack(
Inference,
BatchInference,
AgenticSystem,
RewardScoring,
SyntheticDataGeneration,