diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index e92deaa41..34f060386 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9401,7 +9401,17 @@
"description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint."
},
"prompt": {
- "type": "string",
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ ],
"description": "The prompt to generate a completion for"
},
"best_of": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index f0c5d1a79..85a287643 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6477,7 +6477,11 @@ components:
The identifier of the model to use. The model must be registered with
Llama Stack and available via the /models endpoint.
prompt:
- type: string
+ oneOf:
+ - type: string
+ - type: array
+ items:
+ type: string
description: The prompt to generate a completion for
best_of:
type: integer
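The net effect of the two spec changes above is that the OpenAI-compatible completions request now accepts `prompt` as either a single string or an array of strings. A minimal sketch of both request shapes; the model identifier is a placeholder, not a value taken from this diff:

```python
# Both payloads validate against the updated "prompt" schema
# (oneOf: string | array of strings). The model id is a placeholder.
import json

single_prompt_request = {
    "model": "example-model-id",
    "prompt": "Write a haiku about the sea",
    "max_tokens": 32,
}

batched_prompt_request = {
    "model": "example-model-id",
    "prompt": [  # now valid: a batch of prompts in one request
        "Write a haiku about the sea",
        "Write a haiku about the desert",
    ],
    "max_tokens": 32,
}

print(json.dumps(batched_prompt_request, indent=2))
```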
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 6271466d4..13eacd217 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -780,7 +780,7 @@ class Inference(Protocol):
async def openai_completion(
self,
model: str,
- prompt: str,
+ prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
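With the protocol signature widened as above (and the matching changes to the router and each provider adapter below), any `Inference` implementation can be handed a batch of prompts in one call. A hedged usage sketch; `inference_impl` and the model id are illustrative placeholders, not names defined in this diff:

```python
import asyncio

async def demo(inference_impl) -> None:
    # One call, several prompts: the response carries one choice per prompt
    # (times n, if n > 1 is requested).
    response = await inference_impl.openai_completion(
        model="example-model-id",
        prompt=["Say hello.", "Say goodbye."],
        max_tokens=16,
    )
    for choice in response.choices:
        print(choice.index, choice.text)

# asyncio.run(demo(concrete_inference_impl))  # supply a concrete implementation
```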
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 19cc8ac09..89f174451 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -423,7 +423,7 @@ class InferenceRouter(Inference):
async def openai_completion(
self,
model: str,
- prompt: str,
+ prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index dc2c8b3f5..fc1cf2265 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -331,7 +331,7 @@ class OllamaInferenceAdapter(
async def openai_completion(
self,
model: str,
- prompt: str,
+ prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index cbe6e6cae..09bd22b4c 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -206,7 +206,7 @@ class PassthroughInferenceAdapter(Inference):
async def openai_completion(
self,
model: str,
- prompt: str,
+ prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 3e43a844c..bde32593c 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -260,7 +260,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def openai_completion(
self,
model: str,
- prompt: str,
+ prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 66fb986f9..daeb95b27 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -424,7 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def openai_completion(
self,
model: str,
- prompt: str,
+ prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 8111e4463..cdb4b21aa 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -251,7 +251,7 @@ class LiteLLMOpenAIMixin(
async def openai_completion(
self,
model: str,
- prompt: str,
+ prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index ea02de573..bc6eed104 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -5,6 +5,8 @@
# the root directory of this source tree.
import json
import logging
+import time
+import uuid
import warnings
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
@@ -83,7 +85,7 @@ from llama_stack.apis.inference import (
TopPSamplingStrategy,
UserMessage,
)
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
@@ -844,6 +846,31 @@ def _convert_openai_logprobs(
]
+def _convert_openai_sampling_params(
+ max_tokens: Optional[int] = None,
+ temperature: Optional[float] = None,
+ top_p: Optional[float] = None,
+) -> SamplingParams:
+ sampling_params = SamplingParams()
+
+ if max_tokens:
+ sampling_params.max_tokens = max_tokens
+
+ # Map an explicit temperature of 0 to greedy sampling
+ if temperature == 0:
+ strategy = GreedySamplingStrategy()
+ else:
+ # OpenAI defaults to 1.0 for temperature and top_p if unset
+ if temperature is None:
+ temperature = 1.0
+ if top_p is None:
+ top_p = 1.0
+ strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
+
+ sampling_params.strategy = strategy
+ return sampling_params
+
+
def convert_openai_chat_completion_choice(
choice: OpenAIChoice,
) -> ChatCompletionResponse:
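The `_convert_openai_sampling_params` helper added in the hunk above maps the OpenAI-style knobs onto Llama Stack's `SamplingParams`: an explicit `temperature` of 0 selects greedy decoding, and otherwise unset `temperature`/`top_p` fall back to OpenAI's defaults of 1.0. A small illustrative check of that mapping; the import paths are assumptions based on the imports already present in this module:

```python
# Illustrative only: restates the mapping implemented by
# _convert_openai_sampling_params. Import paths are assumed to match
# what openai_compat.py itself imports.
from llama_stack.apis.inference import GreedySamplingStrategy, TopPSamplingStrategy
from llama_stack.providers.utils.inference.openai_compat import (
    _convert_openai_sampling_params,
)

greedy = _convert_openai_sampling_params(max_tokens=64, temperature=0)
assert isinstance(greedy.strategy, GreedySamplingStrategy)
assert greedy.max_tokens == 64

defaults = _convert_openai_sampling_params()  # temperature and top_p unset
assert isinstance(defaults.strategy, TopPSamplingStrategy)
assert defaults.strategy.temperature == 1.0
assert defaults.strategy.top_p == 1.0
```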
@@ -1061,7 +1088,7 @@ class OpenAICompletionUnsupportedMixin:
async def openai_completion(
self,
model: str,
- prompt: str,
+ prompt: Union[str, List[str]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
@@ -1078,7 +1105,49 @@ class OpenAICompletionUnsupportedMixin:
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAICompletion:
- raise ValueError(f"{self.__class__.__name__} doesn't support openai completion")
+ if stream:
+ raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
+
+        # This is a pretty hacky way to emulate completions -
+        # it simply de-batches them and runs each prompt separately...
+ prompts = [prompt] if not isinstance(prompt, list) else prompt
+
+ sampling_params = _convert_openai_sampling_params(
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ )
+
+ choices = []
+ # "n" is the number of completions to generate per prompt
+        for _i in range(n or 1):  # n is Optional; default to one completion per prompt
+ # and we may have multiple prompts, if batching was used
+
+ for prompt in prompts:
+                result = await self.completion(
+ model_id=model,
+ content=prompt,
+ sampling_params=sampling_params,
+ )
+
+ index = len(choices)
+ text = result.content
+ finish_reason = _convert_openai_finish_reason(result.stop_reason)
+
+ choice = OpenAICompletionChoice(
+ index=index,
+ text=text,
+ finish_reason=finish_reason,
+ )
+ choices.append(choice)
+
+ return OpenAICompletion(
+ id=f"cmpl-{uuid.uuid4()}",
+ choices=choices,
+ created=int(time.time()),
+ model=model,
+ object="text_completion",
+ )
class OpenAIChatCompletionUnsupportedMixin:
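For providers that fall back to this mixin, batching and `n > 1` are emulated rather than supported natively: the outer loop runs once per requested completion and the inner loop walks the prompt list, so choices are appended in (completion round, prompt) order. A self-contained sketch of the resulting ordering, with placeholder text standing in for real model output:

```python
# Mirrors the loop structure of the mixin above to show how choice indices
# line up with prompts. The "text" values are placeholders, not model output.
prompts = ["Say hello.", "Say goodbye."]
n = 2

choices = []
for round_index in range(n):
    for prompt in prompts:
        choices.append(
            {
                "index": len(choices),
                "round": round_index,
                "prompt": prompt,  # shown only to make the pairing visible
                "text": "<generated text>",
            }
        )

for c in choices:
    print(c["index"], c["round"], c["prompt"])
# 0 0 Say hello.
# 1 0 Say goodbye.
# 2 1 Say hello.
# 3 1 Say goodbye.
```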