commit 73d927850e (parent 0cfb2e2473)
Author: Ashwin Bharambe
Date: 2025-04-11 16:15:59 -07:00

4 changed files with 43 additions and 316 deletions


@@ -6,40 +6,50 @@
from typing import List, Optional, Protocol, runtime_checkable
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolConfig,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import webmethod
@runtime_checkable
class BatchInference(Protocol):
@webmethod(route="/batch-inference/completion-inline", method="POST")
async def batch_completion_inline(
"""Batch inference API for generating completions and chat completions.
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
including (post-training, evals, etc).
"""
@webmethod(route="/batch-inference/completion", method="POST")
async def completion(
self,
model: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
) -> Job: ...
@webmethod(route="/batch-inference/chat-completion-inline", method="POST")
async def batch_chat_completion_inline(
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def chat_completion(
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = list,
tool_config: Optional[ToolConfig] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ...
) -> Job: ...
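
Below is a minimal usage sketch of the new asynchronous surface, assuming `impl` is any object satisfying the BatchInference protocol, that BatchInference is re-exported from llama_stack.apis.batch_inference like other API packages, and a hypothetical model id; the returned Job would be polled for completion through a separate jobs API, which this diff does not define.

from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.inference import SamplingParams, UserMessage


async def submit_batch(impl: BatchInference) -> None:
    # chat_completion now returns immediately with a Job handle instead of
    # blocking until a BatchChatCompletionResponse is ready.
    job = await impl.chat_completion(
        model="llama3.2-3b-instruct",  # hypothetical model id
        messages_batch=[
            [UserMessage(content="Summarize the release notes.")],
            [UserMessage(content="Hi how are you?")],
        ],
        sampling_params=SamplingParams(max_tokens=20),
    )
    # Poll the returned Job for completion via a jobs API (not shown here).
    print("submitted batch job:", job)
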


@@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import os
from typing import AsyncGenerator, List, Optional, Union
@@ -44,6 +43,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@@ -72,7 +72,7 @@ from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
log = get_logger(__name__, category="inference")
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
@@ -159,7 +159,7 @@ class MetaReferenceInferenceImpl(
self.model_id = model_id
self.llama_model = llama_model
print("Warming up...")
log.info("Warming up...")
await self.completion(
model_id=model_id,
content="Hello, world!",
@@ -170,7 +170,7 @@ class MetaReferenceInferenceImpl(
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
)
print("Warmed up!")
log.info("Warmed up!")
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None: