Update for latest APIs

Fred Reiss 2025-01-24 16:55:56 -08:00 committed by Ashwin Bharambe
parent fcb87faa36
commit 25c780802f


@@ -47,8 +47,8 @@ from llama_stack.apis.inference import (
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
TokenLogProbs,
ToolCall,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -200,13 +200,17 @@ def _response_format_to_guided_decoding_params(
def _convert_sampling_params(
sampling_params: Optional[SamplingParams],
response_format: Optional[ResponseFormat], # type: ignore
+log_prob_config: Optional[LogProbConfig],
) -> vllm.SamplingParams:
"""Convert sampling and constrained decoding configuration from
Llama Stack's format to vLLM's format."""
+# In the absence of provided config values, use Llama Stack defaults
+# as encoded in the Llama Stack dataclasses. These defaults are
+# different from vLLM's defaults.
if sampling_params is None:
-# In the absence of a user-provided sampling config, we use
-# Llama Stack defaults, which are different from vLLM defaults.
sampling_params = SamplingParams()
+if log_prob_config is None:
+log_prob_config = LogProbConfig()
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
if sampling_params.strategy.top_k == 0:
@@ -235,6 +239,7 @@ def _convert_sampling_params(
top_k=vllm_top_k,
repetition_penalty=sampling_params.repetition_penalty,
guided_decoding=_response_format_to_guided_decoding_params(response_format),
+logprobs=log_prob_config.top_k,
)
return vllm_sampling_params
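
As a quick illustration of the new third parameter, here is a rough sketch of how the updated converter might be called; the names come from the imports and code in the hunks above, and the values are purely illustrative:

from llama_stack.apis.inference import LogProbConfig, SamplingParams

# Request the top-1 logprob for each generated token. Passing None instead
# would fall back to LogProbConfig()'s defaults, mirroring the handling of
# sampling_params above.
vllm_sampling_params = _convert_sampling_params(
    SamplingParams(),        # Llama Stack defaults, not vLLM defaults
    None,                    # no response format / constrained decoding
    LogProbConfig(top_k=1),  # surfaces as logprobs=1 on the vllm.SamplingParams
)
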
@@ -456,10 +461,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
raise NotImplementedError("Multimodal input not currently supported")
if sampling_params is None:
sampling_params = SamplingParams()
-if logprobs is not None:
-raise NotImplementedError("logprobs argument not currently implemented")
-converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
_info(f"{converted_sampling_params=}")
if stream:
return self._streaming_completion(content, converted_sampling_params)
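
With the NotImplementedError gone, a logprobs config can now flow through completion() end to end. A minimal usage sketch, assuming the usual Llama Stack Inference.completion() keyword arguments and a placeholder model id:

# impl is an instance of VLLMInferenceImpl whose model has been registered.
response = await impl.completion(
    model_id="my-registered-model",   # placeholder; must match a registered model id
    content="The capital of France is",
    sampling_params=SamplingParams(),
    logprobs=LogProbConfig(top_k=1),  # previously rejected, now forwarded to vLLM
    stream=False,
)
# response.logprobs should now carry one TokenLogProbs entry per generated token.
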
@@ -505,12 +510,21 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
raise ValueError("Inference produced empty result")
# If we get here, then request_output contains the final output of the
-# generate() call. There should be one or more output chunks.
-completion_string = "".join([output.text for output in request_output.outputs])
+# generate() call.
+# The result may include multiple alternate outputs, but Llama Stack APIs
+# only allow us to return one.
+output: vllm.CompletionOutput = request_output.outputs[0]
+completion_string = output.text
+# Convert logprobs from vLLM's format to Llama Stack's format
+logprobs = [
+TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
+for logprob_dict in output.logprobs
+]
# The final output chunk should be labeled with the reason that the
# overall generate() call completed.
-stop_reason_str = request_output.outputs[-1].stop_reason
+stop_reason_str = output.stop_reason
if stop_reason_str is None:
stop_reason = None # Still going
elif stop_reason_str == "stop":
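
Written out, the logprobs conversion above amounts to the following sketch; the helper name is hypothetical, but the attribute access mirrors vLLM's per-token Logprob objects and Llama Stack's TokenLogProbs as used in the hunk:

def _vllm_logprobs_to_token_logprobs(vllm_logprobs):
    """Hypothetical helper equivalent to the list comprehension above.

    vllm_logprobs is output.logprobs: one dict per generated token, mapping
    token id -> a vLLM Logprob with .logprob and .decoded_token attributes.
    """
    converted = []
    for per_token_dict in vllm_logprobs:
        converted.append(
            TokenLogProbs(
                logprobs_by_token={
                    lp.decoded_token: lp.logprob for lp in per_token_dict.values()
                }
            )
        )
    return converted
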
@@ -529,12 +543,16 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
stop_reason = StopReason.end_of_message
-yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason)
+yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
# Llama Stack requires that the last chunk have a stop reason, but
# vLLM doesn't always provide one if it runs out of tokens.
if stop_reason is None:
-yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=StopReason.out_of_tokens)
+yield CompletionResponseStreamChunk(
+delta=completion_string,
+stop_reason=StopReason.out_of_tokens,
+logprobs=logprobs,
+)
async def embeddings(
self,
@@ -559,8 +577,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
-if logprobs is not None:
-raise NotImplementedError("logprobs argument not currently implemented")
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest
# dataclass.
@@ -569,7 +585,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
converted_messages = [
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages
]
-converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
converted_tools = _convert_tools(tools)
# Llama will try to use built-in tools with no tool catalog, so don't enable
@@ -578,8 +594,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0:
converted_tool_choice = "auto"
-# TODO: Figure out what to do with the tool_prompt_format argument
-# TODO: Convert logprobs argument
+# TODO: Figure out what to do with the tool_prompt_format argument.
+# Other connectors appear to drop it quietly.
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(
model=self.resolved_model_id,
@@ -587,8 +603,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
tools=converted_tools,
tool_choice=converted_tool_choice,
stream=stream,
-# tool_prompt_format=tool_prompt_format,
-# logprobs=logprobs,
)
# vLLM's OpenAI-compatible APIs take sampling parameters as multiple
@@ -727,7 +741,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
# Anything that is not "[DONE]" should be a JSON record
parsed_chunk = json.loads(data_str)
-# print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
+print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
# The result may contain multiple completions, but Llama Stack APIs
# only support returning one.
@@ -778,7 +792,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
-delta=ToolCallDelta(content=tool_call_record, parse_status="succeeded"),
+delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
stop_reason=converted_stop_reason,
)
)
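
The final hunk reflects the updated ToolCallDelta API, where the parsed call is passed as tool_call rather than content. A minimal construction sketch with placeholder ToolCall field values:

example_call = ToolCall(
    call_id="call_0",             # placeholder id
    tool_name="get_weather",      # placeholder tool name
    arguments={"city": "Paris"},  # placeholder arguments
)
delta = ToolCallDelta(tool_call=example_call, parse_status="succeeded")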