diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 2e87b2e24..c360bcfb0 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -101,7 +101,7 @@ class InferenceRouter(Inference):
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         params = dict(
             model=model,
@@ -125,7 +125,7 @@ class InferenceRouter(Inference):
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
         return await self.routing_table.get_provider_impl(model).completion(
             model=model,
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index e50736b04..9abafb451 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -135,6 +135,10 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
                 assert (
                     len(token_result.logprobs) == 1
                 ), "Expected logprob to contain 1 result for the current token"
+                assert (
+                    logprobs.top_k == 1
+                ), "Only top_k=1 is supported for LogProbConfig"
+
                 logprobs.append(
                     TokenLogProbs(
                         logprobs_by_token={
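
For context, the router hunks correct the default value of the `logprobs` parameter: an argument annotated `Optional[LogProbConfig]` should default to `None`, not `False`, since `False` is neither a `LogProbConfig` nor `None` and only happens to pass a truthiness check. The provider hunk then restricts the meta-reference implementation to `top_k=1`. Below is a minimal sketch of the resulting calling convention; it is illustrative only, and `LogProbConfig` is re-declared here as a simplified stand-in rather than imported from `llama_stack`, which may define additional fields.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LogProbConfig:
    # Simplified stand-in for llama_stack's LogProbConfig (assumption: the
    # real class at least carries a top_k field).
    top_k: int = 1


def completion(prompt: str, logprobs: Optional[LogProbConfig] = None) -> None:
    # Defaulting to None (rather than False) keeps the annotation honest:
    # the argument is either a LogProbConfig or absent.
    if logprobs is not None:
        # Mirrors the assertion added in the meta-reference provider:
        # only one logprob per generated token is supported.
        assert logprobs.top_k == 1, "Only top_k=1 is supported for LogProbConfig"
        print(f"returning top-{logprobs.top_k} logprobs per token for {prompt!r}")
    else:
        print(f"plain completion for {prompt!r}")


# Callers either omit logprobs entirely or pass an explicit config.
completion("hello")
completion("hello", logprobs=LogProbConfig(top_k=1))
```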