Mirror of https://github.com/meta-llama/llama-stack.git
feat: better using get_default_tool_prompt_format
Summary:

Test Plan:
parent 7d111c7510
commit 5f3ec93a35

2 changed files with 6 additions and 7 deletions
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import copy
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
 from llama_stack import logcat
@@ -54,7 +53,6 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
-from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 
 class VectorIORouter(VectorIO):
@@ -181,9 +179,6 @@ class InferenceRouter(Inference):
                 params["tool_prompt_format"] = tool_prompt_format
             tool_config = ToolConfig(**params)
 
-        tool_config = copy.copy(tool_config)
-        tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
-
         tools = tools or []
         if tool_config.tool_choice == ToolChoice.none:
            tools = []
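The hunks above touch the router module that defines VectorIORouter and InferenceRouter. Before this commit, the router shallow-copied the caller's ToolConfig and patched a default tool_prompt_format into the copy; after it, the fallback is computed where the prompt is actually encoded (see the second file below). A minimal, self-contained sketch of the two patterns, using stand-in names rather than llama-stack's real types:

# Hypothetical stand-ins to illustrate the pattern; not the repo's code.
import copy
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToolConfigSketch:
    tool_prompt_format: Optional[str] = None


def default_format_for(model_id: str) -> str:
    # stand-in for get_default_tool_prompt_format
    return "json"


def old_router_fallback(tool_config: ToolConfigSketch, model_id: str) -> ToolConfigSketch:
    # removed approach: copy so the caller's config is not mutated, then patch the copy
    tool_config = copy.copy(tool_config)
    tool_config.tool_prompt_format = tool_config.tool_prompt_format or default_format_for(model_id)
    return tool_config


def resolve_format(tool_config: ToolConfigSketch, model_id: str) -> str:
    # new approach: leave the config untouched and resolve the value at the use site
    return tool_config.tool_prompt_format or default_format_for(model_id)


assert old_router_fallback(ToolConfigSketch(), "some-model").tool_prompt_format == "json"
assert resolve_format(ToolConfigSketch("python_list"), "some-model") == "python_list"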
@@ -15,6 +15,7 @@ from typing import List, Optional, Tuple, Union
 import httpx
 from PIL import Image as PIL_Image
 
+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
@@ -253,7 +254,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
 
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
-        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+        request.messages,
+        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
     )
     return formatter.tokenizer.decode(model_input.tokens)
 
@@ -267,7 +269,8 @@ async def chat_completion_request_to_model_input_info(
 
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
-        request.messages, tool_prompt_format=request.tool_config.tool_prompt_format
+        request.messages,
+        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
     )
     return (
         formatter.tokenizer.decode(model_input.tokens),
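The second file is the prompt_adapter module (the import removed above points at llama_stack.providers.utils.inference.prompt_adapter). Both prompt-building helpers now fall back to get_default_tool_prompt_format(llama_model) when the request's tool_config leaves tool_prompt_format unset, so an explicitly requested format still wins and only the unset case picks up the model default. A small sketch of why the "x or default" idiom is safe here, on the assumption that tool_prompt_format is either an enum member or None (enum members are always truthy):

# Illustrative only; Fmt is a stand-in for ToolPromptFormat.
from enum import Enum
from typing import Optional


class Fmt(Enum):
    json = "json"
    python_list = "python_list"


def pick(requested: Optional[Fmt], default: Fmt) -> Fmt:
    # `or` only falls through to the default when the caller left the field unset
    return requested or default


assert pick(None, Fmt.json) is Fmt.json                      # unset -> model default
assert pick(Fmt.python_list, Fmt.json) is Fmt.python_list    # explicit value wins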
@@ -461,6 +464,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
+        logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
         return ToolPromptFormat.json
 
     if llama_model.model_family == ModelFamily.llama3_1 or (
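The last hunk makes the unresolvable-model case visible: when resolve_model(model) returns None, the function now logs a warning through logcat (imported in the first hunk of this file) before defaulting to ToolPromptFormat.json. A rough sketch of the same warn-then-default shape, using the stdlib logging module instead of llama_stack's logcat and a placeholder resolver, purely for a self-contained example:

import logging
from typing import Optional

logger = logging.getLogger("inference")


def resolve_model_sketch(model: str) -> Optional[str]:
    # placeholder for the real resolve_model: None means "unknown model id"
    known = {"meta-llama/Llama-3.1-8B-Instruct"}
    return model if model in known else None


def default_tool_prompt_format(model: str) -> str:
    resolved = resolve_model_sketch(model)
    if resolved is None:
        # mirror the diff: warn, then fall back to the JSON tool prompt format
        logger.warning("Could not resolve model %s, defaulting to json tool prompt format", model)
        return "json"
    # the real function continues with family-specific choices (llama3_1, etc.)
    return "json"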
|