diff --git a/llama_toolchain/inference/api/api.py b/llama_toolchain/inference/api/api.py
index 419e2dafb..712b25b2b 100644
--- a/llama_toolchain/inference/api/api.py
+++ b/llama_toolchain/inference/api/api.py
@@ -170,7 +170,11 @@ class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
         self,
-        request: CompletionRequest,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...

     @webmethod(route="/inference/chat_completion")
diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py
index 2cc7ecfa6..c86c0db8b 100644
--- a/llama_toolchain/inference/meta_reference/inference.py
+++ b/llama_toolchain/inference/meta_reference/inference.py
@@ -65,6 +65,7 @@ class MetaReferenceInferenceImpl(Inference):
     ) -> AsyncIterator[
         Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
     ]:
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
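
For reference, a minimal sketch (not part of the patch) of how a call site adapts to the flattened completion() signature; the `client` object and the model identifier are assumptions made for illustration only.

from typing import Any

# Illustrative only: `client` is assumed to be any object implementing the
# Inference protocol from api.py; the model name is a placeholder.
async def complete_hello(client: Any):
    # Old call shape (removed by this patch): a single CompletionRequest argument.
    #   await client.completion(request=CompletionRequest(model=..., content=...))
    # New call shape: flattened keyword arguments, matching the protocol above.
    return await client.completion(
        model="Meta-Llama3.1-8B-Instruct",
        content="Hello, world",
        stream=False,
    )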