forked from phoenix-oss/llama-stack-mirror
fix: ollama openai completion and chat completion params (#2125)
# What does this PR do? The ollama provider was using an older variant of the code to convert incoming parameters from the OpenAI API completions and chat completion endpoints into requests that get sent to the backend provider over its own OpenAI client. This updates it to use the common `prepare_openai_completion_params` method used elsewhere, which takes care of removing stray `None` values even for nested structures. Without this, some other parameters, even if they have values of `None`, make their way to ollama and actually influence its inference output as opposed to when those parameters are not sent at all. ## Test Plan This passes tests/integration/inference/test_openai_completion.py and fixes the issue found in #2098, which was tested via manual curl requests crafted a particular way. Closes #2098 Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
80c349965f
commit
136e6b3cf7
1 changed files with 45 additions and 52 deletions
|
@ -61,6 +61,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
|
@ -395,29 +396,25 @@ class OllamaInferenceAdapter(
|
|||
raise ValueError("Ollama does not support non-string prompts for completion")
|
||||
|
||||
model_obj = await self._get_model(model)
|
||||
params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"model": model_obj.provider_resource_id,
|
||||
"prompt": prompt,
|
||||
"best_of": best_of,
|
||||
"echo": echo,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"logit_bias": logit_bias,
|
||||
"logprobs": logprobs,
|
||||
"max_tokens": max_tokens,
|
||||
"n": n,
|
||||
"presence_penalty": presence_penalty,
|
||||
"seed": seed,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"stream_options": stream_options,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"user": user,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self.openai_client.completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
|
@ -447,35 +444,31 @@ class OllamaInferenceAdapter(
|
|||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self._get_model(model)
|
||||
params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"model": model_obj.provider_resource_id,
|
||||
"messages": messages,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"function_call": function_call,
|
||||
"functions": functions,
|
||||
"logit_bias": logit_bias,
|
||||
"logprobs": logprobs,
|
||||
"max_completion_tokens": max_completion_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"n": n,
|
||||
"parallel_tool_calls": parallel_tool_calls,
|
||||
"presence_penalty": presence_penalty,
|
||||
"response_format": response_format,
|
||||
"seed": seed,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"stream_options": stream_options,
|
||||
"temperature": temperature,
|
||||
"tool_choice": tool_choice,
|
||||
"tools": tools,
|
||||
"top_logprobs": top_logprobs,
|
||||
"top_p": top_p,
|
||||
"user": user,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self.openai_client.chat.completions.create(**params) # type: ignore
|
||||
|
||||
async def batch_completion(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue