From ac5dc8fae29d433a394363bd001fb150f686a0d3 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Wed, 9 Apr 2025 15:43:53 -0400 Subject: [PATCH] Add prompt_logprobs and guided_choice to OpenAI completions This adds the vLLM-specific extra_body parameters of prompt_logprobs and guided_choice to our openai_completion inference endpoint. The plan here would be to expand this to support all common optional parameters of any of the OpenAI providers, allowing each provider to use or ignore these parameters based on whether their server supports them. Signed-off-by: Ben Browning --- docs/_static/llama-stack-spec.html | 9 ++++ docs/_static/llama-stack-spec.yaml | 6 +++ llama_stack/apis/inference/inference.py | 4 ++ llama_stack/distribution/routers/routers.py | 4 ++ .../remote/inference/ollama/ollama.py | 2 + .../inference/passthrough/passthrough.py | 4 ++ .../remote/inference/together/together.py | 4 ++ .../providers/remote/inference/vllm/vllm.py | 10 ++++ .../utils/inference/litellm_openai_mixin.py | 4 ++ .../utils/inference/openai_compat.py | 2 + .../inference/test_openai_completion.py | 54 +++++++++++++++++-- 11 files changed, 98 insertions(+), 5 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index a74932147..36bfad49e 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9523,6 +9523,15 @@ "user": { "type": "string", "description": "(Optional) The user to use" + }, + "guided_choice": { + "type": "array", + "items": { + "type": "string" + } + }, + "prompt_logprobs": { + "type": "integer" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index b475dc142..82faf450a 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6556,6 +6556,12 @@ components: user: type: string description: (Optional) The user to use + guided_choice: + type: array + items: + type: string + prompt_logprobs: + type: integer additionalProperties: false required: - model diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index b29e165f7..3390a3fef 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -779,6 +779,7 @@ class Inference(Protocol): @webmethod(route="/openai/v1/completions", method="POST") async def openai_completion( self, + # Standard OpenAI completion parameters model: str, prompt: Union[str, List[str], List[int], List[List[int]]], best_of: Optional[int] = None, @@ -796,6 +797,9 @@ class Inference(Protocol): temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + # vLLM-specific parameters + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: """Generate an OpenAI-compatible completion for the given prompt using the specified model. diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 2d0c95688..bc313036f 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -439,6 +439,8 @@ class InferenceRouter(Inference): temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: logger.debug( f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", @@ -467,6 +469,8 @@ class InferenceRouter(Inference): temperature=temperature, top_p=top_p, user=user, + guided_choice=guided_choice, + prompt_logprobs=prompt_logprobs, ) provider = self.routing_table.get_provider_impl(model_obj.identifier) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index cdd41e372..b8671197e 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -347,6 +347,8 @@ class OllamaInferenceAdapter( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: if not isinstance(prompt, str): raise ValueError("Ollama does not support non-string prompts for completion") diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 7d19c7813..0eb38c395 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -222,6 +222,8 @@ class PassthroughInferenceAdapter(Inference): temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: client = self._get_client() model_obj = await self.model_store.get_model(model) @@ -244,6 +246,8 @@ class PassthroughInferenceAdapter(Inference): temperature=temperature, top_p=top_p, user=user, + guided_choice=guided_choice, + prompt_logprobs=prompt_logprobs, ) return await client.inference.openai_completion(**params) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index be984167a..2c9a7ec03 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -276,6 +276,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: model_obj = await self._get_model(model) params = await prepare_openai_completion_params( @@ -296,6 +298,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi temperature=temperature, top_p=top_p, user=user, + guided_choice=guided_choice, + prompt_logprobs=prompt_logprobs, ) return await self._get_openai_client().completions.create(**params) # type: ignore diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 7425d68bd..cac310613 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -440,8 +440,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: model_obj = await self._get_model(model) + + extra_body: Dict[str, Any] = {} + if prompt_logprobs: + extra_body["prompt_logprobs"] = prompt_logprobs + if guided_choice: + extra_body["guided_choice"] = guided_choice + params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, prompt=prompt, @@ -460,6 +469,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): temperature=temperature, top_p=top_p, user=user, + extra_body=extra_body, ) return await self.client.completions.create(**params) # type: ignore diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 3119c8b40..2d2f0400a 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -267,6 +267,8 @@ class LiteLLMOpenAIMixin( temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: model_obj = await self._get_model(model) params = await prepare_openai_completion_params( @@ -287,6 +289,8 @@ class LiteLLMOpenAIMixin( temperature=temperature, top_p=top_p, user=user, + guided_choice=guided_choice, + prompt_logprobs=prompt_logprobs, ) return litellm.text_completion(**params) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 74587c7f5..f33cb4443 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -1104,6 +1104,8 @@ class OpenAICompletionUnsupportedMixin: temperature: Optional[float] = None, top_p: Optional[float] = None, user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: if stream: raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions") diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 78df64af0..410c1fe22 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -13,15 +13,19 @@ from llama_stack.distribution.library_client import LlamaStackAsLibraryClient from ..test_cases.test_case import TestCase -def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id): - if isinstance(client_with_models, LlamaStackAsLibraryClient): - pytest.skip("OpenAI completions are not supported when testing with library client yet.") - +def provider_from_model(client_with_models, model_id): models = {m.identifier: m for m in client_with_models.models.list()} models.update({m.provider_resource_id: m for m in client_with_models.models.list()}) provider_id = models[model_id].provider_id providers = {p.provider_id: p for p in client_with_models.providers.list()} - provider = providers[provider_id] + return providers[provider_id] + + +def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id): + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI completions are not supported when testing with library client yet.") + + provider = provider_from_model(client_with_models, model_id) if provider.provider_type in ( "inline::meta-reference", "inline::sentence-transformers", @@ -37,6 +41,12 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") +def skip_if_provider_isnt_vllm(client_with_models, model_id): + provider = provider_from_model(client_with_models, model_id) + if provider.provider_type != "remote::vllm": + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.") + + @pytest.fixture def openai_client(client_with_models, text_model_id): skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) @@ -85,3 +95,37 @@ def test_openai_completion_streaming(openai_client, text_model_id, test_case): streamed_content = [chunk.choices[0].text for chunk in response] content_str = "".join(streamed_content).lower().strip() assert len(content_str) > 10 + + +def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id): + skip_if_provider_isnt_vllm(client_with_models, text_model_id) + + prompt = "Hello, world!" + response = openai_client.completions.create( + model=text_model_id, + prompt=prompt, + stream=False, + extra_body={ + "prompt_logprobs": 1, + }, + ) + assert len(response.choices) > 0 + choice = response.choices[0] + assert len(choice.prompt_logprobs) > 0 + + +def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id): + skip_if_provider_isnt_vllm(client_with_models, text_model_id) + + prompt = "I am feeling really sad today." + response = openai_client.completions.create( + model=text_model_id, + prompt=prompt, + stream=False, + extra_body={ + "guided_choice": ["joy", "sadness"], + }, + ) + assert len(response.choices) > 0 + choice = response.choices[0] + assert choice.text in ["joy", "sadness"]