Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-31 11:00:01 +00:00)

Commit 967dd0aa08 ("Resolved merge conflicts"), parent 3298e50105
82 changed files with 66055 additions and 0 deletions
33  llama_stack/distribution/utils/context.py  Normal file
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from contextvars import ContextVar
from typing import AsyncGenerator, List, TypeVar

T = TypeVar("T")


def preserve_contexts_async_generator(
    gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
) -> AsyncGenerator[T, None]:
    """
    Wraps an async generator to preserve context variables across iterations.
    This is needed because we start a new asyncio event loop for each streaming request,
    and we need to preserve the context across the event loop boundary.
    """

    async def wrapper() -> AsyncGenerator[T, None]:
        while True:
            try:
                item = await gen.__anext__()
                context_values = {context_var.name: context_var.get() for context_var in context_vars}
                yield item
                for context_var in context_vars:
                    _ = context_var.set(context_values[context_var.name])
            except StopAsyncIteration:
                break

    return wrapper()
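Usage sketch (not part of the diff; the generator and variable names below are illustrative): the server wraps its streaming response generator before handing it to the per-request event loop, so values set on a context variable survive each iteration boundary.

    from contextvars import ContextVar

    from llama_stack.distribution.utils.context import preserve_contexts_async_generator

    REQUEST_ID = ContextVar("request_id", default=None)

    async def stream_response(chunks):
        # `chunks` is assumed to be an async generator of response chunks.
        REQUEST_ID.set("req-42")
        wrapped = preserve_contexts_async_generator(chunks, [REQUEST_ID])
        async for chunk in wrapped:
            # REQUEST_ID.get() still returns "req-42" on every iteration,
            # even when the loop driving this generator is a new event loop.
            yield chunk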
155  llama_stack/distribution/utils/tests/test_context.py  Normal file
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar

import pytest

from llama_stack.distribution.utils.context import preserve_contexts_async_generator


@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
    # Create context variable
    context_var = ContextVar("exception_var", default="initial")
    token = context_var.set("start_value")

    # Create an async generator that raises an exception
    async def exception_generator():
        yield context_var.get()
        context_var.set("modified")
        raise ValueError("Test exception")
        yield None  # This will never be reached

    # Wrap the generator
    wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])

    # First iteration should work
    value = await wrapped_gen.__anext__()
    assert value == "start_value"

    # Second iteration should raise the exception
    with pytest.raises(ValueError, match="Test exception"):
        await wrapped_gen.__anext__()

    # Clean up
    context_var.reset(token)


@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
    # Create context variable
    context_var = ContextVar("empty_var", default="initial")
    token = context_var.set("value")

    # Create an empty async generator
    async def empty_generator():
        if False:  # This condition ensures the generator yields nothing
            yield None

    # Wrap the generator
    wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])

    # The generator should raise StopAsyncIteration immediately
    with pytest.raises(StopAsyncIteration):
        await wrapped_gen.__anext__()

    # Context variable should remain unchanged
    assert context_var.get() == "value"

    # Clean up
    context_var.reset(token)


@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
    """
    Test that context variables are preserved across event loop boundaries with nested generators.
    This simulates the real-world scenario where:
    1. A new event loop is created for each streaming request
    2. The async generator runs inside that loop
    3. There are multiple levels of nested generators
    4. Context needs to be preserved across these boundaries
    """
    # Create context variables
    request_id = ContextVar("request_id", default=None)
    user_id = ContextVar("user_id", default=None)

    # Set initial values

    # Results container to verify values across thread boundaries
    results = []

    # Inner-most generator (level 2)
    async def inner_generator():
        # Should have the context from the outer scope
        yield (1, request_id.get(), user_id.get())

        # Modify one context variable
        user_id.set("user-modified")

        # Should reflect the modification
        yield (2, request_id.get(), user_id.get())

    # Middle generator (level 1)
    async def middle_generator():
        inner_gen = inner_generator()

        # Forward the first yield from inner
        item = await inner_gen.__anext__()
        yield item

        # Forward the second yield from inner
        item = await inner_gen.__anext__()
        yield item

        request_id.set("req-modified")

        # Add our own yield with both modified variables
        yield (3, request_id.get(), user_id.get())

    # Function to run in a separate thread with a new event loop
    def run_in_new_loop():
        # Create a new event loop for this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # Outer generator (runs in the new loop)
            async def outer_generator():
                request_id.set("req-12345")
                user_id.set("user-6789")

                # Wrap the middle generator
                wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])

                # Process all items from the middle generator
                async for item in wrapped_gen:
                    # Store results for verification
                    results.append(item)

            # Run the outer generator in the new loop
            loop.run_until_complete(outer_generator())
        finally:
            loop.close()

    # Run the generator chain in a separate thread with a new event loop
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(run_in_new_loop)
        future.result()  # Wait for completion

    # Verify the results
    assert len(results) == 3

    # First yield should have original values
    assert results[0] == (1, "req-12345", "user-6789")

    # Second yield should have modified user_id
    assert results[1] == (2, "req-12345", "user-modified")

    # Third yield should have both modified values
    assert results[2] == (3, "req-modified", "user-modified")
24  llama_stack/env.py  Normal file
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os


class MissingCredentialError(Exception):
    pass


def get_env_or_fail(key: str) -> str:
    """Get environment variable or raise helpful error"""
    value = os.getenv(key)
    if not value:
        raise MissingCredentialError(
            f"\nMissing {key} in environment. Please set it using one of these methods:"
            f"\n1. Export in shell: export {key}=your-key"
            f"\n2. Create .env file in project root with: {key}=your-key"
            f"\n3. Pass directly to pytest: pytest --env {key}=your-key"
        )
    return value
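Usage sketch (illustrative, not part of the diff):

    from llama_stack.env import get_env_or_fail

    # Returns the value, or raises MissingCredentialError with setup
    # instructions (export, .env file, or pytest --env) if it is unset.
    api_key = get_env_or_fail("TOGETHER_API_KEY")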
203  llama_stack/log.py  Normal file
@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import os
from logging.config import dictConfig
from typing import Dict

from rich.console import Console
from rich.errors import MarkupError
from rich.logging import RichHandler
from termcolor import cprint

# Default log level
DEFAULT_LOG_LEVEL = logging.INFO

# Predefined categories
CATEGORIES = [
    "core",
    "server",
    "router",
    "inference",
    "agents",
    "safety",
    "eval",
    "tools",
    "client",
]

# Initialize category levels with default level
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}


def parse_environment_config(env_config: str) -> Dict[str, int]:
    """
    Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.

    Parameters:
        env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.

    Returns:
        Dict[str, int]: A dictionary mapping categories to their log levels.
    """
    category_levels = {}
    for pair in env_config.split(";"):
        if not pair.strip():
            continue

        try:
            category, level = pair.split("=", 1)
            category = category.strip().lower()
            level = level.strip().upper()  # Convert to uppercase for logging._nameToLevel

            level_value = logging._nameToLevel.get(level)
            if level_value is None:
                logging.warning(
                    f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
                )
                continue

            if category == "all":
                # Apply the log level to all categories and the root logger
                for cat in CATEGORIES:
                    category_levels[cat] = level_value
                # Set the root logger's level to the specified level
                category_levels["root"] = level_value
            elif category in CATEGORIES:
                category_levels[category] = level_value
                logging.info(f"Setting '{category}' category to level '{level}'.")
            else:
                logging.warning(f"Unknown logging category: {category}. No changes made.")

        except ValueError:
            logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")

    return category_levels


class CustomRichHandler(RichHandler):
    def __init__(self, *args, **kwargs):
        kwargs["console"] = Console(width=120)
        super().__init__(*args, **kwargs)

    def emit(self, record):
        """Override emit to handle markup errors gracefully."""
        try:
            super().emit(record)
        except MarkupError:
            original_markup = self.markup
            self.markup = False
            try:
                super().emit(record)
            finally:
                self.markup = original_markup


def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
    """
    Configure logging based on the provided category log levels and an optional log file.

    Parameters:
        category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
        log_file (str): Path to a log file to additionally pipe the logs into
    """
    log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"

    class CategoryFilter(logging.Filter):
        """Ensure category is always present in log records."""

        def filter(self, record):
            if not hasattr(record, "category"):
                record.category = "uncategorized"  # Default to 'uncategorized' if no category found
            return True

    # Determine the root logger's level (default to WARNING if not specified)
    root_level = category_levels.get("root", logging.WARNING)

    handlers = {
        "console": {
            "()": CustomRichHandler,  # Use custom console handler
            "formatter": "rich",
            "rich_tracebacks": True,
            "show_time": False,
            "show_path": False,
            "markup": True,
            "filters": ["category_filter"],
        }
    }

    # Add a file handler if log_file is set
    if log_file:
        handlers["file"] = {
            "class": "logging.FileHandler",
            "formatter": "rich",
            "filename": log_file,
            "mode": "a",
            "encoding": "utf-8",
        }

    logging_config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "rich": {
                "()": logging.Formatter,
                "format": log_format,
            }
        },
        "handlers": handlers,
        "filters": {
            "category_filter": {
                "()": CategoryFilter,
            }
        },
        "loggers": {
            category: {
                "handlers": list(handlers.keys()),  # Apply all handlers
                "level": category_levels.get(category, DEFAULT_LOG_LEVEL),
                "propagate": False,  # Disable propagation to root logger
            }
            for category in CATEGORIES
        },
        "root": {
            "handlers": list(handlers.keys()),
            "level": root_level,  # Set root logger's level dynamically
        },
    }
    dictConfig(logging_config)

    # Ensure third-party libraries follow the root log level
    for _, logger in logging.root.manager.loggerDict.items():
        if isinstance(logger, logging.Logger):
            logger.setLevel(root_level)


def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
    """
    Returns a logger with the specified name and category.
    If no category is provided, defaults to 'uncategorized'.

    Parameters:
        name (str): The name of the logger (e.g., module or filename).
        category (str): The category of the logger (default 'uncategorized').

    Returns:
        logging.LoggerAdapter: Configured logger with category support.
    """
    logger = logging.getLogger(name)
    logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
    return logging.LoggerAdapter(logger, {"category": category})


env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
    cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
    _category_levels.update(parse_environment_config(env_config))

log_file = os.environ.get("LLAMA_STACK_LOG_FILE")

setup_logging(_category_levels, log_file)
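Usage sketch (illustrative, not part of the diff): the module configures itself at import time, so callers only pick a category and log.

    # Shell, using the format handled by parse_environment_config above:
    #   LLAMA_STACK_LOGGING="server=debug;all=warning"

    from llama_stack.log import get_logger

    logger = get_logger(name=__name__, category="server")
    logger.info("Starting up")  # rendered by CustomRichHandler with the category attached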
170  llama_stack/providers/inline/inference/vllm/openai_utils.py  Normal file
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Optional

import vllm

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    GrammarResponseFormat,
    JsonSchemaResponseFormat,
    Message,
    ToolChoice,
    UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict,
    get_sampling_options,
)

###############################################################################
# This file contains OpenAI compatibility code that is currently only used
# by the inline vLLM connector. Some or all of this code may be moved to a
# central location at a later date.


def _merge_context_into_content(message: Message) -> Message:  # type: ignore
    """
    Merge the ``context`` field of a Llama Stack ``Message`` object into
    the content field for compatibility with OpenAI-style APIs.

    Generates a content string that emulates the current behavior
    of ``llama_models.llama3.api.chat_format.encode_message()``.

    :param message: Message that may include ``context`` field

    :returns: A version of ``message`` with any context merged into the
        ``content`` field.
    """
    if not isinstance(message, UserMessage):  # Separate type check for linter
        return message
    if message.context is None:
        return message
    return UserMessage(
        role=message.role,
        # Emulate llama_models.llama3.api.chat_format.encode_message()
        content=message.content + "\n\n" + message.context,
        context=None,
    )


def _llama_stack_tools_to_openai_tools(
    tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
    """
    Convert the list of available tools from Llama Stack's format to vLLM's
    version of OpenAI's format.
    """
    if tools is None:
        return []

    result = []
    for t in tools:
        if isinstance(t.tool_name, BuiltinTool):
            raise NotImplementedError("Built-in tools not yet implemented")
        if t.parameters is None:
            parameters = None
        else:  # if t.parameters is not None
            # Convert the "required" flags to a list of required params
            required_params = [k for k, v in t.parameters.items() if v.required]
            parameters = {
                "type": "object",  # Mystery value that shows up in OpenAI docs
                "properties": {
                    k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
                },
                "required": required_params,
            }

        function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
            name=t.tool_name, description=t.description, parameters=parameters
        )

        # Every tool definition is double-boxed in a ChatCompletionToolsParam
        result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
    return result


async def llama_stack_chat_completion_to_openai_chat_completion_dict(
    request: ChatCompletionRequest,
) -> dict:
    """
    Convert a chat completion request in Llama Stack format into an
    equivalent set of arguments to pass to an OpenAI-compatible
    chat completions API.

    :param request: Bundled request parameters in Llama Stack format.

    :returns: Dictionary of key-value pairs to use as an initializer
        for a dataclass or to be converted directly to JSON and sent
        over the wire.
    """

    converted_messages = [
        # This mystery async call makes the parent function also be async
        await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
        for m in request.messages
    ]
    converted_tools = _llama_stack_tools_to_openai_tools(request.tools)

    # Llama will try to use built-in tools with no tool catalog, so don't enable
    # tool choice unless at least one tool is enabled.
    converted_tool_choice = "none"
    if (
        request.tool_config is not None
        and request.tool_config.tool_choice == ToolChoice.auto
        and request.tools is not None
        and len(request.tools) > 0
    ):
        converted_tool_choice = "auto"

    # TODO: Figure out what to do with the tool_prompt_format argument.
    # Other connectors appear to drop it quietly.

    # Use Llama Stack shared code to translate sampling parameters.
    sampling_options = get_sampling_options(request.sampling_params)

    # get_sampling_options() translates repetition penalties to an option that
    # OpenAI's APIs don't know about.
    # vLLM's OpenAI-compatible API also handles repetition penalties wrong.
    # For now, translate repetition penalties into a format that vLLM's broken
    # API will handle correctly. Two wrongs make a right...
    if "repeat_penalty" in sampling_options:
        del sampling_options["repeat_penalty"]
    if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
        sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty

    # Convert a single response format into four different parameters, per
    # the OpenAI spec
    guided_decoding_options = dict()
    if request.response_format is None:
        # Use defaults
        pass
    elif isinstance(request.response_format, JsonSchemaResponseFormat):
        guided_decoding_options["guided_json"] = request.response_format.json_schema
    elif isinstance(request.response_format, GrammarResponseFormat):
        guided_decoding_options["guided_grammar"] = request.response_format.bnf
    else:
        raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")

    logprob_options = dict()
    if request.logprobs is not None:
        logprob_options["logprobs"] = request.logprobs.top_k

    # Marshall together all the arguments for a ChatCompletionRequest
    request_options = {
        "model": request.model,
        "messages": converted_messages,
        "tools": converted_tools,
        "tool_choice": converted_tool_choice,
        "stream": request.stream,
        **sampling_options,
        **guided_decoding_options,
        **logprob_options,
    }

    return request_options
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    RegexParserScoringFnParams,
    ScoringFn,
)

MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]


regex_parser_math_response = ScoringFn(
    identifier="basic::regex_parser_math_response",
    description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
    return_type=NumberType(),
    provider_id="basic",
    provider_resource_id="regex-parser-math-response",
    params=RegexParserScoringFnParams(
        parsing_regexes=MATH_ANSWER_REGEXES,
        aggregation_functions=[AggregationFunctionType.accuracy],
    ),
)
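To make the pattern concrete, a small sketch of what MATH_ANSWER_REGEXES[0] extracts from a model response (the strings below are illustrative):

    import re

    pattern = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"  # same as MATH_ANSWER_REGEXES[0]
    response = r"... Therefore the final answer is: $\boxed{\frac{1}{2}}$"
    print(re.findall(pattern, response))  # ['\\frac{1}{2}']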
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
from .fn_defs.regex_parser_math_response import (
    regex_parser_math_response,
)


class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
    """
    A scoring_fn for math benchmarks that parses the answer from the generated response according to context and checks whether it matches the expected_answer.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {
            regex_parser_math_response.identifier: regex_parser_math_response,
        }

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        assert scoring_fn_identifier is not None, "Scoring function identifier not found."
        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
        if scoring_params is not None:
            fn_def.params = scoring_params

        assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
            f"RegexParserScoringFnParams not found for {fn_def}."
        )

        expected_answer = input_row["expected_answer"]
        generated_answer = input_row["generated_answer"]

        parsing_regexes = fn_def.params.parsing_regexes
        assert len(parsing_regexes) == 1, (
            "Only one parsing regex is supported for regex_parser_math_response scoring function."
        )
        parsing_regexes = fn_def.params.parsing_regexes[0]

        normalized_generated_answer = normalize_final_answer(
            first_answer(generated_answer),
            parsing_regexes,
            match_first=True,
        )
        normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))

        normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
        normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))

        score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
        return {
            "score": score,
        }
330  llama_stack/providers/inline/scoring/basic/utils/math_utils.py  Normal file
@@ -0,0 +1,330 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import re
from typing import Sequence

from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit

# from minerva
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]

REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]


def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
    if isinstance(expression, float):
        return expression
    new_expression = f"{expression}"
    regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
    for match in re.finditer(regex, expression):
        try:
            value = float(match.group(1)) / float(match.group(2))
            new_expression = new_expression.replace(
                match.group(),
                f"{{value:{fmt}}}".format(value=value),
                1,
            )
        except Exception:
            continue
    return new_expression


def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
    try:
        with time_limit(seconds=5):
            from sympy.parsing.latex import parse_latex

            value = parse_latex(expression).evalf()  # type: ignore
            return f"{{value:{fmt}}}".format(value=value)
    except Exception:
        return expression


def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
    for marker in markers:
        text = text.split(marker)[0]
    return text


def extract_result_from_boxed(answer: str) -> str:
    box_start = "\\boxed"
    # format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
    start = answer.rfind(box_start)
    if start < 0:
        return ""
    answer = answer[start + len(box_start) :].strip()
    ends_with_curly = answer.startswith("{")
    i = 0
    open_braces = 0
    while i < len(answer):
        if answer[i] == "{":
            open_braces += 1
        elif answer[i] == "}":
            open_braces -= 1
            if open_braces == 0:
                if ends_with_curly:
                    answer = answer[: i + 1].strip()
                    break
        elif answer[i] == "$":
            answer = answer[:i].strip()
            break
        i += 1
    else:
        return ""
    # remove extra curly braces
    while True:
        if answer.startswith("{") and answer.endswith("}"):
            answer = answer[1:-1].strip()
        else:
            break
    return answer


# from minerva paper + _normalise_result from xavierm
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
    """Extract and normalize a final answer to a quantitative reasoning question."""
    match = re.findall(regex_pattern, final_answer)
    extraction: str
    if len(match) > 0:
        if match_first:
            extraction = match[0]
        else:
            extraction = match[-1]
    else:
        extraction = extract_result_from_boxed(final_answer)

    if len(extraction) == 0:
        return final_answer
    else:
        final_answer = extraction
    final_answer = final_answer.split("=")[-1]
    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")
    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
    # Normalize shorthand TeX:
    #  \fracab -> \frac{a}{b}
    #  \frac{abc}{bef} -> \frac{abc}{bef}
    #  \fracabc -> \frac{a}{b}c
    #  \sqrta -> \sqrt{a}
    #  \sqrtab -> sqrt{a}b
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")
    # Normalize 100,000 -> 100000
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")
    # If the final answer is a single letter in parentheses, remove the parentheses
    # Example: (a) -> a (but not (ab) -> ab)
    if re.match(r"\([a-zA-Z]\)", final_answer):
        final_answer = final_answer[1]
    return _normalise_result(final_answer)


def _normalise_result(string: str) -> str:
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("cfrac", "frac")
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\le", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    string = string.split("=")[-1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string


def _remove_right_units(string: str) -> str:
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    try:
        if "\\text{ " in string:
            splits = string.split("\\text{ ")
            assert len(splits) == 2
            return splits[0]
        else:
            return string
    except AssertionError:
        return string


def _fix_sqrt(string: str) -> str:
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if len(split) == 0:
            return string
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string


def _fix_fracs(string: str) -> str:
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if len(substr) == 0:
                return string
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except AssertionError:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string


def _fix_a_slash_b(string: str) -> str:
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        ia = int(a)
        ib = int(b)
        assert string == "{}/{}".format(ia, ib)
        new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
        return new_string
    except (ValueError, AssertionError):
        return string
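A sketch of how these helpers combine in the scoring path above (values are illustrative; the numeric result assumes sympy's LaTeX parser is installed, otherwise try_evaluate_latex returns its input unchanged):

    from llama_stack.providers.inline.scoring.basic.utils.math_utils import (
        first_answer,
        normalize_final_answer,
        try_evaluate_frac,
        try_evaluate_latex,
    )

    generated = r"Therefore the final answer is: $\boxed{\frac{1}{2}}$"
    expected = "1/2"

    norm_gen = normalize_final_answer(
        first_answer(generated), r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"
    )
    norm_gen = try_evaluate_frac(try_evaluate_latex(norm_gen))

    norm_exp = try_evaluate_frac(try_evaluate_latex(normalize_final_answer(expected, r".*")))

    # Both sides should reduce to the same string (roughly "5.00e-01"),
    # so the exact-match comparison in score_row() yields 1.0.
    print(norm_gen == norm_exp)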
19  llama_stack/providers/inline/vector_io/milvus/__init__.py  Normal file
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from llama_stack.providers.datatypes import Api

from .config import MilvusVectorIOConfig


async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter

    impl = MilvusVectorIOAdapter(config, deps[Api.inference])
    await impl.initialize()
    return impl
20  llama_stack/providers/inline/vector_io/milvus/config.py  Normal file
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from pydantic import BaseModel

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class MilvusVectorIOConfig(BaseModel):
    db_path: str

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
        return {"db_path": "${env.MILVUS_DB_PATH}"}
21  llama_stack/providers/remote/vector_io/milvus/__init__.py  Normal file
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Dict

from llama_stack.providers.datatypes import Api, ProviderSpec

from .config import MilvusVectorIOConfig


async def get_adapter_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
    from .milvus import MilvusVectorIOAdapter

    assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = MilvusVectorIOAdapter(config, deps[Api.inference])
    await impl.initialize()
    return impl
22  llama_stack/providers/remote/vector_io/milvus/config.py  Normal file
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class MilvusVectorIOConfig(BaseModel):
    uri: str
    token: Optional[str] = None
    consistency_level: str = "Strong"

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
        return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
175  llama_stack/providers/remote/vector_io/milvus/milvus.py  Normal file
@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import hashlib
import logging
import os
import uuid
from typing import Any, Dict, List, Optional, Union

from numpy.typing import NDArray
from pymilvus import MilvusClient

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
    EmbeddingIndex,
    VectorDBWithIndex,
)

from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig

logger = logging.getLogger(__name__)


class MilvusIndex(EmbeddingIndex):
    def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"):
        self.client = client
        self.collection_name = collection_name.replace("-", "_")
        self.consistency_level = consistency_level

    async def delete(self):
        if self.client.has_collection(self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(embeddings), (
            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        )
        if not self.client.has_collection(self.collection_name):
            self.client.create_collection(
                self.collection_name,
                dimension=len(embeddings[0]),
                auto_id=True,
                consistency_level=self.consistency_level,
            )

        data = []
        for chunk, embedding in zip(chunks, embeddings, strict=False):
            chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content)

            data.append(
                {
                    "chunk_id": chunk_id,
                    "vector": embedding,
                    "chunk_content": chunk.model_dump(),
                }
            )
        try:
            self.client.insert(
                self.collection_name,
                data=data,
            )
        except Exception as e:
            logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
            raise e

    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        search_res = self.client.search(
            collection_name=self.collection_name,
            data=[embedding],
            limit=k,
            output_fields=["*"],
            search_params={"params": {"radius": score_threshold}},
        )
        chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
        scores = [res["distance"] for res in search_res[0]]
        return QueryChunksResponse(chunks=chunks, scores=scores)


class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    def __init__(
        self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference
    ) -> None:
        self.config = config
        self.cache = {}
        self.client = None
        self.inference_api = inference_api

    async def initialize(self) -> None:
        if isinstance(self.config, RemoteMilvusVectorIOConfig):
            logger.info(f"Connecting to Milvus server at {self.config.uri}")
            self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
        else:
            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
            uri = os.path.expanduser(self.config.db_path)
            self.client = MilvusClient(uri=uri)

    async def shutdown(self) -> None:
        self.client.close()

    async def register_vector_db(
        self,
        vector_db: VectorDB,
    ) -> None:
        if isinstance(self.config, RemoteMilvusVectorIOConfig):
            consistency_level = self.config.consistency_level
        else:
            consistency_level = "Strong"
        index = VectorDBWithIndex(
            vector_db=vector_db,
            index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
            inference_api=self.inference_api,
        )

        self.cache[vector_db.identifier] = index

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        index = VectorDBWithIndex(
            vector_db=vector_db,
            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier),
            inference_api=self.inference_api,
        )
        self.cache[vector_db_id] = index
        return index

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        if vector_db_id in self.cache:
            await self.cache[vector_db_id].index.delete()
            del self.cache[vector_db_id]

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: List[Chunk],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryChunksResponse:
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        return await index.query_chunks(query, params)


def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    """Generate a unique chunk ID using a hash of document ID and chunk text."""
    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))


# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file
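A small sketch of the deterministic chunk-ID scheme used above (arguments are illustrative):

    from llama_stack.providers.remote.vector_io.milvus.milvus import generate_chunk_id

    # The same (document_id, chunk_text) pair always hashes to the same UUID,
    # so re-adding a chunk reuses its identifier instead of minting a new one.
    a = generate_chunk_id("doc-1", "Milvus stores the chunk embeddings.")
    b = generate_chunk_id("doc-1", "Milvus stores the chunk embeddings.")
    assert a == b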
26  llama_stack/providers/utils/scoring/basic_scoring_utils.py  Normal file
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextlib
import signal
from types import FrameType
from typing import Iterator, Optional


class TimeoutError(Exception):
    pass


@contextlib.contextmanager
def time_limit(seconds: float) -> Iterator[None]:
    def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
        raise TimeoutError("Timed out!")

    signal.setitimer(signal.ITIMER_REAL, seconds)
    signal.signal(signal.SIGALRM, signal_handler)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
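Usage sketch for time_limit (illustrative; note that it relies on SIGALRM, so it only applies on the main thread of a Unix process):

    from llama_stack.providers.utils.scoring.basic_scoring_utils import TimeoutError, time_limit

    try:
        with time_limit(seconds=1.0):
            while True:  # stand-in for a slow parse, e.g. sympy's parse_latex
                pass
    except TimeoutError:
        print("gave up after one second")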
7  llama_stack/templates/open-benchmark/__init__.py  Normal file
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .open_benchmark import get_distribution_template  # noqa: F401
36  llama_stack/templates/open-benchmark/build.yaml  Normal file
@@ -0,0 +1,36 @@
version: '2'
distribution_spec:
  description: Distribution for running open benchmarks
  providers:
    inference:
    - remote::openai
    - remote::anthropic
    - remote::gemini
    - remote::groq
    - remote::together
    vector_io:
    - inline::sqlite-vec
    - remote::chromadb
    - remote::pgvector
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
    eval:
    - inline::meta-reference
    datasetio:
    - remote::huggingface
    - inline::localfs
    scoring:
    - inline::basic
    - inline::llm-as-judge
    - inline::braintrust
    tool_runtime:
    - remote::brave-search
    - remote::tavily-search
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
image_type: conda
300
llama_stack/templates/open-benchmark/open_benchmark.py
Normal file
300
llama_stack/templates/open-benchmark/open_benchmark.py
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BenchmarkInput,
|
||||
DatasetInput,
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
||||
SQLiteVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
|
||||
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
|
||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack.templates.template import (
|
||||
DistributionTemplate,
|
||||
RunConfigSettings,
|
||||
get_model_registry,
|
||||
)
|
||||
|
||||
|
||||
def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
|
||||
# in this template, we allow each API key to be optional
|
||||
providers = [
|
||||
(
|
||||
"openai",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="openai/gpt-4o",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/claude-3-5-sonnet-latest",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
|
||||
),
|
||||
(
|
||||
"gemini",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/gemini-1.5-flash",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
|
||||
),
|
||||
(
|
||||
"groq",
|
||||
[],
|
||||
GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
|
||||
),
|
||||
(
|
||||
"together",
|
||||
[],
|
||||
TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"),
|
||||
),
|
||||
]
|
||||
inference_providers = []
|
||||
available_models = {}
|
||||
for provider_id, model_entries, config in providers:
|
||||
inference_providers.append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=f"remote::{provider_id}",
|
||||
config=config,
|
||||
)
|
||||
)
|
||||
available_models[provider_id] = model_entries
|
||||
return inference_providers, available_models
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
inference_providers, available_models = get_inference_providers()
|
||||
providers = {
|
||||
"inference": [p.provider_type for p in inference_providers],
|
||||
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["inline::meta-reference"],
|
||||
"datasetio": ["remote::huggingface", "inline::localfs"],
|
||||
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
|
||||
"tool_runtime": [
|
||||
"remote::brave-search",
|
||||
"remote::tavily-search",
|
||||
"inline::code-interpreter",
|
||||
"inline::rag-runtime",
|
||||
"remote::model-context-protocol",
|
||||
],
|
||||
}
|
||||
name = "open-benchmark"
|
||||
|
||||
vector_io_providers = [
|
||||
Provider(
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
db="${env.PGVECTOR_DB:}",
|
||||
user="${env.PGVECTOR_USER:}",
|
||||
password="${env.PGVECTOR_PASSWORD:}",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::code_interpreter",
|
||||
provider_id="code-interpreter",
|
||||
),
|
||||
]
|
||||
|
||||
default_models = get_model_registry(available_models) + [
|
||||
ModelInput(
|
||||
model_id="meta-llama/Llama-3.3-70B-Instruct",
|
||||
provider_id="groq",
|
||||
provider_model_id="groq/llama-3.3-70b-versatile",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
ModelInput(
|
||||
model_id="meta-llama/Llama-3.1-405B-Instruct",
|
||||
provider_id="together",
|
||||
provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
|
||||
    default_datasets = [
        DatasetInput(
            dataset_id="simpleqa",
            provider_id="huggingface",
            url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
            metadata={
                "path": "llamastack/simpleqa",
                "split": "train",
            },
            dataset_schema={
                "input_query": {"type": "string"},
                "expected_answer": {"type": "string"},
                "chat_completion_input": {"type": "string"},
            },
        ),
        DatasetInput(
            dataset_id="mmlu_cot",
            provider_id="huggingface",
            url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
            metadata={
                "path": "llamastack/mmlu_cot",
                "name": "all",
                "split": "test",
            },
            dataset_schema={
                "input_query": {"type": "string"},
                "expected_answer": {"type": "string"},
                "chat_completion_input": {"type": "string"},
            },
        ),
        DatasetInput(
            dataset_id="gpqa_cot",
            provider_id="huggingface",
            url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
            metadata={
                "path": "llamastack/gpqa_0shot_cot",
                "name": "gpqa_main",
                "split": "train",
            },
            dataset_schema={
                "input_query": {"type": "string"},
                "expected_answer": {"type": "string"},
                "chat_completion_input": {"type": "string"},
            },
        ),
        DatasetInput(
            dataset_id="math_500",
            provider_id="huggingface",
            url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
            metadata={
                "path": "llamastack/math_500",
                "split": "test",
            },
            dataset_schema={
                "input_query": {"type": "string"},
                "expected_answer": {"type": "string"},
                "chat_completion_input": {"type": "string"},
            },
        ),
    ]
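
    # Every dataset above shares the same three-column string schema, so a single row is
    # just a mapping of those fields. Illustrative values only; in particular, the
    # serialized-messages convention for chat_completion_input is an assumption:
    #
    #     example_row = {
    #         "input_query": "What is 2 + 2?",
    #         "expected_answer": "4",
    #         "chat_completion_input": '[{"role": "user", "content": "What is 2 + 2?"}]',
    #     }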
    default_benchmarks = [
        BenchmarkInput(
            benchmark_id="meta-reference-simpleqa",
            dataset_id="simpleqa",
            scoring_functions=["llm-as-judge::405b-simpleqa"],
        ),
        BenchmarkInput(
            benchmark_id="meta-reference-mmlu-cot",
            dataset_id="mmlu_cot",
            scoring_functions=["basic::regex_parser_multiple_choice_answer"],
        ),
        BenchmarkInput(
            benchmark_id="meta-reference-gpqa-cot",
            dataset_id="gpqa_cot",
            scoring_functions=["basic::regex_parser_multiple_choice_answer"],
        ),
        BenchmarkInput(
            benchmark_id="meta-reference-math-500",
            dataset_id="math_500",
            scoring_functions=["basic::regex_parser_math_response"],
        ),
    ]
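
    # Each benchmark binds a dataset to the scoring functions used to grade generations.
    # As a rough illustration of what a regex-based multiple-choice parser does (a
    # simplified approximation, not the inline::basic implementation):
    #
    #     import re
    #
    #     def parse_choice(generation: str) -> str | None:
    #         match = re.search(r"answer is\s*\(?([A-D])\)?", generation, re.IGNORECASE)
    #         return match.group(1).upper() if match else None
    #
    #     parse_choice("... so the answer is (B).")  # -> "B"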
    return DistributionTemplate(
        name=name,
        distro_type="self_hosted",
        description="Distribution for running open benchmarks",
        container_image=None,
        template_path=None,
        providers=providers,
        available_models_by_provider=available_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": inference_providers,
                    "vector_io": vector_io_providers,
                },
                default_models=default_models,
                default_tool_groups=default_tool_groups,
                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
                default_datasets=default_datasets,
                default_benchmarks=default_benchmarks,
            ),
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "TOGETHER_API_KEY": (
                "",
                "Together API Key",
            ),
            "OPENAI_API_KEY": (
                "",
                "OpenAI API Key",
            ),
            "GEMINI_API_KEY": (
                "",
                "Gemini API Key",
            ),
            "ANTHROPIC_API_KEY": (
                "",
                "Anthropic API Key",
            ),
            "GROQ_API_KEY": (
                "",
                "Groq API Key",
            ),
        },
    )
249
llama_stack/templates/open-benchmark/run.yaml
Normal file
@ -0,0 +1,249 @@
version: '2'
image_name: open-benchmark
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: openai
    provider_type: remote::openai
    config:
      api_key: ${env.OPENAI_API_KEY:}
  - provider_id: anthropic
    provider_type: remote::anthropic
    config:
      api_key: ${env.ANTHROPIC_API_KEY:}
  - provider_id: gemini
    provider_type: remote::gemini
    config:
      api_key: ${env.GEMINI_API_KEY:}
  - provider_id: groq
    provider_type: remote::groq
    config:
      url: https://api.groq.com
      api_key: ${env.GROQ_API_KEY:}
  - provider_id: together
    provider_type: remote::together
    config:
      url: https://api.together.xyz/v1
      api_key: ${env.TOGETHER_API_KEY:}
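  # Remote inference providers read their API keys from the environment; with the
  # "${env.VAR:}" form an unset variable falls back to an empty string, so only the
  # providers you actually plan to use need credentials.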
  vector_io:
  - provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
    config:
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/sqlite_vec.db
  - provider_id: ${env.ENABLE_CHROMADB+chromadb}
    provider_type: remote::chromadb
    config:
      url: ${env.CHROMADB_URL:}
  - provider_id: ${env.ENABLE_PGVECTOR+pgvector}
    provider_type: remote::pgvector
    config:
      host: ${env.PGVECTOR_HOST:localhost}
      port: ${env.PGVECTOR_PORT:5432}
      db: ${env.PGVECTOR_DB:}
      user: ${env.PGVECTOR_USER:}
      password: ${env.PGVECTOR_PASSWORD:}
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/meta_reference_eval.db
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/huggingface_datasetio.db
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/localfs_datasetio.db
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db
models:
- metadata: {}
  model_id: openai/gpt-4o
  provider_id: openai
  provider_model_id: openai/gpt-4o
  model_type: llm
- metadata: {}
  model_id: anthropic/claude-3-5-sonnet-latest
  provider_id: anthropic
  provider_model_id: anthropic/claude-3-5-sonnet-latest
  model_type: llm
- metadata: {}
  model_id: gemini/gemini-1.5-flash
  provider_id: gemini
  provider_model_id: gemini/gemini-1.5-flash
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.3-70B-Instruct
  provider_id: groq
  provider_model_id: groq/llama-3.3-70b-versatile
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-405B-Instruct
  provider_id: together
  provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
  model_type: llm
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets:
- dataset_schema:
    input_query:
      type: string
    expected_answer:
      type: string
    chat_completion_input:
      type: string
  url:
    uri: https://huggingface.co/datasets/llamastack/simpleqa
  metadata:
    path: llamastack/simpleqa
    split: train
  dataset_id: simpleqa
  provider_id: huggingface
- dataset_schema:
    input_query:
      type: string
    expected_answer:
      type: string
    chat_completion_input:
      type: string
  url:
    uri: https://huggingface.co/datasets/llamastack/mmlu_cot
  metadata:
    path: llamastack/mmlu_cot
    name: all
    split: test
  dataset_id: mmlu_cot
  provider_id: huggingface
- dataset_schema:
    input_query:
      type: string
    expected_answer:
      type: string
    chat_completion_input:
      type: string
  url:
    uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
  metadata:
    path: llamastack/gpqa_0shot_cot
    name: gpqa_main
    split: train
  dataset_id: gpqa_cot
  provider_id: huggingface
- dataset_schema:
    input_query:
      type: string
    expected_answer:
      type: string
    chat_completion_input:
      type: string
  url:
    uri: https://huggingface.co/datasets/llamastack/math_500
  metadata:
    path: llamastack/math_500
    split: test
  dataset_id: math_500
  provider_id: huggingface
scoring_fns: []
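# Each benchmark below pairs one of the datasets registered above with the scoring
# functions used to grade it; the benchmark_id values mirror the template's
# default_benchmarks.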
benchmarks:
- dataset_id: simpleqa
  scoring_functions:
  - llm-as-judge::405b-simpleqa
  metadata: {}
  benchmark_id: meta-reference-simpleqa
- dataset_id: mmlu_cot
  scoring_functions:
  - basic::regex_parser_multiple_choice_answer
  metadata: {}
  benchmark_id: meta-reference-mmlu-cot
- dataset_id: gpqa_cot
  scoring_functions:
  - basic::regex_parser_multiple_choice_answer
  metadata: {}
  benchmark_id: meta-reference-gpqa-cot
- dataset_id: math_500
  scoring_functions:
  - basic::regex_parser_math_response
  metadata: {}
  benchmark_id: meta-reference-math-500
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
server:
  port: 8321