Use Llama Stack template when streaming

Fred Reiss 2025-02-04 13:09:49 -08:00 committed by Ashwin Bharambe
parent ade413f1e3
commit 5d54c2ee70

@@ -68,7 +68,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
 )
-from llama_stack.providers.utils.inference.openai_compat import get_stop_reason
+from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    get_stop_reason,
+    process_chat_completion_stream_response,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
 )
@@ -92,7 +97,7 @@ logger = logging.getLogger(__name__)
 # Adjust logging parameters from Python code. This appears to be the standard way to control
 # logging in Llama Stack.
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)
 stderr_handler = logging.StreamHandler()
 stderr_handler.setFormatter(logging.Formatter("%(asctime)s: %(filename)s [%(levelname)s] %(message)s"))
 logger.addHandler(stderr_handler)
@@ -266,7 +271,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         else:  # if resolved_llama_model is None
             # Not a Llama model name. Pass the model id through to vLLM's loader
             resolved_model_id = model.provider_model_id
-            is_meta_llama_model = True
+            is_meta_llama_model = False
         logger.info(f"Model id {model} resolved to {resolved_model_id}")
@@ -282,6 +287,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             self.model_ids.add(model.model_id)
             return model

+        if is_meta_llama_model:
+            logger.info(f"Model {resolved_model_id} is a Meta Llama model.")
         self.is_meta_llama_model = is_meta_llama_model

         logger.info(f"Preloading model: {resolved_model_id}")
@@ -443,11 +450,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Routing {self.resolved_model_id} chat completion through "
                 f"Llama Stack's templating layer instead of vLLM's."
             )
-            if stream:
-                # return self._chat_completion_for_meta_llama_streaming(request)
-                pass  # Use vLLM until the above method is implemented.
-            else:
-                return await self._chat_completion_for_meta_llama_non_streaming(request)
+            return await self._chat_completion_for_meta_llama(request)

         logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
@@ -523,15 +526,19 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         # The final output chunk should be labeled with the reason that the overall generate()
         # call completed.
-        stop_reason_str = output.stop_reason
-        if stop_reason_str is None:
+        logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
+        if output.stop_reason is None:
             stop_reason = None  # Still going
-        elif stop_reason_str == "stop":
+        elif output.stop_reason == "stop":
             stop_reason = StopReason.end_of_turn
-        elif stop_reason_str == "length":
+        elif output.stop_reason == "length":
             stop_reason = StopReason.out_of_tokens
+        elif isinstance(output.stop_reason, int):
+            # If the model config specifies multiple end-of-sequence tokens, then vLLM
+            # will return the token ID of the EOS token in the stop_reason field.
+            stop_reason = StopReason.end_of_turn
         else:
-            raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")
+            raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")

         # vLLM's protocol outputs the stop token, then sets end of message on the next step for
         # some reason.
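Taken together, the branches above normalize vLLM's polymorphic stop_reason, which can be None while generation is still running, the strings "stop" or "length", or, when the model config declares several end-of-sequence tokens, the integer token ID of whichever EOS token fired. A self-contained sketch of the same mapping, with a stand-in for Llama Stack's StopReason enum:

from enum import Enum
from typing import Optional, Union


class StopReason(Enum):
    # Stand-in for llama_models' StopReason; values here are illustrative.
    end_of_turn = "end_of_turn"
    out_of_tokens = "out_of_tokens"


def normalize_stop_reason(raw: Union[str, int, None]) -> Optional[StopReason]:
    if raw is None:
        return None  # Still generating
    if raw == "stop":
        return StopReason.end_of_turn
    if raw == "length":
        return StopReason.out_of_tokens
    if isinstance(raw, int):
        # Token ID of whichever EOS token ended generation.
        return StopReason.end_of_turn
    raise ValueError(f"Unrecognized stop reason '{raw}'")


assert normalize_stop_reason(None) is None
assert normalize_stop_reason("stop") is StopReason.end_of_turn
assert normalize_stop_reason(128009) is StopReason.end_of_turn  # e.g. an EOS token ID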
@@ -549,49 +556,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             logprobs=logprobs,
         )

-    async def _chat_completion_for_meta_llama_non_streaming(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
-        """
-        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
-        chat template instead of using vLLM's version of that template. The Llama Stack version
-        of the chat template currently produces more reliable outputs.
-
-        Once vLLM's support for Meta Llama models has matured more, we should consider routing
-        Meta Llama requests through the vLLM chat completions API instead of using this method.
-        """
-        logger.debug("Routing request through Llama Stack templates.")
-        formatter = ChatFormat(Tokenizer.get_instance())
-
-        # Note that this function call modifies `request` in place.
-        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
-
-        model_id = list(self.model_ids)[0]  # Any model ID will do here
-        completion_response = await self.completion(
-            model_id=model_id,
-            content=prompt,
-            sampling_params=request.sampling_params,
-            response_format=request.response_format,
-            stream=False,
-            logprobs=request.logprobs,
-        )
-        if not isinstance(completion_response, CompletionResponse):  # Sanity check
-            raise TypeError(f"Unexpected type '{type(completion_response)}' for completion response.")
-        raw_message = formatter.decode_assistant_message_from_content(
-            completion_response.content, completion_response.stop_reason
-        )
-        return ChatCompletionResponse(
-            completion_message=CompletionMessage(
-                content=raw_message.content,
-                stop_reason=raw_message.stop_reason,
-                tool_calls=raw_message.tool_calls,
-            ),
-            logprobs=completion_response.logprobs,
-        )
-
     def _convert_non_streaming_results(
         self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
     ) -> ChatCompletionResponse:
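The method deleted above is folded into a combined helper in the next two hunks, but the underlying idea is unchanged: bypass vLLM's chat template, render the dialog into a prompt string with Llama Stack's formatter, and call the plain completion endpoint. A hedged sketch of that shape, using a made-up template and made-up helpers rather than the real ChatFormat/Tokenizer machinery:

from dataclasses import dataclass
from typing import List


@dataclass
class Message:
    role: str
    content: str


def render_prompt(messages: List[Message]) -> str:
    # Stand-in for chat_completion_request_to_prompt(): flatten the dialog into
    # the single string that the base-completion endpoint expects. The tag
    # syntax below is invented, not the actual Llama template.
    parts = [f"<|{m.role}|>\n{m.content}\n<|end|>" for m in messages]
    parts.append("<|assistant|>\n")  # Cue the model to answer
    return "".join(parts)


def complete(prompt: str) -> str:
    # Stand-in for the raw completion call; a real implementation hands the
    # pre-rendered prompt to the engine untouched instead of re-templating it.
    return f"(completion of a {len(prompt)}-character prompt)"


print(complete(render_prompt([Message("user", "What is vLLM?")])))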
@@ -634,7 +598,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             completion_message=converted_message,
         )

-    def _chat_completion_for_meta_llama_streaming(self, request: ChatCompletionRequest) -> AsyncIterator:
+    async def _chat_completion_for_meta_llama(
+        self, request: ChatCompletionRequest
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         """
         Subroutine that routes chat completions for Meta Llama models through Llama Stack's
         chat template instead of using vLLM's version of that template. The Llama Stack version
@@ -643,9 +609,82 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         Once vLLM's support for Meta Llama models has matured more, we should consider routing
         Meta Llama requests through the vLLM chat completions API instead of using this method.
         """
-        logger.debug("Routing streaming request through Llama Stack templates.")
-        raise NotImplementedError()
+        formatter = ChatFormat(Tokenizer.get_instance())
+
+        # Note that this function call modifies `request` in place.
+        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
+
+        model_id = list(self.model_ids)[0]  # Any model ID will do here
+        completion_response_or_iterator = await self.completion(
+            model_id=model_id,
+            content=prompt,
+            sampling_params=request.sampling_params,
+            response_format=request.response_format,
+            stream=request.stream,
+            logprobs=request.logprobs,
+        )
+
+        if request.stream:
+            if not isinstance(completion_response_or_iterator, AsyncIterator):
+                raise TypeError(
+                    f"Received unexpected result type {type(completion_response_or_iterator)} for streaming request."
+                )
+            return self._chat_completion_for_meta_llama_streaming(formatter, completion_response_or_iterator)
+
+        # elif not request.stream:
+        if not isinstance(completion_response_or_iterator, CompletionResponse):
+            raise TypeError(
+                f"Received unexpected result type {type(completion_response_or_iterator)} for non-streaming request."
+            )
+        completion_response: CompletionResponse = completion_response_or_iterator
+        raw_message = formatter.decode_assistant_message_from_content(
+            completion_response.content, completion_response.stop_reason
+        )
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=raw_message.content,
+                stop_reason=raw_message.stop_reason,
+                tool_calls=raw_message.tool_calls,
+            ),
+            logprobs=completion_response.logprobs,
+        )
+
+    async def _chat_completion_for_meta_llama_streaming(
+        self, formatter: ChatFormat, results_iterator: AsyncIterator
+    ) -> AsyncIterator:
+        """
+        Code from :func:`_chat_completion_for_meta_llama` that needs to be a separate
+        method to keep asyncio happy.
+        """
+
+        # Convert to OpenAI format, then use shared code to convert to Llama Stack format.
+        async def _generate_and_convert_to_openai_compat():
+            chunk: CompletionResponseStreamChunk  # Make Pylance happy
+            last_text_len = 0
+            async for chunk in results_iterator:
+                if chunk.stop_reason == StopReason.end_of_turn:
+                    finish_reason = "stop"
+                elif chunk.stop_reason == StopReason.end_of_message:
+                    finish_reason = "eos"
+                elif chunk.stop_reason == StopReason.out_of_tokens:
+                    finish_reason = "length"
+                else:
+                    finish_reason = None
+
+                # `chunk.delta` holds the cumulative generated text, so slice off
+                # the part already emitted to recover an actual delta.
+                text_delta = chunk.delta[last_text_len:]
+                last_text_len = len(chunk.delta)
+
+                logger.debug(f"{text_delta=}; {finish_reason=}")
+                yield OpenAICompatCompletionResponse(
+                    choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
+                )
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(stream, formatter):
+            logger.debug(f"Returning chunk: {chunk}")
+            yield chunk

     async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
         """