# What does this PR do?

Remove unused chat_completion implementations.

vLLM features ported:
- requires max_tokens be set, use config value
- set tool_choice to none if no tools provided

## Test Plan

ci
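The two ported vLLM behaviors above are small request-normalization steps. A minimal sketch of what they amount to, assuming an OpenAI-style request dict and a provider config that carries a default `max_tokens` (the helper name and parameters are illustrative, not the adapter's actual code):

```python
def normalize_request(params: dict, default_max_tokens: int) -> dict:
    # Require max_tokens: fall back to the configured default when unset.
    if not params.get("max_tokens"):
        params["max_tokens"] = default_max_tokens
    # Without any tools, an explicit tool_choice is meaningless; send "none".
    if not params.get("tools"):
        params["tool_choice"] = "none"
    return params
```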
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import asyncio
from typing import Any

from ollama import AsyncClient as AsyncOllamaClient

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    TextContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    GrammarResponseFormat,
    InferenceProvider,
    JsonSchemaResponseFormat,
    Message,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    convert_image_content_to_url,
    request_has_media,
)

logger = get_logger(name=__name__, category="inference::ollama")


class OllamaInferenceAdapter(
    OpenAIMixin,
    ModelRegistryHelper,
    InferenceProvider,
    ModelsProtocolPrivate,
):
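    """Inference adapter backed by a local Ollama server.

    Completions are served through Ollama's OpenAI-compatible /v1 endpoints
    (see get_base_url); a native Ollama client is kept per event loop for
    health checks and native API calls.
    """
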
    # automatically set by the resolver when instantiating the provider
    __provider_id__: str

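    # Embedding models this adapter recognizes, keyed by Ollama model name,
    # with their embedding dimension and context length.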
    embedding_model_metadata = {
        "all-minilm:l6-v2": {
            "embedding_dimension": 384,
            "context_length": 512,
        },
        "nomic-embed-text:latest": {
            "embedding_dimension": 768,
            "context_length": 8192,
        },
        "nomic-embed-text:v1.5": {
            "embedding_dimension": 768,
            "context_length": 8192,
        },
        "nomic-embed-text:137m-v1.5-fp16": {
            "embedding_dimension": 768,
            "context_length": 8192,
        },
    }

    def __init__(self, config: OllamaImplConfig) -> None:
        # TODO: remove ModelRegistryHelper.__init__ when completion and
        # chat_completion are fully removed. This exists to satisfy the input /
        # output processing for llama models: specifically, tool_calling is
        # handled by raw template processing instead of using the /api/chat
        # endpoint w/ tools=...
        ModelRegistryHelper.__init__(
            self,
            model_entries=[
                build_hf_repo_model_entry(
                    "llama3.2:3b-instruct-fp16",
                    CoreModelId.llama3_2_3b_instruct.value,
                ),
                build_hf_repo_model_entry(
                    "llama-guard3:1b",
                    CoreModelId.llama_guard_3_1b.value,
                ),
            ],
        )
        self.config = config
        # Ollama does not support image urls, so we need to download the image and convert it to base64
        self.download_images = True
        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}

    @property
    def ollama_client(self) -> AsyncOllamaClient:
        # ollama client attaches itself to the current event loop (sadly?)
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
        return self._clients[loop]

    def get_api_key(self):
        return "NO_KEY"

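    # Ollama itself needs no API key; its OpenAI-compatible endpoints are
    # exposed under /v1 of the configured base URL.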
    def get_base_url(self):
        return self.config.url.rstrip("/") + "/v1"

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
        r = await self.health()
        if r["status"] == HealthStatus.ERROR:
            logger.warning(
                f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
            )

    async def should_refresh_models(self) -> bool:
        return self.config.refresh_models

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to the Ollama server.

        This method is used by initialize() and the Provider API to verify that
        the service is running correctly.

        Returns:
            HealthResponse: A dictionary containing the health status.
        """
        try:
            await self.ollama_client.ps()
            return HealthResponse(status=HealthStatus.OK)
        except Exception as e:
            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")

    async def shutdown(self) -> None:
        self._clients.clear()

    async def _get_model(self, model_id: str) -> Model:
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

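    # Build keyword arguments for Ollama's native API from a ChatCompletionRequest:
    # requests with media (or for models without a known Llama template) are sent
    # as chat messages, everything else goes through the raw prompt path.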
    async def _get_params(self, request: ChatCompletionRequest) -> dict:
        sampling_options = get_sampling_options(request.sampling_params)
        # This is needed since the Ollama API expects num_predict to be set
        # for early truncation instead of max_tokens.
        if sampling_options.get("max_tokens") is not None:
            sampling_options["num_predict"] = sampling_options["max_tokens"]

        input_dict: dict[str, Any] = {}
        media_present = request_has_media(request)
        llama_model = self.get_llama_model(request.model)
        if media_present or not llama_model:
            contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
            # flatten the list of lists
            input_dict["messages"] = [item for sublist in contents for item in sublist]
        else:
            input_dict["raw"] = True
            input_dict["prompt"] = await chat_completion_request_to_prompt(
                request,
                llama_model,
            )

        if fmt := request.response_format:
            if isinstance(fmt, JsonSchemaResponseFormat):
                input_dict["format"] = fmt.json_schema
            elif isinstance(fmt, GrammarResponseFormat):
                raise NotImplementedError("Grammar response format is not supported")
            else:
                raise ValueError(f"Unknown response format type: {fmt.type}")

        params = {
            "model": request.model,
            **input_dict,
            "options": sampling_options,
            "stream": request.stream,
        }
        logger.debug(f"params to ollama: {params}")

        return params

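    # Accept the model if Ollama knows the exact id; otherwise retry with an
    # explicit ':latest' tag, which is how Ollama labels untagged pulls.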
    async def register_model(self, model: Model) -> Model:
        if await self.check_model_availability(model.provider_model_id):
            return model
        elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
            model.provider_resource_id = f"{model.provider_model_id}:latest"
            logger.warning(
                f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
            )
            return model

        raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))


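# Ollama's native chat API takes text and images as separate fields, so a single
# Message may expand into several dicts; images are downloaded and inlined as
# base64 instead of being passed as URLs.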
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
    async def _convert_content(content) -> dict:
        if isinstance(content, ImageContentItem):
            return {
                "role": message.role,
                "images": [await convert_image_content_to_url(content, download=True, include_format=False)],
            }
        else:
            text = content.text if isinstance(content, TextContentItem) else content
            assert isinstance(text, str)
            return {
                "role": message.role,
                "content": text,
            }

    if isinstance(message.content, list):
        return [await _convert_content(c) for c in message.content]
    else:
        return [await _convert_content(message.content)]