diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index acf154627..908654db6 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional

 import httpx

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ModelDef
 from llama_models.llama3.api.tokenizer import Tokenizer

 from ollama import AsyncClient
@@ -59,7 +59,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         pass

     async def register_model(self, model: ModelDef) -> None:
-        raise ValueError("Dynamic model registration is not supported")
+        """
+        Registers and loads the specified model using ollama.client.pull().
+        """
+        ollama_model = OLLAMA_SUPPORTED_MODELS.get(model.llama_model)
+        if not ollama_model:
+            raise ValueError(f"Model {model.llama_model} is not supported by Ollama.")
+
+        print(f"Registering model: {ollama_model}")
+        status = await self.client.pull(ollama_model)
+        assert (
+            status["status"] == "success"
+        ), f"Failed to register model {ollama_model} in Ollama."

     async def list_models(self) -> List[ModelDef]:
         ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
@@ -88,7 +99,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model: str,
         content: InterleavedTextMedia,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -98,20 +109,20 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_choice: Optional[ToolChoice] = None,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
-            sampling_params=sampling_params,
+            sampling_params=sampling_params or SamplingParams(),
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
+            tool_choice=tool_choice or ToolChoice.auto,
+            tool_prompt_format=tool_prompt_format or ToolPromptFormat.json,
             stream=stream,
             logprobs=logprobs,
         )
@@ -133,17 +144,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = await self.client.generate(**params)
-        assert isinstance(r, dict)
+        response = await self.client.generate(**params)
+        assert isinstance(response, dict)

         choice = OpenAICompatCompletionChoice(
-            finish_reason=r["done_reason"] if r["done"] else None,
-            text=r["response"],
+            finish_reason=response["done_reason"] if response["done"] else None,
+            text=response["response"],
         )
-        response = OpenAICompatCompletionResponse(
+        openai_response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(request, response, self.formatter)
+        return process_chat_completion_response(request, openai_response, self.formatter)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
@@ -151,8 +162,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         params = self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
-            async for chunk in s:
+            stream = await self.client.generate(**params)
+            async for chunk in stream:
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=chunk["done_reason"] if chunk["done"] else None,
                     text=chunk["response"],
@@ -161,9 +172,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                     choices=[choice],
                 )

-        stream = _generate_and_convert_to_openai_compat()
         async for chunk in process_chat_completion_stream_response(
-            request, stream, self.formatter
+            request, _generate_and_convert_to_openai_compat(), self.formatter
         ):
             yield chunk
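
Reviewer note: below is a self-contained sketch of the lookup-then-pull flow that the new register_model() implements. SUPPORTED_MODELS, FakeOllamaClient, and register() are stand-ins invented for illustration (not the provider's real OLLAMA_SUPPORTED_MODELS or ollama.AsyncClient), so it runs without an Ollama server and is not part of this change.

import asyncio

# Stand-in for the adapter's OLLAMA_SUPPORTED_MODELS mapping; the tag value is illustrative.
SUPPORTED_MODELS = {"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"}


class FakeOllamaClient:
    """Stand-in for ollama.AsyncClient so the sketch runs offline."""

    async def pull(self, tag: str) -> dict:
        # The diff above assumes pull() resolves to a dict whose "status" is "success".
        return {"status": "success"}


async def register(llama_model: str, client: FakeOllamaClient) -> None:
    # Same lookup-then-pull flow as register_model() in the diff.
    ollama_model = SUPPORTED_MODELS.get(llama_model)
    if not ollama_model:
        raise ValueError(f"Model {llama_model} is not supported by Ollama.")
    status = await client.pull(ollama_model)
    assert status["status"] == "success", f"Failed to register model {ollama_model}"


asyncio.run(register("Llama3.1-8B-Instruct", FakeOllamaClient()))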