diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 0a6c3e6e2..fffcf4692 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -112,11 +112,19 @@ async def run_main(
         content="hello world, write me a 2 sentence poem about the moon"
     )
     cprint(f"User>{message.content}", "green")
+
+    if logprobs:
+        logprobs_config = LogProbConfig(
+            top_k=1,
+        )
+    else:
+        logprobs_config = None
+
     iterator = client.chat_completion(
         model=model,
         messages=[message],
         stream=stream,
-        logprobs=logprobs,
+        logprobs=logprobs_config,
     )
 
     if logprobs:
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index abb2822e3..428f29b88 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -69,7 +69,7 @@ class ChatCompletionResponseEvent(BaseModel):
 
     event_type: ChatCompletionResponseEventType
     delta: Union[str, ToolCallDelta]
-    logprobs: Optional[List[float]] = None
+    logprobs: Optional[List[TokenLogProbs]] = None
     stop_reason: Optional[StopReason] = None
 
 
@@ -80,7 +80,7 @@ class CompletionRequest(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()
 
     stream: Optional[bool] = False
-    logprobs: Optional[bool] = False
+    logprobs: Optional[LogProbConfig] = None
 
 
 @json_schema_type
@@ -88,7 +88,7 @@ class CompletionResponse(BaseModel):
     """Completion response."""
 
     completion_message: CompletionMessage
-    logprobs: Optional[List[float]] = None
+    logprobs: Optional[List[TokenLogProbs]] = None
 
 
 @json_schema_type
@@ -97,7 +97,7 @@ class CompletionResponseStreamChunk(BaseModel):
 
     delta: str
     stop_reason: Optional[StopReason] = None
-    logprobs: Optional[List[float]] = None
+    logprobs: Optional[List[TokenLogProbs]] = None
 
 
 @json_schema_type
@@ -105,7 +105,7 @@ class BatchCompletionRequest(BaseModel):
     model: str
     content_batch: List[InterleavedTextMedia]
     sampling_params: Optional[SamplingParams] = SamplingParams()
-    logprobs: Optional[bool] = False
+    logprobs: Optional[LogProbConfig] = None
 
 
 @json_schema_type
@@ -129,7 +129,7 @@ class ChatCompletionRequest(BaseModel):
     )
 
     stream: Optional[bool] = False
-    logprobs: Optional[bool] = False
+    logprobs: Optional[LogProbConfig] = None
 
 
 @json_schema_type
@@ -144,7 +144,7 @@ class ChatCompletionResponse(BaseModel):
     """Chat completion response."""
 
     completion_message: CompletionMessage
-    logprobs: Optional[List[float]] = None
+    logprobs: Optional[List[TokenLogProbs]] = None
 
 
 @json_schema_type
@@ -159,7 +159,7 @@ class BatchChatCompletionRequest(BaseModel):
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
     )
-    logprobs: Optional[bool] = False
+    logprobs: Optional[LogProbConfig] = None
 
 
 @json_schema_type
@@ -180,7 +180,7 @@ class Inference(Protocol):
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
-        logprobs: Optional[bool] = None,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
 
     @webmethod(route="/inference/chat_completion")
@@ -194,7 +194,7 @@ class Inference(Protocol):
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
-        logprobs: Optional[bool] = None,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
 
     @webmethod(route="/inference/embeddings")
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index fb562dc2f..2e87b2e24 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -101,7 +101,7 @@ class InferenceRouter(Inference):
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
-        logprobs: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         params = dict(
             model=model,
@@ -125,7 +125,7 @@ class InferenceRouter(Inference):
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
-        logprobs: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
         return await self.routing_table.get_provider_impl(model).completion(
             model=model,
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index b8cccf949..e50736b04 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -58,7 +58,7 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
-        logprobs: Optional[bool] = None,
+        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncIterator[
         Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
     ]:
@@ -135,7 +135,13 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
                     assert (
                         len(token_result.logprobs) == 1
                     ), "Expected logprob to contain 1 result for the current token"
-                    logprobs.append(token_result.logprobs[0])
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
 
                 continue
 
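
Usage sketch (not part of the diff): with these changes a caller opts into logprobs by passing a LogProbConfig instead of a bool, and each returned entry is a TokenLogProbs whose logprobs_by_token maps token text to its log probability. The helper below is illustrative only; the import path, the helper name, and the assumption that `client` is an already-constructed inference client (as in run_main in client.py) are mine, not part of this change.

from llama_stack.apis.inference import LogProbConfig, UserMessage  # assumed import path


async def print_token_logprobs(client, model: str, prompt: str) -> None:
    # Opt in with a config object; top_k=1 mirrors the client.py hunk above.
    iterator = client.chat_completion(
        model=model,
        messages=[UserMessage(content=prompt)],
        stream=False,
        logprobs=LogProbConfig(top_k=1),
    )
    async for chunk in iterator:
        for token_logprobs in chunk.logprobs or []:
            # TokenLogProbs.logprobs_by_token: {token text: log probability}
            print(token_logprobs.logprobs_by_token)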