From baeaf7dfe0a4fc05b757934865d3a4b1a636a2c8 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Tue, 9 Sep 2025 13:45:58 -0400
Subject: [PATCH] chore: update the ollama inference impl to use OpenAIMixin
 for openai-compat functions

---
 .../remote/inference/ollama/ollama.py         | 157 +++---------------
 1 file changed, 24 insertions(+), 133 deletions(-)

diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index d3d107e1d..67a22cbe3 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -7,12 +7,10 @@
 
 import asyncio
 import base64
-import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
-from ollama import AsyncClient  # type: ignore[attr-defined]
-from openai import AsyncOpenAI
+from ollama import AsyncClient as AsyncOllamaClient
 
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -37,9 +35,6 @@ from llama_stack.apis.inference import (
     Message,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ResponseFormat,
@@ -64,15 +59,14 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
-    b64_encode_openai_embeddings_response,
     get_sampling_options,
     prepare_openai_completion_params,
-    prepare_openai_embeddings_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -89,6 +83,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
 
 
 class OllamaInferenceAdapter(
+    OpenAIMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
@@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
         self.config = config
-        self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
-        self._openai_client = None
+        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
 
     @property
-    def client(self) -> AsyncClient:
+    def ollama_client(self) -> AsyncOllamaClient:
         # ollama client attaches itself to the current event loop (sadly?)
         loop = asyncio.get_running_loop()
         if loop not in self._clients:
-            self._clients[loop] = AsyncClient(host=self.config.url)
+            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
         return self._clients[loop]
 
-    @property
-    def openai_client(self) -> AsyncOpenAI:
-        if self._openai_client is None:
-            url = self.config.url.rstrip("/")
-            self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
-        return self._openai_client
+    def get_api_key(self):
+        return "NO_KEY"
+
+    def get_base_url(self):
+        return self.config.url.rstrip("/") + "/v1"
 
     async def initialize(self) -> None:
         logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
@@ -129,7 +122,7 @@ class OllamaInferenceAdapter(
 
     async def list_models(self) -> list[Model] | None:
         provider_id = self.__provider_id__
-        response = await self.client.list()
+        response = await self.ollama_client.list()
 
         # always add the two embedding models which can be pulled on demand
         models = [
@@ -189,7 +182,7 @@ class OllamaInferenceAdapter(
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            await self.client.ps()
+            await self.ollama_client.ps()
             return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -238,7 +231,7 @@ class OllamaInferenceAdapter(
         params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
+            s = await self.ollama_client.generate(**params)
             async for chunk in s:
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=chunk["done_reason"] if chunk["done"] else None,
@@ -254,7 +247,7 @@ class OllamaInferenceAdapter(
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self.client.generate(**params)
+        r = await self.ollama_client.generate(**params)
 
         choice = OpenAICompatCompletionChoice(
             finish_reason=r["done_reason"] if r["done"] else None,
@@ -346,9 +339,9 @@ class OllamaInferenceAdapter(
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self.client.chat(**params)
+            r = await self.ollama_client.chat(**params)
         else:
-            r = await self.client.generate(**params)
+            r = await self.ollama_client.generate(**params)
 
         if "message" in r:
             choice = OpenAICompatCompletionChoice(
@@ -372,9 +365,9 @@ class OllamaInferenceAdapter(
 
         async def _generate_and_convert_to_openai_compat():
             if "messages" in params:
-                s = await self.client.chat(**params)
+                s = await self.ollama_client.chat(**params)
             else:
-                s = await self.client.generate(**params)
+                s = await self.ollama_client.generate(**params)
             async for chunk in s:
                 if "message" in chunk:
                     choice = OpenAICompatCompletionChoice(
@@ -407,7 +400,7 @@ class OllamaInferenceAdapter(
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
         )
-        response = await self.client.embed(
+        response = await self.ollama_client.embed(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
@@ -422,14 +415,14 @@ class OllamaInferenceAdapter(
             pass  # Ignore statically unknown model, will check live listing
 
         if model.model_type == ModelType.embedding:
-            response = await self.client.list()
+            response = await self.ollama_client.list()
             if model.provider_resource_id not in [m.model for m in response.models]:
-                await self.client.pull(model.provider_resource_id)
+                await self.ollama_client.pull(model.provider_resource_id)
 
         # we use list() here instead of ps() -
         #  - ps() only lists running models, not available models
         #  - models not currently running are run by the ollama server as needed
-        response = await self.client.list()
+        response = await self.ollama_client.list()
         available_models = [m.model for m in response.models]
 
         provider_resource_id = model.provider_resource_id
@@ -448,90 +441,6 @@ class OllamaInferenceAdapter(
 
         return model
 
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        model_obj = await self._get_model(model)
-        if model_obj.provider_resource_id is None:
-            raise ValueError(f"Model {model} has no provider_resource_id set")
-
-        # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
-        params = prepare_openai_embeddings_params(
-            model=model_obj.provider_resource_id,
-            input=input,
-            encoding_format=encoding_format,
-            dimensions=dimensions,
-            user=user,
-        )
-
-        response = await self.openai_client.embeddings.create(**params)
-        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
-
-        usage = OpenAIEmbeddingUsage(
-            prompt_tokens=response.usage.prompt_tokens,
-            total_tokens=response.usage.total_tokens,
-        )
-        # TODO: Investigate why model_obj.identifier is used instead of response.model
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.identifier,
-            usage=usage,
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        if not isinstance(prompt, str):
-            raise ValueError("Ollama does not support non-string prompts for completion")
-
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            suffix=suffix,
-        )
-        return await self.openai_client.completions.create(**params)  # type: ignore
-
     async def openai_chat_completion(
         self,
         model: str,
@@ -599,25 +508,7 @@ class OllamaInferenceAdapter(
             top_p=top_p,
             user=user,
         )
-        response = await self.openai_client.chat.completions.create(**params)
-        return await self._adjust_ollama_chat_completion_response_ids(response)
-
-    async def _adjust_ollama_chat_completion_response_ids(
-        self,
-        response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        id = f"chatcmpl-{uuid.uuid4()}"
-        if isinstance(response, AsyncIterator):
-
-            async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
-                async for chunk in response:
-                    chunk.id = id
-                    yield chunk
-
-            return stream_with_chunk_ids()
-        else:
-            response.id = id
-            return response
+        return await OpenAIMixin.openai_chat_completion(self, **params)
 
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
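
Note on the resulting design: after this patch the adapter keeps a native ollama
client only for the legacy chat/generate/embed paths, while the openai-compat
entry points are inherited from OpenAIMixin, parameterized solely by
get_api_key() and get_base_url(). The following is a minimal sketch of that
contract, not the real llama_stack mixin; it assumes the mixin builds an
AsyncOpenAI client from those two hooks, and all class names are illustrative.

    # Illustrative sketch only -- not the actual OpenAIMixin from
    # llama_stack.providers.utils.inference.openai_mixin.
    from openai import AsyncOpenAI

    class OpenAIMixinSketch:
        """Hypothetical mixin: adapters supply the two hooks below."""

        def get_api_key(self) -> str:
            raise NotImplementedError

        def get_base_url(self) -> str:
            raise NotImplementedError

        @property
        def client(self) -> AsyncOpenAI:
            # Rebuilt on each access in this sketch; the real mixin may cache it.
            return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())

        async def openai_chat_completion(self, **params):
            # Plain pass-through: ollama's /v1 endpoint speaks the OpenAI wire
            # format, so this sketch does no response post-processing at all.
            return await self.client.chat.completions.create(**params)

    class OllamaAdapterSketch(OpenAIMixinSketch):
        def __init__(self, url: str) -> None:
            self.url = url

        def get_api_key(self) -> str:
            return "NO_KEY"  # ollama ignores the key, but the client requires one

        def get_base_url(self) -> str:
            return self.url.rstrip("/") + "/v1"  # openai-compat routes live under /v1

With the base URL pointed at ollama's /v1 routes, the generic pass-throughs
cover what the hand-written openai_embeddings and openai_completion bodies did,
and the _adjust_ollama_chat_completion_response_ids helper is dropped, so chat
completion ids now come through unmodified from the server.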