llama-stack-mirror/llama_stack/providers/remote/inference/ollama/ollama.py
Matthew Farrellee d266c59c2a
chore: remove deprecated inference.chat_completion implementations (#3654)
# What does this PR do?

remove unused chat_completion implementations

vllm features ported -
 - requires max_tokens be set, use config value
 - set tool_choice to none if no tools provided


## Test Plan

ci
2025-10-03 07:55:34 -04:00

220 lines
8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from ollama import AsyncClient as AsyncOllamaClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
InferenceProvider,
JsonSchemaResponseFormat,
Message,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
convert_image_content_to_url,
request_has_media,
)
# Module-level logger for this provider, created with the "inference::ollama" category.
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
    OpenAIMixin,
    ModelRegistryHelper,
    InferenceProvider,
    ModelsProtocolPrivate,
):
    """Inference provider adapter that talks to an Ollama server.

    The adapter keeps one native ``AsyncOllamaClient`` per running event loop
    (used here for health checks) and exposes the server's ``/v1`` path via
    ``get_base_url``.
    """

    # automatically set by the resolver when instantiating the provider
    __provider_id__: str

    # Embedding dimension and context length for known embedding models,
    # keyed by model name.
    embedding_model_metadata = {
        "all-minilm:l6-v2": {
            "embedding_dimension": 384,
            "context_length": 512,
        },
        "nomic-embed-text:latest": {
            "embedding_dimension": 768,
            "context_length": 8192,
        },
        "nomic-embed-text:v1.5": {
            "embedding_dimension": 768,
            "context_length": 8192,
        },
        "nomic-embed-text:137m-v1.5-fp16": {
            "embedding_dimension": 768,
            "context_length": 8192,
        },
    }

    def __init__(self, config: OllamaImplConfig) -> None:
        """Store the provider config and register the built-in llama model entries.

        Args:
            config: Provider configuration; ``url`` and ``refresh_models`` are
                read by other methods of this class.
        """
        # TODO: remove ModelRegistryHelper.__init__ when completion and
        #       chat_completion are removed. It exists to satisfy the input /
        #       output processing for llama models: tool calling is handled by
        #       raw template processing instead of the /api/chat endpoint
        #       w/ tools=...
        llama_entries = [
            build_hf_repo_model_entry(
                "llama3.2:3b-instruct-fp16",
                CoreModelId.llama3_2_3b_instruct.value,
            ),
            build_hf_repo_model_entry(
                "llama-guard3:1b",
                CoreModelId.llama_guard_3_1b.value,
            ),
        ]
        ModelRegistryHelper.__init__(self, model_entries=llama_entries)

        self.config = config
        # Ollama does not support image urls, so images must be downloaded
        # and converted to base64.
        self.download_images = True
        # One client per event loop; see ``ollama_client``.
        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}

    @property
    def ollama_client(self) -> AsyncOllamaClient:
        """Return the Ollama client bound to the currently running event loop.

        A client is created lazily the first time a given loop asks for one
        and is cached for later calls from that same loop.
        """
        # The ollama client attaches itself to the event loop it was created
        # on, so clients cannot be shared across loops.
        current_loop = asyncio.get_running_loop()
        client = self._clients.get(current_loop)
        if client is None:
            client = AsyncOllamaClient(host=self.config.url)
            self._clients[current_loop] = client
        return client

    def get_api_key(self):
        """Return a fixed placeholder API key."""
        return "NO_KEY"

    def get_base_url(self):
        """Return the configured URL, without trailing slashes, followed by ``/v1``."""
        root = self.config.url.rstrip("/")
        return root + "/v1"

    async def initialize(self) -> None:
        """Run a health check against the server and warn if it reports an error.

        A failing check is only logged; it does not raise.
        """
        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
        health_response = await self.health()
        if health_response["status"] != HealthStatus.ERROR:
            return
        logger.warning(
            f"Ollama Server is not running (message: {health_response['message']}). Make sure to start it using `ollama serve` in a separate terminal"
        )

    async def should_refresh_models(self) -> bool:
        """Return the ``refresh_models`` setting from the provider config."""
        return self.config.refresh_models

    async def health(self) -> HealthResponse:
        """Check connectivity to the Ollama server.

        Used by ``initialize()`` and the Provider API to verify that the
        service is running correctly.

        Returns:
            HealthResponse: ``status`` is ``HealthStatus.OK`` when the ``ps``
            call on the client succeeds; otherwise ``HealthStatus.ERROR``
            together with a ``message`` describing the failure.
        """
        try:
            await self.ollama_client.ps()
        except Exception as e:
            # Best-effort probe: any failure is reported as an ERROR status
            # rather than propagated to the caller.
            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
        return HealthResponse(status=HealthStatus.OK)

    async def shutdown(self) -> None:
        """Forget all cached per-event-loop clients (no explicit close is performed)."""
        self._clients.clear()

    async def _get_model(self, model_id: str) -> Model:
        """Look up ``model_id`` in the model store.

        Raises:
            ValueError: If no model store has been set.
        """
        store = self.model_store
        if not store:
            raise ValueError("Model store not set")
        return await store.get_model(model_id)

    async def _get_params(self, request: ChatCompletionRequest) -> dict:
        """Translate a chat completion request into Ollama request parameters.

        The result always carries ``model``, ``options`` and ``stream``. It
        additionally carries either ``messages`` (when the request has media
        or the model is not a known llama model) or ``raw`` + ``prompt``
        (rendered prompt for known llama models), and ``format`` when a JSON
        schema response format was requested.

        Raises:
            NotImplementedError: If a grammar response format is requested.
            ValueError: If the response format is of an unknown type.
        """
        sampling_options = get_sampling_options(request.sampling_params)
        # The Ollama API expects num_predict (not max_tokens) for early
        # truncation, so mirror the value under that key.
        max_tokens = sampling_options.get("max_tokens")
        if max_tokens is not None:
            sampling_options["num_predict"] = max_tokens

        # Keys are inserted in the order: model, input fields, options, stream.
        params: dict[str, Any] = {"model": request.model}

        has_media = request_has_media(request)
        llama_model = self.get_llama_model(request.model)
        if llama_model and not has_media:
            params["raw"] = True
            params["prompt"] = await chat_completion_request_to_prompt(
                request,
                llama_model,
            )
        else:
            # Each message converts to a list of dicts; concatenate them all
            # into a single flat list.
            flat_messages: list[dict] = []
            for message in request.messages:
                flat_messages.extend(await convert_message_to_openai_dict_for_ollama(message))
            params["messages"] = flat_messages

        response_format = request.response_format
        if response_format:
            if not isinstance(response_format, JsonSchemaResponseFormat):
                if isinstance(response_format, GrammarResponseFormat):
                    raise NotImplementedError("Grammar response format is not supported")
                raise ValueError(f"Unknown response format type: {response_format.type}")
            params["format"] = response_format.json_schema

        params["options"] = sampling_options
        params["stream"] = request.stream

        logger.debug(f"params to ollama: {params}")
        return params

    async def register_model(self, model: Model) -> Model:
        """Register ``model`` if it, or its ``:latest`` variant, is available.

        When only the ``:latest`` variant is available, the model's
        ``provider_resource_id`` is rewritten to that variant and a warning
        is logged.

        Raises:
            UnsupportedModelError: If neither identifier passes
                ``check_model_availability``.
        """
        if await self.check_model_availability(model.provider_model_id):
            return model

        latest_id = f"{model.provider_model_id}:latest"
        if await self.check_model_availability(latest_id):
            model.provider_resource_id = latest_id
            # provider_model_id is read again only after the reassignment above.
            logger.warning(
                f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
            )
            return model

        raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
    """Convert one message into a list of Ollama-style message dicts.

    Every content item of the message becomes its own dict carrying the
    message's role: image items produce an ``images`` entry (a single-element
    list with the converted image), everything else produces a ``content``
    entry holding the text.

    Args:
        message: The message to convert; its ``content`` may be a single item
            or a list of items.

    Returns:
        One dict per content item, in the original order.
    """

    async def _to_entry(item) -> dict:
        # Image items are downloaded and converted without the format prefix.
        if isinstance(item, ImageContentItem):
            image = await_image = None  # placeholders replaced below for clarity
            del image, await_image
            return {
                "role": message.role,
                "images": [await convert_image_content_to_url(item, download=True, include_format=False)],
            }

        # Text items expose ``.text``; anything else is expected to already be a str.
        text = item.text if isinstance(item, TextContentItem) else item
        assert isinstance(text, str)
        return {
            "role": message.role,
            "content": text,
        }

    # Normalize scalar content to a one-element list so both shapes share a path.
    raw_content = message.content
    items = raw_content if isinstance(raw_content, list) else [raw_content]
    return [await _to_entry(item) for item in items]