diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 4e1fc853d..265c2ab78 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -47,8 +47,8 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
-    SamplingParams,
-    TextTruncation,
+    TokenLogProbs,
+    ToolCall,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -200,13 +200,17 @@ def _response_format_to_guided_decoding_params(
 
 def _convert_sampling_params(
     sampling_params: Optional[SamplingParams],
     response_format: Optional[ResponseFormat],  # type: ignore
+    log_prob_config: Optional[LogProbConfig],
 ) -> vllm.SamplingParams:
     """Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's format."""
+    # In the absence of provided config values, use Llama Stack defaults
+    # as encoded in the Llama Stack dataclasses. These defaults are
+    # different from vLLM's defaults.
     if sampling_params is None:
-        # In the absence of a user-provided sampling config, we use
-        # Llama Stack defaults, which are different from vLLM defaults.
         sampling_params = SamplingParams()
+    if log_prob_config is None:
+        log_prob_config = LogProbConfig()
 
     if isinstance(sampling_params.strategy, TopKSamplingStrategy):
         if sampling_params.strategy.top_k == 0:
@@ -235,6 +239,7 @@ def _convert_sampling_params(
         top_k=vllm_top_k,
         repetition_penalty=sampling_params.repetition_penalty,
         guided_decoding=_response_format_to_guided_decoding_params(response_format),
+        logprobs=log_prob_config.top_k,
     )
     return vllm_sampling_params
 
@@ -456,10 +461,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             raise NotImplementedError("Multimodal input not currently supported")
         if sampling_params is None:
             sampling_params = SamplingParams()
-        if logprobs is not None:
-            raise NotImplementedError("logprobs argument not currently implemented")
 
-        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
+
+        _info(f"{converted_sampling_params=}")
 
         if stream:
             return self._streaming_completion(content, converted_sampling_params)
@@ -505,12 +510,21 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 raise ValueError("Inference produced empty result")
 
         # If we get here, then request_output contains the final output of the
-        # generate() call. There should be one or more output chunks.
-        completion_string = "".join([output.text for output in request_output.outputs])
+        # generate() call.
+        # The result may include multiple alternate outputs, but Llama Stack APIs
+        # only allow us to return one.
+        output: vllm.CompletionOutput = request_output.outputs[0]
+        completion_string = output.text
+
+        # Convert logprobs from vLLM's format to Llama Stack's format
+        logprobs = [
+            TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
+            for logprob_dict in output.logprobs
+        ]
 
         # The final output chunk should be labeled with the reason that the
         # overall generate() call completed.
-        stop_reason_str = request_output.outputs[-1].stop_reason
+        stop_reason_str = output.stop_reason
         if stop_reason_str is None:
             stop_reason = None  # Still going
         elif stop_reason_str == "stop":
@@ -529,12 +543,16 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             if request_output.outputs[-1].token_ids[-1] == eos_token_id:
                 stop_reason = StopReason.end_of_message
 
-        yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason)
+        yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
 
         # Llama Stack requires that the last chunk have a stop reason, but
         # vLLM doesn't always provide one if it runs out of tokens.
         if stop_reason is None:
-            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=StopReason.out_of_tokens)
+            yield CompletionResponseStreamChunk(
+                delta=completion_string,
+                stop_reason=StopReason.out_of_tokens,
+                logprobs=logprobs,
+            )
 
     async def embeddings(
         self,
@@ -559,8 +577,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             raise ValueError(
                 f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
             )
-        if logprobs is not None:
-            raise NotImplementedError("logprobs argument not currently implemented")
 
         # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
         # dataclass.
@@ -569,7 +585,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         converted_messages = [
             await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages
         ]
-        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
         converted_tools = _convert_tools(tools)
 
         # Llama will try to use built-in tools with no tool catalog, so don't enable
@@ -578,8 +594,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0:
             converted_tool_choice = "auto"
 
-        # TODO: Figure out what to do with the tool_prompt_format argument
-        # TODO: Convert logprobs argument
+        # TODO: Figure out what to do with the tool_prompt_format argument.
+        # Other connectors appear to drop it quietly.
 
         chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(
             model=self.resolved_model_id,
@@ -587,8 +603,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             tools=converted_tools,
             tool_choice=converted_tool_choice,
             stream=stream,
-            # tool_prompt_format=tool_prompt_format,
-            # logprobs=logprobs,
         )
 
         # vLLM's OpenAI-compatible APIs take sampling parameters as multiple
@@ -727,7 +741,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
 
             # Anything that is not "[DONE]" should be a JSON record
             parsed_chunk = json.loads(data_str)
-            # print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
+            print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
 
             # The result may contain multiple completions, but Llama Stack APIs
             # only support returning one.
@@ -778,7 +792,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(content=tool_call_record, parse_status="succeeded"),
+                        delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
                         stop_reason=converted_stop_reason,
                     )
                 )
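
The streaming-completion hunks fold vLLM's per-position logprobs (a list of `{token_id: Logprob}` dicts on `CompletionOutput.logprobs`) into Llama Stack's `TokenLogProbs` records. Below is a minimal, self-contained sketch of that conversion, with simplified stand-ins for `vllm.sequence.Logprob` and `llama_stack.apis.inference.TokenLogProbs` so it runs without either package installed; unlike the list comprehension in the patch, it also guards against `CompletionOutput.logprobs` being `None`, which vLLM produces when no logprobs were requested.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Logprob:
    """Stand-in for vllm.sequence.Logprob (only the fields the patch reads)."""
    logprob: float
    decoded_token: Optional[str] = None


@dataclass
class TokenLogProbs:
    """Stand-in for llama_stack.apis.inference.TokenLogProbs."""
    logprobs_by_token: Dict[str, float]


def convert_logprobs(
    vllm_logprobs: Optional[List[Dict[int, Logprob]]],
) -> Optional[List[TokenLogProbs]]:
    """Fold vLLM's per-position {token_id: Logprob} dicts into Llama Stack's
    per-position {token_string: logprob} records."""
    if vllm_logprobs is None:
        return None  # Caller did not request logprobs
    return [
        TokenLogProbs(logprobs_by_token={lp.decoded_token: lp.logprob for lp in position.values()})
        for position in vllm_logprobs
    ]


# Two generated positions with the top-2 candidates at each position.
chunks = convert_logprobs(
    [
        {42: Logprob(-0.11, "Hello"), 7: Logprob(-2.31, "Hi")},
        {13: Logprob(-0.52, ","), 0: Logprob(-1.87, "!")},
    ]
)
assert chunks is not None and chunks[0].logprobs_by_token["Hello"] == -0.11
```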
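The same hunks also map vLLM's string `stop_reason` onto Llama Stack's `StopReason` enum, with an EOS-token fixup and an `out_of_tokens` fallback for the final chunk. The condensed sketch below illustrates that mapping; the `None` branch, the `"stop"` comparison, the EOS check, and the `end_of_message` / `out_of_tokens` members come from the patch, while the `end_of_turn` mapping and the out-of-tokens branch are assumptions filled in for illustration (the patch elides those lines).

```python
from enum import Enum
from typing import Optional


class StopReason(Enum):
    """Mirrors the llama_stack.apis.inference.StopReason members used here."""
    end_of_turn = "end_of_turn"        # assumed mapping target for "stop"
    end_of_message = "end_of_message"  # patch: set when the last token is EOS
    out_of_tokens = "out_of_tokens"    # patch: fallback for the final chunk


def map_stop_reason(
    stop_reason_str: Optional[str],
    last_token_id: int,
    eos_token_id: int,
) -> Optional[StopReason]:
    """Map vLLM's CompletionOutput.stop_reason to a Llama Stack StopReason."""
    if stop_reason_str is None:
        stop_reason = None  # Still going
    elif stop_reason_str == "stop":
        stop_reason = StopReason.end_of_turn  # assumption: normal completion
    else:
        # Assumption: any other reason means the token budget ran out.
        stop_reason = StopReason.out_of_tokens
    # Per the patch: a stop on the EOS token is really an end-of-message.
    if stop_reason is not None and last_token_id == eos_token_id:
        stop_reason = StopReason.end_of_message
    return stop_reason


# The final chunk must carry a stop reason even when vLLM never supplied one,
# hence the patch's trailing out_of_tokens fallback (eos id is illustrative):
final = map_stop_reason(None, last_token_id=3, eos_token_id=128009) or StopReason.out_of_tokens
assert final is StopReason.out_of_tokens
```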