chore: update the ollama inference impl to use OpenAIMixin for openai-compat functions

Author: Matthew Farrellee
Date:   2025-09-09 13:45:58 -04:00
parent c86e45496e
commit baeaf7dfe0


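For context before the diff: the adapter now inherits from OpenAIMixin and only supplies connection details via get_api_key() and get_base_url(); those hook names and the delegation to OpenAIMixin.openai_chat_completion come from the diff below. The sketch here is an illustrative assumption about what such a mixin does with those hooks, not the actual OpenAIMixin implementation.

from openai import AsyncOpenAI


class OpenAICompatMixinSketch:
    """Illustrative stand-in for OpenAIMixin; its real internals are not shown in this diff."""

    def get_api_key(self) -> str:
        # Overridden by the adapter (the diff returns "NO_KEY" since Ollama needs no key).
        raise NotImplementedError

    def get_base_url(self) -> str:
        # Overridden by the adapter (the diff returns self.config.url.rstrip("/") + "/v1").
        raise NotImplementedError

    @property
    def client(self) -> AsyncOpenAI:
        # Build an OpenAI-compatible client from the two hooks, replacing the
        # hand-rolled self._openai_client property removed by this commit.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())

    async def openai_chat_completion(self, **params):
        # Delegate to the server's OpenAI-compatible /v1/chat/completions endpoint.
        return await self.client.chat.completions.create(**params)
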
@@ -7,12 +7,10 @@
 import asyncio
 import base64
-import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
-from ollama import AsyncClient  # type: ignore[attr-defined]
-from openai import AsyncOpenAI
+from ollama import AsyncClient as AsyncOllamaClient
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -37,9 +35,6 @@ from llama_stack.apis.inference import (
     Message,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ResponseFormat,
@@ -64,15 +59,14 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
-    b64_encode_openai_embeddings_response,
     get_sampling_options,
     prepare_openai_completion_params,
-    prepare_openai_embeddings_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -89,6 +83,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
 class OllamaInferenceAdapter(
+    OpenAIMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
@@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
         self.config = config
-        self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
-        self._openai_client = None
+        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
     @property
-    def client(self) -> AsyncClient:
+    def ollama_client(self) -> AsyncOllamaClient:
         # ollama client attaches itself to the current event loop (sadly?)
         loop = asyncio.get_running_loop()
         if loop not in self._clients:
-            self._clients[loop] = AsyncClient(host=self.config.url)
+            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
         return self._clients[loop]
-    @property
-    def openai_client(self) -> AsyncOpenAI:
-        if self._openai_client is None:
-            url = self.config.url.rstrip("/")
-            self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
-        return self._openai_client
+    def get_api_key(self):
+        return "NO_KEY"
+    def get_base_url(self):
+        return self.config.url.rstrip("/") + "/v1"
     async def initialize(self) -> None:
         logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
@@ -129,7 +122,7 @@ class OllamaInferenceAdapter(
     async def list_models(self) -> list[Model] | None:
         provider_id = self.__provider_id__
-        response = await self.client.list()
+        response = await self.ollama_client.list()
         # always add the two embedding models which can be pulled on demand
         models = [
@@ -189,7 +182,7 @@ class OllamaInferenceAdapter(
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            await self.client.ps()
+            await self.ollama_client.ps()
             return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -238,7 +231,7 @@ class OllamaInferenceAdapter(
         params = await self._get_params(request)
         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
+            s = await self.ollama_client.generate(**params)
             async for chunk in s:
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=chunk["done_reason"] if chunk["done"] else None,
@@ -254,7 +247,7 @@ class OllamaInferenceAdapter(
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self.client.generate(**params)
+        r = await self.ollama_client.generate(**params)
         choice = OpenAICompatCompletionChoice(
             finish_reason=r["done_reason"] if r["done"] else None,
@@ -346,9 +339,9 @@ class OllamaInferenceAdapter(
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self.client.chat(**params)
+            r = await self.ollama_client.chat(**params)
         else:
-            r = await self.client.generate(**params)
+            r = await self.ollama_client.generate(**params)
         if "message" in r:
             choice = OpenAICompatCompletionChoice(
@@ -372,9 +365,9 @@ class OllamaInferenceAdapter(
         async def _generate_and_convert_to_openai_compat():
             if "messages" in params:
-                s = await self.client.chat(**params)
+                s = await self.ollama_client.chat(**params)
             else:
-                s = await self.client.generate(**params)
+                s = await self.ollama_client.generate(**params)
             async for chunk in s:
                 if "message" in chunk:
                     choice = OpenAICompatCompletionChoice(
@@ -407,7 +400,7 @@ class OllamaInferenceAdapter(
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
         )
-        response = await self.client.embed(
+        response = await self.ollama_client.embed(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
         )
@@ -422,14 +415,14 @@ class OllamaInferenceAdapter(
             pass  # Ignore statically unknown model, will check live listing
         if model.model_type == ModelType.embedding:
-            response = await self.client.list()
+            response = await self.ollama_client.list()
             if model.provider_resource_id not in [m.model for m in response.models]:
-                await self.client.pull(model.provider_resource_id)
+                await self.ollama_client.pull(model.provider_resource_id)
         # we use list() here instead of ps() -
         #  - ps() only lists running models, not available models
         #  - models not currently running are run by the ollama server as needed
-        response = await self.client.list()
+        response = await self.ollama_client.list()
         available_models = [m.model for m in response.models]
         provider_resource_id = model.provider_resource_id
@@ -448,90 +441,6 @@ class OllamaInferenceAdapter(
         return model
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        model_obj = await self._get_model(model)
-        if model_obj.provider_resource_id is None:
-            raise ValueError(f"Model {model} has no provider_resource_id set")
-        # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
-        params = prepare_openai_embeddings_params(
-            model=model_obj.provider_resource_id,
-            input=input,
-            encoding_format=encoding_format,
-            dimensions=dimensions,
-            user=user,
-        )
-        response = await self.openai_client.embeddings.create(**params)
-        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
-        usage = OpenAIEmbeddingUsage(
-            prompt_tokens=response.usage.prompt_tokens,
-            total_tokens=response.usage.total_tokens,
-        )
-        # TODO: Investigate why model_obj.identifier is used instead of response.model
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.identifier,
-            usage=usage,
-        )
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        if not isinstance(prompt, str):
-            raise ValueError("Ollama does not support non-string prompts for completion")
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            suffix=suffix,
-        )
-        return await self.openai_client.completions.create(**params)  # type: ignore
     async def openai_chat_completion(
         self,
         model: str,
@ -599,25 +508,7 @@ class OllamaInferenceAdapter(
top_p=top_p, top_p=top_p,
user=user, user=user,
) )
response = await self.openai_client.chat.completions.create(**params) return await OpenAIMixin.openai_chat_completion(self, **params)
return await self._adjust_ollama_chat_completion_response_ids(response)
async def _adjust_ollama_chat_completion_response_ids(
self,
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
id = f"chatcmpl-{uuid.uuid4()}"
if isinstance(response, AsyncIterator):
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
async for chunk in response:
chunk.id = id
yield chunk
return stream_with_chunk_ids()
else:
response.id = id
return response
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: