diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index 74aed6e5e..b19d54182 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -23,9 +23,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionResponse,
     process_chat_completion_response,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
+    completion_request_to_prompt,
 )
 
 OLLAMA_SUPPORTED_MODELS = {
@@ -93,7 +96,64 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError()
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        sampling_options = get_sampling_options(request)
+        # This is needed since the Ollama API expects num_predict to be set
+        # for early truncation instead of max_tokens.
+        if sampling_options["max_tokens"] is not None:
+            sampling_options["num_predict"] = sampling_options["max_tokens"]
+        return {
+            "model": OLLAMA_SUPPORTED_MODELS[request.model],
+            "prompt": completion_request_to_prompt(request, self.formatter),
+            "options": sampling_options,
+            "raw": True,
+            "stream": request.stream,
+        }
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+
+        async def _generate_and_convert_to_openai_compat():
+            s = await self.client.generate(**params)
+            async for chunk in s:
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                    text=chunk["response"],
+                )
+                yield OpenAICompatCompletionResponse(
+                    choices=[choice],
+                )
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
+    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+        r = await self.client.generate(**params)
+        assert isinstance(r, dict)
+
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=r["done_reason"] if r["done"] else None,
+            text=r["response"],
+        )
+        response = OpenAICompatCompletionResponse(
+            choices=[choice],
+        )
+
+        return process_completion_response(response, self.formatter)
 
     async def chat_completion(
         self,
diff --git a/llama_stack/providers/tests/inference/provider_config_example.yaml b/llama_stack/providers/tests/inference/provider_config_example.yaml
index c4bb4af16..675ece1ea 100644
--- a/llama_stack/providers/tests/inference/provider_config_example.yaml
+++ b/llama_stack/providers/tests/inference/provider_config_example.yaml
@@ -4,6 +4,10 @@ providers:
     config:
       host: localhost
       port: 11434
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      model: Llama3.2-1B-Instruct
   - provider_id: test-tgi
     provider_type: remote::tgi
     config:
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 09d6a69db..afec9a837 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -132,7 +132,10 @@ async def test_completion(inference_settings):
     params = inference_settings["common_params"]
 
     provider = inference_impl.routing_table.get_provider_impl(params["model"])
-    if provider.__provider_id__ != "meta-reference":
+    if provider.__provider_spec__.provider_type not in (
+        "meta-reference",
+        "remote::ollama",
+    ):
         pytest.skip("Other inference providers don't support completion() yet")
 
     response = await inference_impl.completion(
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 72db7b18c..add29da99 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -34,6 +34,8 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
     if params := request.sampling_params:
         for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
             if getattr(params, attr):
+                if attr == "max_tokens":
+                    options["num_predict"] = getattr(params, attr)
                 options[attr] = getattr(params, attr)
 
         if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
@@ -49,25 +51,35 @@ def text_from_choice(choice) -> str:
     return choice.text
 
 
+def get_stop_reason(finish_reason: str) -> StopReason:
+    if finish_reason in ["stop", "eos"]:
+        return StopReason.end_of_turn
+    elif finish_reason == "eom":
+        return StopReason.end_of_message
+    elif finish_reason == "length":
+        return StopReason.out_of_tokens
+
+    return StopReason.out_of_tokens
+
+
+def process_completion_response(
+    response: OpenAICompatCompletionResponse, formatter: ChatFormat
+) -> CompletionResponse:
+    choice = response.choices[0]
+
+    return CompletionResponse(
+        stop_reason=get_stop_reason(choice.finish_reason),
+        content=choice.text,
+    )
+
+
 def process_chat_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
-    stop_reason = None
-    if reason := choice.finish_reason:
-        if reason in ["stop", "eos"]:
-            stop_reason = StopReason.end_of_turn
-        elif reason == "eom":
-            stop_reason = StopReason.end_of_message
-        elif reason == "length":
-            stop_reason = StopReason.out_of_tokens
-
-    if stop_reason is None:
-        stop_reason = StopReason.out_of_tokens
-
     completion_message = formatter.decode_assistant_message_from_content(
-        text_from_choice(choice), stop_reason
+        text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
     return ChatCompletionResponse(
         completion_message=completion_message,
@@ -75,6 +87,43 @@
     )
 
 
+async def process_completion_stream_response(
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+) -> AsyncGenerator:
+
+    stop_reason = None
+
+    async for chunk in stream:
+        choice = chunk.choices[0]
+        finish_reason = choice.finish_reason
+
+        if finish_reason:
+            if finish_reason in ["stop", "eos", "eos_token"]:
+                stop_reason = StopReason.end_of_turn
+            elif finish_reason == "length":
+                stop_reason = StopReason.out_of_tokens
+            break
+
+        text = text_from_choice(choice)
+        if text == "<|eot_id|>":
+            stop_reason = StopReason.end_of_turn
+            text = ""
+            continue
+        elif text == "<|eom_id|>":
+            stop_reason = StopReason.end_of_message
+            text = ""
+            continue
+        yield CompletionResponseStreamChunk(
+            delta=text,
+            stop_reason=stop_reason,
+        )
+
+    yield CompletionResponseStreamChunk(
+        delta="",
+        stop_reason=stop_reason,
+    )
+
+
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
 ) -> AsyncGenerator:
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 5b8ded52c..9d695698f 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -23,6 +23,13 @@ from llama_models.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models
 
 
+def completion_request_to_prompt(
+    request: CompletionRequest, formatter: ChatFormat
+) -> str:
+    model_input = formatter.encode_content(request.content)
+    return formatter.tokenizer.decode(model_input.tokens)
+
+
 def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> str:
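
For reference, below is a minimal sketch (not part of the patch) of how the completion() path added above could be exercised once an OllamaInferenceAdapter has been initialized against a running Ollama server. Only the completion() call shape and the CompletionResponse / CompletionResponseStreamChunk fields follow from the code in this diff; the helper function, the prompt text, the model identifier, and the adapter setup are assumptions that depend on the run configuration.

# Illustrative sketch only, not part of the patch. `adapter` is assumed to be an
# already-initialized OllamaInferenceAdapter whose client points at a running
# Ollama server; the model name must be a key of OLLAMA_SUPPORTED_MODELS.
async def demo_completion(adapter) -> None:
    # Non-streaming path: completion() builds a CompletionRequest and returns a
    # CompletionResponse carrying `content` and `stop_reason`.
    response = await adapter.completion(
        model="Llama3.2-1B-Instruct",
        content="The capital of France is",
        stream=False,
    )
    print(response.content, response.stop_reason)

    # Streaming path: completion() returns an async generator of
    # CompletionResponseStreamChunk objects; the final chunk has an empty delta
    # and carries the final stop_reason.
    async for chunk in await adapter.completion(
        model="Llama3.2-1B-Instruct",
        content="The capital of France is",
        stream=True,
    ):
        print(chunk.delta, end="")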