Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 13:00:39 +00:00

Use Llama Stack template when streaming

parent: ade413f1e3
commit: 5d54c2ee70

1 changed file with 98 additions and 59 deletions
```diff
@@ -68,7 +68,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
 )
-from llama_stack.providers.utils.inference.openai_compat import get_stop_reason
+from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    get_stop_reason,
+    process_chat_completion_stream_response,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
 )
```
```diff
@@ -92,7 +97,7 @@ logger = logging.getLogger(__name__)
 
 # Adjust logging parameters from Python code. This appears to be the standard way to control
 # logging in Llama Stack.
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)
 stderr_handler = logging.StreamHandler()
 stderr_handler.setFormatter(logging.Formatter("%(asctime)s: %(filename)s [%(levelname)s] %(message)s"))
 logger.addHandler(stderr_handler)
```
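A side effect of this hunk is that per-request debug output disappears by default. If the old verbosity is wanted without patching the file, the provider's logger can be raised again from calling code; a minimal sketch, with the module path of the inline vLLM provider assumed (adjust it to wherever `VLLMInferenceImpl` lives in your tree):

```python
import logging

# Assumed module path for the inline vLLM provider; the logger name matches
# logging.getLogger(__name__) in that module.
vllm_logger = logging.getLogger("llama_stack.providers.inline.inference.vllm.vllm")
vllm_logger.setLevel(logging.DEBUG)  # restore the pre-commit DEBUG verbosity
```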
```diff
@@ -266,7 +271,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         else:  # if resolved_llama_model is None
             # Not a Llama model name. Pass the model id through to vLLM's loader
             resolved_model_id = model.provider_model_id
-            is_meta_llama_model = True
+            is_meta_llama_model = False
 
         logger.info(f"Model id {model} resolved to {resolved_model_id}")
 
```
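This one-word fix corrects an inverted flag: the `else` branch handles model ids that did not resolve to a known Llama model, so tagging them `is_meta_llama_model = True` sent non-Llama models down the Llama-specific path. A hedged sketch of the surrounding resolution logic, reconstructed from this hunk's context lines (the `resolve_model` import and the `True` branch are assumptions, not shown in the diff):

```python
from llama_models.sku_list import resolve_model  # assumed import

resolved_llama_model = resolve_model(model.provider_model_id)
if resolved_llama_model is not None:
    # Assumed branch: a known Meta Llama model name resolved successfully.
    resolved_model_id = resolved_llama_model.huggingface_repo
    is_meta_llama_model = True
else:  # if resolved_llama_model is None
    # Not a Llama model name. Pass the model id through to vLLM's loader.
    resolved_model_id = model.provider_model_id
    is_meta_llama_model = False  # previously (incorrectly) True
```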
```diff
@@ -282,6 +287,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             self.model_ids.add(model.model_id)
             return model
 
+        if is_meta_llama_model:
+            logger.info(f"Model {resolved_model_id} is a Meta Llama model.")
         self.is_meta_llama_model = is_meta_llama_model
         logger.info(f"Preloading model: {resolved_model_id}")
 
```
```diff
@@ -443,11 +450,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Routing {self.resolved_model_id} chat completion through "
                 f"Llama Stack's templating layer instead of vLLM's."
             )
-            if stream:
-                # return self._chat_completion_for_meta_llama_streaming(request)
-                pass  # Use vLLM until the above method is implemented.
-            else:
-                return await self._chat_completion_for_meta_llama_non_streaming(request)
+            return await self._chat_completion_for_meta_llama(request)
 
         logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
 
```
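Before this hunk, streaming requests for Meta Llama models fell through to vLLM's own chat template while only non-streaming requests used Llama Stack's. Now both are delegated to a single `_chat_completion_for_meta_llama()`, which checks `request.stream` itself. A minimal sketch of the resulting dispatch, with the enclosing `chat_completion()` abbreviated to just the branch this hunk touches:

```python
# Sketch only: the real chat_completion() takes the full Llama Stack argument
# list and builds `request` before reaching this point.
async def chat_completion(self, request: ChatCompletionRequest):
    if self.is_meta_llama_model:
        # One entry point now serves both streaming and non-streaming requests.
        return await self._chat_completion_for_meta_llama(request)
    logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
    ...  # fall through to vLLM's OpenAI-compatible chat completions path
```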
```diff
@@ -523,15 +526,19 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
 
             # The final output chunk should be labeled with the reason that the overall generate()
             # call completed.
-            stop_reason_str = output.stop_reason
-            if stop_reason_str is None:
+            logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
+            if output.stop_reason is None:
                 stop_reason = None  # Still going
-            elif stop_reason_str == "stop":
+            elif output.stop_reason == "stop":
                 stop_reason = StopReason.end_of_turn
-            elif stop_reason_str == "length":
+            elif output.stop_reason == "length":
                 stop_reason = StopReason.out_of_tokens
+            elif isinstance(output.stop_reason, int):
+                # If the model config specifies multiple end-of-sequence tokens, then vLLM
+                # will return the token ID of the EOS token in the stop_reason field.
+                stop_reason = StopReason.end_of_turn
             else:
-                raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")
+                raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
 
             # vLLM's protocol outputs the stop token, then sets end of message on the next step for
             # some reason.
```
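The new `isinstance(..., int)` arm handles models whose config declares several end-of-sequence tokens; vLLM then reports the matched EOS token id instead of the string `"stop"`. For illustration, the whole mapping can be read as a standalone helper (not part of the commit; the `StopReason` import path is the one llama-stack used at the time):

```python
from typing import Optional, Union

from llama_models.llama3.api.datatypes import StopReason  # assumed import path


def map_vllm_stop_reason(stop_reason: Union[str, int, None]) -> Optional[StopReason]:
    """Translate vLLM's per-output stop_reason into Llama Stack's StopReason."""
    if stop_reason is None:
        return None  # Generation is still in progress
    if stop_reason == "stop":
        return StopReason.end_of_turn
    if stop_reason == "length":
        return StopReason.out_of_tokens
    if isinstance(stop_reason, int):
        # vLLM returns the EOS token id when the model config specifies
        # multiple end-of-sequence tokens.
        return StopReason.end_of_turn
    raise ValueError(f"Unrecognized stop reason '{stop_reason}'")
```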
```diff
@@ -549,49 +556,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             logprobs=logprobs,
         )
 
-    async def _chat_completion_for_meta_llama_non_streaming(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
-        """
-        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
-        chat template instead of using vLLM's version of that template. The Llama Stack version
-        of the chat template currently produces more reliable outputs.
-
-        Once vLLM's support for Meta Llama models has matured more, we should consider routing
-        Meta Llama requests through the vLLM chat completions API instead of using this method.
-        """
-
-        logger.debug("Routing request through Llama Stack templates.")
-
-        formatter = ChatFormat(Tokenizer.get_instance())
-
-        # Note that this function call modifies `request` in place.
-        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
-
-        model_id = list(self.model_ids)[0]  # Any model ID will do here
-        completion_response = await self.completion(
-            model_id=model_id,
-            content=prompt,
-            sampling_params=request.sampling_params,
-            response_format=request.response_format,
-            stream=False,
-            logprobs=request.logprobs,
-        )
-        if not isinstance(completion_response, CompletionResponse):  # Sanity check
-            raise TypeError(f"Unexpected type '{type(completion_response)}' for completion response.")
-
-        raw_message = formatter.decode_assistant_message_from_content(
-            completion_response.content, completion_response.stop_reason
-        )
-        return ChatCompletionResponse(
-            completion_message=CompletionMessage(
-                content=raw_message.content,
-                stop_reason=raw_message.stop_reason,
-                tool_calls=raw_message.tool_calls,
-            ),
-            logprobs=completion_response.logprobs,
-        )
-
     def _convert_non_streaming_results(
         self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
     ) -> ChatCompletionResponse:
```
```diff
@@ -634,7 +598,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             completion_message=converted_message,
         )
 
-    def _chat_completion_for_meta_llama_streaming(self, request: ChatCompletionRequest) -> AsyncIterator:
+    async def _chat_completion_for_meta_llama(
+        self, request: ChatCompletionRequest
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         """
         Subroutine that routes chat completions for Meta Llama models through Llama Stack's
         chat template instead of using vLLM's version of that template. The Llama Stack version
```
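The old never-implemented streaming stub becomes the single entry point (the non-streaming helper deleted in the previous hunk is folded into it), and the return type widens to a union: a `ChatCompletionResponse` when `stream=False`, an async iterator of `ChatCompletionResponseStreamChunk` when `stream=True`. A hedged sketch of how a caller copes with both shapes (`impl` and `request` are stand-ins; the chunk field access assumes Llama Stack's streaming schema of the time):

```python
# `impl` is a VLLMInferenceImpl and `request` a ChatCompletionRequest, both
# assumed to be in scope already.
result = await impl._chat_completion_for_meta_llama(request)

if isinstance(result, ChatCompletionResponse):
    # Non-streaming: a single complete message.
    print(result.completion_message.content)
else:
    # Streaming: consume chunks as they arrive.
    async for chunk in result:
        print(chunk.event.delta, end="", flush=True)
```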
```diff
@@ -643,9 +609,82 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         Once vLLM's support for Meta Llama models has matured more, we should consider routing
         Meta Llama requests through the vLLM chat completions API instead of using this method.
         """
-        logger.debug("Routing streaming request through Llama Stack templates.")
+        formatter = ChatFormat(Tokenizer.get_instance())
 
-        raise NotImplementedError()
+        # Note that this function call modifies `request` in place.
+        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
+
+        model_id = list(self.model_ids)[0]  # Any model ID will do here
+        completion_response_or_iterator = await self.completion(
+            model_id=model_id,
+            content=prompt,
+            sampling_params=request.sampling_params,
+            response_format=request.response_format,
+            stream=request.stream,
+            logprobs=request.logprobs,
+        )
+
+        if request.stream:
+            if not isinstance(completion_response_or_iterator, AsyncIterator):
+                raise TypeError(
+                    f"Received unexpected result type {type(completion_response_or_iterator)} for streaming request."
+                )
+            return self._chat_completion_for_meta_llama_streaming(formatter, completion_response_or_iterator)
+
+        # else: not request.stream
+        if not isinstance(completion_response_or_iterator, CompletionResponse):
+            raise TypeError(
+                f"Received unexpected result type {type(completion_response_or_iterator)} for non-streaming request."
+            )
+        completion_response: CompletionResponse = completion_response_or_iterator
+        raw_message = formatter.decode_assistant_message_from_content(
+            completion_response.content, completion_response.stop_reason
+        )
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=raw_message.content,
+                stop_reason=raw_message.stop_reason,
+                tool_calls=raw_message.tool_calls,
+            ),
+            logprobs=completion_response.logprobs,
+        )
+
+    async def _chat_completion_for_meta_llama_streaming(
+        self, formatter: ChatFormat, results_iterator: AsyncIterator
+    ) -> AsyncIterator:
+        """
+        Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
+        method to keep asyncio happy.
+        """
+
+        # Convert to OpenAI format, then use shared code to convert to Llama Stack format.
+        async def _generate_and_convert_to_openai_compat():
+            chunk: CompletionResponseStreamChunk  # Make Pylance happy
+            last_text_len = 0
+            async for chunk in results_iterator:
+                if chunk.stop_reason == StopReason.end_of_turn:
+                    finish_reason = "stop"
+                elif chunk.stop_reason == StopReason.end_of_message:
+                    finish_reason = "eos"
+                elif chunk.stop_reason == StopReason.out_of_tokens:
+                    finish_reason = "length"
+                else:
+                    finish_reason = None
+
+                # Convert delta back to an actual delta
+                text_delta = chunk.delta[last_text_len:]
+                last_text_len = len(chunk.delta)
+
+                logger.debug(f"{text_delta=}; {finish_reason=}")
+
+                yield OpenAICompatCompletionResponse(
+                    choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
+                )
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(stream, formatter):
+            logger.debug(f"Returning chunk: {chunk}")
+            yield chunk
 
     async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
         """
```
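One subtlety in the adapter above: the slicing in `_generate_and_convert_to_openai_compat()` implies that `chunk.delta` arriving from the provider's `completion()` stream carries the cumulative generated text, hence the comment "Convert delta back to an actual delta". A self-contained toy run of that re-derivation:

```python
# The strings below mimic a stream whose "delta" field is actually cumulative.
cumulative_chunks = ["Hel", "Hello", "Hello, wor", "Hello, world!"]

last_text_len = 0
for text in cumulative_chunks:
    text_delta = text[last_text_len:]  # only the genuinely new characters
    last_text_len = len(text)
    print(repr(text_delta))
# Prints: 'Hel', 'lo', ', wor', 'ld!'
```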