Add prompt_logprobs and guided_choice to OpenAI completions

This adds the vLLM-specific extra_body parameters of prompt_logprobs
and guided_choice to our openai_completion inference endpoint. The
plan here would be to expand this to support all common optional
parameters of any of the OpenAI providers, allowing each provider to
use or ignore these parameters based on whether their server supports them.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
Ben Browning 2025-04-09 15:43:53 -04:00
parent ef684ff178
commit ac5dc8fae2
11 changed files with 98 additions and 5 deletions

View file

@ -9523,6 +9523,15 @@
"user": {
"type": "string",
"description": "(Optional) The user to use"
},
"guided_choice": {
"type": "array",
"items": {
"type": "string"
}
},
"prompt_logprobs": {
"type": "integer"
}
},
"additionalProperties": false,

View file

@ -6556,6 +6556,12 @@ components:
user:
type: string
description: (Optional) The user to use
guided_choice:
type: array
items:
type: string
prompt_logprobs:
type: integer
additionalProperties: false
required:
- model

View file

@ -779,6 +779,7 @@ class Inference(Protocol):
@webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
@ -796,6 +797,9 @@ class Inference(Protocol):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
# vLLM-specific parameters
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.

View file

@ -439,6 +439,8 @@ class InferenceRouter(Inference):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@ -467,6 +469,8 @@ class InferenceRouter(Inference):
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)

View file

@ -347,6 +347,8 @@ class OllamaInferenceAdapter(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")

View file

@ -222,6 +222,8 @@ class PassthroughInferenceAdapter(Inference):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
@ -244,6 +246,8 @@ class PassthroughInferenceAdapter(Inference):
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
return await client.inference.openai_completion(**params)

View file

@ -276,6 +276,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
@ -296,6 +298,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
return await self._get_openai_client().completions.create(**params) # type: ignore

View file

@ -440,8 +440,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self._get_model(model)
extra_body: Dict[str, Any] = {}
if prompt_logprobs:
extra_body["prompt_logprobs"] = prompt_logprobs
if guided_choice:
extra_body["guided_choice"] = guided_choice
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
@ -460,6 +469,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
temperature=temperature,
top_p=top_p,
user=user,
extra_body=extra_body,
)
return await self.client.completions.create(**params) # type: ignore

View file

@ -267,6 +267,8 @@ class LiteLLMOpenAIMixin(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
@ -287,6 +289,8 @@ class LiteLLMOpenAIMixin(
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
return litellm.text_completion(**params)

View file

@ -1104,6 +1104,8 @@ class OpenAICompletionUnsupportedMixin:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
if stream:
raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")

View file

@ -13,15 +13,19 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from ..test_cases.test_case import TestCase
def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI completions are not supported when testing with library client yet.")
def provider_from_model(client_with_models, model_id):
models = {m.identifier: m for m in client_with_models.models.list()}
models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id]
return providers[provider_id]
def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI completions are not supported when testing with library client yet.")
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in (
"inline::meta-reference",
"inline::sentence-transformers",
@ -37,6 +41,12 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
def skip_if_provider_isnt_vllm(client_with_models, model_id):
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type != "remote::vllm":
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.")
@pytest.fixture
def openai_client(client_with_models, text_model_id):
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
@ -85,3 +95,37 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case):
streamed_content = [chunk.choices[0].text for chunk in response]
content_str = "".join(streamed_content).lower().strip()
assert len(content_str) > 10
def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id):
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
prompt = "Hello, world!"
response = openai_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
extra_body={
"prompt_logprobs": 1,
},
)
assert len(response.choices) > 0
choice = response.choices[0]
assert len(choice.prompt_logprobs) > 0
def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
prompt = "I am feeling really sad today."
response = openai_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
extra_body={
"guided_choice": ["joy", "sadness"],
},
)
assert len(response.choices) > 0
choice = response.choices[0]
assert choice.text in ["joy", "sadness"]