diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index a74932147..36bfad49e 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9523,6 +9523,15 @@
"user": {
"type": "string",
"description": "(Optional) The user to use"
+ },
+ "guided_choice": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "prompt_logprobs": {
+ "type": "integer"
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index b475dc142..82faf450a 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6556,6 +6556,12 @@ components:
user:
type: string
description: (Optional) The user to use
+ guided_choice:
+ type: array
+ items:
+ type: string
+ prompt_logprobs:
+ type: integer
additionalProperties: false
required:
- model
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index b29e165f7..3390a3fef 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -779,6 +779,7 @@ class Inference(Protocol):
@webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion(
self,
+ # Standard OpenAI completion parameters
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
@@ -796,6 +797,9 @@ class Inference(Protocol):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
+ # vLLM-specific parameters
+ guided_choice: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 2d0c95688..bc313036f 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -439,6 +439,8 @@ class InferenceRouter(Inference):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
+ guided_choice: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@@ -467,6 +469,8 @@ class InferenceRouter(Inference):
temperature=temperature,
top_p=top_p,
user=user,
+ guided_choice=guided_choice,
+ prompt_logprobs=prompt_logprobs,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index cdd41e372..b8671197e 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -347,6 +347,8 @@ class OllamaInferenceAdapter(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
+ guided_choice: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 7d19c7813..0eb38c395 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -222,6 +222,8 @@ class PassthroughInferenceAdapter(Inference):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
+ guided_choice: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
@@ -244,6 +246,8 @@ class PassthroughInferenceAdapter(Inference):
temperature=temperature,
top_p=top_p,
user=user,
+ guided_choice=guided_choice,
+ prompt_logprobs=prompt_logprobs,
)
return await client.inference.openai_completion(**params)
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index be984167a..2c9a7ec03 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -276,6 +276,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
+ guided_choice: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
@@ -296,6 +298,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
temperature=temperature,
top_p=top_p,
user=user,
+ guided_choice=guided_choice,
+ prompt_logprobs=prompt_logprobs,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 7425d68bd..cac310613 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -440,8 +440,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
+ guided_choice: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self._get_model(model)
+
+ extra_body: Dict[str, Any] = {}
+        if prompt_logprobs is not None and prompt_logprobs >= 0:
+ extra_body["prompt_logprobs"] = prompt_logprobs
+ if guided_choice:
+ extra_body["guided_choice"] = guided_choice
+
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
@@ -460,6 +469,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
temperature=temperature,
top_p=top_p,
user=user,
+ extra_body=extra_body,
)
return await self.client.completions.create(**params) # type: ignore
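Because the OpenAI Python SDK does not accept unknown keyword arguments, the adapter forwards the vLLM-only options through extra_body, which the SDK merges into the JSON request payload; vLLM's OpenAI-compatible server then picks them up as guided-decoding and prompt-logprobs settings. The effect is roughly what a direct call against a vLLM server would do, as in this sketch (the endpoint URL and model id are placeholders):

    import asyncio

    from openai import AsyncOpenAI


    async def main() -> None:
        # Placeholder vLLM OpenAI-compatible endpoint; a real API key is usually not required.
        client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="none")
        response = await client.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            prompt="I am feeling really sad today.",
            # extra_body keys are merged into the request JSON and read by vLLM
            extra_body={"guided_choice": ["joy", "sadness"], "prompt_logprobs": 1},
        )
        print(response.choices[0].text)


    asyncio.run(main())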
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 3119c8b40..2d2f0400a 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -267,6 +267,8 @@ class LiteLLMOpenAIMixin(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
+ guided_choice: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
@@ -287,6 +289,8 @@ class LiteLLMOpenAIMixin(
temperature=temperature,
top_p=top_p,
user=user,
+ guided_choice=guided_choice,
+ prompt_logprobs=prompt_logprobs,
)
return litellm.text_completion(**params)
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 74587c7f5..f33cb4443 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -1104,6 +1104,8 @@ class OpenAICompletionUnsupportedMixin:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
+ guided_choice: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
if stream:
raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 78df64af0..410c1fe22 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -13,15 +13,19 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from ..test_cases.test_case import TestCase
-def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
- if isinstance(client_with_models, LlamaStackAsLibraryClient):
- pytest.skip("OpenAI completions are not supported when testing with library client yet.")
-
+def provider_from_model(client_with_models, model_id):
models = {m.identifier: m for m in client_with_models.models.list()}
models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
- provider = providers[provider_id]
+ return providers[provider_id]
+
+
+def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
+ if isinstance(client_with_models, LlamaStackAsLibraryClient):
+ pytest.skip("OpenAI completions are not supported when testing with library client yet.")
+
+ provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in (
"inline::meta-reference",
"inline::sentence-transformers",
@@ -37,6 +41,12 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
+def skip_if_provider_isnt_vllm(client_with_models, model_id):
+ provider = provider_from_model(client_with_models, model_id)
+ if provider.provider_type != "remote::vllm":
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vLLM extra_body parameters.")
+
+
@pytest.fixture
def openai_client(client_with_models, text_model_id):
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
@@ -85,3 +95,37 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case):
streamed_content = [chunk.choices[0].text for chunk in response]
content_str = "".join(streamed_content).lower().strip()
assert len(content_str) > 10
+
+
+def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id):
+ skip_if_provider_isnt_vllm(client_with_models, text_model_id)
+
+ prompt = "Hello, world!"
+ response = openai_client.completions.create(
+ model=text_model_id,
+ prompt=prompt,
+ stream=False,
+ extra_body={
+ "prompt_logprobs": 1,
+ },
+ )
+ assert len(response.choices) > 0
+ choice = response.choices[0]
+ assert len(choice.prompt_logprobs) > 0
+
+
+def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
+ skip_if_provider_isnt_vllm(client_with_models, text_model_id)
+
+ prompt = "I am feeling really sad today."
+ response = openai_client.completions.create(
+ model=text_model_id,
+ prompt=prompt,
+ stream=False,
+ extra_body={
+ "guided_choice": ["joy", "sadness"],
+ },
+ )
+ assert len(response.choices) > 0
+ choice = response.choices[0]
+ assert choice.text in ["joy", "sadness"]