OpenAI completion prompt can also be an array

The OpenAI completion prompt field can be a string or an array of
strings, so update the API spec and the inference providers to accept
and pass either form through properly.
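
For example, with the OpenAI Python client pointed at a Llama Stack
server, both of the following now work (the base URL, API key, and
model name are placeholders, not values from this change):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# a single prompt string
client.completions.create(model="example-model", prompt="Say hello")

# a batched prompt: a list of strings, yielding one choice per prompt
client.completions.create(model="example-model", prompt=["Say hello", "Say goodbye"])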

This also stubs in a basic conversion of OpenAI non-streaming
completion requests to Llama Stack completion calls, so that providers
without an actual OpenAI backend can still accept requests via the
OpenAI APIs.
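
As a rough sketch of how a provider without an OpenAI backend picks
this up (the adapter class is illustrative and the import paths are
assumed, since file names are not shown here):

from llama_stack.apis.inference import Inference
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompletionUnsupportedMixin,
)


class ExampleNativeOnlyAdapter(OpenAICompletionUnsupportedMixin, Inference):
    # The adapter keeps its native Llama Stack `completion` implementation;
    # the mixin's `openai_completion` (added below) converts non-streaming
    # OpenAI completion requests into calls to it.
    ...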

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Ben Browning 2025-04-09 09:28:50 -04:00
parent 24cfa1ef1a
commit a6cf8fa12b
10 changed files with 95 additions and 12 deletions


@@ -9401,7 +9401,17 @@
"description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint."
},
"prompt": {
"type": "string",
"oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "string"
}
}
],
"description": "The prompt to generate a completion for"
},
"best_of": {


@@ -6477,7 +6477,11 @@ components:
The identifier of the model to use. The model must be registered with
Llama Stack and available via the /models endpoint.
prompt:
type: string
oneOf:
- type: string
- type: array
items:
type: string
description: The prompt to generate a completion for
best_of:
type: integer
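
A quick way to sanity-check the updated prompt schema fragment, using the python-jsonschema package (a sketch, not part of this change):

import jsonschema

prompt_schema = {
    "oneOf": [
        {"type": "string"},
        {"type": "array", "items": {"type": "string"}},
    ]
}

# both forms validate against the new schema
jsonschema.validate("Say hello", prompt_schema)
jsonschema.validate(["Say hello", "Say goodbye"], prompt_schema)

# anything else is rejected
try:
    jsonschema.validate(42, prompt_schema)
except jsonschema.ValidationError:
    print("rejected as expected")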


@@ -780,7 +780,7 @@ class Inference(Protocol):
async def openai_completion(
self,
model: str,
prompt: str,
prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,


@@ -423,7 +423,7 @@ class InferenceRouter(Inference):
async def openai_completion(
self,
model: str,
prompt: str,
prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,


@@ -331,7 +331,7 @@ class OllamaInferenceAdapter(
async def openai_completion(
self,
model: str,
prompt: str,
prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,


@@ -206,7 +206,7 @@ class PassthroughInferenceAdapter(Inference):
async def openai_completion(
self,
model: str,
prompt: str,
prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,


@@ -260,7 +260,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def openai_completion(
self,
model: str,
prompt: str,
prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,


@@ -424,7 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def openai_completion(
self,
model: str,
prompt: str,
prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,


@@ -251,7 +251,7 @@ class LiteLLMOpenAIMixin(
async def openai_completion(
self,
model: str,
prompt: str,
prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
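
For adapters that wrap an OpenAI-compatible backend, passing the prompt through means forwarding it untouched, since the OpenAI SDK itself accepts either form. A minimal illustrative sketch (the adapter class and its constructor are assumptions, not code from this diff):

from typing import List, Union

from openai import AsyncOpenAI


class ExampleOpenAIBackedAdapter:
    def __init__(self, base_url: str, api_key: str) -> None:
        self._client = AsyncOpenAI(base_url=base_url, api_key=api_key)

    async def openai_completion(self, model: str, prompt: Union[str, List[str]], **params):
        # no conversion needed: the backend accepts a string or a list of strings
        return await self._client.completions.create(model=model, prompt=prompt, **params)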


@@ -5,6 +5,8 @@
# the root directory of this source tree.
import json
import logging
import time
import uuid
import warnings
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
@@ -83,7 +85,7 @@ from llama_stack.apis.inference import (
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
@@ -844,6 +846,31 @@ def _convert_openai_logprobs(
]
def _convert_openai_sampling_params(
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> SamplingParams:
sampling_params = SamplingParams()
if max_tokens:
sampling_params.max_tokens = max_tokens
# Map an explicit temperature of 0 to greedy sampling
if temperature == 0:
strategy = GreedySamplingStrategy()
else:
# OpenAI defaults to 1.0 for temperature and top_p if unset
if temperature is None:
temperature = 1.0
if top_p is None:
top_p = 1.0
strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
sampling_params.strategy = strategy
return sampling_params
def convert_openai_chat_completion_choice(
choice: OpenAIChoice,
) -> ChatCompletionResponse:
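
A rough illustration of what the `_convert_openai_sampling_params` helper added above produces, assuming the `SamplingParams` and strategy types imported in this module:

params = _convert_openai_sampling_params(max_tokens=128, temperature=0, top_p=0.9)
assert params.max_tokens == 128
assert isinstance(params.strategy, GreedySamplingStrategy)  # temperature == 0 -> greedy

params = _convert_openai_sampling_params()
assert isinstance(params.strategy, TopPSamplingStrategy)
assert params.strategy.temperature == 1.0  # OpenAI defaults applied when unset
assert params.strategy.top_p == 1.0
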
@@ -1061,7 +1088,7 @@ class OpenAICompletionUnsupportedMixin:
async def openai_completion(
self,
model: str,
prompt: str,
prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
@@ -1078,7 +1105,49 @@ class OpenAICompletionUnsupportedMixin:
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAICompletion:
raise ValueError(f"{self.__class__.__name__} doesn't support openai completion")
if stream:
raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
# This is a pretty hacky way to emulate completions -
# basically just de-batches them...
prompts = [prompt] if not isinstance(prompt, list) else prompt
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
choices = []
# "n" is the number of completions to generate per prompt
for _i in range(0, n):
# and we may have multiple prompts, if batching was used
for prompt in prompts:
result = self.completion(
model_id=model,
content=prompt,
sampling_params=sampling_params,
)
index = len(choices)
text = result.content
finish_reason = _convert_openai_finish_reason(result.stop_reason)
choice = OpenAICompletionChoice(
index=index,
text=text,
finish_reason=finish_reason,
)
choices.append(choice)
return OpenAICompletion(
id=f"cmpl-{uuid.uuid4()}",
choices=choices,
created=int(time.time()),
model=model,
object="text_completion",
)
class OpenAIChatCompletionUnsupportedMixin:
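
With this fallback, a non-streaming request is de-batched into one Llama Stack completion call per prompt per `n`, so the response carries `n * len(prompts)` choices. A hedged usage sketch (the adapter instance and model name are placeholders):

completion = await adapter.openai_completion(
    model="example-registered-model",
    prompt=["First prompt", "Second prompt"],
    n=2,
    max_tokens=64,
    stream=False,
)
assert completion.object == "text_completion"
assert len(completion.choices) == 4  # 2 prompts x n=2, indexed in generation order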