forked from phoenix-oss/llama-stack-mirror
		
	feat: Support tool calling for streaming chat completion in remote vLLM provider (#1063)
# What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] Closes https://github.com/meta-llama/llama-stack/issues/1046. ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] ``` LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/inference/test_text_inference.py ================================================================= test session starts ================================================================= platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10 cachedir: .pytest_cache rootdir: /home/yutang/repos/llama-stack configfile: pyproject.toml plugins: anyio-4.8.0 collected 14 items tests/client-sdk/inference/test_text_inference.py::test_text_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 7%] tests/client-sdk/inference/test_text_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 14%] tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote:...) [ 21%] tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::vll...) [ 28%] tests/client-sdk/inference/test_text_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 35%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet do humans live on?-Earth] PASSED [ 42%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED [ 50%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED [ 57%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED [ 64%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 71%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 78%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 85%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-True] PASSED [ 92%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-False] PASSED [100%] =============================================== 12 passed, 2 xfailed, 1 warning in 366.56s (0:06:06) ================================================ ``` --------- Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
		
							parent
							
								
									bf11cc0450
								
							
						
					
					
						commit
						5e97dd9919
					
				
					 3 changed files with 101 additions and 45 deletions
				
			
		|  | @ -6,7 +6,7 @@ | |||
| 
 | ||||
| import json | ||||
| import warnings | ||||
| from typing import AsyncGenerator, Literal, Union | ||||
| from typing import AsyncGenerator, Literal | ||||
| 
 | ||||
| from groq import Stream | ||||
| from groq.types.chat.chat_completion import ChatCompletion | ||||
|  | @ -15,9 +15,6 @@ from groq.types.chat.chat_completion_assistant_message_param import ( | |||
| ) | ||||
| from groq.types.chat.chat_completion_chunk import ChatCompletionChunk | ||||
| from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam | ||||
| from groq.types.chat.chat_completion_message_tool_call import ( | ||||
|     ChatCompletionMessageToolCall, | ||||
| ) | ||||
| from groq.types.chat.chat_completion_system_message_param import ( | ||||
|     ChatCompletionSystemMessageParam, | ||||
| ) | ||||
|  | @ -30,7 +27,6 @@ from groq.types.shared.function_definition import FunctionDefinition | |||
| 
 | ||||
| from llama_models.llama3.api.datatypes import ToolParamDefinition | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| 
 | ||||
| from llama_stack.apis.common.content_types import ( | ||||
|     TextDelta, | ||||
|  | @ -52,6 +48,8 @@ from llama_stack.apis.inference import ( | |||
| ) | ||||
| from llama_stack.providers.utils.inference.openai_compat import ( | ||||
|     get_sampling_strategy_options, | ||||
|     convert_tool_call, | ||||
|     UnparseableToolCall, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -143,7 +141,7 @@ def convert_chat_completion_response( | |||
|     # groq only supports n=1 at time of writing, so there is only one choice | ||||
|     choice = response.choices[0] | ||||
|     if choice.finish_reason == "tool_calls": | ||||
|         tool_calls = [_convert_groq_tool_call(tool_call) for tool_call in choice.message.tool_calls] | ||||
|         tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] | ||||
|         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): | ||||
|             # If we couldn't parse a tool call, jsonify the tool calls and return them | ||||
|             return ChatCompletionResponse( | ||||
|  | @ -216,7 +214,7 @@ async def convert_chat_completion_response_stream( | |||
|                 warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.") | ||||
| 
 | ||||
|             # We assume Groq produces fully formed tool calls for each chunk | ||||
|             tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0]) | ||||
|             tool_call = convert_tool_call(choice.delta.tool_calls[0]) | ||||
|             if isinstance(tool_call, ToolCall): | ||||
|                 yield ChatCompletionResponseStreamChunk( | ||||
|                     event=ChatCompletionResponseEvent( | ||||
|  | @ -247,37 +245,3 @@ async def convert_chat_completion_response_stream( | |||
|                 ) | ||||
|             ) | ||||
|         event_type = ChatCompletionResponseEventType.progress | ||||
| 
 | ||||
| 
 | ||||
| class UnparseableToolCall(BaseModel): | ||||
|     """ | ||||
|     A ToolCall with arguments that are not valid JSON. | ||||
|     Mirrors the ToolCall schema, but with arguments as a string. | ||||
|     """ | ||||
| 
 | ||||
|     call_id: str | ||||
|     tool_name: str | ||||
|     arguments: str | ||||
| 
 | ||||
| 
 | ||||
| def _convert_groq_tool_call( | ||||
|     tool_call: ChatCompletionMessageToolCall, | ||||
| ) -> Union[ToolCall, UnparseableToolCall]: | ||||
|     """ | ||||
|     Convert a Groq tool call to a ToolCall. | ||||
|     Returns an UnparseableToolCall if the tool call is not valid JSON. | ||||
|     """ | ||||
|     try: | ||||
|         arguments = json.loads(tool_call.function.arguments) | ||||
|     except Exception as e: | ||||
|         return UnparseableToolCall( | ||||
|             call_id=tool_call.id, | ||||
|             tool_name=tool_call.function.name, | ||||
|             arguments=tool_call.function.arguments, | ||||
|         ) | ||||
| 
 | ||||
|     return ToolCall( | ||||
|         call_id=tool_call.id, | ||||
|         tool_name=tool_call.function.name, | ||||
|         arguments=arguments, | ||||
|     ) | ||||
|  |  | |||
|  | @ -13,7 +13,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer | |||
| from llama_models.sku_list import all_registered_models | ||||
| from openai import OpenAI | ||||
| 
 | ||||
| from llama_stack.apis.common.content_types import InterleavedContent | ||||
| from llama_stack.apis.common.content_types import InterleavedContent, ToolCallDelta, ToolCallParseStatus, TextDelta | ||||
| from llama_stack.apis.inference import ( | ||||
|     ChatCompletionRequest, | ||||
|     ChatCompletionResponse, | ||||
|  | @ -32,6 +32,9 @@ from llama_stack.apis.inference import ( | |||
|     ToolDefinition, | ||||
|     ToolPromptFormat, | ||||
|     CompletionMessage, | ||||
|     ChatCompletionResponseEventType, | ||||
|     ChatCompletionResponseStreamChunk, | ||||
|     ChatCompletionResponseEvent, | ||||
| ) | ||||
| from llama_stack.apis.models import Model, ModelType | ||||
| from llama_stack.providers.datatypes import ModelsProtocolPrivate | ||||
|  | @ -42,9 +45,12 @@ from llama_stack.providers.utils.inference.model_registry import ( | |||
| from llama_stack.providers.utils.inference.openai_compat import ( | ||||
|     convert_message_to_openai_dict, | ||||
|     get_sampling_options, | ||||
|     process_chat_completion_stream_response, | ||||
|     process_completion_response, | ||||
|     process_completion_stream_response, | ||||
|     OpenAICompatCompletionResponse, | ||||
|     UnparseableToolCall, | ||||
|     convert_tool_call, | ||||
|     process_chat_completion_stream_response, | ||||
| ) | ||||
| from llama_stack.providers.utils.inference.prompt_adapter import ( | ||||
|     completion_request_to_prompt, | ||||
|  | @ -136,6 +142,51 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason: | |||
|     }.get(finish_reason, StopReason.end_of_turn) | ||||
| 
 | ||||
| 
 | ||||
| async def _process_vllm_chat_completion_stream_response( | ||||
|     stream: AsyncGenerator[OpenAICompatCompletionResponse, None], | ||||
| ) -> AsyncGenerator: | ||||
|     event_type = ChatCompletionResponseEventType.start | ||||
|     tool_call_buf = UnparseableToolCall() | ||||
|     async for chunk in stream: | ||||
|         choice = chunk.choices[0] | ||||
|         if choice.finish_reason: | ||||
|             yield ChatCompletionResponseStreamChunk( | ||||
|                 event=ChatCompletionResponseEvent( | ||||
|                     event_type=event_type, | ||||
|                     delta=ToolCallDelta( | ||||
|                         tool_call=ToolCall( | ||||
|                             call_id=tool_call_buf.call_id, | ||||
|                             tool_name=tool_call_buf.tool_name, | ||||
|                             arguments=json.loads(tool_call_buf.arguments), | ||||
|                         ), | ||||
|                         parse_status=ToolCallParseStatus.succeeded, | ||||
|                     ), | ||||
|                 ) | ||||
|             ) | ||||
|             yield ChatCompletionResponseStreamChunk( | ||||
|                 event=ChatCompletionResponseEvent( | ||||
|                     event_type=ChatCompletionResponseEventType.complete, | ||||
|                     delta=TextDelta(text=choice.delta.content or ""), | ||||
|                     logprobs=None, | ||||
|                     stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason), | ||||
|                 ) | ||||
|             ) | ||||
|         elif choice.delta.tool_calls: | ||||
|             tool_call = convert_tool_call(choice.delta.tool_calls[0]) | ||||
|             tool_call_buf.tool_name += tool_call.tool_name | ||||
|             tool_call_buf.call_id += tool_call.call_id | ||||
|             tool_call_buf.arguments += tool_call.arguments | ||||
|         else: | ||||
|             yield ChatCompletionResponseStreamChunk( | ||||
|                 event=ChatCompletionResponseEvent( | ||||
|                     event_type=event_type, | ||||
|                     delta=TextDelta(text=choice.delta.content or ""), | ||||
|                     logprobs=None, | ||||
|                 ) | ||||
|             ) | ||||
|             event_type = ChatCompletionResponseEventType.progress | ||||
| 
 | ||||
| 
 | ||||
| class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): | ||||
|     def __init__(self, config: VLLMInferenceAdapterConfig) -> None: | ||||
|         self.register_helper = ModelRegistryHelper(build_model_aliases()) | ||||
|  | @ -232,7 +283,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): | |||
|                 yield chunk | ||||
| 
 | ||||
|         stream = _to_async_generator() | ||||
|         async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): | ||||
|         if len(request.tools) > 0: | ||||
|             res = _process_vllm_chat_completion_stream_response(stream) | ||||
|         else: | ||||
|             res = process_chat_completion_stream_response(stream, self.formatter, request) | ||||
|         async for chunk in res: | ||||
|             yield chunk | ||||
| 
 | ||||
|     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: | ||||
|  |  | |||
|  | @ -3,6 +3,7 @@ | |||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| import json | ||||
| import logging | ||||
| from typing import AsyncGenerator, Dict, List, Optional, Union | ||||
| 
 | ||||
|  | @ -14,7 +15,8 @@ from llama_models.datatypes import ( | |||
| ) | ||||
| 
 | ||||
| from llama_models.llama3.api.chat_format import ChatFormat | ||||
| from llama_models.llama3.api.datatypes import StopReason | ||||
| from llama_models.llama3.api.datatypes import StopReason, ToolCall | ||||
| from openai.types.chat import ChatCompletionMessageToolCall | ||||
| from pydantic import BaseModel | ||||
| 
 | ||||
| from llama_stack.apis.common.content_types import ( | ||||
|  | @ -408,3 +410,38 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals | |||
|         "role": message.role, | ||||
|         "content": content, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class UnparseableToolCall(BaseModel): | ||||
|     """ | ||||
|     A ToolCall with arguments that are not valid JSON. | ||||
|     Mirrors the ToolCall schema, but with arguments as a string. | ||||
|     """ | ||||
| 
 | ||||
|     call_id: str = "" | ||||
|     tool_name: str = "" | ||||
|     arguments: str = "" | ||||
| 
 | ||||
| 
 | ||||
| def convert_tool_call( | ||||
|     tool_call: ChatCompletionMessageToolCall, | ||||
| ) -> Union[ToolCall, UnparseableToolCall]: | ||||
|     """ | ||||
|     Convert a ChatCompletionMessageToolCall tool call to either a | ||||
|     ToolCall or UnparseableToolCall. Returns an UnparseableToolCall | ||||
|     if the tool call is not valid JSON. | ||||
|     """ | ||||
|     try: | ||||
|         arguments = json.loads(tool_call.function.arguments) | ||||
|     except Exception as e: | ||||
|         return UnparseableToolCall( | ||||
|             call_id=tool_call.id or "", | ||||
|             tool_name=tool_call.function.name or "", | ||||
|             arguments=tool_call.function.arguments or "", | ||||
|         ) | ||||
| 
 | ||||
|     return ToolCall( | ||||
|         call_id=tool_call.id, | ||||
|         tool_name=tool_call.function.name, | ||||
|         arguments=arguments, | ||||
|     ) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue