The `/v1/providers` endpoint now reports the health status of each provider, when the provider implements a health check.

```
curl -L http://127.0.0.1:8321/v1/providers | jq
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  4072  100  4072    0     0   246k      0 --:--:-- --:--:-- --:--:--  248k
{
  "data": [
    {
      "api": "inference",
      "provider_id": "ollama",
      "provider_type": "remote::ollama",
      "config": { "url": "http://localhost:11434" },
      "health": { "status": "OK" }
    },
    {
      "api": "vector_io",
      "provider_id": "faiss",
      "provider_type": "inline::faiss",
      "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/faiss_store.db" } },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "safety",
      "provider_id": "llama-guard",
      "provider_type": "inline::llama-guard",
      "config": { "excluded_categories": [] },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "agents",
      "provider_id": "meta-reference",
      "provider_type": "inline::meta-reference",
      "config": { "persistence_store": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/agents_store.db" } },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "telemetry",
      "provider_id": "meta-reference",
      "provider_type": "inline::meta-reference",
      "config": { "service_name": "llama-stack", "sinks": "console,sqlite", "sqlite_db_path": "/Users/leseb/.llama/distributions/ollama/trace_store.db" },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "eval",
      "provider_id": "meta-reference",
      "provider_type": "inline::meta-reference",
      "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/meta_reference_eval.db" } },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "datasetio",
      "provider_id": "huggingface",
      "provider_type": "remote::huggingface",
      "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/huggingface_datasetio.db" } },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "datasetio",
      "provider_id": "localfs",
      "provider_type": "inline::localfs",
      "config": { "kvstore": { "type": "sqlite", "namespace": null, "db_path": "/Users/leseb/.llama/distributions/ollama/localfs_datasetio.db" } },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "scoring",
      "provider_id": "basic",
      "provider_type": "inline::basic",
      "config": {},
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "scoring",
      "provider_id": "llm-as-judge",
      "provider_type": "inline::llm-as-judge",
      "config": {},
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "scoring",
      "provider_id": "braintrust",
      "provider_type": "inline::braintrust",
      "config": { "openai_api_key": "********" },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "tool_runtime",
      "provider_id": "brave-search",
      "provider_type": "remote::brave-search",
      "config": { "api_key": "********", "max_results": 3 },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "tool_runtime",
      "provider_id": "tavily-search",
      "provider_type": "remote::tavily-search",
      "config": { "api_key": "********", "max_results": 3 },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "tool_runtime",
      "provider_id": "code-interpreter",
      "provider_type": "inline::code-interpreter",
      "config": {},
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "tool_runtime",
      "provider_id": "rag-runtime",
      "provider_type": "inline::rag-runtime",
      "config": {},
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "tool_runtime",
      "provider_id": "model-context-protocol",
      "provider_type": "remote::model-context-protocol",
      "config": {},
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    },
    {
      "api": "tool_runtime",
      "provider_id": "wolfram-alpha",
      "provider_type": "remote::wolfram-alpha",
      "config": { "api_key": "********" },
      "health": { "status": "Not Implemented", "message": "Provider does not implement health check" }
    }
  ]
}
```

The same health information is reported per provider:

```
curl -L http://127.0.0.1:8321/v1/providers/ollama
{"api":"inference","provider_id":"ollama","provider_type":"remote::ollama","config":{"url":"http://localhost:11434"},"health":{"status":"OK"}}
```

Signed-off-by: Sébastien Han <seb@redhat.com>
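As an aside, the endpoint is easy to script against. Below is a minimal sketch (not part of this change) that lists each provider with its health status using `httpx`, assuming a Llama Stack server on the default `http://127.0.0.1:8321`:

```
# Minimal sketch: summarize provider health from the /v1/providers endpoint.
# Assumes a Llama Stack server listening on http://127.0.0.1:8321.
import httpx

resp = httpx.get("http://127.0.0.1:8321/v1/providers")
resp.raise_for_status()

for provider in resp.json()["data"]:
    status = provider.get("health", {}).get("status", "unknown")
    print(f"{provider['api']}/{provider['provider_id']}: {status}")
```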
496 lines · 19 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, AsyncGenerator, Dict, List, Optional, Union

import httpx
from ollama import AsyncClient
from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    GrammarResponseFormat,
    Inference,
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
    content_has_media,
    convert_image_content_to_url,
    interleaved_content_as_str,
    request_has_media,
)

from .models import model_entries

logger = get_logger(name=__name__, category="inference")

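# Inference provider backed by an Ollama server: the native client (generate/chat/embed)
# serves the Llama Stack inference APIs, while Ollama's OpenAI-compatible /v1 endpoint
# backs the openai_* passthrough methods below.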
class OllamaInferenceAdapter(
    Inference,
    ModelsProtocolPrivate,
):
    def __init__(self, url: str) -> None:
        self.register_helper = ModelRegistryHelper(model_entries)
        self.url = url

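    # Note: clients are constructed on each property access rather than cached,
    # so they always point at the currently configured URL.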
    @property
    def client(self) -> AsyncClient:
        return AsyncClient(host=self.url)

    @property
    def openai_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.url}`...")
        await self.health()

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to the Ollama server.
        This method is used by initialize() and the Provider API to verify that the service is running
        correctly.

        Returns:
            HealthResponse: A dictionary containing the health status.
        """
        try:
            await self.client.ps()
            return HealthResponse(status=HealthStatus.OK)
        except httpx.ConnectError as e:
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
            ) from e

    async def shutdown(self) -> None:
        pass

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def _get_model(self, model_id: str) -> Model:
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self._get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    async def _stream_completion(
        self, request: CompletionRequest
    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.generate(**params)
            async for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
                    text=chunk["response"],
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream):
            yield chunk

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        params = await self._get_params(request)
        r = await self.client.generate(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r["done_reason"] if r["done"] else None,
            text=r["response"],
        )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )

        return process_completion_response(response)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self._get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            stream=stream,
            logprobs=logprobs,
            response_format=response_format,
            tool_config=tool_config,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

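    # Translate a Llama Stack request into kwargs for the Ollama client: chat-style
    # messages when media is present (or the model is not a known Llama model),
    # otherwise a pre-templated raw prompt.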
    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
        sampling_options = get_sampling_options(request.sampling_params)
        # This is needed since the Ollama API expects num_predict to be set
        # for early truncation instead of max_tokens.
        if sampling_options.get("max_tokens") is not None:
            sampling_options["num_predict"] = sampling_options["max_tokens"]

        input_dict: dict[str, Any] = {}
        media_present = request_has_media(request)
        llama_model = self.register_helper.get_llama_model(request.model)
        if isinstance(request, ChatCompletionRequest):
            if media_present or not llama_model:
                contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                # flatten the list of lists
                input_dict["messages"] = [item for sublist in contents for item in sublist]
            else:
                input_dict["raw"] = True
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request,
                    llama_model,
                )
        else:
            assert not media_present, "Ollama does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request)
            input_dict["raw"] = True

        if fmt := request.response_format:
            if isinstance(fmt, JsonSchemaResponseFormat):
                input_dict["format"] = fmt.json_schema
            elif isinstance(fmt, GrammarResponseFormat):
                raise NotImplementedError("Grammar response format is not supported")
            else:
                raise ValueError(f"Unknown response format type: {fmt.type}")

        params = {
            "model": request.model,
            **input_dict,
            "options": sampling_options,
            "stream": request.stream,
        }
        logger.debug(f"params to ollama: {params}")

        return params

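    # Ollama's chat API returns generated text under "message", while generate
    # returns it under "response"; both are normalized to the OpenAI-compat shape.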
    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
        if "messages" in params:
            r = await self.client.chat(**params)
        else:
            r = await self.client.generate(**params)

        if "message" in r:
            choice = OpenAICompatCompletionChoice(
                finish_reason=r["done_reason"] if r["done"] else None,
                text=r["message"]["content"],
            )
        else:
            choice = OpenAICompatCompletionChoice(
                finish_reason=r["done_reason"] if r["done"] else None,
                text=r["response"],
            )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_chat_completion_response(response, request)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            if "messages" in params:
                s = await self.client.chat(**params)
            else:
                s = await self.client.generate(**params)
            async for chunk in s:
                if "message" in chunk:
                    choice = OpenAICompatCompletionChoice(
                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
                        text=chunk["message"]["content"],
                    )
                else:
                    choice = OpenAICompatCompletionChoice(
                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
                        text=chunk["response"],
                    )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        model = await self._get_model(model_id)

        assert all(not content_has_media(content) for content in contents), (
            "Ollama does not support media for embeddings"
        )
        response = await self.client.embed(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
        )
        embeddings = response["embeddings"]

        return EmbeddingsResponse(embeddings=embeddings)

    async def register_model(self, model: Model) -> Model:
        model = await self.register_helper.register_model(model)
        if model.model_type == ModelType.embedding:
            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
            await self.client.pull(model.provider_resource_id)
        # we use list() here instead of ps() -
        #  - ps() only lists running models, not available models
        #  - models not currently running are run by the ollama server as needed
        response = await self.client.list()
        available_models = [m["model"] for m in response["models"]]
        if model.provider_resource_id not in available_models:
            raise ValueError(
                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
            )

        return model

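    # OpenAI-compatible passthrough: build the kwargs from the non-None arguments
    # and forward them to Ollama's /v1 endpoint via the OpenAI client.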
    async def openai_completion(
        self,
        model: str,
        prompt: Union[str, List[str], List[int], List[List[int]]],
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
        guided_choice: Optional[List[str]] = None,
        prompt_logprobs: Optional[int] = None,
    ) -> OpenAICompletion:
        if not isinstance(prompt, str):
            raise ValueError("Ollama does not support non-string prompts for completion")

        model_obj = await self._get_model(model)
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "prompt": prompt,
                "best_of": best_of,
                "echo": echo,
                "frequency_penalty": frequency_penalty,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_tokens": max_tokens,
                "n": n,
                "presence_penalty": presence_penalty,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        return await self.openai_client.completions.create(**params)  # type: ignore

    async def openai_chat_completion(
        self,
        model: str,
        messages: List[OpenAIMessageParam],
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Dict[str, str]] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
    ) -> OpenAIChatCompletion:
        model_obj = await self._get_model(model)
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "messages": messages,
                "frequency_penalty": frequency_penalty,
                "function_call": function_call,
                "functions": functions,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_completion_tokens": max_completion_tokens,
                "max_tokens": max_tokens,
                "n": n,
                "parallel_tool_calls": parallel_tool_calls,
                "presence_penalty": presence_penalty,
                "response_format": response_format,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "tool_choice": tool_choice,
                "tools": tools,
                "top_logprobs": top_logprobs,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        return await self.openai_client.chat.completions.create(**params)  # type: ignore

    async def batch_completion(
        self,
        model_id: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ):
        raise NotImplementedError("Batch completion is not supported for Ollama")

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_config: Optional[ToolConfig] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ):
        raise NotImplementedError("Batch chat completion is not supported for Ollama")

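# Ollama expects images as a top-level "images" list on each message rather than
# as inline content parts, so each content item becomes its own message dict.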
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
    async def _convert_content(content) -> dict:
        if isinstance(content, ImageContentItem):
            return {
                "role": message.role,
                "images": [await convert_image_content_to_url(content, download=True, include_format=False)],
            }
        else:
            text = content.text if isinstance(content, TextContentItem) else content
            assert isinstance(text, str)
            return {
                "role": message.role,
                "content": text,
            }

    if isinstance(message.content, list):
        return [await _convert_content(c) for c in message.content]
    else:
        return [await _convert_content(message.content)]