From fcdeb3d7bfe6cb5ea4bc0b48e030b0c898ae6fb1 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Wed, 9 Apr 2025 10:05:50 -0400
Subject: [PATCH] OpenAI completion prompt can also include tokens

The OpenAI completion API accepts a string, an array of strings, an
array of tokens, or an array of token arrays as its prompt, so expand
our type hints to cover all four shapes.

Signed-off-by: Ben Browning
---
 llama_stack/apis/inference/inference.py                       | 2 +-
 llama_stack/distribution/routers/routers.py                   | 2 +-
 llama_stack/providers/remote/inference/ollama/ollama.py       | 2 +-
 .../providers/remote/inference/passthrough/passthrough.py     | 2 +-
 llama_stack/providers/remote/inference/together/together.py   | 2 +-
 llama_stack/providers/remote/inference/vllm/vllm.py           | 2 +-
 llama_stack/providers/utils/inference/litellm_openai_mixin.py | 2 +-
 llama_stack/providers/utils/inference/openai_compat.py        | 2 +-
 8 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 13eacd217..b29e165f7 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -780,7 +780,7 @@ class Inference(Protocol):
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 89f174451..2d0c95688 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -423,7 +423,7 @@ class InferenceRouter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index fc1cf2265..1fbc9e747 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -331,7 +331,7 @@ class OllamaInferenceAdapter(
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 09bd22b4c..7d19c7813 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -206,7 +206,7 @@ class PassthroughInferenceAdapter(Inference):
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index bde32593c..be984167a 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -260,7 +260,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index daeb95b27..7425d68bd 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -424,7 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index cdb4b21aa..3119c8b40 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -251,7 +251,7 @@ class LiteLLMOpenAIMixin(
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index bc6eed104..74587c7f5 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -1088,7 +1088,7 @@ class OpenAICompletionUnsupportedMixin:
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str], List[int], List[List[int]]],
         best_of: Optional[int] = None,
         echo: Optional[bool] = None,
         frequency_penalty: Optional[float] = None,
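
For reference, the widened hint mirrors the four prompt shapes the OpenAI
completions API itself accepts. The following is a minimal Python sketch
against an OpenAI-compatible endpoint; the base URL, model name, and token
IDs are illustrative assumptions, not values taken from this patch:

    from openai import OpenAI

    # Assumed endpoint and credentials, for illustration only.
    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

    prompts = [
        "Say hello",                   # str
        ["Say hello", "Say goodbye"],  # List[str]: batch of string prompts
        [9906, 1917],                  # List[int]: one pre-tokenized prompt
        [[9906, 1917], [34, 56]],      # List[List[int]]: batch of token arrays
    ]

    for prompt in prompts:
        # Each of the four shapes is now valid for the `prompt` parameter.
        response = client.completions.create(
            model="meta-llama/Llama-3.2-3B",
            prompt=prompt,
            max_tokens=8,
        )
        print(response.choices[0].text)

Note that this patch only widens the accepted types; whether token-array
prompts work end to end still depends on the backing provider.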