From a7be58e4e188533e8398217ed078640d91fdc558 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 11 Sep 2024 12:29:22 -0700
Subject: [PATCH] migrate inference/completion

---
 llama_toolchain/inference/api/api.py                  | 6 +++++-
 llama_toolchain/inference/meta_reference/inference.py | 1 +
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/llama_toolchain/inference/api/api.py b/llama_toolchain/inference/api/api.py
index 419e2dafb..712b25b2b 100644
--- a/llama_toolchain/inference/api/api.py
+++ b/llama_toolchain/inference/api/api.py
@@ -170,7 +170,11 @@ class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
         self,
-        request: CompletionRequest,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
 
     @webmethod(route="/inference/chat_completion")
diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py
index 2cc7ecfa6..c86c0db8b 100644
--- a/llama_toolchain/inference/meta_reference/inference.py
+++ b/llama_toolchain/inference/meta_reference/inference.py
@@ -65,6 +65,7 @@ class MetaReferenceInferenceImpl(Inference):
     ) -> AsyncIterator[
         Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
     ]:
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
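
Reviewer note: below is a minimal sketch (not part of the patch) of how a
call site changes once the wrapped CompletionRequest is flattened into
top-level parameters. The `client` object stands in for any hypothetical
implementation of the Inference protocol, the model id is illustrative,
and the import path is assumed from the api.py file modified above.

    # Sketch only; names marked as assumptions are not part of this patch.
    from llama_toolchain.inference.api.api import SamplingParams  # assumed path

    async def demo(client) -> None:
        # Before this patch, callers built a wrapper object:
        #   response = await client.completion(
        #       CompletionRequest(model=..., content=...)
        #   )
        # After this patch, the same fields are passed as top-level
        # keyword arguments matching the new Inference.completion() signature.
        response = await client.completion(
            model="Meta-Llama-3.1-8B-Instruct",  # illustrative model id
            content="The capital of France is",
            sampling_params=SamplingParams(),
            stream=False,
            logprobs=None,
        )
        print(response)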