diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index e92deaa41..34f060386 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9401,7 +9401,17 @@
                     "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint."
                 },
                 "prompt": {
-                    "type": "string",
+                    "oneOf": [
+                        {
+                            "type": "string"
+                        },
+                        {
+                            "type": "array",
+                            "items": {
+                                "type": "string"
+                            }
+                        }
+                    ],
                     "description": "The prompt to generate a completion for"
                 },
                 "best_of": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index f0c5d1a79..85a287643 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6477,7 +6477,11 @@ components:
             The identifier of the model to use. The model must be registered with
             Llama Stack and available via the /models endpoint.
         prompt:
-          type: string
+          oneOf:
+            - type: string
+            - type: array
+              items:
+                type: string
           description: The prompt to generate a completion for
         best_of:
           type: integer
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 6271466d4..13eacd217 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -780,7 +780,7 @@ class Inference(Protocol):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 19cc8ac09..89f174451 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -423,7 +423,7 @@ class InferenceRouter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index dc2c8b3f5..fc1cf2265 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -331,7 +331,7 @@ class OllamaInferenceAdapter(
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index cbe6e6cae..09bd22b4c 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -206,7 +206,7 @@ class PassthroughInferenceAdapter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 3e43a844c..bde32593c 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -260,7 +260,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 66fb986f9..daeb95b27 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -424,7 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 8111e4463..cdb4b21aa 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -251,7 +251,7 @@ class LiteLLMOpenAIMixin(
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index ea02de573..bc6eed104 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 import json
 import logging
+import time
+import uuid
 import warnings
 from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
 
@@ -83,7 +85,7 @@
     TopPSamplingStrategy,
     UserMessage,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
@@ -844,6 +846,31 @@
     ]
 
 
+def _convert_openai_sampling_params(
+    max_tokens: Optional[int] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+) -> SamplingParams:
+    sampling_params = SamplingParams()
+
+    if max_tokens:
+        sampling_params.max_tokens = max_tokens
+
+    # Map an explicit temperature of 0 to greedy sampling
+    if temperature == 0:
+        strategy = GreedySamplingStrategy()
+    else:
+        # OpenAI defaults to 1.0 for temperature and top_p if unset
+        if temperature is None:
+            temperature = 1.0
+        if top_p is None:
+            top_p = 1.0
+        strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
+
+    sampling_params.strategy = strategy
+    return sampling_params
+
+
 def convert_openai_chat_completion_choice(
     choice: OpenAIChoice,
 ) -> ChatCompletionResponse:
@@ -1061,7 +1088,7 @@ class OpenAICompletionUnsupportedMixin:
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -1078,7 +1105,49 @@ class OpenAICompletionUnsupportedMixin:
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> OpenAICompletion:
-        raise ValueError(f"{self.__class__.__name__} doesn't support openai completion")
+        if stream:
+            raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
+
+        # This is a pretty hacky way to emulate completions -
+        # basically just de-batches them...
+        prompts = [prompt] if not isinstance(prompt, list) else prompt
+
+        sampling_params = _convert_openai_sampling_params(
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        )
+
+        choices = []
+        # "n" is the number of completions to generate per prompt
+        for _i in range(0, n):
+            # and we may have multiple prompts, if batching was used
+
+            for prompt in prompts:
+                result = await self.completion(
+                    model_id=model,
+                    content=prompt,
+                    sampling_params=sampling_params,
+                )
+
+                index = len(choices)
+                text = result.content
+                finish_reason = _convert_openai_finish_reason(result.stop_reason)
+
+                choice = OpenAICompletionChoice(
+                    index=index,
+                    text=text,
+                    finish_reason=finish_reason,
+                )
+                choices.append(choice)
+
+        return OpenAICompletion(
+            id=f"cmpl-{uuid.uuid4()}",
+            choices=choices,
+            created=int(time.time()),
+            model=model,
+            object="text_completion",
+        )
 
 
 class OpenAIChatCompletionUnsupportedMixin: