diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 944493b6d..dc2c8b3f5 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -5,10 +5,11 @@
 # the root directory of this source tree.


-from typing import Any, AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union

 import httpx
 from ollama import AsyncClient
+from openai import AsyncOpenAI

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -38,6 +39,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -45,10 +47,8 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionUnsupportedMixin,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
-    OpenAICompletionUnsupportedMixin,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -70,8 +70,6 @@ logger = get_logger(name=__name__, category="inference")


 class OllamaInferenceAdapter(
-    OpenAICompletionUnsupportedMixin,
-    OpenAIChatCompletionUnsupportedMixin,
     Inference,
     ModelsProtocolPrivate,
 ):
@@ -83,6 +81,10 @@ class OllamaInferenceAdapter(
     def client(self) -> AsyncClient:
         return AsyncClient(host=self.url)

+    @property
+    def openai_client(self) -> AsyncOpenAI:
+        return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
+
     async def initialize(self) -> None:
         logger.info(f"checking connectivity to Ollama at `{self.url}`...")
         try:
@@ -326,6 +328,110 @@ class OllamaInferenceAdapter(

         return model

+    async def openai_completion(
+        self,
+        model: str,
+        prompt: str,
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> OpenAICompletion:
+        model_obj = await self._get_model(model)
+        params = {
+            k: v
+            for k, v in {
+                "model": model_obj.provider_resource_id,
+                "prompt": prompt,
+                "best_of": best_of,
+                "echo": echo,
+                "frequency_penalty": frequency_penalty,
+                "logit_bias": logit_bias,
+                "logprobs": logprobs,
+                "max_tokens": max_tokens,
+                "n": n,
+                "presence_penalty": presence_penalty,
+                "seed": seed,
+                "stop": stop,
+                "stream": stream,
+                "stream_options": stream_options,
+                "temperature": temperature,
+                "top_p": top_p,
+                "user": user,
+            }.items()
+            if v is not None
+        }
+        return await self.openai_client.completions.create(**params)  # type: ignore
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[Dict[str, str]] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> OpenAIChatCompletion:
+        model_obj = await self._get_model(model)
+        params = {
+            k: v
+            for k, v in {
+                "model": model_obj.provider_resource_id,
+                "messages": messages,
+                "frequency_penalty": frequency_penalty,
+                "function_call": function_call,
+                "functions": functions,
+                "logit_bias": logit_bias,
+                "logprobs": logprobs,
+                "max_completion_tokens": max_completion_tokens,
+                "max_tokens": max_tokens,
+                "n": n,
+                "parallel_tool_calls": parallel_tool_calls,
+                "presence_penalty": presence_penalty,
+                "response_format": response_format,
+                "seed": seed,
+                "stop": stop,
+                "stream": stream,
+                "stream_options": stream_options,
+                "temperature": temperature,
+                "tool_choice": tool_choice,
+                "tools": tools,
+                "top_logprobs": top_logprobs,
+                "top_p": top_p,
+                "user": user,
+            }.items()
+            if v is not None
+        }
+        return await self.openai_client.chat.completions.create(**params)  # type: ignore
+

 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
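For context (not part of the patch): the new `openai_client` property simply points the official `openai` SDK at Ollama's OpenAI-compatible `/v1` endpoint, and the two new methods delegate to it after dropping `None`-valued parameters. A minimal sketch of the equivalent direct call, assuming an Ollama server at its default `http://localhost:11434` and a locally pulled model name (both assumptions, not taken from this diff):

```python
import asyncio

from openai import AsyncOpenAI

# Assumed local Ollama endpoint and model name; adjust to your setup.
OLLAMA_URL = "http://localhost:11434"


async def main() -> None:
    # Mirrors OllamaInferenceAdapter.openai_client: the key is a placeholder,
    # since the openai client requires a non-empty value.
    client = AsyncOpenAI(base_url=f"{OLLAMA_URL}/v1", api_key="ollama")
    response = await client.chat.completions.create(
        model="llama3.2:3b",  # assumed model; the adapter resolves provider_resource_id instead
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```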