Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-05 10:13:05 +00:00)
OpenAI completion prompt can also be an array
The OpenAI completion prompt field can be a string or an array of strings, so update the API and providers to accept and pass that through properly.

This also stubs in a basic conversion of OpenAI non-streaming completion requests to Llama Stack completion calls, so providers that don't actually have an OpenAI backend can still accept requests via the OpenAI APIs.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
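For illustration only (not part of the commit), a minimal sketch of what this enables from a client's point of view, using the openai Python package pointed at a Llama Stack server; the base URL, port, and model id below are placeholders, not values from this repository:

from openai import OpenAI

# Placeholder base URL and model id; adjust for your Llama Stack deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# A single string prompt, as before.
single = client.completions.create(model="my-model", prompt="Say hello")

# A batch of prompts in one request, now also accepted.
batch = client.completions.create(
    model="my-model",
    prompt=["Say hello", "Say goodbye"],
)
print([choice.text for choice in batch.choices])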
parent 24cfa1ef1a
commit a6cf8fa12b
10 changed files with 95 additions and 12 deletions
docs/_static/llama-stack-spec.html (vendored, 12 changed lines)
@@ -9401,7 +9401,17 @@
           "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint."
         },
         "prompt": {
-          "type": "string",
+          "oneOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "array",
+              "items": {
+                "type": "string"
+              }
+            }
+          ],
           "description": "The prompt to generate a completion for"
         },
         "best_of": {
docs/_static/llama-stack-spec.yaml (vendored, 6 changed lines)
@@ -6477,7 +6477,11 @@ components:
           The identifier of the model to use. The model must be registered with
           Llama Stack and available via the /models endpoint.
       prompt:
-        type: string
+        oneOf:
+          - type: string
+          - type: array
+            items:
+              type: string
         description: The prompt to generate a completion for
       best_of:
         type: integer
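As a quick illustration (not part of the commit), the updated prompt schema can be exercised with the third-party jsonschema package; this is an assumption about how a client payload would validate, not code from this repository:

from jsonschema import validate  # third-party package, used here only for illustration

prompt_schema = {
    "oneOf": [
        {"type": "string"},
        {"type": "array", "items": {"type": "string"}},
    ]
}

validate(instance="a single prompt string", schema=prompt_schema)                      # passes
validate(instance=["batched prompt one", "batched prompt two"], schema=prompt_schema)  # passes
# validate(instance=42, schema=prompt_schema) would raise a ValidationError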
@@ -780,7 +780,7 @@ class Inference(Protocol):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -423,7 +423,7 @@ class InferenceRouter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
        best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -331,7 +331,7 @@ class OllamaInferenceAdapter(
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -206,7 +206,7 @@ class PassthroughInferenceAdapter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -260,7 +260,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -424,7 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -251,7 +251,7 @@ class LiteLLMOpenAIMixin(
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 import json
 import logging
+import time
+import uuid
 import warnings
 from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
 
@@ -83,7 +85,7 @@ from llama_stack.apis.inference import (
     TopPSamplingStrategy,
     UserMessage,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
@@ -844,6 +846,31 @@ def _convert_openai_logprobs(
     ]
 
 
+def _convert_openai_sampling_params(
+    max_tokens: Optional[int] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+) -> SamplingParams:
+    sampling_params = SamplingParams()
+
+    if max_tokens:
+        sampling_params.max_tokens = max_tokens
+
+    # Map an explicit temperature of 0 to greedy sampling
+    if temperature == 0:
+        strategy = GreedySamplingStrategy()
+    else:
+        # OpenAI defaults to 1.0 for temperature and top_p if unset
+        if temperature is None:
+            temperature = 1.0
+        if top_p is None:
+            top_p = 1.0
+        strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
+
+    sampling_params.strategy = strategy
+    return sampling_params
+
+
 def convert_openai_chat_completion_choice(
     choice: OpenAIChoice,
 ) -> ChatCompletionResponse:
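A quick usage sketch (illustrative, not from the diff) of how the helper added above maps OpenAI-style parameters onto Llama Stack sampling parameters; it assumes the function is importable from the module this commit changes:

# temperature=0 maps to greedy decoding
greedy = _convert_openai_sampling_params(max_tokens=64, temperature=0)
# greedy.strategy is a GreedySamplingStrategy and greedy.max_tokens == 64

# unset temperature / top_p fall back to OpenAI's defaults of 1.0
default = _convert_openai_sampling_params()
# default.strategy is TopPSamplingStrategy(temperature=1.0, top_p=1.0)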
@@ -1061,7 +1088,7 @@ class OpenAICompletionUnsupportedMixin:
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
@@ -1078,7 +1105,49 @@ class OpenAICompletionUnsupportedMixin:
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> OpenAICompletion:
-        raise ValueError(f"{self.__class__.__name__} doesn't support openai completion")
+        if stream:
+            raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
+
+        # This is a pretty hacky way to do emulate completions -
+        # basically just de-batches them...
+        prompts = [prompt] if not isinstance(prompt, list) else prompt
+
+        sampling_params = _convert_openai_sampling_params(
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        )
+
+        choices = []
+        # "n" is the number of completions to generate per prompt
+        for _i in range(0, n):
+            # and we may have multiple prompts, if batching was used
+
+            for prompt in prompts:
+                result = self.completion(
+                    model_id=model,
+                    content=prompt,
+                    sampling_params=sampling_params,
+                )
+
+                index = len(choices)
+                text = result.content
+                finish_reason = _convert_openai_finish_reason(result.stop_reason)
+
+                choice = OpenAICompletionChoice(
+                    index=index,
+                    text=text,
+                    finish_reason=finish_reason,
+                )
+                choices.append(choice)
+
+        return OpenAICompletion(
+            id=f"cmpl-{uuid.uuid4()}",
+            choices=choices,
+            created=int(time.time()),
+            model=model,
+            object="text_completion",
+        )
 
 
 class OpenAIChatCompletionUnsupportedMixin:
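For illustration (not part of the commit), a small sketch of how the nested loops above assign choice indices when a list-valued prompt and n > 1 are combined; the prompt strings are hypothetical:

prompts = ["first prompt", "second prompt"]  # de-batched from a list-valued "prompt"
n = 2  # completions requested per prompt

order = []
for _i in range(0, n):
    for prompt in prompts:
        order.append((len(order), prompt))

# order == [(0, 'first prompt'), (1, 'second prompt'),
#           (2, 'first prompt'), (3, 'second prompt')]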