Further cleanup after merge

Fred Reiss, 2025-02-21 15:50:07 -08:00; committed by Ashwin Bharambe
parent 3567a387bc
commit 8040f1463e
2 changed files with 24 additions and 24 deletions

File 1 of 2:

@@ -7,7 +7,6 @@
 from typing import List, Optional

 import vllm

-from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -17,6 +16,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
@@ -114,7 +114,12 @@ async def llama_stack_chat_completion_to_openai_chat_completion_dict(
     # Llama will try to use built-in tools with no tool catalog, so don't enable
     # tool choice unless at least one tool is enabled.
     converted_tool_choice = "none"
-    if request.tool_choice == ToolChoice.auto and request.tools is not None and len(request.tools) > 0:
+    if (
+        request.tool_config is not None
+        and request.tool_config.tool_choice == ToolChoice.auto
+        and request.tools is not None
+        and len(request.tools) > 0
+    ):
         converted_tool_choice = "auto"

     # TODO: Figure out what to do with the tool_prompt_format argument.
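Note on the hunk above: the guard now requires an explicit tool_config on the request before promoting tool choice to "auto". Below is a minimal, self-contained sketch of the same check; the ToolChoice, ToolConfig, and ChatCompletionRequest classes and the convert_tool_choice helper are simplified stand-ins for illustration, not the llama_stack types.

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class ToolChoice(Enum):
    auto = "auto"
    none = "none"


@dataclass
class ToolConfig:
    tool_choice: ToolChoice = ToolChoice.auto


@dataclass
class ChatCompletionRequest:
    tools: Optional[List[dict]] = None
    tool_config: Optional[ToolConfig] = None


def convert_tool_choice(request: ChatCompletionRequest) -> str:
    # Llama will try to use built-in tools with no tool catalog, so don't enable
    # tool choice unless at least one tool is enabled.
    converted_tool_choice = "none"
    if (
        request.tool_config is not None
        and request.tool_config.tool_choice == ToolChoice.auto
        and request.tools is not None
        and len(request.tools) > 0
    ):
        converted_tool_choice = "auto"
    return converted_tool_choice


# Usage: an empty request stays "none"; only a request with a tool_config asking
# for auto plus a non-empty tool list becomes "auto".
assert convert_tool_choice(ChatCompletionRequest()) == "none"
assert convert_tool_choice(ChatCompletionRequest(tools=[{"name": "get_weather"}], tool_config=ToolConfig())) == "auto"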

File 2 of 2:

@@ -10,22 +10,11 @@ import re
 import uuid
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

-import llama_models.sku_list
-
 # These vLLM modules contain names that overlap with Llama Stack names, so we import
 # fully-qualified names
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import (
-    SamplingParams,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -54,17 +43,23 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
+    SamplingParams,
     TextTruncation,
     TokenLogProbs,
     ToolChoice,
     ToolConfig,
 )
 from llama_stack.apis.models import Model
-from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.remote.inference.vllm.vllm import build_model_aliases
+from llama_stack.models.llama import sku_list
+from llama_stack.models.llama.datatypes import (
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolPromptFormat,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
+from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
@@ -202,7 +197,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.config = config
         logger.info(f"Config is: {self.config}")

-        self.register_helper = ModelRegistryHelper(build_model_aliases())
+        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.formatter = ChatFormat(Tokenizer.get_instance())

         # The following are initialized when paths are bound to this provider
@@ -255,7 +250,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         logger.debug(f"In register_model({model})")

         # First attempt to interpret the model coordinates as a Llama model name
-        resolved_llama_model = llama_models.sku_list.resolve_model(model.provider_model_id)
+        resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
         if resolved_llama_model is not None:
             # Load from Hugging Face repo into default local cache dir
             model_id_for_vllm = resolved_llama_model.huggingface_repo
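For context on the hunk above: after the cleanup, model resolution goes through the llama_stack copy of the SKU list rather than the standalone llama_models package. A rough sketch of the call pattern, assuming llama_stack is installed; the model name string is illustrative, and the None branch mirrors the check in the diff.

from llama_stack.models.llama import sku_list

# Try to interpret a provider model ID as a Llama model name (name is illustrative).
resolved_llama_model = sku_list.resolve_model("Llama3.1-8B-Instruct")
if resolved_llama_model is not None:
    # Known Llama model: load from its Hugging Face repo into the default local cache dir.
    print(resolved_llama_model.huggingface_repo)
else:
    # Not a recognized Llama SKU; register_model() handles this case separately.
    print("not a recognized Llama model name")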
@@ -616,7 +611,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         formatter = ChatFormat(Tokenizer.get_instance())

         # Note that this function call modifies `request` in place.
-        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)

         model_id = list(self.model_ids)[0]  # Any model ID will do here
         completion_response_or_iterator = await self.completion(
@@ -633,7 +628,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 raise TypeError(
                     f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
                 )
-            return self._chat_completion_for_meta_llama_streaming(formatter, completion_response_or_iterator)
+            return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)

         # elsif not request.stream:
         if not isinstance(completion_response_or_iterator, CompletionResponse):
@@ -654,7 +649,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         )

     async def _chat_completion_for_meta_llama_streaming(
-        self, formatter: ChatFormat, results_iterator: AsyncIterator
+        self, results_iterator: AsyncIterator, request: ChatCompletionRequest
     ) -> AsyncIterator:
         """
         Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
@@ -686,7 +681,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             )

         stream = _generate_and_convert_to_openai_compat()

-        async for chunk in process_chat_completion_stream_response(stream, formatter):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             logger.debug(f"Returning chunk: {chunk}")
             yield chunk
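The last three hunks drop the ChatFormat formatter that used to be threaded through the Meta Llama streaming path and pass the original ChatCompletionRequest instead, so the chunk-processing stage sees the request directly. A simplified, self-contained sketch of that shape follows; the types and helper names below are stand-ins, not the llama_stack or vLLM APIs.

import asyncio
from typing import Any, AsyncIterator


async def fake_generation_results() -> AsyncIterator[str]:
    # Stand-in for the vLLM results iterator produced by the completion call.
    for token in ["Hello", ", ", "world"]:
        yield token


async def process_stream(stream: AsyncIterator[str], request: Any) -> AsyncIterator[str]:
    # Stand-in for process_chat_completion_stream_response(stream, request); the real helper
    # presumably consults the request (rather than a pre-built formatter) when decoding chunks.
    async for chunk in stream:
        yield chunk


async def main() -> None:
    request = {"stream": True}  # illustrative placeholder for a ChatCompletionRequest
    async for chunk in process_stream(fake_generation_results(), request):
        print(chunk, end="")
    print()


asyncio.run(main())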