Further cleanup after merge

This commit is contained in:
Fred Reiss 2025-02-21 15:50:07 -08:00 committed by Ashwin Bharambe
parent 3567a387bc
commit 8040f1463e
2 changed files with 24 additions and 24 deletions

View file

@ -7,7 +7,6 @@
from typing import List, Optional
import vllm
from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition
from llama_stack.apis.inference import (
ChatCompletionRequest,
@ -17,6 +16,7 @@ from llama_stack.apis.inference import (
ToolChoice,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
@ -114,7 +114,12 @@ async def llama_stack_chat_completion_to_openai_chat_completion_dict(
# Llama will try to use built-in tools with no tool catalog, so don't enable
# tool choice unless at least one tool is enabled.
converted_tool_choice = "none"
if request.tool_choice == ToolChoice.auto and request.tools is not None and len(request.tools) > 0:
if (
request.tool_config is not None
and request.tool_config.tool_choice == ToolChoice.auto
and request.tools is not None
and len(request.tools) > 0
):
converted_tool_choice = "auto"
# TODO: Figure out what to do with the tool_prompt_format argument.

View file

@ -10,22 +10,11 @@ import re
import uuid
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import llama_models.sku_list
# These vLLM modules contain names that overlap with Llama Stack names, so we import
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@ -54,17 +43,23 @@ from llama_stack.apis.inference import (
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
TokenLogProbs,
ToolChoice,
ToolConfig,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.remote.inference.vllm.vllm import build_model_aliases
from llama_stack.models.llama import sku_list
from llama_stack.models.llama.datatypes import (
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
ModelsProtocolPrivate,
@ -202,7 +197,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self.config = config
logger.info(f"Config is: {self.config}")
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.formatter = ChatFormat(Tokenizer.get_instance())
# The following are initialized when paths are bound to this provider
@ -255,7 +250,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
logger.debug(f"In register_model({model})")
# First attempt to interpret the model coordinates as a Llama model name
resolved_llama_model = llama_models.sku_list.resolve_model(model.provider_model_id)
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
if resolved_llama_model is not None:
# Load from Hugging Face repo into default local cache dir
model_id_for_vllm = resolved_llama_model.huggingface_repo
@ -616,7 +611,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
formatter = ChatFormat(Tokenizer.get_instance())
# Note that this function call modifies `request` in place.
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
model_id = list(self.model_ids)[0] # Any model ID will do here
completion_response_or_iterator = await self.completion(
@ -633,7 +628,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
raise TypeError(
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
)
return self._chat_completion_for_meta_llama_streaming(formatter, completion_response_or_iterator)
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
# elsif not request.stream:
if not isinstance(completion_response_or_iterator, CompletionResponse):
@ -654,7 +649,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
async def _chat_completion_for_meta_llama_streaming(
self, formatter: ChatFormat, results_iterator: AsyncIterator
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
) -> AsyncIterator:
"""
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
@ -686,7 +681,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, formatter):
async for chunk in process_chat_completion_stream_response(stream, request):
logger.debug(f"Returning chunk: {chunk}")
yield chunk