chore: update the ollama inference impl to use OpenAIMixin for openai-compat functions (#3395)

# What does this PR do?

Update the Ollama inference provider to use OpenAIMixin for its OpenAI-compatible
endpoints: the hand-rolled AsyncOpenAI client goes away, the local
openai_embeddings and openai_completion implementations are dropped in favor of
the mixin's, and openai_chat_completion now delegates to the mixin after
preparing its params.
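
For reviewers unfamiliar with the mixin: the pattern this PR adopts is that the adapter supplies connection details through two small hooks and inherits the OpenAI-compatible endpoint methods. A minimal sketch of that contract, using a simplified stand-in mixin (this is not the actual `OpenAIMixin` source, just the shape the diff below relies on):

```python
# Sketch only -- a reduced stand-in for llama_stack's OpenAIMixin.
from openai import AsyncOpenAI


class SketchOpenAIMixin:
    # Subclasses provide these two hooks...
    def get_api_key(self) -> str:
        raise NotImplementedError

    def get_base_url(self) -> str:
        raise NotImplementedError

    # ...and the mixin can then serve openai-compat endpoints through one client.
    @property
    def client(self) -> AsyncOpenAI:
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())

    async def openai_chat_completion(self, **params):
        return await self.client.chat.completions.create(**params)
```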

## Test Plan

ci
Matthew Farrellee 2025-09-18 07:09:57 -04:00 committed by GitHub
parent 521865c388
commit ea396a54cd
4 changed files with 1100 additions and 133 deletions

@@ -7,12 +7,10 @@
import asyncio
import base64
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from ollama import AsyncClient  # type: ignore[attr-defined]
from openai import AsyncOpenAI
from ollama import AsyncClient as AsyncOllamaClient
from llama_stack.apis.common.content_types import (
    ImageContentItem,
@@ -37,9 +35,6 @@ from llama_stack.apis.inference import (
    Message,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    ResponseFormat,
@@ -64,15 +59,14 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    b64_encode_openai_embeddings_response,
    get_sampling_options,
    prepare_openai_completion_params,
    prepare_openai_embeddings_params,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
@@ -89,6 +83,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
    OpenAIMixin,
    InferenceProvider,
    ModelsProtocolPrivate,
):
@@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
    def __init__(self, config: OllamaImplConfig) -> None:
        self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
        self.config = config
        self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
        self._openai_client = None
        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}

    @property
    def client(self) -> AsyncClient:
    def ollama_client(self) -> AsyncOllamaClient:
        # ollama client attaches itself to the current event loop (sadly?)
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = AsyncClient(host=self.config.url)
            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
        return self._clients[loop]

    @property
    def openai_client(self) -> AsyncOpenAI:
        if self._openai_client is None:
            url = self.config.url.rstrip("/")
            self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
        return self._openai_client

    def get_api_key(self):
        return "NO_KEY"

    def get_base_url(self):
        return self.config.url.rstrip("/") + "/v1"

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
@@ -129,7 +122,7 @@ class OllamaInferenceAdapter(
    async def list_models(self) -> list[Model] | None:
        provider_id = self.__provider_id__
        response = await self.client.list()
        response = await self.ollama_client.list()

        # always add the two embedding models which can be pulled on demand
        models = [
@@ -189,7 +182,7 @@ class OllamaInferenceAdapter(
            HealthResponse: A dictionary containing the health status.
        """
        try:
            await self.client.ps()
            await self.ollama_client.ps()
            return HealthResponse(status=HealthStatus.OK)
        except Exception as e:
            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
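
`ps()` here acts purely as a liveness probe. A standalone sketch of the same check, with an assumed default host (not a value from this PR):

```python
from ollama import AsyncClient


async def ollama_reachable(host: str = "http://localhost:11434") -> bool:
    # ps() lists running models; any successful round trip proves the
    # server is reachable, which is all the health check needs.
    try:
        await AsyncClient(host=host).ps()
        return True
    except Exception:
        return False
```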
@@ -238,7 +231,7 @@ class OllamaInferenceAdapter(
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.generate(**params)
            s = await self.ollama_client.generate(**params)
            async for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
@@ -254,7 +247,7 @@ class OllamaInferenceAdapter(
    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        params = await self._get_params(request)
        r = await self.client.generate(**params)
        r = await self.ollama_client.generate(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r["done_reason"] if r["done"] else None,
@@ -346,9 +339,9 @@ class OllamaInferenceAdapter(
    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
        if "messages" in params:
            r = await self.client.chat(**params)
            r = await self.ollama_client.chat(**params)
        else:
            r = await self.client.generate(**params)
            r = await self.ollama_client.generate(**params)

        if "message" in r:
            choice = OpenAICompatCompletionChoice(
@@ -372,9 +365,9 @@ class OllamaInferenceAdapter(
        async def _generate_and_convert_to_openai_compat():
            if "messages" in params:
                s = await self.client.chat(**params)
                s = await self.ollama_client.chat(**params)
            else:
                s = await self.client.generate(**params)
                s = await self.ollama_client.generate(**params)
            async for chunk in s:
                if "message" in chunk:
                    choice = OpenAICompatCompletionChoice(
@@ -407,7 +400,7 @@ class OllamaInferenceAdapter(
        assert all(not content_has_media(content) for content in contents), (
            "Ollama does not support media for embeddings"
        )
        response = await self.client.embed(
        response = await self.ollama_client.embed(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
        )
@@ -422,14 +415,14 @@ class OllamaInferenceAdapter(
            pass  # Ignore statically unknown model, will check live listing

        if model.model_type == ModelType.embedding:
            response = await self.client.list()
            response = await self.ollama_client.list()
            if model.provider_resource_id not in [m.model for m in response.models]:
                await self.client.pull(model.provider_resource_id)
                await self.ollama_client.pull(model.provider_resource_id)

        # we use list() here instead of ps() -
        # - ps() only lists running models, not available models
        # - models not currently running are run by the ollama server as needed
        response = await self.client.list()
        response = await self.ollama_client.list()
        available_models = [m.model for m in response.models]

        provider_resource_id = model.provider_resource_id
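
The pull-on-demand flow above can be exercised directly against an Ollama server. A standalone sketch under assumed host and model values (neither comes from this PR):

```python
import asyncio

from ollama import AsyncClient


async def ensure_model(host: str, name: str) -> None:
    client = AsyncClient(host=host)
    # list() reports every locally available model; ps() would only show
    # models currently running, which is why the adapter uses list() here.
    available = [m.model for m in (await client.list()).models]
    if name not in available:
        await client.pull(name)  # download on demand, as the adapter does


# placeholder host and model name, for illustration only
asyncio.run(ensure_model("http://localhost:11434", "all-minilm:latest"))
```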
@@ -448,90 +441,6 @@ class OllamaInferenceAdapter(
        return model

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        model_obj = await self._get_model(model)
        if model_obj.provider_resource_id is None:
            raise ValueError(f"Model {model} has no provider_resource_id set")

        # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
        params = prepare_openai_embeddings_params(
            model=model_obj.provider_resource_id,
            input=input,
            encoding_format=encoding_format,
            dimensions=dimensions,
            user=user,
        )
        response = await self.openai_client.embeddings.create(**params)
        data = b64_encode_openai_embeddings_response(response.data, encoding_format)

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response.usage.prompt_tokens,
            total_tokens=response.usage.total_tokens,
        )

        # TODO: Investigate why model_obj.identifier is used instead of response.model
        return OpenAIEmbeddingsResponse(
            data=data,
            model=model_obj.identifier,
            usage=usage,
        )

    async def openai_completion(
        self,
        model: str,
        prompt: str | list[str] | list[int] | list[list[int]],
        best_of: int | None = None,
        echo: bool | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        if not isinstance(prompt, str):
            raise ValueError("Ollama does not support non-string prompts for completion")

        model_obj = await self._get_model(model)
        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id,
            prompt=prompt,
            best_of=best_of,
            echo=echo,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            user=user,
            suffix=suffix,
        )
        return await self.openai_client.completions.create(**params)  # type: ignore

    async def openai_chat_completion(
        self,
        model: str,
@@ -599,25 +508,7 @@ class OllamaInferenceAdapter(
            top_p=top_p,
            user=user,
        )
        response = await self.openai_client.chat.completions.create(**params)
        return await self._adjust_ollama_chat_completion_response_ids(response)

    async def _adjust_ollama_chat_completion_response_ids(
        self,
        response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        id = f"chatcmpl-{uuid.uuid4()}"
        if isinstance(response, AsyncIterator):

            async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
                async for chunk in response:
                    chunk.id = id
                    yield chunk

            return stream_with_chunk_ids()
        else:
            response.id = id
            return response
        return await OpenAIMixin.openai_chat_completion(self, **params)
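
The replacement line calls the mixin implementation through the class rather than through `self`. Because the adapter keeps its own `openai_chat_completion` override to prepare params first, calling `self.openai_chat_completion(**params)` would recurse; the explicit class call bypasses the override. A minimal sketch of the same delegation pattern, with hypothetical names:

```python
import asyncio


class Mixin:
    async def do_request(self, **params):
        # stand-in for the mixin's generic OpenAI-compat implementation
        return {"handled_by": "Mixin", **params}


class Adapter(Mixin):
    async def do_request(self, **params):
        # provider-specific tweaks happen first...
        params.setdefault("stream", False)
        # ...then the mixin implementation is invoked explicitly, bypassing
        # this override; with this MRO it is equivalent to
        # `await super().do_request(**params)`
        return await Mixin.do_request(self, **params)


print(asyncio.run(Adapter().do_request(model="m")))
# {'handled_by': 'Mixin', 'model': 'm', 'stream': False}
```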
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: