Mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-07 07:49:55 +00:00)

Merge remote-tracking branch 'origin/main' into support_more_data_format

Commit a3b1c3438b
171 changed files with 14529 additions and 5612 deletions
@@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union

 import httpx
 from llama_models.datatypes import is_multimodal, ModelFamily

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import (
     RawContent,
@@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContentItem,
     TextContentItem,
 )

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
@@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )

 from llama_stack.providers.utils.inference import supported_inference_models

 log = logging.getLogger(__name__)
@@ -361,14 +358,13 @@ def augment_messages_for_tools_llama_3_1(

     has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
     if has_custom_tools:
-        if request.tool_prompt_format == ToolPromptFormat.json:
+        fmt = request.tool_prompt_format or ToolPromptFormat.json
+        if fmt == ToolPromptFormat.json:
             tool_gen = JsonCustomToolGenerator()
-        elif request.tool_prompt_format == ToolPromptFormat.function_tag:
+        elif fmt == ToolPromptFormat.function_tag:
             tool_gen = FunctionTagCustomToolGenerator()
         else:
-            raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_prompt_format}"
-            )
+            raise ValueError(f"Non supported ToolPromptFormat {fmt}")

         custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
         custom_template = tool_gen.gen(custom_tools)
@@ -413,7 +409,8 @@ def augment_messages_for_tools_llama_3_2(

     custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
     if custom_tools:
-        if request.tool_prompt_format != ToolPromptFormat.python_list:
+        fmt = request.tool_prompt_format or ToolPromptFormat.python_list
+        if fmt != ToolPromptFormat.python_list:
             raise ValueError(
                 f"Non supported ToolPromptFormat {request.tool_prompt_format}"
             )
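The two hunks above let callers omit tool_prompt_format: when the request leaves it unset, the prompt builder now falls back to the default for that code path (JSON for Llama 3.1 custom tools, the Python-list format for Llama 3.2) instead of falling through to the "Non supported ToolPromptFormat" error. A minimal sketch of the resulting behavior, using a stand-in enum rather than the real ToolPromptFormat from llama_models:

from enum import Enum
from typing import Optional


class ToolPromptFormat(Enum):
    # Stand-in for llama_models' ToolPromptFormat; member names match the diff,
    # the values here are only illustrative.
    json = "json"
    function_tag = "function_tag"
    python_list = "python_list"


def resolve_format(requested: Optional[ToolPromptFormat]) -> ToolPromptFormat:
    # Same pattern as the diff: fall back to a default when the request
    # does not specify a tool prompt format.
    return requested or ToolPromptFormat.json


assert resolve_format(None) is ToolPromptFormat.json
assert resolve_format(ToolPromptFormat.function_tag) is ToolPromptFormat.function_tag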
@@ -48,5 +48,27 @@ class RedisKVStoreImpl(KVStore):
     async def range(self, start_key: str, end_key: str) -> List[str]:
         start_key = self._namespaced_key(start_key)
         end_key = self._namespaced_key(end_key)
-
-        return await self.redis.zrangebylex(start_key, end_key)
+        cursor = 0
+        pattern = start_key + "*"  # Match all keys starting with start_key prefix
+        matching_keys = []
+        while True:
+            cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
+
+            for key in keys:
+                key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+                if start_key <= key_str <= end_key:
+                    matching_keys.append(key)
+
+            if cursor == 0:
+                break
+
+        # Then fetch all values in a single MGET call
+        if matching_keys:
+            values = await self.redis.mget(matching_keys)
+            return [
+                value.decode("utf-8") if isinstance(value, bytes) else value
+                for value in values
+                if value is not None
+            ]
+
+        return []
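The new range implementation replaces the single ZRANGEBYLEX call (a sorted-set command) with a cursor-based SCAN over the namespaced key prefix, followed by one MGET for every key that falls lexically between start_key and end_key. A standalone sketch of the same pattern against redis-py's asyncio client; the URL and key names are placeholders, and this illustrates the approach rather than reproducing the provider's code:

from redis import asyncio as aioredis  # redis-py's asyncio client


async def range_scan(url: str, start_key: str, end_key: str) -> list[str]:
    # Same shape as the diff: SCAN collects candidate keys in batches without
    # blocking the server, then a single MGET fetches all matching values.
    r = aioredis.from_url(url)
    cursor, matching_keys = 0, []
    while True:
        cursor, keys = await r.scan(cursor, match=start_key + "*", count=1000)
        for key in keys:
            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            if start_key <= key_str <= end_key:
                matching_keys.append(key)
        if cursor == 0:
            break
    if not matching_keys:
        return []
    values = await r.mget(matching_keys)
    return [v.decode("utf-8") if isinstance(v, bytes) else v for v in values if v is not None]


# Illustrative usage (assumes a reachable Redis):
# import asyncio
# asyncio.run(range_scan("redis://localhost:6379", "sessions:a", "sessions:z"))

SCAN with a count hint keeps each round trip bounded, and deferring the reads to one MGET avoids issuing a GET per matching key.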