diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index a74932147..36bfad49e 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9523,6 +9523,15 @@
         "user": {
           "type": "string",
           "description": "(Optional) The user to use"
+        },
+        "guided_choice": {
+          "type": "array",
+          "items": {
+            "type": "string"
+          }
+        },
+        "prompt_logprobs": {
+          "type": "integer"
         }
       },
       "additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index b475dc142..82faf450a 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6556,6 +6556,12 @@ components:
         user:
           type: string
           description: (Optional) The user to use
+        guided_choice:
+          type: array
+          items:
+            type: string
+        prompt_logprobs:
+          type: integer
       additionalProperties: false
       required:
         - model
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index b29e165f7..3390a3fef 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -779,6 +779,7 @@ class Inference(Protocol):
     @webmethod(route="/openai/v1/completions", method="POST")
     async def openai_completion(
         self,
+        # Standard OpenAI completion parameters
         model: str,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,
@@ -796,6 +797,9 @@ class Inference(Protocol):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        # vLLM-specific parameters
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.

diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 2d0c95688..bc313036f 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -439,6 +439,8 @@ class InferenceRouter(Inference):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@@ -467,6 +469,8 @@ class InferenceRouter(Inference):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
         )

         provider = self.routing_table.get_provider_impl(model_obj.identifier)
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index cdd41e372..b8671197e 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -347,6 +347,8 @@ class OllamaInferenceAdapter(
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         if not isinstance(prompt, str):
             raise ValueError("Ollama does not support non-string prompts for completion")
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 7d19c7813..0eb38c395 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -222,6 +222,8 @@ class PassthroughInferenceAdapter(Inference):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         client = self._get_client()
         model_obj = await self.model_store.get_model(model)
@@ -244,6 +246,8 @@ class PassthroughInferenceAdapter(Inference):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
         )

         return await client.inference.openai_completion(**params)
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index be984167a..2c9a7ec03 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -276,6 +276,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         model_obj = await self._get_model(model)
         params = await prepare_openai_completion_params(
@@ -296,6 +298,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
         )

         return await self._get_openai_client().completions.create(**params)  # type: ignore
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 7425d68bd..cac310613 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -440,8 +440,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         model_obj = await self._get_model(model)
+
+        extra_body: Dict[str, Any] = {}
+        if prompt_logprobs:
+            extra_body["prompt_logprobs"] = prompt_logprobs
+        if guided_choice:
+            extra_body["guided_choice"] = guided_choice
+
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
             prompt=prompt,
@@ -460,6 +469,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            extra_body=extra_body,
         )

         return await self.client.completions.create(**params)  # type: ignore
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 3119c8b40..2d2f0400a 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -267,6 +267,8 @@ class LiteLLMOpenAIMixin(
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         model_obj = await self._get_model(model)
         params = await prepare_openai_completion_params(
@@ -287,6 +289,8 @@ class LiteLLMOpenAIMixin(
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
         )

         return litellm.text_completion(**params)
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 74587c7f5..f33cb4443 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -1104,6 +1104,8 @@ class OpenAICompletionUnsupportedMixin:
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
         if stream:
             raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 78df64af0..410c1fe22 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -13,15 +13,19 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 from ..test_cases.test_case import TestCase


-def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
-    if isinstance(client_with_models, LlamaStackAsLibraryClient):
-        pytest.skip("OpenAI completions are not supported when testing with library client yet.")
-
+def provider_from_model(client_with_models, model_id):
     models = {m.identifier: m for m in client_with_models.models.list()}
     models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
-    provider = providers[provider_id]
+    return providers[provider_id]
+
+
+def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI completions are not supported when testing with library client yet.")
+
+    provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type in (
         "inline::meta-reference",
         "inline::sentence-transformers",
@@ -37,6 +41,12 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")


+def skip_if_provider_isnt_vllm(client_with_models, model_id):
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type != "remote::vllm":
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.")
+
+
 @pytest.fixture
 def openai_client(client_with_models, text_model_id):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
@@ -85,3 +95,37 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case):
     streamed_content = [chunk.choices[0].text for chunk in response]
     content_str = "".join(streamed_content).lower().strip()
     assert len(content_str) > 10
+
+
+def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id):
+    skip_if_provider_isnt_vllm(client_with_models, text_model_id)
+
+    prompt = "Hello, world!"
+    response = openai_client.completions.create(
+        model=text_model_id,
+        prompt=prompt,
+        stream=False,
+        extra_body={
+            "prompt_logprobs": 1,
+        },
+    )
+    assert len(response.choices) > 0
+    choice = response.choices[0]
+    assert len(choice.prompt_logprobs) > 0
+
+
+def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
+    skip_if_provider_isnt_vllm(client_with_models, text_model_id)
+
+    prompt = "I am feeling really sad today."
+    response = openai_client.completions.create(
+        model=text_model_id,
+        prompt=prompt,
+        stream=False,
+        extra_body={
+            "guided_choice": ["joy", "sadness"],
+        },
+    )
+    assert len(response.choices) > 0
+    choice = response.choices[0]
+    assert choice.text in ["joy", "sadness"]
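
For reference, the sketch below mirrors the new integration tests and shows how a client could reach the vLLM-only parameters through the OpenAI-compatible completions route by passing them in extra_body. This is a minimal usage sketch, not part of the change itself: the base URL, API key, and model id are placeholder assumptions for a Llama Stack deployment backed by the remote::vllm provider, which (per the new skip helper) is the only provider exercised with these fields.

# Usage sketch (assumptions: a Llama Stack server backed by remote::vllm is
# reachable at the placeholder base URL below, and the model id is registered).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")  # placeholder endpoint
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model id

# guided_choice constrains the completion to one of the listed strings.
sentiment = client.completions.create(
    model=MODEL_ID,
    prompt="I am feeling really sad today.",
    stream=False,
    extra_body={"guided_choice": ["joy", "sadness"]},
)
print(sentiment.choices[0].text)  # expected to be "joy" or "sadness"

# prompt_logprobs asks vLLM to return log-probabilities for the prompt tokens,
# surfaced on the choice object as in the tests above.
echoed = client.completions.create(
    model=MODEL_ID,
    prompt="Hello, world!",
    stream=False,
    extra_body={"prompt_logprobs": 1},
)
print(echoed.choices[0].prompt_logprobs)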