Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)
Update for latest APIs

Commit 25c780802f (parent fcb87faa36)
1 changed file with 35 additions and 21 deletions
```diff
@@ -47,8 +47,8 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
     TokenLogProbs,
     ToolCall,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
```
```diff
@@ -200,13 +200,17 @@ def _response_format_to_guided_decoding_params(
 def _convert_sampling_params(
     sampling_params: Optional[SamplingParams],
     response_format: Optional[ResponseFormat],  # type: ignore
+    log_prob_config: Optional[LogProbConfig],
 ) -> vllm.SamplingParams:
     """Convert sampling and constrained decoding configuration from
     Llama Stack's format to vLLM's format."""
-    # In the absence of provided config values, use Llama Stack defaults
-    # a encoded in the Llama Stack dataclasses. These defaults are
-    # different from vLLM's defaults.
     if sampling_params is None:
+        # In the absence of a user-provided sampling config, we use
+        # Llama Stack defaults, which are different from vLLM defaults.
         sampling_params = SamplingParams()
+    if log_prob_config is None:
+        log_prob_config = LogProbConfig()
 
     if isinstance(sampling_params.strategy, TopKSamplingStrategy):
         if sampling_params.strategy.top_k == 0:
```
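For orientation, here is a hedged sketch of how the three inputs to the updated `_convert_sampling_params` might be built on the Llama Stack side. The `top_k=1` value and the default constructions are illustrative assumptions, not part of this commit; only the class names and the `top_k` field come from the diff above.

```python
from llama_stack.apis.inference import LogProbConfig, SamplingParams

# Llama Stack defaults differ from vLLM's, as the comment in the diff notes.
sampling_params = SamplingParams()
response_format = None                    # no constrained decoding
log_prob_config = LogProbConfig(top_k=1)  # assumed: one logprob per generated token

# The provider would then call (per the new signature above):
# vllm_params = _convert_sampling_params(sampling_params, response_format, log_prob_config)
```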
```diff
@@ -235,6 +239,7 @@ def _convert_sampling_params(
         top_k=vllm_top_k,
         repetition_penalty=sampling_params.repetition_penalty,
         guided_decoding=_response_format_to_guided_decoding_params(response_format),
+        logprobs=log_prob_config.top_k,
     )
     return vllm_sampling_params
 
```
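On the vLLM side, the new keyword maps `LogProbConfig.top_k` onto `vllm.SamplingParams.logprobs`, the number of logprobs vLLM returns per generated token. A minimal illustration (the other values are arbitrary and not taken from the commit):

```python
import vllm

# Request one logprob per generated token, mirroring logprobs=log_prob_config.top_k above.
params = vllm.SamplingParams(max_tokens=64, temperature=0.0, logprobs=1)
print(params.logprobs)  # -> 1
```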
```diff
@@ -456,10 +461,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             raise NotImplementedError("Multimodal input not currently supported")
         if sampling_params is None:
             sampling_params = SamplingParams()
         if logprobs is not None:
             raise NotImplementedError("logprobs argument not currently implemented")
 
-        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
 
         _info(f"{converted_sampling_params=}")
 
         if stream:
             return self._streaming_completion(content, converted_sampling_params)
```
```diff
@@ -505,12 +510,21 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             raise ValueError("Inference produced empty result")
 
         # If we get here, then request_output contains the final output of the
-        # generate() call. There should be one or more output chunks.
-        completion_string = "".join([output.text for output in request_output.outputs])
+        # generate() call.
+        # The result may include multiple alternate outputs, but Llama Stack APIs
+        # only allow us to return one.
+        output: vllm.CompletionOutput = request_output.outputs[0]
+        completion_string = output.text
+
+        # Convert logprobs from vLLM's format to Llama Stack's format
+        logprobs = [
+            TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
+            for logprob_dict in output.logprobs
+        ]
 
         # The final output chunk should be labeled with the reason that the
         # overall generate() call completed.
-        stop_reason_str = request_output.outputs[-1].stop_reason
+        stop_reason_str = output.stop_reason
         if stop_reason_str is None:
             stop_reason = None  # Still going
         elif stop_reason_str == "stop":
```
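To make the reshaping above concrete, here is a self-contained sketch with dummy data. `FakeLogprob` is a stand-in for vLLM's per-token logprob entries and is not part of either library; the real code additionally wraps each map in `TokenLogProbs(logprobs_by_token=...)`.

```python
from dataclasses import dataclass

@dataclass
class FakeLogprob:
    """Stand-in for vLLM's per-token logprob entries (decoded_token, logprob)."""
    decoded_token: str
    logprob: float

# vLLM reports, for each generated position, a dict of {token_id: logprob entry}.
fake_vllm_logprobs = [
    {42: FakeLogprob(" Hello", -0.01), 7: FakeLogprob(" Hi", -4.2)},
    {13: FakeLogprob(" world", -0.10)},
]

# Same reshaping as in the diff: one {decoded_token: logprob} map per position.
converted = [
    {v.decoded_token: v.logprob for _, v in logprob_dict.items()}
    for logprob_dict in fake_vllm_logprobs
]
print(converted)  # [{' Hello': -0.01, ' Hi': -4.2}, {' world': -0.1}]
```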
```diff
@@ -529,12 +543,16 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             if request_output.outputs[-1].token_ids[-1] == eos_token_id:
                 stop_reason = StopReason.end_of_message
 
-        yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason)
+        yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
 
         # Llama Stack requires that the last chunk have a stop reason, but
         # vLLM doesn't always provide one if it runs out of tokens.
         if stop_reason is None:
-            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=StopReason.out_of_tokens)
+            yield CompletionResponseStreamChunk(
+                delta=completion_string,
+                stop_reason=StopReason.out_of_tokens,
+                logprobs=logprobs,
+            )
 
     async def embeddings(
         self,
```
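The second `yield` above guarantees that a stream which ends without a stop reason still closes with `out_of_tokens`. Below is a loose, simplified analogue of that guarantee using plain dicts instead of `CompletionResponseStreamChunk`; it illustrates the idea only and is not the provider's code.

```python
from typing import Dict, Iterable, Iterator, Optional

def close_stream(chunks: Iterable[Dict], final_delta: str = "") -> Iterator[Dict]:
    """Pass chunks through; if none carried a stop reason, append a final
    'out_of_tokens' chunk so the stream always ends with a stop reason."""
    last_stop: Optional[str] = None
    for chunk in chunks:
        last_stop = chunk.get("stop_reason") or last_stop
        yield chunk
    if last_stop is None:
        yield {"delta": final_delta, "stop_reason": "out_of_tokens", "logprobs": []}

# Example: a stream that never reported a stop reason gets one appended.
print(list(close_stream([{"delta": "Hello", "stop_reason": None, "logprobs": []}])))
```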
```diff
@@ -559,8 +577,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             raise ValueError(
                 f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
             )
-        if logprobs is not None:
-            raise NotImplementedError("logprobs argument not currently implemented")
 
         # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
         # dataclass.
```
```diff
@@ -569,7 +585,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         converted_messages = [
             await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages
         ]
-        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
         converted_tools = _convert_tools(tools)
 
         # Llama will try to use built-in tools with no tool catalog, so don't enable
```
```diff
@@ -578,8 +594,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0:
             converted_tool_choice = "auto"
 
-        # TODO: Figure out what to do with the tool_prompt_format argument
-        # TODO: Convert logprobs argument
+        # TODO: Figure out what to do with the tool_prompt_format argument.
+        # Other connectors appear to drop it quietly.
 
         chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(
             model=self.resolved_model_id,
```
```diff
@@ -587,8 +603,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             tools=converted_tools,
             tool_choice=converted_tool_choice,
             stream=stream,
-            # tool_prompt_format=tool_prompt_format,
-            # logprobs=logprobs,
         )
 
         # vLLM's OpenAI-compatible APIs take sampling parameters as multiple
```
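For reference, a hedged sketch of constructing the same vLLM OpenAI-protocol request outside the provider. The model id, message, and tool schema are invented; only the keyword names visible in the diff (`model`, `tools`, `tool_choice`, `stream`) are taken from it.

```python
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "What's the weather in Tokyo?"}],
    tools=[{
        # OpenAI-style tool schema; the provider builds this via _convert_tools(tools).
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }],
    tool_choice="auto",  # set only when the caller actually supplied tools
    stream=True,
)
```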
```diff
@@ -727,7 +741,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             # Anything that is not "[DONE]" should be a JSON record
             parsed_chunk = json.loads(data_str)
 
-            # print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
+            print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
 
             # The result may contain multiple completions, but Llama Stack APIs
             # only support returning one.
```
```diff
@@ -778,7 +792,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(content=tool_call_record, parse_status="succeeded"),
+                        delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
                         stop_reason=converted_stop_reason,
                     )
                 )
```
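The final hunk tracks an API rename: `ToolCallDelta` now carries the parsed call under `tool_call` instead of `content`. A small illustrative construction follows; the import paths and field values are assumptions (only `ToolCall` is confirmed by the import hunk above), so adjust them to your llama-stack version.

```python
# Import paths are assumed and may differ between llama-stack versions.
from llama_stack.apis.common.content_types import ToolCallDelta
from llama_stack.apis.inference import ToolCall

tool_call_record = ToolCall(
    call_id="call_0",             # invented id
    tool_name="get_weather",      # invented tool
    arguments={"city": "Tokyo"},  # invented arguments
)

# New style per the diff: pass the record as tool_call=, not content=.
delta = ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded")
```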