llama-stack/llama_stack/providers/remote/inference/ollama/ollama.py
Nathan Weinberg cf158f2cb9
feat: allow ollama to use 'latest' if available but not specified (#1903)
# What does this PR do?
Ollama's CLI supports running models via commands such as `ollama run llama3.2`.
This syntax does not currently work with the llama-stack `INFERENCE_MODEL`
variable, which requires a tag such as `latest` to be specified.

This commit checks whether the `latest` tag of the requested model is available
in Ollama and uses that model when a user passes a model name without a tag,
as sketched below.
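
The fallback boils down to the following check (a minimal sketch for illustration; `resolve_model_id` is a hypothetical helper, the actual check lives in `OllamaInferenceAdapter.register_model` below):

```python
# Sketch only: if the exact id is not listed by Ollama, accept it when the
# ':latest' tag of the same name is installed.
def resolve_model_id(requested: str, available: list[str]) -> str:
    if requested in available:
        return requested
    if f"{requested}:latest" in available:
        return f"{requested}:latest"
    raise ValueError(f"Model '{requested}' is not available in Ollama")
```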

## Test Plan
Behavior before the code change
```bash
$ INFERENCE_MODEL=llama3.2 llama stack build --template ollama --image-type venv --run
...
INFO     2025-04-08 13:42:42,842 llama_stack.providers.remote.inference.ollama.ollama:80 inference: checking            
         connectivity to Ollama at `http://beanlab1.bss.redhat.com:11434`...                                            
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/distribution/server/server.py", line 502, in <module>
    main()
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/distribution/server/server.py", line 401, in main
    impls = asyncio.run(construct_stack(config))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.12/asyncio/base_events.py", line 691, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/distribution/stack.py", line 222, in construct_stack
    await register_resources(run_config, impls)
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/distribution/stack.py", line 99, in register_resources
    await method(**obj.model_dump())
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 102, in async_wrapper
    result = await method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/distribution/routers/routing_tables.py", line 294, in register_model
    registered_model = await self.register_object(model)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/distribution/routers/routing_tables.py", line 228, in register_object
    registered_obj = await register_object_with_provider(obj, p)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/distribution/routers/routing_tables.py", line 77, in register_object_with_provider
    return await p.register_model(obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 102, in async_wrapper
    result = await method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nathan/ai/llama-stack/repos/llama-stack/llama_stack/providers/remote/inference/ollama/ollama.py", line 315, in register_model
    raise ValueError(
ValueError: Model 'llama3.2' is not available in Ollama. Available models: llama3.2:latest
++ error_handler 108
++ echo 'Error occurred in script at line: 108'
Error occurred in script at line: 108
++ exit 1
```

Behavior after the code change
```bash
$ INFERENCE_MODEL=llama3.2 llama stack build --template ollama --image-type venv --run
...
INFO     2025-04-08 13:58:17,365 llama_stack.providers.remote.inference.ollama.ollama:80 inference: checking            
         connectivity to Ollama at `http://beanlab1.bss.redhat.com:11434`...                                            
WARNING  2025-04-08 13:58:18,190 llama_stack.providers.remote.inference.ollama.ollama:317 inference: Imprecise provider 
         resource id was used but 'latest' is available in Ollama - using 'llama3.2:latest'                             
INFO     2025-04-08 13:58:18,191 llama_stack.providers.remote.inference.ollama.ollama:308 inference: Pulling embedding  
         model `all-minilm:latest` if necessary...                                                                      
INFO     2025-04-08 13:58:18,799 __main__:478 server: Listening on ['::', '0.0.0.0']:8321                               
INFO:     Started server process [28378]
INFO:     Waiting for application startup.
INFO     2025-04-08 13:58:18,803 __main__:148 server: Starting up                                                       
INFO:     Application startup complete.
INFO:     Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
...
```

## Documentation
I did not document this anywhere, but I'm happy to do so if there is an
appropriate place.

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
2025-04-14 09:03:54 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

import httpx
from ollama import AsyncClient
from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    GrammarResponseFormat,
    Inference,
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
    content_has_media,
    convert_image_content_to_url,
    interleaved_content_as_str,
    request_has_media,
)

from .models import model_entries

logger = get_logger(name=__name__, category="inference")


class OllamaInferenceAdapter(
    Inference,
    ModelsProtocolPrivate,
):
    def __init__(self, url: str) -> None:
        self.register_helper = ModelRegistryHelper(model_entries)
        self.url = url
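
    # NOTE: the client properties below construct a new client object on every
    # access rather than caching one on the adapter instance.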
    @property
    def client(self) -> AsyncClient:
        return AsyncClient(host=self.url)

    @property
    def openai_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.url}`...")
        await self.health()

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to the Ollama server.

        This method is used by initialize() and the Provider API to verify that the service is running
        correctly.

        Returns:
            HealthResponse: A dictionary containing the health status.
        """
        try:
            await self.client.ps()
            return HealthResponse(status=HealthStatus.OK)
        except httpx.ConnectError as e:
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
            ) from e

    async def shutdown(self) -> None:
        pass

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def _get_model(self, model_id: str) -> Model:
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
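        """Run a completion against Ollama, returning either a full response or,
        when `stream=True`, an async stream of response chunks."""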
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self._get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    async def _stream_completion(
        self, request: CompletionRequest
    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.generate(**params)
            async for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
                    text=chunk["response"],
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream):
            yield chunk

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        params = await self._get_params(request)
        r = await self.client.generate(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r["done_reason"] if r["done"] else None,
            text=r["response"],
        )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_completion_response(response)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self._get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            stream=stream,
            logprobs=logprobs,
            response_format=response_format,
            tool_config=tool_config,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
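        """Build the keyword arguments for the Ollama client call: OpenAI-style
        `messages` when media is present or no Llama model mapping exists, otherwise
        a raw prompt for `generate()`, plus sampling options and response format."""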
        sampling_options = get_sampling_options(request.sampling_params)
        # This is needed since the Ollama API expects num_predict to be set
        # for early truncation instead of max_tokens.
        if sampling_options.get("max_tokens") is not None:
            sampling_options["num_predict"] = sampling_options["max_tokens"]

        input_dict: dict[str, Any] = {}
        media_present = request_has_media(request)
        llama_model = self.register_helper.get_llama_model(request.model)
        if isinstance(request, ChatCompletionRequest):
            if media_present or not llama_model:
                contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                # flatten the list of lists
                input_dict["messages"] = [item for sublist in contents for item in sublist]
            else:
                input_dict["raw"] = True
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request,
                    llama_model,
                )
        else:
            assert not media_present, "Ollama does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request)
            input_dict["raw"] = True
        if fmt := request.response_format:
            if isinstance(fmt, JsonSchemaResponseFormat):
                input_dict["format"] = fmt.json_schema
            elif isinstance(fmt, GrammarResponseFormat):
                raise NotImplementedError("Grammar response format is not supported")
            else:
                raise ValueError(f"Unknown response format type: {fmt.type}")

        params = {
            "model": request.model,
            **input_dict,
            "options": sampling_options,
            "stream": request.stream,
        }
        logger.debug(f"params to ollama: {params}")
        return params

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
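        # _get_params() only sets "messages" for the chat-style path; otherwise
        # fall back to the raw-prompt generate() API.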
if "messages" in params:
r = await self.client.chat(**params)
else:
r = await self.client.generate(**params)
if "message" in r:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
if "messages" in params:
s = await self.client.chat(**params)
else:
s = await self.client.generate(**params)
async for chunk in s:
if "message" in chunk:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
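        """Embed a batch of text contents via Ollama's embed API. Media content
        is not supported for embeddings."""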
        model = await self._get_model(model_id)
        assert all(not content_has_media(content) for content in contents), (
            "Ollama does not support media for embeddings"
        )
        response = await self.client.embed(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
        )
        embeddings = response["embeddings"]
        return EmbeddingsResponse(embeddings=embeddings)

    async def register_model(self, model: Model) -> Model:
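        """Validate that the requested model is served by Ollama, pulling embedding
        models on demand and accepting an implicit ':latest' tag when the bare
        model name is not listed."""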
        model = await self.register_helper.register_model(model)
        if model.model_type == ModelType.embedding:
            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
            await self.client.pull(model.provider_resource_id)

        # we use list() here instead of ps() -
        #  - ps() only lists running models, not available models
        #  - models not currently running are run by the ollama server as needed
        response = await self.client.list()
        available_models = [m["model"] for m in response["models"]]
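
        # If the exact id is missing, fall back to '<id>:latest' so that e.g.
        # INFERENCE_MODEL=llama3.2 matches an installed 'llama3.2:latest'.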
        if model.provider_resource_id not in available_models:
            available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
            if model.provider_resource_id in available_models_latest:
                logger.warning(
                    f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
                )
                return model
            raise ValueError(
                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
            )

        return model

    async def openai_completion(
        self,
        model: str,
        prompt: Union[str, List[str], List[int], List[List[int]]],
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
        guided_choice: Optional[List[str]] = None,
        prompt_logprobs: Optional[int] = None,
    ) -> OpenAICompletion:
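        """Proxy an OpenAI-compatible completion request to Ollama's /v1 endpoint,
        passing through only the parameters that were explicitly set."""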
        if not isinstance(prompt, str):
            raise ValueError("Ollama does not support non-string prompts for completion")

        model_obj = await self._get_model(model)
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "prompt": prompt,
                "best_of": best_of,
                "echo": echo,
                "frequency_penalty": frequency_penalty,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_tokens": max_tokens,
                "n": n,
                "presence_penalty": presence_penalty,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        return await self.openai_client.completions.create(**params)  # type: ignore

    async def openai_chat_completion(
        self,
        model: str,
        messages: List[OpenAIMessageParam],
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[OpenAIResponseFormatParam] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
        model_obj = await self._get_model(model)
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "messages": messages,
                "frequency_penalty": frequency_penalty,
                "function_call": function_call,
                "functions": functions,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_completion_tokens": max_completion_tokens,
                "max_tokens": max_tokens,
                "n": n,
                "parallel_tool_calls": parallel_tool_calls,
                "presence_penalty": presence_penalty,
                "response_format": response_format,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "tool_choice": tool_choice,
                "tools": tools,
                "top_logprobs": top_logprobs,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        return await self.openai_client.chat.completions.create(**params)  # type: ignore

    async def batch_completion(
        self,
        model_id: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ):
        raise NotImplementedError("Batch completion is not supported for Ollama")

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_config: Optional[ToolConfig] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ):
        raise NotImplementedError("Batch chat completion is not supported for Ollama")


async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
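    """Convert a Message into the list of per-item dicts expected by Ollama's chat
    API: image items are attached under 'images', text items under 'content'."""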
    async def _convert_content(content) -> dict:
        if isinstance(content, ImageContentItem):
            return {
                "role": message.role,
                "images": [await convert_image_content_to_url(content, download=True, include_format=False)],
            }
        else:
            text = content.text if isinstance(content, TextContentItem) else content
            assert isinstance(text, str)
            return {
                "role": message.role,
                "content": text,
            }

    if isinstance(message.content, list):
        return [await _convert_content(c) for c in message.content]
    else:
        return [await _convert_content(message.content)]