Add prompt_logprobs and guided_choice to OpenAI completions

This adds the vLLM-specific extra_body parameters of prompt_logprobs
and guided_choice to our openai_completion inference endpoint. The
plan here would be to expand this to support all common optional
parameters of any of the OpenAI providers, allowing each provider to
use or ignore these parameters based on whether their server supports them.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
Ben Browning 2025-04-09 15:43:53 -04:00
parent ef684ff178
commit ac5dc8fae2
11 changed files with 98 additions and 5 deletions

View file

@ -13,15 +13,19 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from ..test_cases.test_case import TestCase
def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI completions are not supported when testing with library client yet.")
def provider_from_model(client_with_models, model_id):
models = {m.identifier: m for m in client_with_models.models.list()}
models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id]
return providers[provider_id]
def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI completions are not supported when testing with library client yet.")
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in (
"inline::meta-reference",
"inline::sentence-transformers",
@ -37,6 +41,12 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
def skip_if_provider_isnt_vllm(client_with_models, model_id):
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type != "remote::vllm":
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.")
@pytest.fixture
def openai_client(client_with_models, text_model_id):
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
@ -85,3 +95,37 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case):
streamed_content = [chunk.choices[0].text for chunk in response]
content_str = "".join(streamed_content).lower().strip()
assert len(content_str) > 10
def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id):
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
prompt = "Hello, world!"
response = openai_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
extra_body={
"prompt_logprobs": 1,
},
)
assert len(response.choices) > 0
choice = response.choices[0]
assert len(choice.prompt_logprobs) > 0
def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
prompt = "I am feeling really sad today."
response = openai_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
extra_body={
"guided_choice": ["joy", "sadness"],
},
)
assert len(response.choices) > 0
choice = response.choices[0]
assert choice.text in ["joy", "sadness"]