Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 09:21:45 +00:00)
Add prompt_logprobs and guided_choice to OpenAI completions
This adds the vLLM-specific extra_body parameters prompt_logprobs and guided_choice to our openai_completion inference endpoint. The plan is to expand this to cover the common optional parameters of all OpenAI providers, letting each provider use or ignore these parameters depending on whether its server supports them.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent ef684ff178
commit ac5dc8fae2

11 changed files with 98 additions and 5 deletions
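For a sense of what this enables end to end, here is a minimal client-side sketch, assuming a locally running Llama Stack server fronting a vLLM provider; the base URL, port, and model id are placeholders, not values from this commit:

```python
# Sketch: calling the OpenAI-compatible completions endpoint with the new
# vLLM-specific parameters via the openai SDK's extra_body escape hatch.
# Base URL and model id below are placeholder assumptions.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/openai/v1",  # placeholder Llama Stack URL
    api_key="none",  # the stack itself does not require an OpenAI key
)

response = client.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
    prompt="I am feeling really sad today.",
    # extra_body fields are forwarded untouched; vLLM constrains decoding to
    # one of the guided_choice strings and returns prompt token logprobs.
    extra_body={
        "guided_choice": ["joy", "sadness"],
        "prompt_logprobs": 1,
    },
)
print(response.choices[0].text)  # expected to be "joy" or "sadness"
```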
docs/_static/llama-stack-spec.html (vendored): 9 changes

@@ -9523,6 +9523,15 @@
                 "user": {
                     "type": "string",
                     "description": "(Optional) The user to use"
                 },
+                "guided_choice": {
+                    "type": "array",
+                    "items": {
+                        "type": "string"
+                    }
+                },
+                "prompt_logprobs": {
+                    "type": "integer"
+                }
             },
             "additionalProperties": false,
docs/_static/llama-stack-spec.yaml (vendored): 6 changes

@@ -6556,6 +6556,12 @@ components:
         user:
           type: string
           description: (Optional) The user to use
+        guided_choice:
+          type: array
+          items:
+            type: string
+        prompt_logprobs:
+          type: integer
       additionalProperties: false
       required:
         - model
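Because the spec change above makes guided_choice and prompt_logprobs first-class request properties, a raw HTTP client can pass them at the top level of the body rather than through an SDK's extra_body mechanism. A sketch, with the URL and model id again placeholders:

```python
# Sketch: posting directly to the Llama Stack endpoint registered at
# /openai/v1/completions. The host, path prefix, and model id are
# placeholder assumptions for illustration.
import requests

resp = requests.post(
    "http://localhost:8321/openai/v1/completions",  # placeholder URL
    json={
        "model": "meta-llama/Llama-3.2-3B-Instruct",  # placeholder model id
        "prompt": "I am feeling really sad today.",
        "guided_choice": ["joy", "sadness"],  # first-class per the new schema
        "prompt_logprobs": 1,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```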

@@ -779,6 +779,7 @@ class Inference(Protocol):
     @webmethod(route="/openai/v1/completions", method="POST")
     async def openai_completion(
         self,
+        # Standard OpenAI completion parameters
         model: str,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,

@@ -796,6 +797,9 @@ class Inference(Protocol):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        # vLLM-specific parameters
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.

@@ -439,6 +439,8 @@ class InferenceRouter(Inference):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",

@@ -467,6 +469,8 @@ class InferenceRouter(Inference):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
         )

         provider = self.routing_table.get_provider_impl(model_obj.identifier)

@@ -347,6 +347,8 @@ class OllamaInferenceAdapter(
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         if not isinstance(prompt, str):
             raise ValueError("Ollama does not support non-string prompts for completion")

@@ -222,6 +222,8 @@ class PassthroughInferenceAdapter(Inference):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         client = self._get_client()
         model_obj = await self.model_store.get_model(model)

@@ -244,6 +246,8 @@ class PassthroughInferenceAdapter(Inference):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
         )

         return await client.inference.openai_completion(**params)

@@ -276,6 +276,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         model_obj = await self._get_model(model)
         params = await prepare_openai_completion_params(

@@ -296,6 +298,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
         )
         return await self._get_openai_client().completions.create(**params)  # type: ignore

@@ -440,8 +440,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         model_obj = await self._get_model(model)
+
+        extra_body: Dict[str, Any] = {}
+        if prompt_logprobs:
+            extra_body["prompt_logprobs"] = prompt_logprobs
+        if guided_choice:
+            extra_body["guided_choice"] = guided_choice
+
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
             prompt=prompt,

@@ -460,6 +469,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            extra_body=extra_body,
         )
         return await self.client.completions.create(**params)  # type: ignore
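Note that only the vLLM adapter routes the two parameters through extra_body: the OpenAI python client rejects unknown keyword arguments, so vendor extensions have to travel that way. The other adapters hand guided_choice and prompt_logprobs straight to prepare_openai_completion_params, whose implementation is not part of this diff; for that to be safe it presumably filters out unset (None-valued) entries, roughly like the sketch below. This is an assumption for illustration, not the repository's actual helper.

```python
# Hypothetical sketch of prepare_openai_completion_params-style behavior:
# keep only the parameters a caller actually set, so optional values left at
# None never reach the downstream provider client. Illustrative assumption;
# the real helper's code is not shown in this diff.
from typing import Any, Dict


async def prepare_params_sketch(**kwargs: Any) -> Dict[str, Any]:
    return {key: value for key, value in kwargs.items() if value is not None}
```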

@@ -267,6 +267,8 @@ class LiteLLMOpenAIMixin(
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         model_obj = await self._get_model(model)
         params = await prepare_openai_completion_params(

@@ -287,6 +289,8 @@ class LiteLLMOpenAIMixin(
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
         )
         return litellm.text_completion(**params)

@@ -1104,6 +1104,8 @@ class OpenAICompletionUnsupportedMixin:
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         if stream:
             raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")

@@ -13,15 +13,19 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 from ..test_cases.test_case import TestCase


-def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI completions are not supported when testing with library client yet.")
-
+def provider_from_model(client_with_models, model_id):
     models = {m.identifier: m for m in client_with_models.models.list()}
     models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
-    provider = providers[provider_id]
+    return providers[provider_id]
+
+
+def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI completions are not supported when testing with library client yet.")
+
+    provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type in (
         "inline::meta-reference",
         "inline::sentence-transformers",

@@ -37,6 +41,12 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
     pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")


+def skip_if_provider_isnt_vllm(client_with_models, model_id):
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type != "remote::vllm":
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.")
+
+
 @pytest.fixture
 def openai_client(client_with_models, text_model_id):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)

@@ -85,3 +95,37 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case):
     streamed_content = [chunk.choices[0].text for chunk in response]
     content_str = "".join(streamed_content).lower().strip()
     assert len(content_str) > 10
+
+
+def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id):
+    skip_if_provider_isnt_vllm(client_with_models, text_model_id)
+
+    prompt = "Hello, world!"
+    response = openai_client.completions.create(
+        model=text_model_id,
+        prompt=prompt,
+        stream=False,
+        extra_body={
+            "prompt_logprobs": 1,
+        },
+    )
+    assert len(response.choices) > 0
+    choice = response.choices[0]
+    assert len(choice.prompt_logprobs) > 0
+
+
+def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
+    skip_if_provider_isnt_vllm(client_with_models, text_model_id)
+
+    prompt = "I am feeling really sad today."
+    response = openai_client.completions.create(
+        model=text_model_id,
+        prompt=prompt,
+        stream=False,
+        extra_body={
+            "guided_choice": ["joy", "sadness"],
+        },
+    )
+    assert len(response.choices) > 0
+    choice = response.choices[0]
+    assert choice.text in ["joy", "sadness"]
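Both tests skip unless the model is served by remote::vllm, since these parameters are vLLM extensions. The exact-membership assertion in the guided_choice test is deterministic because vLLM's guided decoding constrains the completion to one of the supplied strings, while prompt_logprobs=1 asks vLLM to attach log probabilities for the prompt tokens, which the test checks are present on the returned choice.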