# What does this PR do?

This PR adds two methods to the Inference API:
- `batch_completion`
- `batch_chat_completion`

The motivation is evaluations that target a local inference engine (like meta-reference or vllm), where batch APIs provide a substantial speedup.

Why not add this to `Api.batch_inference` instead? That would have required a lot more book-keeping given the structure of Llama Stack: I would have needed to introduce a notion of a "batch model" resource, set up routing based on it, and so on, which does not sound ideal.

So what is the future of the batch inference API? I am not sure. Maybe we can keep it for truly _asynchronous_ execution, where you submit requests and it returns a Job instance, etc.

## Test Plan

Run meta-reference-gpu using:

```bash
export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct
export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct-20250331210000
export MODEL_PARALLEL_SIZE=4
export MAX_BATCH_SIZE=32
export MAX_SEQ_LEN=6144
LLAMA_MODELS_DEBUG=1 llama stack run meta-reference-gpu
```

Then run the batch inference test case.
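For illustration, here is a rough sketch of how an evaluation harness might drive the new batch API once the stack above is running. The `LlamaStackClient` call shape, the `response.batch` field, and the model id used here are assumptions inferred from the server-side signatures in this PR, not a confirmed client interface.

```python
# Hedged sketch: assumes the Python client mirrors the server-side
# Inference.batch_chat_completion(model_id, messages_batch, ...) signature
# added in this PR; the actual SDK surface may differ.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# One conversation per batch entry; each entry is a list of messages.
messages_batch = [
    [{"role": "user", "content": "Summarize photosynthesis in one sentence."}],
    [{"role": "user", "content": "What is the capital of France?"}],
]

# Hypothetical call shape mirroring the provider method added in this PR.
response = client.inference.batch_chat_completion(
    model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages_batch=messages_batch,
)

# Assumes the response carries one chat completion per input conversation.
for completion in response.batch:
    print(completion.completion_message.content)
```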
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, AsyncGenerator, Dict, List, Optional, Union

import httpx
from ollama import AsyncClient
from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    GrammarResponseFormat,
    Inference,
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
    content_has_media,
    convert_image_content_to_url,
    interleaved_content_as_str,
    request_has_media,
)

from .models import model_entries

logger = get_logger(name=__name__, category="inference")

class OllamaInferenceAdapter(
    Inference,
    ModelsProtocolPrivate,
):
    def __init__(self, url: str) -> None:
        self.register_helper = ModelRegistryHelper(model_entries)
        self.url = url

    @property
    def client(self) -> AsyncClient:
        return AsyncClient(host=self.url)

    @property
    def openai_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.url}`...")
        try:
            await self.client.ps()
        except httpx.ConnectError as e:
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
            ) from e

    async def shutdown(self) -> None:
        pass

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def _get_model(self, model_id: str) -> Model:
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self._get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    async def _stream_completion(
        self, request: CompletionRequest
    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.generate(**params)
            async for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
                    text=chunk["response"],
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream):
            yield chunk

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        params = await self._get_params(request)
        r = await self.client.generate(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r["done_reason"] if r["done"] else None,
            text=r["response"],
        )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )

        return process_completion_response(response)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self._get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            stream=stream,
            logprobs=logprobs,
            response_format=response_format,
            tool_config=tool_config,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
        sampling_options = get_sampling_options(request.sampling_params)
        # This is needed since the Ollama API expects num_predict to be set
        # for early truncation instead of max_tokens.
        if sampling_options.get("max_tokens") is not None:
            sampling_options["num_predict"] = sampling_options["max_tokens"]

        input_dict: dict[str, Any] = {}
        media_present = request_has_media(request)
        llama_model = self.register_helper.get_llama_model(request.model)
        if isinstance(request, ChatCompletionRequest):
            if media_present or not llama_model:
                contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                # flatten the list of lists
                input_dict["messages"] = [item for sublist in contents for item in sublist]
            else:
                input_dict["raw"] = True
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request,
                    llama_model,
                )
        else:
            assert not media_present, "Ollama does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request)
            input_dict["raw"] = True

        if fmt := request.response_format:
            if isinstance(fmt, JsonSchemaResponseFormat):
                input_dict["format"] = fmt.json_schema
            elif isinstance(fmt, GrammarResponseFormat):
                raise NotImplementedError("Grammar response format is not supported")
            else:
                raise ValueError(f"Unknown response format type: {fmt.type}")

        params = {
            "model": request.model,
            **input_dict,
            "options": sampling_options,
            "stream": request.stream,
        }
        logger.debug(f"params to ollama: {params}")

        return params

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
        if "messages" in params:
            r = await self.client.chat(**params)
        else:
            r = await self.client.generate(**params)

        if "message" in r:
            choice = OpenAICompatCompletionChoice(
                finish_reason=r["done_reason"] if r["done"] else None,
                text=r["message"]["content"],
            )
        else:
            choice = OpenAICompatCompletionChoice(
                finish_reason=r["done_reason"] if r["done"] else None,
                text=r["response"],
            )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_chat_completion_response(response, request)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            if "messages" in params:
                s = await self.client.chat(**params)
            else:
                s = await self.client.generate(**params)
            async for chunk in s:
                if "message" in chunk:
                    choice = OpenAICompatCompletionChoice(
                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
                        text=chunk["message"]["content"],
                    )
                else:
                    choice = OpenAICompatCompletionChoice(
                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
                        text=chunk["response"],
                    )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        model = await self._get_model(model_id)

        assert all(not content_has_media(content) for content in contents), (
            "Ollama does not support media for embeddings"
        )
        response = await self.client.embed(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
        )
        embeddings = response["embeddings"]

        return EmbeddingsResponse(embeddings=embeddings)

    async def register_model(self, model: Model) -> Model:
        model = await self.register_helper.register_model(model)
        if model.model_type == ModelType.embedding:
            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
            await self.client.pull(model.provider_resource_id)
        # we use list() here instead of ps() -
        # - ps() only lists running models, not available models
        # - models not currently running are run by the ollama server as needed
        response = await self.client.list()
        available_models = [m["model"] for m in response["models"]]
        if model.provider_resource_id not in available_models:
            raise ValueError(
                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
            )

        return model

    async def openai_completion(
        self,
        model: str,
        prompt: Union[str, List[str], List[int], List[List[int]]],
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
        guided_choice: Optional[List[str]] = None,
        prompt_logprobs: Optional[int] = None,
    ) -> OpenAICompletion:
        if not isinstance(prompt, str):
            raise ValueError("Ollama does not support non-string prompts for completion")

        model_obj = await self._get_model(model)
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "prompt": prompt,
                "best_of": best_of,
                "echo": echo,
                "frequency_penalty": frequency_penalty,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_tokens": max_tokens,
                "n": n,
                "presence_penalty": presence_penalty,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        return await self.openai_client.completions.create(**params)  # type: ignore

    async def openai_chat_completion(
        self,
        model: str,
        messages: List[OpenAIMessageParam],
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Dict[str, str]] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
    ) -> OpenAIChatCompletion:
        model_obj = await self._get_model(model)
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "messages": messages,
                "frequency_penalty": frequency_penalty,
                "function_call": function_call,
                "functions": functions,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_completion_tokens": max_completion_tokens,
                "max_tokens": max_tokens,
                "n": n,
                "parallel_tool_calls": parallel_tool_calls,
                "presence_penalty": presence_penalty,
                "response_format": response_format,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "tool_choice": tool_choice,
                "tools": tools,
                "top_logprobs": top_logprobs,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        return await self.openai_client.chat.completions.create(**params)  # type: ignore

    async def batch_completion(
        self,
        model_id: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ):
        raise NotImplementedError("Batch completion is not supported for Ollama")

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_config: Optional[ToolConfig] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ):
        raise NotImplementedError("Batch chat completion is not supported for Ollama")


async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
    async def _convert_content(content) -> dict:
        if isinstance(content, ImageContentItem):
            return {
                "role": message.role,
                "images": [await convert_image_content_to_url(content, download=True, include_format=False)],
            }
        else:
            text = content.text if isinstance(content, TextContentItem) else content
            assert isinstance(text, str)
            return {
                "role": message.role,
                "content": text,
            }

    if isinstance(message.content, list):
        return [await _convert_content(c) for c in message.content]
    else:
        return [await _convert_content(message.content)]