OpenAI completion prompt can also be an array

The OpenAI completion prompt field can be a string or an array of
strings, so update the API spec and provider signatures to accept both
forms and pass them through properly.

This also stubs in a basic conversion of OpenAI non-streaming
completion requests into Llama Stack completion calls, so providers
that don't actually have an OpenAI backend can still accept requests
via the OpenAI APIs.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Ben Browning 2025-04-09 09:28:50 -04:00
parent 24cfa1ef1a
commit a6cf8fa12b
10 changed files with 95 additions and 12 deletions
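
With this change, the OpenAI-compatible completions endpoint accepts either form of prompt. As a rough client-side sketch (the base URL, API key, and model id below are placeholders for illustration, not part of this commit):

from openai import OpenAI

# Point the stock OpenAI client at a Llama Stack server's OpenAI-compatible
# route; the URL, key, and model id here are assumptions, adjust for your setup.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# prompt as a single string
single = client.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    prompt="Say hello in one sentence.",
    stream=False,
)

# prompt as an array of strings, each entry producing its own choice
batched = client.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    prompt=["Say hello in one sentence.", "Say goodbye in one sentence."],
    stream=False,
)

print(len(single.choices), len(batched.choices))  # typically 1 and 2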


@@ -9401,7 +9401,17 @@
                   "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint."
                 },
                 "prompt": {
-                  "type": "string",
+                  "oneOf": [
+                    {
+                      "type": "string"
+                    },
+                    {
+                      "type": "array",
+                      "items": {
+                        "type": "string"
+                      }
+                    }
+                  ],
                   "description": "The prompt to generate a completion for"
                 },
                 "best_of": {


@@ -6477,7 +6477,11 @@ components:
              The identifier of the model to use. The model must be registered with
              Llama Stack and available via the /models endpoint.
            prompt:
-              type: string
+              oneOf:
+                - type: string
+                - type: array
+                  items:
+                    type: string
              description: The prompt to generate a completion for
            best_of:
              type: integer


@@ -780,7 +780,7 @@ class Inference(Protocol):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,


@@ -423,7 +423,7 @@ class InferenceRouter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,


@@ -331,7 +331,7 @@ class OllamaInferenceAdapter(
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,


@@ -206,7 +206,7 @@ class PassthroughInferenceAdapter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,


@@ -260,7 +260,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,


@@ -424,7 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,


@@ -251,7 +251,7 @@ class LiteLLMOpenAIMixin(
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
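
The same one-line signature change is applied to the Inference protocol, the router, and each adapter above: prompt is now Union[str, List[str]]. A minimal normalization sketch (illustrative only, not code from any of these adapters) shows the shape providers now have to handle:

from typing import List, Union

def normalize_prompt(prompt: Union[str, List[str]]) -> List[str]:
    # Treat a bare string as a single-element batch so downstream code can
    # always iterate over a list of prompts.
    return [prompt] if isinstance(prompt, str) else list(prompt)

assert normalize_prompt("hi") == ["hi"]
assert normalize_prompt(["hi", "bye"]) == ["hi", "bye"]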


@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 import json
 import logging
+import time
+import uuid
 import warnings
 from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
 
@@ -83,7 +85,7 @@ from llama_stack.apis.inference import (
     TopPSamplingStrategy,
     UserMessage,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
@@ -844,6 +846,31 @@ def _convert_openai_logprobs(
     ]
 
 
+def _convert_openai_sampling_params(
+    max_tokens: Optional[int] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+) -> SamplingParams:
+    sampling_params = SamplingParams()
+
+    if max_tokens:
+        sampling_params.max_tokens = max_tokens
+
+    # Map an explicit temperature of 0 to greedy sampling
+    if temperature == 0:
+        strategy = GreedySamplingStrategy()
+    else:
+        # OpenAI defaults to 1.0 for temperature and top_p if unset
+        if temperature is None:
+            temperature = 1.0
+        if top_p is None:
+            top_p = 1.0
+        strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
+
+    sampling_params.strategy = strategy
+    return sampling_params
+
+
 def convert_openai_chat_completion_choice(
     choice: OpenAIChoice,
 ) -> ChatCompletionResponse:
@@ -1061,7 +1088,7 @@ class OpenAICompletionUnsupportedMixin:
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -1078,7 +1105,49 @@
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> OpenAICompletion:
-        raise ValueError(f"{self.__class__.__name__} doesn't support openai completion")
+        if stream:
+            raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
+
+        # This is a pretty hacky way to do emulate completions -
+        # basically just de-batches them...
+        prompts = [prompt] if not isinstance(prompt, list) else prompt
+
+        sampling_params = _convert_openai_sampling_params(
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        )
+
+        choices = []
+
+        # "n" is the number of completions to generate per prompt
+        for _i in range(0, n):
+            # and we may have multiple prompts, if batching was used
+            for prompt in prompts:
+                result = self.completion(
+                    model_id=model,
+                    content=prompt,
+                    sampling_params=sampling_params,
+                )
+
+                index = len(choices)
+                text = result.content
+                finish_reason = _convert_openai_finish_reason(result.stop_reason)
+
+                choice = OpenAICompletionChoice(
+                    index=index,
+                    text=text,
+                    finish_reason=finish_reason,
+                )
+                choices.append(choice)
+
+        return OpenAICompletion(
+            id=f"cmpl-{uuid.uuid4()}",
+            choices=choices,
+            created=int(time.time()),
+            model=model,
+            object="text_completion",
+        )
 
 
 class OpenAIChatCompletionUnsupportedMixin:
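
For providers that fall back to this emulation, the new helper maps the OpenAI sampling knobs onto Llama Stack's SamplingParams. A rough check of that mapping (the import paths are guesses, since the diff view above does not show filenames; adjust them to your tree):

# Assumed import locations; the diff does not show file paths.
from llama_stack.apis.inference import GreedySamplingStrategy, TopPSamplingStrategy
from llama_stack.providers.utils.inference.openai_compat import (
    _convert_openai_sampling_params,
)

# An explicit temperature of 0 maps to greedy sampling.
params = _convert_openai_sampling_params(max_tokens=64, temperature=0, top_p=0.9)
assert isinstance(params.strategy, GreedySamplingStrategy)
assert params.max_tokens == 64

# Unset temperature/top_p fall back to the OpenAI defaults of 1.0.
params = _convert_openai_sampling_params()
assert isinstance(params.strategy, TopPSamplingStrategy)
assert params.strategy.temperature == 1.0
assert params.strategy.top_p == 1.0

Combined with the de-batching loop above, a non-streaming request with a two-element prompt array and n=2 would issue four completion() calls and return four choices in the resulting OpenAICompletion.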