mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 20:40:40 +00:00)

Further cleanup after merge

This commit is contained in:
parent 3567a387bc
commit 8040f1463e

2 changed files with 24 additions and 24 deletions
@@ -7,7 +7,6 @@
 from typing import List, Optional

 import vllm
-from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -17,6 +16,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
@@ -114,7 +114,12 @@ async def llama_stack_chat_completion_to_openai_chat_completion_dict(
     # Llama will try to use built-in tools with no tool catalog, so don't enable
     # tool choice unless at least one tool is enabled.
     converted_tool_choice = "none"
-    if request.tool_choice == ToolChoice.auto and request.tools is not None and len(request.tools) > 0:
+    if (
+        request.tool_config is not None
+        and request.tool_config.tool_choice == ToolChoice.auto
+        and request.tools is not None
+        and len(request.tools) > 0
+    ):
         converted_tool_choice = "auto"

     # TODO: Figure out what to do with the tool_prompt_format argument.
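For reference, the new guard can be exercised in isolation. The sketch below uses stand-in dataclasses instead of the real llama_stack.apis.inference types; only the field names that appear in the diff are modeled, everything else is illustrative.

# Minimal stand-ins for the Llama Stack request types used by the guard above.
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class ToolChoice(Enum):
    auto = "auto"
    required = "required"


@dataclass
class ToolConfig:
    tool_choice: ToolChoice = ToolChoice.auto


@dataclass
class ChatCompletionRequest:
    tools: Optional[List[object]] = None
    tool_config: Optional[ToolConfig] = None


def convert_tool_choice(request: ChatCompletionRequest) -> str:
    # Mirrors the guard added in this commit: only request automatic tool
    # choice when a tool config exists, it asks for "auto", and at least one
    # tool is actually defined on the request.
    if (
        request.tool_config is not None
        and request.tool_config.tool_choice == ToolChoice.auto
        and request.tools is not None
        and len(request.tools) > 0
    ):
        return "auto"
    return "none"


print(convert_tool_choice(ChatCompletionRequest()))  # "none"
print(convert_tool_choice(ChatCompletionRequest(tools=["get_weather"], tool_config=ToolConfig())))  # "auto"

The effect is that tool choice is only forwarded as "auto" when the request explicitly carries a tool config and a non-empty tool list, matching the comment about not enabling built-in tools without a tool catalog.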
@@ -10,22 +10,11 @@ import re
 import uuid
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

-import llama_models.sku_list
-
 # These vLLM modules contain names that overlap with Llama Stack names, so we import
 # fully-qualified names
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import (
-    SamplingParams,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
 from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -54,17 +43,23 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
+    SamplingParams,
     TextTruncation,
     TokenLogProbs,
     ToolChoice,
     ToolConfig,
 )
 from llama_stack.apis.models import Model
-from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.remote.inference.vllm.vllm import build_model_aliases
+from llama_stack.models.llama import sku_list
+from llama_stack.models.llama.datatypes import (
+    StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolPromptFormat,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
+from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
@@ -202,7 +197,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.config = config
         logger.info(f"Config is: {self.config}")

-        self.register_helper = ModelRegistryHelper(build_model_aliases())
+        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.formatter = ChatFormat(Tokenizer.get_instance())

         # The following are initialized when paths are bound to this provider
@@ -255,7 +250,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         logger.debug(f"In register_model({model})")

         # First attempt to interpret the model coordinates as a Llama model name
-        resolved_llama_model = llama_models.sku_list.resolve_model(model.provider_model_id)
+        resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
         if resolved_llama_model is not None:
             # Load from Hugging Face repo into default local cache dir
             model_id_for_vllm = resolved_llama_model.huggingface_repo
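The register_model() path above first tries to interpret the provider model ID as a Llama model name and, if it resolves, swaps in the Hugging Face repo. A minimal sketch of that pattern follows; the lookup table is a hypothetical stand-in for llama_stack.models.llama.sku_list, and the pass-through fallback for unresolved IDs is an assumption for illustration, not taken from the diff.

# Stand-in for the resolve-then-fallback flow around the changed line.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ResolvedModel:
    huggingface_repo: str


# Hypothetical lookup table standing in for the real Llama SKU list.
_KNOWN_LLAMA_MODELS = {
    "Llama3.2-3B-Instruct": ResolvedModel("meta-llama/Llama-3.2-3B-Instruct"),
}


def resolve_model(provider_model_id: str) -> Optional[ResolvedModel]:
    return _KNOWN_LLAMA_MODELS.get(provider_model_id)


def model_id_for_vllm(provider_model_id: str) -> str:
    # First attempt to interpret the model coordinates as a Llama model name.
    resolved = resolve_model(provider_model_id)
    if resolved is not None:
        # Load from the Hugging Face repo into the default local cache dir.
        return resolved.huggingface_repo
    # Assumed fallback: treat the ID as something vLLM can load directly.
    return provider_model_id


print(model_id_for_vllm("Llama3.2-3B-Instruct"))  # meta-llama/Llama-3.2-3B-Instruct
print(model_id_for_vllm("my-org/my-finetune"))    # my-org/my-finetune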
@@ -616,7 +611,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         formatter = ChatFormat(Tokenizer.get_instance())

         # Note that this function call modifies `request` in place.
-        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)

         model_id = list(self.model_ids)[0]  # Any model ID will do here
         completion_response_or_iterator = await self.completion(
@@ -633,7 +628,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 raise TypeError(
                     f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
                 )
-            return self._chat_completion_for_meta_llama_streaming(formatter, completion_response_or_iterator)
+            return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)

        # elsif not request.stream:
        if not isinstance(completion_response_or_iterator, CompletionResponse):
@@ -654,7 +649,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         )

     async def _chat_completion_for_meta_llama_streaming(
-        self, formatter: ChatFormat, results_iterator: AsyncIterator
+        self, results_iterator: AsyncIterator, request: ChatCompletionRequest
     ) -> AsyncIterator:
         """
         Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
@@ -686,7 +681,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, formatter):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             logger.debug(f"Returning chunk: {chunk}")
             yield chunk

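The streaming path keeps the same pipeline shape: an inner generator converts engine output into OpenAI-compatible deltas, and process_chat_completion_stream_response() now receives the original request instead of a ChatFormat formatter. Below is a self-contained sketch of that shape; every function is a stand-in, not the real llama_stack or vLLM code.

# Sketch of the streaming pipeline in _chat_completion_for_meta_llama_streaming().
import asyncio
from typing import AsyncIterator


async def _generate_and_convert_to_openai_compat() -> AsyncIterator[str]:
    # Stand-in for the inner generator that reads engine results and converts
    # them to OpenAI-compatible deltas.
    for token in ["Hello", ", ", "world", "!"]:
        await asyncio.sleep(0)  # pretend to await the engine
        yield token


async def process_chat_completion_stream_response(
    stream: AsyncIterator[str], request: dict
) -> AsyncIterator[str]:
    # Stand-in for the shared helper that now takes the original request
    # (rather than a formatter) when building response chunks.
    async for delta in stream:
        yield f"chunk({delta!r})"


async def chat_completion_stream(request: dict) -> AsyncIterator[str]:
    # Same wiring as the hunk above: build the converted stream, then relay
    # each processed chunk to the caller.
    stream = _generate_and_convert_to_openai_compat()
    async for chunk in process_chat_completion_stream_response(stream, request):
        yield chunk


async def main() -> None:
    async for chunk in chat_completion_stream({"stream": True}):
        print(chunk)


asyncio.run(main())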