Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)
Further cleanup after merge

commit 8040f1463e
parent 3567a387bc
2 changed files with 24 additions and 24 deletions
@@ -7,7 +7,6 @@
 from typing import List, Optional
 
 import vllm
-from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -17,6 +16,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
@@ -114,7 +114,12 @@ async def llama_stack_chat_completion_to_openai_chat_completion_dict(
     # Llama will try to use built-in tools with no tool catalog, so don't enable
     # tool choice unless at least one tool is enabled.
     converted_tool_choice = "none"
-    if request.tool_choice == ToolChoice.auto and request.tools is not None and len(request.tools) > 0:
+    if (
+        request.tool_config is not None
+        and request.tool_config.tool_choice == ToolChoice.auto
+        and request.tools is not None
+        and len(request.tools) > 0
+    ):
         converted_tool_choice = "auto"
 
     # TODO: Figure out what to do with the tool_prompt_format argument.
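The hunk above replaces the flat request.tool_choice check with a guard on request.tool_config: tool choice is only reported as "auto" when a tool config asks for automatic choice and at least one tool is actually defined, since Llama will otherwise try to call built-in tools it does not have. A minimal sketch of that decision in isolation, using simplified stand-in classes instead of the real ChatCompletionRequest/ToolConfig types and plain strings instead of the ToolChoice enum:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _ToolConfig:
    # Stand-in for the Llama Stack ToolConfig; only the field used here.
    tool_choice: str = "auto"


@dataclass
class _Request:
    # Stand-in for ChatCompletionRequest; only the fields the guard reads.
    tool_config: Optional[_ToolConfig] = None
    tools: Optional[List[dict]] = None


def convert_tool_choice(request: _Request) -> str:
    """Mirror of the new guard: only report "auto" when automatic tool choice
    is requested and at least one tool is actually configured."""
    if (
        request.tool_config is not None
        and request.tool_config.tool_choice == "auto"
        and request.tools is not None
        and len(request.tools) > 0
    ):
        return "auto"
    return "none"


# No tools configured -> "none"; one tool plus automatic choice -> "auto".
assert convert_tool_choice(_Request()) == "none"
assert convert_tool_choice(_Request(tool_config=_ToolConfig(), tools=[{"name": "calc"}])) == "auto"

The hunks that follow are from the commit's second changed file, which defines the VLLMInferenceImpl provider class.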
@@ -10,22 +10,11 @@ import re
 import uuid
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
-import llama_models.sku_list
-
 # These vLLM modules contain names that overlap with Llama Stack names, so we import
 # fully-qualified names
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import (
-    SamplingParams,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
-from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -54,17 +43,23 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
     TokenLogProbs,
     ToolChoice,
+    ToolConfig,
 )
 from llama_stack.apis.models import Model
-from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.remote.inference.vllm.vllm import build_model_aliases
+from llama_stack.models.llama import sku_list
+from llama_stack.models.llama.datatypes import (
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolPromptFormat,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
+from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
@@ -202,7 +197,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.config = config
         logger.info(f"Config is: {self.config}")
 
-        self.register_helper = ModelRegistryHelper(build_model_aliases())
+        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
         # The following are initialized when paths are bound to this provider
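This hunk only swaps build_model_aliases for the renamed build_hf_repo_model_entries when constructing the ModelRegistryHelper. The underlying idea is an alias table that maps user-facing model names onto provider model IDs; the toy registry below illustrates that pattern only and is not the real ModelRegistryHelper API or entry format:

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ModelEntry:
    # Simplified stand-in for a registry entry: a provider-side model ID plus
    # the aliases users may refer to it by.
    provider_model_id: str
    aliases: List[str]


class RegistryHelper:
    """Toy analogue of a model-registry helper: resolve aliases to provider IDs."""

    def __init__(self, entries: List[ModelEntry]) -> None:
        self._alias_to_id: Dict[str, str] = {}
        for entry in entries:
            self._alias_to_id[entry.provider_model_id] = entry.provider_model_id
            for alias in entry.aliases:
                self._alias_to_id[alias] = entry.provider_model_id

    def get_provider_model_id(self, name: str) -> str:
        return self._alias_to_id[name]


entries = [ModelEntry("meta-llama/Llama-3.2-3B-Instruct", ["Llama3.2-3B-Instruct"])]
helper = RegistryHelper(entries)
print(helper.get_provider_model_id("Llama3.2-3B-Instruct"))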
@@ -255,7 +250,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         logger.debug(f"In register_model({model})")
 
         # First attempt to interpret the model coordinates as a Llama model name
-        resolved_llama_model = llama_models.sku_list.resolve_model(model.provider_model_id)
+        resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
         if resolved_llama_model is not None:
             # Load from Hugging Face repo into default local cache dir
             model_id_for_vllm = resolved_llama_model.huggingface_repo
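register_model now resolves the provider model ID through the vendored llama_stack.models.llama sku_list module instead of the external llama_models package. Only resolve_model() and the huggingface_repo attribute appear in the diff; the lookup table and the pass-through fallback in this sketch are assumptions added for illustration:

from dataclasses import dataclass
from typing import Optional


@dataclass
class _ResolvedModel:
    # Stand-in for a resolved Llama SKU; only the field used below.
    huggingface_repo: str


def resolve_model(provider_model_id: str) -> Optional[_ResolvedModel]:
    """Toy stand-in for sku_list.resolve_model(): map known Llama names to
    their Hugging Face repos and return None for anything unrecognized."""
    known = {"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct"}
    repo = known.get(provider_model_id)
    return _ResolvedModel(huggingface_repo=repo) if repo else None


def model_id_for_vllm(provider_model_id: str) -> str:
    # First attempt to interpret the model coordinates as a Llama model name;
    # if that works, load from the Hugging Face repo. The pass-through branch
    # is an illustrative assumption; the hunk only shows the resolved case.
    resolved = resolve_model(provider_model_id)
    if resolved is not None:
        return resolved.huggingface_repo
    return provider_model_id


print(model_id_for_vllm("Llama3.2-3B-Instruct"))  # meta-llama/Llama-3.2-3B-Instruct
print(model_id_for_vllm("some/custom-model"))     # some/custom-model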
@@ -616,7 +611,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         formatter = ChatFormat(Tokenizer.get_instance())
 
         # Note that this function call modifies `request` in place.
-        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
 
         model_id = list(self.model_ids)[0]  # Any model ID will do here
         completion_response_or_iterator = await self.completion(
@@ -633,7 +628,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 raise TypeError(
                     f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
                 )
            return self._chat_completion_for_meta_llama_streaming(formatter, completion_response_or_iterator)
+            return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
 
         # elsif not request.stream:
         if not isinstance(completion_response_or_iterator, CompletionResponse):
@@ -654,7 +649,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         )
 
     async def _chat_completion_for_meta_llama_streaming(
-        self, formatter: ChatFormat, results_iterator: AsyncIterator
+        self, results_iterator: AsyncIterator, request: ChatCompletionRequest
     ) -> AsyncIterator:
         """
         Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
@@ -686,7 +681,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             )
 
         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, formatter):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             logger.debug(f"Returning chunk: {chunk}")
             yield chunk
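The last three hunks thread the original ChatCompletionRequest through the streaming path instead of a shared ChatFormat instance, so the stream post-processing call becomes process_chat_completion_stream_response(stream, request). The shape of that code, an async generator that wraps another async iterator and yields converted chunks, is sketched below with toy stand-ins for the generator, the converter, and the chunk type rather than the Llama Stack helpers:

import asyncio
from typing import AsyncIterator


async def _generate_and_convert_to_openai_compat() -> AsyncIterator[str]:
    # Stand-in for the provider's internal generator: pretend each piece of
    # model output arrives as a short string chunk.
    for piece in ["Hello", ", ", "world"]:
        await asyncio.sleep(0)  # simulate asynchronous arrival
        yield piece


async def process_stream(stream: AsyncIterator[str], request: dict) -> AsyncIterator[dict]:
    """Toy analogue of process_chat_completion_stream_response(stream, request):
    consume the upstream iterator and yield converted chunks, with the request
    available for per-request conversion decisions."""
    async for piece in stream:
        yield {"model": request.get("model"), "delta": piece}


async def main() -> None:
    request = {"model": "example-model", "stream": True}
    stream = _generate_and_convert_to_openai_compat()
    async for chunk in process_stream(stream, request):
        print(chunk)


asyncio.run(main())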