Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-03 09:21:45 +00:00
OpenAI completion prompt can also be an array
The OpenAI completion prompt field can be a string or an array, so update the API and providers to accept and pass through either form properly. This also stubs in a basic conversion of OpenAI non-streaming completion requests to Llama Stack completion calls, so providers that don't actually have an OpenAI backend can still accept requests via the OpenAI APIs.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Parent: 24cfa1ef1a
Commit: a6cf8fa12b
10 changed files with 95 additions and 12 deletions
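With this change, clients can pass the completions `prompt` either as a single string or as an array of strings through Llama Stack's OpenAI-compatible API. A minimal client-side sketch using the `openai` Python package is below; the base URL, API key, and model identifier are illustrative assumptions and depend on your deployment.

```python
from openai import OpenAI

# Assumed base URL for a local Llama Stack server's OpenAI-compatible API;
# adjust host, port, and path prefix to match your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

MODEL = "meta-llama/Llama-3.2-3B-Instruct"  # assumed model identifier

# prompt as a single string (previously the only accepted form)
single = client.completions.create(model=MODEL, prompt="Say hello.")
print(single.choices[0].text)

# prompt as an array of strings (batched prompts), now also accepted
batched = client.completions.create(model=MODEL, prompt=["Say hello.", "Say goodbye."])
for choice in batched.choices:
    print(choice.index, choice.text)
```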
docs/_static/llama-stack-spec.html (vendored, 12 changes)
```diff
@@ -9401,7 +9401,17 @@
           "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint."
         },
         "prompt": {
-          "type": "string",
+          "oneOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "array",
+              "items": {
+                "type": "string"
+              }
+            }
+          ],
           "description": "The prompt to generate a completion for"
         },
         "best_of": {
```
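The new `oneOf` schema for `prompt` can be checked in isolation. A small sketch using the `jsonschema` package (an assumption, not a dependency declared by this change) shows that both a string and an array of strings validate, while other types are rejected:

```python
from jsonschema import ValidationError, validate

# The "prompt" schema fragment introduced by the spec change above.
prompt_schema = {
    "oneOf": [
        {"type": "string"},
        {"type": "array", "items": {"type": "string"}},
    ]
}

validate(instance="Write a haiku about mirrors.", schema=prompt_schema)      # accepted
validate(instance=["First prompt", "Second prompt"], schema=prompt_schema)   # accepted

try:
    validate(instance=123, schema=prompt_schema)
except ValidationError:
    print("non-string, non-array prompts are rejected")
```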
docs/_static/llama-stack-spec.yaml (vendored, 6 changes)
```diff
@@ -6477,7 +6477,11 @@ components:
           The identifier of the model to use. The model must be registered with
           Llama Stack and available via the /models endpoint.
       prompt:
-        type: string
+        oneOf:
+          - type: string
+          - type: array
+            items:
+              type: string
         description: The prompt to generate a completion for
       best_of:
         type: integer
```
The Python-side changes widen the `openai_completion` prompt parameter from `str` to `Union[str, List[str]]` across the API definition, the inference router, and the provider adapters:

```diff
@@ -780,7 +780,7 @@ class Inference(Protocol):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
```
```diff
@@ -423,7 +423,7 @@ class InferenceRouter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
```
```diff
@@ -331,7 +331,7 @@ class OllamaInferenceAdapter(
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
```
```diff
@@ -206,7 +206,7 @@ class PassthroughInferenceAdapter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
```
```diff
@@ -260,7 +260,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
```
```diff
@@ -424,7 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
```
```diff
@@ -251,7 +251,7 @@ class LiteLLMOpenAIMixin(
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
```
The remaining hunks are in the OpenAI compatibility helpers, which gain the imports and conversion code needed to emulate OpenAI completions on top of Llama Stack's `completion` API:

```diff
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 import json
 import logging
+import time
+import uuid
 import warnings
 from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
 
```
```diff
@@ -83,7 +85,7 @@ from llama_stack.apis.inference import (
     TopPSamplingStrategy,
     UserMessage,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
```
```diff
@@ -844,6 +846,31 @@ def _convert_openai_logprobs(
     ]
 
 
+def _convert_openai_sampling_params(
+    max_tokens: Optional[int] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+) -> SamplingParams:
+    sampling_params = SamplingParams()
+
+    if max_tokens:
+        sampling_params.max_tokens = max_tokens
+
+    # Map an explicit temperature of 0 to greedy sampling
+    if temperature == 0:
+        strategy = GreedySamplingStrategy()
+    else:
+        # OpenAI defaults to 1.0 for temperature and top_p if unset
+        if temperature is None:
+            temperature = 1.0
+        if top_p is None:
+            top_p = 1.0
+        strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
+
+    sampling_params.strategy = strategy
+    return sampling_params
+
+
 def convert_openai_chat_completion_choice(
     choice: OpenAIChoice,
 ) -> ChatCompletionResponse:
```
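To make the sampling conversion concrete, here is a rough usage sketch of the new helper; the import locations are assumptions based on where this patch appears to live and may differ in your checkout.

```python
# Import locations are assumed, not confirmed by this diff.
from llama_stack.apis.inference import GreedySamplingStrategy, TopPSamplingStrategy
from llama_stack.providers.utils.inference.openai_compat import (
    _convert_openai_sampling_params,
)

# An explicit temperature of 0 maps to greedy decoding.
greedy = _convert_openai_sampling_params(max_tokens=64, temperature=0)
assert isinstance(greedy.strategy, GreedySamplingStrategy)
assert greedy.max_tokens == 64

# Unset temperature/top_p fall back to OpenAI's defaults of 1.0 with top-p sampling.
defaults = _convert_openai_sampling_params()
assert isinstance(defaults.strategy, TopPSamplingStrategy)
assert defaults.strategy.temperature == 1.0
assert defaults.strategy.top_p == 1.0
```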
```diff
@@ -1061,7 +1088,7 @@ class OpenAICompletionUnsupportedMixin:
     async def openai_completion(
         self,
         model: str,
-        prompt: str,
+        prompt: Union[str, List[str]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
```
```diff
@@ -1078,7 +1105,49 @@ class OpenAICompletionUnsupportedMixin:
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> OpenAICompletion:
-        raise ValueError(f"{self.__class__.__name__} doesn't support openai completion")
+        if stream:
+            raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
+
+        # This is a pretty hacky way to do emulate completions -
+        # basically just de-batches them...
+        prompts = [prompt] if not isinstance(prompt, list) else prompt
+
+        sampling_params = _convert_openai_sampling_params(
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        )
+
+        choices = []
+        # "n" is the number of completions to generate per prompt
+        for _i in range(0, n):
+            # and we may have multiple prompts, if batching was used
+
+            for prompt in prompts:
+                result = self.completion(
+                    model_id=model,
+                    content=prompt,
+                    sampling_params=sampling_params,
+                )
+
+                index = len(choices)
+                text = result.content
+                finish_reason = _convert_openai_finish_reason(result.stop_reason)
+
+                choice = OpenAICompletionChoice(
+                    index=index,
+                    text=text,
+                    finish_reason=finish_reason,
+                )
+                choices.append(choice)
+
+        return OpenAICompletion(
+            id=f"cmpl-{uuid.uuid4()}",
+            choices=choices,
+            created=int(time.time()),
+            model=model,
+            object="text_completion",
+        )
 
 
 class OpenAIChatCompletionUnsupportedMixin:
```
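The intent of the mixin change is that providers without an OpenAI backend de-batch an OpenAI completion request into individual Llama Stack `completion` calls, one per prompt per `n`. As a rough client-side illustration (base URL and model identifier assumed, as in the earlier example), a non-streaming request with two prompts and `n=2` is meant to come back with four choices:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")  # assumed URL

response = client.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",  # assumed model identifier
    prompt=["First prompt", "Second prompt"],  # batched prompts
    n=2,                                       # two completions per prompt
    stream=False,                              # streaming is explicitly rejected by this mixin
)

# 2 prompts * n=2 -> 4 choices, each produced by a separate Llama Stack completion call.
print(len(response.choices))
```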