Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 13:00:39 +00:00
Update for latest APIs
Commit 25c780802f (parent fcb87faa36)
1 changed file with 35 additions and 21 deletions
@@ -47,8 +47,8 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
-    SamplingParams,
-    TextTruncation,
+    TokenLogProbs,
+    ToolCall,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -200,13 +200,17 @@ def _response_format_to_guided_decoding_params(
 def _convert_sampling_params(
     sampling_params: Optional[SamplingParams],
     response_format: Optional[ResponseFormat],  # type: ignore
+    log_prob_config: Optional[LogProbConfig],
 ) -> vllm.SamplingParams:
     """Convert sampling and constrained decoding configuration from
     Llama Stack's format to vLLM's format."""
+    # In the absence of provided config values, use Llama Stack defaults
+    # as encoded in the Llama Stack dataclasses. These defaults are
+    # different from vLLM's defaults.
     if sampling_params is None:
-        # In the absence of a user-provided sampling config, we use
-        # Llama Stack defaults, which are different from vLLM defaults.
         sampling_params = SamplingParams()
+    if log_prob_config is None:
+        log_prob_config = LogProbConfig()

     if isinstance(sampling_params.strategy, TopKSamplingStrategy):
         if sampling_params.strategy.top_k == 0:
@@ -235,6 +239,7 @@ def _convert_sampling_params(
         top_k=vllm_top_k,
         repetition_penalty=sampling_params.repetition_penalty,
         guided_decoding=_response_format_to_guided_decoding_params(response_format),
+        logprobs=log_prob_config.top_k,
     )
     return vllm_sampling_params

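The two hunks above add a third parameter, log_prob_config, and map its top_k onto vLLM's logprobs setting. Below is a minimal, hypothetical sketch of that mapping in isolation; the helper name logprobs_kwarg and the example value are invented for illustration, while LogProbConfig, its top_k field, and vllm.SamplingParams(logprobs=...) come from the diff itself.

# Illustrative sketch only; mirrors the log_prob_config handling added above.
from typing import Optional

import vllm
from llama_stack.apis.inference import LogProbConfig


def logprobs_kwarg(log_prob_config: Optional[LogProbConfig]) -> Optional[int]:
    # Fall back to Llama Stack's default config, as the updated
    # _convert_sampling_params() does, then hand its top_k to vLLM.
    if log_prob_config is None:
        log_prob_config = LogProbConfig()
    return log_prob_config.top_k


# vLLM treats logprobs=N as a request for the log probabilities of the
# N most likely tokens at each generation step.
params = vllm.SamplingParams(logprobs=logprobs_kwarg(LogProbConfig(top_k=5)))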
@@ -456,10 +461,10 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             raise NotImplementedError("Multimodal input not currently supported")
         if sampling_params is None:
             sampling_params = SamplingParams()
-        if logprobs is not None:
-            raise NotImplementedError("logprobs argument not currently implemented")

-        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)

+        _info(f"{converted_sampling_params=}")
+
         if stream:
             return self._streaming_completion(content, converted_sampling_params)
@@ -505,12 +510,21 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             raise ValueError("Inference produced empty result")

         # If we get here, then request_output contains the final output of the
-        # generate() call. There should be one or more output chunks.
-        completion_string = "".join([output.text for output in request_output.outputs])
+        # generate() call.
+        # The result may include multiple alternate outputs, but Llama Stack APIs
+        # only allow us to return one.
+        output: vllm.CompletionOutput = request_output.outputs[0]
+        completion_string = output.text
+
+        # Convert logprobs from vLLM's format to Llama Stack's format
+        logprobs = [
+            TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
+            for logprob_dict in output.logprobs
+        ]

         # The final output chunk should be labeled with the reason that the
         # overall generate() call completed.
-        stop_reason_str = request_output.outputs[-1].stop_reason
+        stop_reason_str = output.stop_reason
         if stop_reason_str is None:
             stop_reason = None  # Still going
         elif stop_reason_str == "stop":
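The conversion added in this hunk relies on the shape of vLLM's per-token logprob records: CompletionOutput.logprobs is a list with one dict per generated token, keyed by token ID, whose values carry decoded_token and logprob attributes. A self-contained sketch of the same transformation using stand-in data follows; FakeLogprob and the sample numbers are invented, while TokenLogProbs and the comprehension come from the diff.

# Stand-alone illustration of the logprobs conversion above. FakeLogprob
# imitates the attributes of vLLM's Logprob objects; the values are made up.
from dataclasses import dataclass

from llama_stack.apis.inference import TokenLogProbs


@dataclass
class FakeLogprob:
    logprob: float
    decoded_token: str


# One dict per generated token, keyed by token ID, matching the shape of
# vllm.CompletionOutput.logprobs.
fake_output_logprobs = [
    {1: FakeLogprob(-0.1, "Hello"), 2: FakeLogprob(-2.3, "Hi")},
    {7: FakeLogprob(-0.4, " world"), 9: FakeLogprob(-1.9, " there")},
]

logprobs = [
    TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
    for logprob_dict in fake_output_logprobs
]
# logprobs[0].logprobs_by_token == {"Hello": -0.1, "Hi": -2.3}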
@@ -529,12 +543,16 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             if request_output.outputs[-1].token_ids[-1] == eos_token_id:
                 stop_reason = StopReason.end_of_message

-        yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason)
+        yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)

         # Llama Stack requires that the last chunk have a stop reason, but
         # vLLM doesn't always provide one if it runs out of tokens.
         if stop_reason is None:
-            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=StopReason.out_of_tokens)
+            yield CompletionResponseStreamChunk(
+                delta=completion_string,
+                stop_reason=StopReason.out_of_tokens,
+                logprobs=logprobs,
+            )

     async def embeddings(
         self,
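With this hunk, every streamed chunk, including the synthesized out-of-tokens chunk, carries the converted logprobs next to the text delta. A hypothetical consumer loop is sketched below; the stream variable and its origin are assumptions, and only the chunk fields delta, stop_reason, and logprobs come from the code above.

# Hypothetical consumer; `stream` is assumed to be the async iterator that
# completion(..., stream=True) returns from this provider.
async def show_chunks(stream) -> None:
    async for chunk in stream:
        print(chunk.delta, chunk.stop_reason)
        for token_logprobs in chunk.logprobs or []:
            # One entry per generated token, mapping decoded token -> logprob.
            print(token_logprobs.logprobs_by_token)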
@@ -559,8 +577,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             raise ValueError(
                 f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
             )
-        if logprobs is not None:
-            raise NotImplementedError("logprobs argument not currently implemented")

         # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
         # dataclass.
@@ -569,7 +585,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         converted_messages = [
             await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages
         ]
-        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
         converted_tools = _convert_tools(tools)

         # Llama will try to use built-in tools with no tool catalog, so don't enable
@@ -578,8 +594,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0:
             converted_tool_choice = "auto"

-        # TODO: Figure out what to do with the tool_prompt_format argument
-        # TODO: Convert logprobs argument
+        # TODO: Figure out what to do with the tool_prompt_format argument.
+        # Other connectors appear to drop it quietly.

         chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(
             model=self.resolved_model_id,
@@ -587,8 +603,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             tools=converted_tools,
             tool_choice=converted_tool_choice,
             stream=stream,
-            # tool_prompt_format=tool_prompt_format,
-            # logprobs=logprobs,
         )

         # vLLM's OpenAI-compatible APIs take sampling parameters as multiple
@@ -727,7 +741,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             # Anything that is not "[DONE]" should be a JSON record
             parsed_chunk = json.loads(data_str)

-            # print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
+            print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")

             # The result may contain multiple completions, but Llama Stack APIs
             # only support returning one.
@@ -778,7 +792,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(content=tool_call_record, parse_status="succeeded"),
+                    delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
                     stop_reason=converted_stop_reason,
                 )
             )
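The final hunk tracks a rename in the ToolCallDelta schema: the parsed tool call is now passed as tool_call rather than content. A minimal sketch of the new keyword follows; the call_id, tool_name, and arguments values are placeholders, the ToolCallDelta import path is an assumption, and ToolCall itself comes from llama_stack.apis.inference as in the first hunk.

# Sketch of the renamed keyword. Field values are placeholders; depending on
# the llama-stack version, ToolCall.arguments may be a JSON string instead of
# a dict, and ToolCallDelta may live in a different module.
from llama_stack.apis.common.content_types import ToolCallDelta  # assumed path
from llama_stack.apis.inference import ToolCall

tool_call_record = ToolCall(
    call_id="call-0",              # placeholder
    tool_name="get_weather",       # placeholder
    arguments={"city": "Paris"},   # placeholder
)

# Previously: ToolCallDelta(content=tool_call_record, parse_status="succeeded")
delta = ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded")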