From 136e6b3cf71f4da30979cb770fc956435717301a Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Mon, 12 May 2025 13:57:53 -0400
Subject: [PATCH] fix: ollama openai completion and chat completion params (#2125)

# What does this PR do?

The ollama provider was using an older variant of the code that converts
incoming parameters from the OpenAI API completions and chat completion
endpoints into requests sent to the backend provider over its own OpenAI
client. This updates it to use the common `prepare_openai_completion_params`
method used elsewhere, which takes care of removing stray `None` values,
even in nested structures.

Without this, parameters with a value of `None` still make their way to
ollama and influence its inference output, unlike when those parameters are
not sent at all.

## Test Plan

This passes tests/integration/inference/test_openai_completion.py and fixes
the issue found in #2098, which was verified via manually crafted curl
requests.

Closes #2098

Signed-off-by: Ben Browning
---
 .../remote/inference/ollama/ollama.py | 97 +++++++++----------
 1 file changed, 45 insertions(+), 52 deletions(-)

diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 32e2b17d0..72cf0d129 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -61,6 +61,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
     get_sampling_options,
+    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
@@ -395,29 +396,25 @@ class OllamaInferenceAdapter(
             raise ValueError("Ollama does not support non-string prompts for completion")
 
         model_obj = await self._get_model(model)
-        params = {
-            k: v
-            for k, v in {
-                "model": model_obj.provider_resource_id,
-                "prompt": prompt,
-                "best_of": best_of,
-                "echo": echo,
-                "frequency_penalty": frequency_penalty,
-                "logit_bias": logit_bias,
-                "logprobs": logprobs,
-                "max_tokens": max_tokens,
-                "n": n,
-                "presence_penalty": presence_penalty,
-                "seed": seed,
-                "stop": stop,
-                "stream": stream,
-                "stream_options": stream_options,
-                "temperature": temperature,
-                "top_p": top_p,
-                "user": user,
-            }.items()
-            if v is not None
-        }
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
         return await self.openai_client.completions.create(**params)  # type: ignore
 
     async def openai_chat_completion(
@@ -447,35 +444,31 @@ class OllamaInferenceAdapter(
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         model_obj = await self._get_model(model)
-        params = {
-            k: v
-            for k, v in {
-                "model": model_obj.provider_resource_id,
-                "messages": messages,
-                "frequency_penalty": frequency_penalty,
-                "function_call": function_call,
-                "functions": functions,
-                "logit_bias": logit_bias,
-                "logprobs": logprobs,
-                "max_completion_tokens": max_completion_tokens,
-                "max_tokens": max_tokens,
-                "n": n,
-                "parallel_tool_calls": parallel_tool_calls,
-                "presence_penalty": presence_penalty,
-                "response_format": response_format,
-                "seed": seed,
-                "stop": stop,
-                "stream": stream,
-                "stream_options": stream_options,
-                "temperature": temperature,
-                "tool_choice": tool_choice,
-                "tools": tools,
-                "top_logprobs": top_logprobs,
-                "top_p": top_p,
-                "user": user,
-            }.items()
-            if v is not None
-        }
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
         return await self.openai_client.chat.completions.create(**params)  # type: ignore
 
     async def batch_completion(