diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py
index c59261d2c..90b5398f9 100644
--- a/llama_stack/providers/inline/inference/vllm/openai_utils.py
+++ b/llama_stack/providers/inline/inference/vllm/openai_utils.py
@@ -7,7 +7,6 @@ from typing import List, Optional
 
 import vllm
 
-from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -17,6 +16,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
@@ -114,7 +114,12 @@ async def llama_stack_chat_completion_to_openai_chat_completion_dict(
     # Llama will try to use built-in tools with no tool catalog, so don't enable
     # tool choice unless at least one tool is enabled.
     converted_tool_choice = "none"
-    if request.tool_choice == ToolChoice.auto and request.tools is not None and len(request.tools) > 0:
+    if (
+        request.tool_config is not None
+        and request.tool_config.tool_choice == ToolChoice.auto
+        and request.tools is not None
+        and len(request.tools) > 0
+    ):
         converted_tool_choice = "auto"
 
     # TODO: Figure out what to do with the tool_prompt_format argument.
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 639728278..6854b95f6 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -10,22 +10,11 @@ import re
 import uuid
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
-import llama_models.sku_list
-
 # These vLLM modules contain names that overlap with Llama Stack names, so we import
 # fully-qualified names
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import (
-    SamplingParams,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -54,17 +43,23 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
+    SamplingParams,
     TextTruncation,
     TokenLogProbs,
     ToolChoice,
     ToolConfig,
 )
 from llama_stack.apis.models import Model
-from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.remote.inference.vllm.vllm import build_model_aliases
+from llama_stack.models.llama import sku_list
+from llama_stack.models.llama.datatypes import (
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolPromptFormat,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
+from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
@@ -202,7 +197,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.config = config
         logger.info(f"Config is: {self.config}")
 
-        self.register_helper = ModelRegistryHelper(build_model_aliases())
+        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
         # The following are initialized when paths are bound to this provider
@@ -255,7 +250,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         logger.debug(f"In register_model({model})")
 
         # First attempt to interpret the model coordinates as a Llama model name
-        resolved_llama_model = llama_models.sku_list.resolve_model(model.provider_model_id)
+        resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
         if resolved_llama_model is not None:
             # Load from Hugging Face repo into default local cache dir
             model_id_for_vllm = resolved_llama_model.huggingface_repo
@@ -616,7 +611,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         formatter = ChatFormat(Tokenizer.get_instance())
 
         # Note that this function call modifies `request` in place.
-        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
 
         model_id = list(self.model_ids)[0]  # Any model ID will do here
         completion_response_or_iterator = await self.completion(
@@ -633,7 +628,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             raise TypeError(
                 f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
             )
-        return self._chat_completion_for_meta_llama_streaming(formatter, completion_response_or_iterator)
+        return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
 
         # elsif not request.stream:
         if not isinstance(completion_response_or_iterator, CompletionResponse):
@@ -654,7 +649,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         )
 
     async def _chat_completion_for_meta_llama_streaming(
-        self, formatter: ChatFormat, results_iterator: AsyncIterator
+        self, results_iterator: AsyncIterator, request: ChatCompletionRequest
     ) -> AsyncIterator:
         """
         Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
@@ -686,7 +681,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, formatter):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             logger.debug(f"Returning chunk: {chunk}")
             yield chunk
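For reviewers, a minimal standalone sketch of the new `converted_tool_choice` guard in `openai_utils.py`: the OpenAI-style `tool_choice` is only promoted to `"auto"` when the request carries a `tool_config` whose `tool_choice` is `ToolChoice.auto` and a non-empty tool list. The `Request` and `ToolConfig` classes below are simplified stand-ins for the real Llama Stack types, used here only to make the guard runnable in isolation.

```python
# Sketch of the tool_choice conversion guard added in openai_utils.py.
# `Request` and `ToolConfig` are simplified stand-ins, not the Llama Stack types.
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class ToolChoice(Enum):
    auto = "auto"
    required = "required"


@dataclass
class ToolConfig:
    tool_choice: Optional[ToolChoice] = None


@dataclass
class Request:
    tools: Optional[List[str]] = None
    tool_config: Optional[ToolConfig] = None


def convert_tool_choice(request: Request) -> str:
    # Llama will try to use built-in tools with no tool catalog, so don't enable
    # tool choice unless at least one tool is enabled.
    converted_tool_choice = "none"
    if (
        request.tool_config is not None
        and request.tool_config.tool_choice == ToolChoice.auto
        and request.tools is not None
        and len(request.tools) > 0
    ):
        converted_tool_choice = "auto"
    return converted_tool_choice


assert convert_tool_choice(Request()) == "none"
assert convert_tool_choice(Request(tools=["get_weather"])) == "none"  # no tool_config
assert (
    convert_tool_choice(
        Request(tools=["get_weather"], tool_config=ToolConfig(tool_choice=ToolChoice.auto))
    )
    == "auto"
)
```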