Merge remote-tracking branch 'origin/main' into support_more_data_format

This commit is contained in:
Botao Chen 2025-01-13 20:36:14 -08:00
commit a3b1c3438b
171 changed files with 14529 additions and 5612 deletions

View file

@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union
import httpx
from llama_models.datatypes import is_multimodal, ModelFamily
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
RawContent,
@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
ToolChoice,
UserMessage,
)
from llama_stack.providers.utils.inference import supported_inference_models
log = logging.getLogger(__name__)
@ -361,14 +358,13 @@ def augment_messages_for_tools_llama_3_1(
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
if request.tool_prompt_format == ToolPromptFormat.json:
fmt = request.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
elif fmt == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
@ -413,7 +409,8 @@ def augment_messages_for_tools_llama_3_2(
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
if request.tool_prompt_format != ToolPromptFormat.python_list:
fmt = request.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)

View file

@ -48,5 +48,27 @@ class RedisKVStoreImpl(KVStore):
async def range(self, start_key: str, end_key: str) -> List[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
cursor = 0
pattern = start_key + "*" # Match all keys starting with start_key prefix
matching_keys = []
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
return await self.redis.zrangebylex(start_key, end_key)
for key in keys:
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
if start_key <= key_str <= end_key:
matching_keys.append(key)
if cursor == 0:
break
# Then fetch all values in a single MGET call
if matching_keys:
values = await self.redis.mget(matching_keys)
return [
value.decode("utf-8") if isinstance(value, bytes) else value
for value in values
if value is not None
]
return []