Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-11 20:40:40 +00:00
Cleanup after merge

parent 43998e4348, commit 3567a387bc
3 changed files with 19 additions and 41 deletions

@@ -7,7 +7,6 @@
 from typing import List, Optional

 import vllm

 from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition

 from llama_stack.apis.inference import (
@@ -23,7 +22,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
 )

-
 ###############################################################################
 # This file contains OpenAI compatibility code that is currently only used
 # by the inline vLLM connector. Some or all of this code may be moved to a
@@ -77,8 +75,7 @@ def _llama_stack_tools_to_openai_tools(
         parameters = {
             "type": "object",  # Mystery value that shows up in OpenAI docs
             "properties": {
-                k: {"type": v.param_type, "description": v.description}
-                for k, v in t.parameters.items()
+                k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
             },
             "required": required_params,
         }
@@ -88,11 +85,7 @@ def _llama_stack_tools_to_openai_tools(
         )

         # Every tool definition is double-boxed in a ChatCompletionToolsParam
-        result.append(
-            vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(
-                function=function_def
-            )
-        )
+        result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
     return result

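
For context on the two hunks above: _llama_stack_tools_to_openai_tools builds a JSON-schema-style "parameters" object for each tool and then wraps every tool in a ChatCompletionToolsParam. A minimal sketch of the "parameters" value for a hypothetical tool with a single string argument (the tool field names below are invented for illustration; only the type/properties/required layout comes from the diff):

    # Hypothetical single-parameter tool, shown only to illustrate the shape of
    # the "parameters" dict assembled in the hunks above.
    parameters = {
        "type": "object",  # fixed value the OpenAI tools schema expects
        "properties": {
            "city": {"type": "string", "description": "City to look up"},
        },
        "required": ["city"],
    }
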
@@ -113,9 +106,7 @@ async def llama_stack_chat_completion_to_openai_chat_completion_dict(

     converted_messages = [
         # This mystery async call makes the parent function also be async
-        await convert_message_to_openai_dict(
-            _merge_context_into_content(m), download=True
-        )
+        await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
         for m in request.messages
     ]
     converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
@@ -123,11 +114,7 @@ async def llama_stack_chat_completion_to_openai_chat_completion_dict(
     # Llama will try to use built-in tools with no tool catalog, so don't enable
     # tool choice unless at least one tool is enabled.
     converted_tool_choice = "none"
-    if (
-        request.tool_choice == ToolChoice.auto
-        and request.tools is not None
-        and len(request.tools) > 0
-    ):
+    if request.tool_choice == ToolChoice.auto and request.tools is not None and len(request.tools) > 0:
         converted_tool_choice = "auto"

     # TODO: Figure out what to do with the tool_prompt_format argument.
@@ -143,13 +130,8 @@ async def llama_stack_chat_completion_to_openai_chat_completion_dict(
     # API will handle correctly. Two wrongs make a right...
     if "repeat_penalty" in sampling_options:
         del sampling_options["repeat_penalty"]
-    if (
-        request.sampling_params.repetition_penalty is not None
-        and request.sampling_params.repetition_penalty != 1.0
-    ):
-        sampling_options["repetition_penalty"] = (
-            request.sampling_params.repetition_penalty
-        )
+    if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
+        sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty

     # Convert a single response format into four different parameters, per
     # the OpenAI spec
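
The hunk above works around a naming mismatch: get_sampling_options() emits a "repeat_penalty" key, while vLLM's sampling options expect "repetition_penalty". A rough, self-contained sketch of that rename (the literal values are invented; the real code takes the value from request.sampling_params):

    # Illustration only: drop the key vLLM does not understand and re-add the
    # value under the name vLLM expects, skipping the no-op penalty of 1.0.
    sampling_options = {"max_tokens": 64, "repeat_penalty": 1.1}
    repetition_penalty = sampling_options.pop("repeat_penalty", None)
    if repetition_penalty is not None and repetition_penalty != 1.0:
        sampling_options["repetition_penalty"] = repetition_penalty
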
@@ -162,10 +144,7 @@ async def llama_stack_chat_completion_to_openai_chat_completion_dict(
     elif isinstance(request.response_format, GrammarResponseFormat):
         guided_decoding_options["guided_grammar"] = request.response_format.bnf
     else:
-        raise TypeError(
-            f"ResponseFormat object is of unexpected "
-            f"subtype '{type(request.response_format)}'"
-        )
+        raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")

     logprob_options = dict()
     if request.logprobs is not None:

@@ -20,14 +20,13 @@ from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import (
     SamplingParams,
     StopReason,
+    ToolCall,
     ToolDefinition,
     ToolPromptFormat,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
 from llama_models.llama3.api.tokenizer import Tokenizer
-
-# We deep-import the names that don't conflict with Llama Stack names
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -35,6 +34,7 @@ from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingM

 from llama_stack.apis.common.content_types import (
     InterleavedContent,
+    InterleavedContentItem,
     TextDelta,
     ToolCallDelta,
 )
@@ -54,9 +54,10 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
+    TextTruncation,
     TokenLogProbs,
-    ToolCall,
     ToolChoice,
+    ToolConfig,
 )
 from llama_stack.apis.models import Model
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
@@ -254,7 +255,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         logger.debug(f"In register_model({model})")

         # First attempt to interpret the model coordinates as a Llama model name
-        resolved_llama_model = resolve_model(model.provider_model_id)
+        resolved_llama_model = llama_models.sku_list.resolve_model(model.provider_model_id)
         if resolved_llama_model is not None:
             # Load from Hugging Face repo into default local cache dir
             model_id_for_vllm = resolved_llama_model.huggingface_repo
@@ -277,16 +278,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             else:
                 # Model already loaded
                 logger.info(
-                    f"Requested id {model} resolves to {model_id_for_vllm}, "
-                    f"which is already loaded. Continuing."
+                    f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
                 )
                 self.model_ids.add(model.model_id)
                 return model

-        logger.info(
-            f"Requested id {model} resolves to {model_id_for_vllm}. Loading "
-            f"{model_id_for_vllm}."
-        )
+        logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
         if is_meta_llama_model:
             logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
         self.is_meta_llama_model = is_meta_llama_model
@@ -425,7 +422,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        tool_config: Optional[ToolConfig] = None,
+    ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
         if model_id not in self.model_ids:
             raise ValueError(
                 f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"

@@ -15,11 +15,12 @@ providers:
   - provider_id: vllm
     provider_type: inline::vllm
     config:
-      model: ${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}
       tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:1}
      max_tokens: ${env.MAX_TOKENS:4096}
+      max_model_len: ${env.MAX_MODEL_LEN:4096}
+      max_num_seqs: ${env.MAX_NUM_SEQS:4}
       enforce_eager: ${env.ENFORCE_EAGER:False}
-      gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:0.7}
+      gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:0.3}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
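
The ${env.NAME:default} placeholders in this config are filled from environment variables, with the text after the colon used as a fallback default. A rough illustration of that substitution (this is not the actual Llama Stack resolver, just the behaviour the syntax implies):

    import os
    import re

    _ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z0-9_]+):([^}]*)\}")

    def substitute_env(value: str) -> str:
        # Replace each ${env.NAME:default} with os.environ["NAME"] if set, else the default.
        return _ENV_PATTERN.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

    print(substitute_env("${env.MAX_MODEL_LEN:4096}"))  # -> "4096" unless MAX_MODEL_LEN is set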