diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 36941480c..a2ac113e8 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union
 
 import httpx
 from ollama import AsyncClient
 
@@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
     CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -94,10 +99,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -111,7 +117,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)
 
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
         params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
@@ -129,7 +137,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
-    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
 
         r = await self.client.generate(**params)
@@ -148,17 +156,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -181,7 +190,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if sampling_options.get("max_tokens") is not None:
             sampling_options["num_predict"] = sampling_options["max_tokens"]
 
-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
         llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
@@ -201,9 +210,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["raw"] = True
 
         if fmt := request.response_format:
-            if fmt.type == "json_schema":
+            if isinstance(fmt, JsonSchemaResponseFormat):
                 input_dict["format"] = fmt.json_schema
-            elif fmt.type == "grammar":
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format is not supported")
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")
@@ -240,7 +249,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return process_chat_completion_response(response, request)
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
@@ -275,7 +286,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.model_store is not None
+        model = self.model_store.get_model(model_id)
 
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
@@ -288,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
 
         return EmbeddingsResponse(embeddings=embeddings)
 
-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: Model):
         model = await self.register_helper.register_model(model)
         if model.model_type == ModelType.embedding:
             logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
@@ -302,8 +314,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
             )
 
-        return model
-
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 04db45e67..9aedfade7 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -300,7 +300,7 @@ def process_chat_completion_response(
 
 async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-) -> AsyncGenerator:
+) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
     stop_reason = None
 
     async for chunk in stream:
@@ -337,7 +337,7 @@ async def process_completion_stream_response(
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
     request: ChatCompletionRequest,
-) -> AsyncGenerator:
+) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
diff --git a/pyproject.toml b/pyproject.toml
index fee0191e7..990d365b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -245,7 +245,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/gemini/",
     "^llama_stack/providers/remote/inference/groq/",
     "^llama_stack/providers/remote/inference/nvidia/",
-    "^llama_stack/providers/remote/inference/ollama/",
     "^llama_stack/providers/remote/inference/openai/",
     "^llama_stack/providers/remote/inference/passthrough/",
     "^llama_stack/providers/remote/inference/runpod/",