mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 07:52:02 +00:00
Merge branch 'meta-llama:main' into feat/litellm_sambanova_usage
This commit is contained in:
commit
716cb09056
145 changed files with 21384 additions and 1283 deletions
|
|
@ -14,6 +14,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
|
|
|
|||
|
|
@ -117,13 +117,11 @@ class ToolResponseMessage(BaseModel):
|
|||
|
||||
:param role: Must be "tool" to identify this as a tool response
|
||||
:param call_id: Unique identifier for the tool call this response is for
|
||||
:param tool_name: Name of the tool that was called
|
||||
:param content: The response content from the tool
|
||||
"""
|
||||
|
||||
role: Literal["tool"] = "tool"
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
content: InterleavedContent
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,6 @@ from pydantic import BaseModel
|
|||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
route: str
|
||||
|
|
@ -32,14 +25,21 @@ class HealthInfo(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class VersionInfo(BaseModel):
|
||||
version: str
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VersionInfo(BaseModel):
|
||||
version: str
|
||||
|
||||
|
||||
class ListRoutesResponse(BaseModel):
|
||||
data: List[RouteInfo]
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,4 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
from .providers import * # noqa: F401 F403
|
||||
36
llama_stack/apis/providers/providers.py
Normal file
36
llama_stack/apis/providers/providers.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Providers(Protocol):
|
||||
"""
|
||||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET")
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
|
||||
|
|
@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
|
|||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricInResponse(BaseModel):
|
||||
metric: str
|
||||
value: Union[int, float]
|
||||
unit: Optional[str] = None
|
||||
|
||||
|
||||
# This is a short term solution to allow inference API to return metrics
|
||||
# The ideal way to do this is to have a way for all response types to include metrics
|
||||
# and all metric events logged to the telemetry API to be inlcuded with the response
|
||||
|
|
@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
|
|||
|
||||
|
||||
class MetricResponseMixin(BaseModel):
|
||||
metrics: Optional[List[MetricEvent]] = None
|
||||
metrics: Optional[List[MetricInResponse]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
|
@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
|||
d = json.load(f)
|
||||
manifest = Manifest(**d)
|
||||
|
||||
if datetime.now() > manifest.expires_on:
|
||||
if datetime.now(timezone.utc) > manifest.expires_on:
|
||||
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
|
||||
|
||||
console = Console()
|
||||
|
|
|
|||
|
|
@ -41,8 +41,8 @@ class ModelPromptFormat(Subcommand):
|
|||
"-m",
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="llama3_1",
|
||||
help="Model Family (llama3_1, llama3_X, etc.)",
|
||||
help="Example: Llama3.1-8B or Llama3.2-11B-Vision, etc\n"
|
||||
"(Run `llama model list` to see a list of valid model names)",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-l",
|
||||
|
|
@ -60,7 +60,6 @@ class ModelPromptFormat(Subcommand):
|
|||
]
|
||||
|
||||
model_list = [m.value for m in supported_model_ids]
|
||||
model_str = "\n".join(model_list)
|
||||
|
||||
if args.list:
|
||||
headers = ["Model(s)"]
|
||||
|
|
@ -81,10 +80,16 @@ class ModelPromptFormat(Subcommand):
|
|||
try:
|
||||
model_id = CoreModelId(args.model_name)
|
||||
except ValueError:
|
||||
self.parser.error(f"{args.model_name} is not a valid Model. Choose one from --\n{model_str}")
|
||||
self.parser.error(
|
||||
f"{args.model_name} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
if model_id not in supported_model_ids:
|
||||
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
|
||||
self.parser.error(
|
||||
f"{model_id} is not a valid Model. Choose one from the list of valid models. "
|
||||
f"Run `llama model list` to see the valid model names."
|
||||
)
|
||||
|
||||
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
|
||||
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
if config.apis:
|
||||
apis_to_serve = config.apis
|
||||
else:
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect)]
|
||||
apis_to_serve = [a.value for a in Api if a not in (Api.telemetry, Api.inspect, Api.providers)]
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
|
|
|||
|
|
@ -117,6 +117,14 @@ class Provider(BaseModel):
|
|||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
category_levels: Dict[str, str] = Field(
|
||||
default_factory=Dict,
|
||||
description="""
|
||||
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
|
||||
)
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
|
|
@ -176,6 +184,8 @@ a default SQLite store will be used.""",
|
|||
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
|
||||
server: ServerConfig = Field(
|
||||
default_factory=ServerConfig,
|
||||
description="Configuration for the HTTP(S) server",
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
|||
|
||||
def providable_apis() -> List[Api]:
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
||||
|
||||
|
||||
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
|
|
|
|||
59
llama_stack/distribution/providers.py
Normal file
59
llama_stack/distribution/providers.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .stack import redact_sensitive_fields
|
||||
|
||||
|
||||
class ProviderImplConfig(BaseModel):
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
impl = ProviderImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class ProviderImpl(Providers):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
all_providers = await self.list_providers()
|
||||
for p in all_providers.data:
|
||||
if p.provider_id == provider_id:
|
||||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
|
@ -16,6 +16,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.providers import Providers as ProvidersAPI
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
|
|
@ -59,6 +60,7 @@ class InvalidProviderError(Exception):
|
|||
|
||||
def api_protocol_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.providers: ProvidersAPI,
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
|
|
@ -247,6 +249,25 @@ def sort_providers_by_deps(
|
|||
)
|
||||
)
|
||||
|
||||
sorted_providers.append(
|
||||
(
|
||||
"providers",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.providers,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.providers.ProviderImplConfig",
|
||||
module="llama_stack.distribution.providers",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ from llama_stack.apis.scoring import (
|
|||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||
from llama_stack.apis.tools import (
|
||||
RAGDocument,
|
||||
RAGQueryConfig,
|
||||
|
|
@ -206,12 +206,12 @@ class InferenceRouter(Inference):
|
|||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricEvent]:
|
||||
) -> List[MetricInResponse]:
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
return metrics
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
|
|
@ -238,7 +238,6 @@ class InferenceRouter(Inference):
|
|||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
logger.debug(
|
||||
"core",
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
|
|
@ -306,34 +306,42 @@ def main():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
log_line = ""
|
||||
if args.yaml_config:
|
||||
# if the user provided a config file, use it, even if template was specified
|
||||
config_file = Path(args.yaml_config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
logger.info(f"Using config file: {config_file}")
|
||||
log_line = f"Using config file: {config_file}"
|
||||
elif args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
logger.info(f"Using template {args.template} config file: {config_file}")
|
||||
log_line = f"Using template {args.template} config file: {config_file}"
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
logger_config = None
|
||||
with open(config_file, "r") as fp:
|
||||
config = replace_env_vars(yaml.safe_load(fp))
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**config)
|
||||
|
||||
# now that the logger is initialized, print the line about which type of config we are using.
|
||||
logger.info(log_line)
|
||||
|
||||
logger.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
|
@ -368,6 +376,7 @@ def main():
|
|||
apis_to_serve.add(inf.routing_table_api.value)
|
||||
|
||||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.providers import Providers
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
|
|
@ -44,6 +45,7 @@ logger = get_logger(name=__name__, category="core")
|
|||
|
||||
|
||||
class LlamaStack(
|
||||
Providers,
|
||||
VectorDBs,
|
||||
Inference,
|
||||
BatchInference,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ def preserve_contexts_async_generator(
|
|||
and we need to preserve the context across the event loop boundary.
|
||||
"""
|
||||
|
||||
async def wrapper():
|
||||
async def wrapper() -> AsyncGenerator[T, None]:
|
||||
while True:
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
|
|
|
|||
|
|
@ -7,13 +7,15 @@
|
|||
import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
from rich.logging import RichHandler
|
||||
from termcolor import cprint
|
||||
|
||||
from .distribution.datatypes import LoggingConfig
|
||||
|
||||
# Default log level
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
||||
|
|
@ -34,6 +36,56 @@ CATEGORIES = [
|
|||
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
|
||||
|
||||
|
||||
def config_to_category_levels(category: str, level: str):
|
||||
"""
|
||||
Helper function to be called either by environment parsing or yaml parsing to go from a list of categories and levels to a dictionary ready to be
|
||||
used by the logger dictConfig.
|
||||
|
||||
Parameters:
|
||||
category (str): logging category to apply the level to
|
||||
level (str): logging level to be used in the category
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
|
||||
category_levels: Dict[str, int] = {}
|
||||
level_value = logging._nameToLevel.get(str(level).upper())
|
||||
if level_value is None:
|
||||
logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.")
|
||||
return category_levels
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
return category_levels
|
||||
|
||||
|
||||
def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
|
||||
"""
|
||||
Helper function to parse a yaml logging configuration found in the run.yaml
|
||||
|
||||
Parameters:
|
||||
yaml_config (Logging): the logger config object found in the run.yaml
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
category_levels = {}
|
||||
for category, level in yaml_config.category_levels.items():
|
||||
category_levels.update(config_to_category_levels(category=category, level=level))
|
||||
|
||||
return category_levels
|
||||
|
||||
|
||||
def parse_environment_config(env_config: str) -> Dict[str, int]:
|
||||
"""
|
||||
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
|
||||
|
|
@ -53,25 +105,7 @@ def parse_environment_config(env_config: str) -> Dict[str, int]:
|
|||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
|
||||
|
||||
level_value = logging._nameToLevel.get(level)
|
||||
if level_value is None:
|
||||
logging.warning(
|
||||
f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
|
||||
)
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
category_levels.update(config_to_category_levels(category=category, level=level))
|
||||
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
|
||||
|
|
@ -176,7 +210,9 @@ def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None
|
|||
logger.setLevel(root_level)
|
||||
|
||||
|
||||
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
|
||||
def get_logger(
|
||||
name: str, category: str = "uncategorized", config: Optional[LoggingConfig] | None = None
|
||||
) -> logging.LoggerAdapter:
|
||||
"""
|
||||
Returns a logger with the specified name and category.
|
||||
If no category is provided, defaults to 'uncategorized'.
|
||||
|
|
@ -184,10 +220,14 @@ def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdap
|
|||
Parameters:
|
||||
name (str): The name of the logger (e.g., module or filename).
|
||||
category (str): The category of the logger (default 'uncategorized').
|
||||
config (Logging): optional yaml config to override the existing logger configuration
|
||||
|
||||
Returns:
|
||||
logging.LoggerAdapter: Configured logger with category support.
|
||||
"""
|
||||
if config:
|
||||
_category_levels.update(parse_yaml_config(config))
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
|
|
|||
|
|
@ -34,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
|
|||
)
|
||||
return PromptTemplate(
|
||||
template_str.lstrip("\n"),
|
||||
{"today": datetime.now().strftime("%d %B %Y")},
|
||||
{
|
||||
"today": datetime.now().strftime("%d %B %Y") # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
|
||||
},
|
||||
)
|
||||
|
||||
def data_examples(self) -> List[Any]:
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ import re
|
|||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
|
@ -153,7 +153,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
tool_name=response.tool_name,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
|
|
@ -181,6 +180,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
|
|
@ -191,6 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
await self._initialize_tools()
|
||||
async with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
|
|
@ -219,8 +220,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages = await self.get_messages_from_turns(turns)
|
||||
if is_resume:
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
|
||||
for x in request.tool_responses
|
||||
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
|
||||
]
|
||||
messages.extend(tool_response_messages)
|
||||
last_turn = turns[-1]
|
||||
|
|
@ -239,7 +239,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||
request.session_id, request.turn_id
|
||||
)
|
||||
now = datetime.now().astimezone().isoformat()
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
|
|
@ -264,7 +264,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
start_time = last_turn.started_at
|
||||
else:
|
||||
messages.extend(request.messages)
|
||||
start_time = datetime.now().astimezone().isoformat()
|
||||
start_time = datetime.now(timezone.utc).isoformat()
|
||||
input_messages = request.messages
|
||||
|
||||
output_message = None
|
||||
|
|
@ -275,7 +275,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents if not is_resume else None,
|
||||
toolgroups_for_turn=request.toolgroups if not is_resume else None,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
|
|
@ -296,7 +295,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages=input_messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
|
@ -327,7 +326,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
|
|
@ -350,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params,
|
||||
stream,
|
||||
documents,
|
||||
toolgroups_for_turn,
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -389,7 +386,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
shield_call_start_time = datetime.now().astimezone().isoformat()
|
||||
shield_call_start_time = datetime.now(timezone.utc).isoformat()
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
@ -413,7 +410,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -436,7 +433,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
violation=None,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -451,30 +448,19 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: simplify all of this code, it can be simpler
|
||||
toolgroup_args = {}
|
||||
toolgroups = set()
|
||||
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroups.add(tool_group_name)
|
||||
toolgroup_args[tool_group_name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||
await self.handle_documents(session_id, documents, input_messages)
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info and session_info.vector_db_id:
|
||||
if RAG_TOOL_GROUP not in toolgroup_args:
|
||||
toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
||||
else:
|
||||
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
for tool_name in self.tool_name_to_args.keys():
|
||||
if tool_name == MEMORY_QUERY_TOOL:
|
||||
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
|
||||
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
|
||||
else:
|
||||
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
|
||||
output_attachments = []
|
||||
|
||||
|
|
@ -486,7 +472,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
client_tools[tool.name] = tool
|
||||
while True:
|
||||
step_id = str(uuid.uuid4())
|
||||
inference_start_time = datetime.now().astimezone().isoformat()
|
||||
inference_start_time = datetime.now(timezone.utc).isoformat()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
|
|
@ -504,7 +490,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=tool_defs,
|
||||
tools=self.tool_defs,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
response_format=self.agent_config.response_format,
|
||||
stream=True,
|
||||
|
|
@ -596,7 +582,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
model_response=copy.deepcopy(message),
|
||||
started_at=inference_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -667,7 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[],
|
||||
started_at=datetime.now().astimezone().isoformat(),
|
||||
started_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
yield message
|
||||
|
|
@ -684,14 +670,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now().astimezone().isoformat()
|
||||
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
|
||||
tool_call = message.tool_calls[0]
|
||||
tool_result = await execute_tool_call_maybe(
|
||||
self.tool_runtime_api,
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
session_id,
|
||||
tool_call,
|
||||
toolgroup_args,
|
||||
tool_to_group,
|
||||
)
|
||||
if tool_result.content is None:
|
||||
raise ValueError(
|
||||
|
|
@ -700,7 +683,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
result_messages = [
|
||||
ToolResponseMessage(
|
||||
call_id=tool_call.call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=tool_result.content,
|
||||
)
|
||||
]
|
||||
|
|
@ -720,13 +702,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=result_message.call_id,
|
||||
tool_name=result_message.tool_name,
|
||||
tool_name=tool_call.tool_name,
|
||||
content=result_message.content,
|
||||
metadata=tool_result.metadata,
|
||||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -744,9 +726,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
input_messages = input_messages + [message, result_message]
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||
async def _initialize_tools(
|
||||
self,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
) -> None:
|
||||
toolgroup_to_args = {}
|
||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroup_to_args[tool_group_name] = toolgroup.args
|
||||
|
||||
# Determine which tools to include
|
||||
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||
agent_config_toolgroups = []
|
||||
|
|
@ -755,8 +744,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if name not in agent_config_toolgroups:
|
||||
agent_config_toolgroups.append(name)
|
||||
|
||||
toolgroup_to_args = toolgroup_to_args or {}
|
||||
|
||||
tool_name_to_def = {}
|
||||
tool_to_group = {}
|
||||
tool_name_to_args = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
|
|
@ -774,53 +765,38 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
if not tools.data:
|
||||
available_tool_groups = ", ".join(
|
||||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||
)
|
||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
||||
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
tool_name = tool_def.identifier
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
if tool_name == "web_search":
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
identifier: str | BuiltinTool | None = tool_def.identifier
|
||||
if identifier == "web_search":
|
||||
identifier = BuiltinTool.brave_search
|
||||
else:
|
||||
built_in_type = BuiltinTool(tool_name)
|
||||
identifier = BuiltinTool(identifier)
|
||||
else:
|
||||
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
|
||||
if input_tool_name in (None, tool_def.identifier):
|
||||
identifier = tool_def.identifier
|
||||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(built_in_type, None):
|
||||
raise ValueError(f"Tool {built_in_type} already exists")
|
||||
|
||||
tool_name_to_def[built_in_type] = ToolDefinition(
|
||||
tool_name=built_in_type,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
if tool_name_to_def.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
if tool_name in (None, tool_def.identifier):
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
tool_name=identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
|
|
@ -832,9 +808,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
return list(tool_name_to_def.values()), tool_to_group
|
||||
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
|
@ -853,15 +829,46 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_group, tool_name = split_names[0], None
|
||||
return tool_group, tool_name
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
self,
|
||||
session_id: str,
|
||||
tool_call: ToolCall,
|
||||
) -> ToolInvocationResult:
|
||||
tool_name = tool_call.tool_name
|
||||
registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
|
||||
if tool_name not in registered_tool_names:
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
|
||||
)
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
if tool_name == BuiltinTool.brave_search:
|
||||
tool_name_str = WEB_SEARCH_TOOL
|
||||
else:
|
||||
tool_name_str = tool_name.value
|
||||
else:
|
||||
tool_name_str = tool_name
|
||||
|
||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
session_id: str,
|
||||
documents: List[Document],
|
||||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
|
|
@ -989,42 +996,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
|
||||
return ToolResponseMessage(
|
||||
call_id="",
|
||||
tool_name=BuiltinTool.code_interpreter,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
async def execute_tool_call_maybe(
|
||||
tool_runtime_api: ToolRuntime,
|
||||
session_id: str,
|
||||
tool_call: ToolCall,
|
||||
toolgroup_args: Dict[str, Dict[str, Any]],
|
||||
tool_to_group: Dict[str, str],
|
||||
) -> ToolInvocationResult:
|
||||
name = tool_call.tool_name
|
||||
group_name = tool_to_group.get(name, None)
|
||||
if group_name is None:
|
||||
raise ValueError(f"Tool {name} not found in any tool group")
|
||||
if isinstance(name, BuiltinTool):
|
||||
if name == BuiltinTool.brave_search:
|
||||
name = WEB_SEARCH_TOOL
|
||||
else:
|
||||
name = name.value
|
||||
|
||||
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logger.info(f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
def _interpret_content_as_attachment(
|
||||
content: str,
|
||||
) -> Optional[Attachment]:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -36,7 +36,7 @@ class AgentPersistence:
|
|||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(),
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class LocalFSDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "localfs_datasetio.db").as_posix()
|
||||
) # Uses SQLite config specific to localfs storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="localfs_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
|
||||
) # Uses SQLite config specific to Meta Reference Eval storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="meta_reference_eval.db",
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, UserMessage
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
|
|
@ -118,7 +118,7 @@ class MetaReferenceEvalImpl(
|
|||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
|
||||
|
|
@ -168,10 +168,11 @@ class MetaReferenceEvalImpl(
|
|||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json]
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=candidate.model,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
|
@ -40,7 +42,7 @@ class VLLMConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -12,3 +12,9 @@ from pydantic import BaseModel
|
|||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
torch_seed: Optional[int] = None
|
||||
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"checkpoint_format": "meta",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
|
@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
|
|||
job_status_response = PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=JobStatus.scheduled,
|
||||
scheduled_at=datetime.now(),
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self.jobs[job_uuid] = job_status_response
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ class TorchtunePostTrainingImpl:
|
|||
)
|
||||
|
||||
job_status_response.status = JobStatus.in_progress
|
||||
job_status_response.started_at = datetime.now()
|
||||
job_status_response.started_at = datetime.now(timezone.utc)
|
||||
|
||||
await recipe.setup()
|
||||
resources_allocated, checkpoints = await recipe.train()
|
||||
|
|
@ -93,7 +93,7 @@ class TorchtunePostTrainingImpl:
|
|||
job_status_response.resources_allocated = resources_allocated
|
||||
job_status_response.checkpoints = checkpoints
|
||||
job_status_response.status = JobStatus.completed
|
||||
job_status_response.completed_at = datetime.now()
|
||||
job_status_response.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
except Exception:
|
||||
job_status_response.status = JobStatus.failed
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import gc
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
|
@ -532,7 +532,7 @@ class LoraFinetuningSingleDevice:
|
|||
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
|
||||
checkpoint = Checkpoint(
|
||||
identifier=f"{self.model_id}-sft-{curr_epoch}",
|
||||
created_at=datetime.now(),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
epoch=curr_epoch,
|
||||
post_training_job_id=self.job_uuid,
|
||||
path=checkpoint_path,
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeScannerConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlamaGuardConfig(BaseModel):
|
||||
excluded_categories: List[str] = []
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"excluded_categories": [],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
|
@ -23,3 +24,9 @@ class PromptGuardConfig(BaseModel):
|
|||
if v not in [t.value for t in PromptGuardType]:
|
||||
raise ValueError(f"Unknown prompt guard type: {v}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"guard_type": "injection",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasicScoringConfig(BaseModel): ...
|
||||
class BasicScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -22,12 +22,19 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
|
||||
FIXED_FNS = [
|
||||
EqualityScoringFn,
|
||||
SubsetOfScoringFn,
|
||||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
BFCLScoringFn,
|
||||
]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.bfcl.ast_parser import decode_ast
|
||||
from ..utils.bfcl.checker import ast_checker, is_empty_output
|
||||
from .fn_defs.bfcl import bfcl
|
||||
|
||||
|
||||
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
|
||||
contain_func_call = False
|
||||
error = None
|
||||
error_type = None
|
||||
checker_result = {}
|
||||
try:
|
||||
prediction = decode_ast(x["generated_answer"], x["language"]) or ""
|
||||
contain_func_call = True
|
||||
# if not is_function_calling_format_output(prediction):
|
||||
if is_empty_output(prediction):
|
||||
contain_func_call = False
|
||||
error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
|
||||
error_type = "ast_decoder:decoder_wrong_output_format"
|
||||
else:
|
||||
checker_result = ast_checker(
|
||||
json.loads(x["function"]),
|
||||
prediction,
|
||||
json.loads(x["ground_truth"]),
|
||||
x["language"],
|
||||
test_category=test_category,
|
||||
model_name="",
|
||||
)
|
||||
except Exception as e:
|
||||
prediction = ""
|
||||
error = f"Invalid syntax. Failed to decode AST. {str(e)}"
|
||||
error_type = "ast_decoder:decoder_failed"
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"contain_func_call": contain_func_call,
|
||||
"valid": checker_result.get("valid", False),
|
||||
"error": error or checker_result.get("error", ""),
|
||||
"error_type": error_type or checker_result.get("error_type", ""),
|
||||
}
|
||||
|
||||
|
||||
def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
return {"valid": x["valid"]}
|
||||
|
||||
|
||||
def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
|
||||
# If `test_category` is "irrelevance", the model is expected to output no function call.
|
||||
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
|
||||
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
|
||||
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
|
||||
return {"valid": float(acc)}
|
||||
|
||||
|
||||
class BFCLScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for BFCL
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
bfcl.identifier: bfcl,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "bfcl",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
|
||||
score_result = postprocess(input_row, test_category)
|
||||
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
|
||||
score = gen_relevance_acc(score_result)["valid"]
|
||||
else:
|
||||
score = gen_valid(score_result)["valid"]
|
||||
return {
|
||||
"score": float(score),
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
bfcl = ScoringFn(
|
||||
identifier="basic::bfcl",
|
||||
description="BFCL complex scoring",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="bfcl",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
||||
|
|
@ -3,10 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import ast
|
||||
|
||||
from .tree_sitter import get_parser
|
||||
|
||||
|
||||
def parse_java_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("java")
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
|
||||
if root_node.has_error:
|
||||
raise Exception("Error parsing java the source code.")
|
||||
|
||||
def get_text(node):
|
||||
"""Returns the text represented by the node."""
|
||||
return source_code[node.start_byte : node.end_byte]
|
||||
|
||||
def traverse_node(node, nested=False):
|
||||
if node.type == "string_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding quotes from string literals
|
||||
return get_text(node)[1:-1]
|
||||
elif node.type == "character_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding single quotes from character literals
|
||||
return get_text(node)[1:-1]
|
||||
"""Traverse the node to collect texts for complex structures."""
|
||||
if node.type in [
|
||||
"identifier",
|
||||
"class_literal",
|
||||
"type_identifier",
|
||||
"method_invocation",
|
||||
]:
|
||||
return get_text(node)
|
||||
elif node.type == "array_creation_expression":
|
||||
# Handle array creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
value_node = node.child_by_field_name("value")
|
||||
type_text = traverse_node(type_node, True)
|
||||
value_text = traverse_node(value_node, True)
|
||||
return f"new {type_text}[]{value_text}"
|
||||
elif node.type == "object_creation_expression":
|
||||
# Handle object creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
type_text = traverse_node(type_node, True)
|
||||
if arguments_node:
|
||||
# Process each argument carefully, avoiding unnecessary punctuation
|
||||
argument_texts = []
|
||||
for child in arguments_node.children:
|
||||
if child.type not in [
|
||||
",",
|
||||
"(",
|
||||
")",
|
||||
]: # Exclude commas and parentheses
|
||||
argument_text = traverse_node(child, True)
|
||||
argument_texts.append(argument_text)
|
||||
arguments_text = ", ".join(argument_texts)
|
||||
return f"new {type_text}({arguments_text})"
|
||||
else:
|
||||
return f"new {type_text}()"
|
||||
elif node.type == "set":
|
||||
# Handling sets specifically
|
||||
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
|
||||
return "{" + ", ".join(items) + "}"
|
||||
|
||||
elif node.child_count > 0:
|
||||
return "".join(traverse_node(child, True) for child in node.children)
|
||||
else:
|
||||
return get_text(node)
|
||||
|
||||
def extract_arguments(args_node):
|
||||
arguments = {}
|
||||
for child in args_node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# For named parameters
|
||||
name_node, value_node = child.children[0], child.children[2]
|
||||
name = get_text(name_node)
|
||||
value = traverse_node(value_node)
|
||||
if name in arguments:
|
||||
if not isinstance(arguments[name], list):
|
||||
arguments[name] = [arguments[name]]
|
||||
arguments[name].append(value)
|
||||
else:
|
||||
arguments[name] = value
|
||||
# arguments.append({'name': name, 'value': value})
|
||||
elif child.type in ["identifier", "class_literal", "set"]:
|
||||
# For unnamed parameters and handling sets
|
||||
value = traverse_node(child)
|
||||
if None in arguments:
|
||||
if not isinstance(arguments[None], list):
|
||||
arguments[None] = [arguments[None]]
|
||||
arguments[None].append(value)
|
||||
else:
|
||||
arguments[None] = value
|
||||
return arguments
|
||||
|
||||
def traverse(node):
|
||||
if node.type == "method_invocation":
|
||||
# Extract the function name and its arguments
|
||||
method_name = get_text(node.child_by_field_name("name"))
|
||||
class_name_node = node.child_by_field_name("object")
|
||||
if class_name_node:
|
||||
class_name = get_text(class_name_node)
|
||||
function_name = f"{class_name}.{method_name}"
|
||||
else:
|
||||
function_name = method_name
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
if arguments_node:
|
||||
arguments = extract_arguments(arguments_node)
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
return [{function_name: arguments}]
|
||||
|
||||
else:
|
||||
for child in node.children:
|
||||
result = traverse(child)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = traverse(root_node)
|
||||
return result if result else {}
|
||||
|
||||
|
||||
def parse_javascript_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("javascript")
|
||||
# Parse the source code
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
if root_node.has_error:
|
||||
raise Exception("Error js parsing the source code.")
|
||||
|
||||
# Function to recursively extract argument details
|
||||
def extract_arguments(node):
|
||||
args = {}
|
||||
for child in node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# Extract left (name) and right (value) parts of the assignment
|
||||
name = child.children[0].text.decode("utf-8")
|
||||
value = child.children[2].text.decode("utf-8")
|
||||
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1] # Trim the quotation marks
|
||||
if name in args:
|
||||
if not isinstance(args[name], list):
|
||||
args[name] = [args[name]]
|
||||
args[name].append(value)
|
||||
else:
|
||||
args[name] = value
|
||||
|
||||
elif child.type == "identifier" or child.type == "true":
|
||||
# Handle non-named arguments and boolean values
|
||||
value = child.text.decode("utf-8")
|
||||
if None in args:
|
||||
if not isinstance(args[None], list):
|
||||
args[None] = [args[None]]
|
||||
args[None].append(value)
|
||||
else:
|
||||
args[None] = value
|
||||
return args
|
||||
|
||||
# Find the function call and extract its name and arguments
|
||||
if root_node.type == "program":
|
||||
for child in root_node.children:
|
||||
if child.type == "expression_statement":
|
||||
for sub_child in child.children:
|
||||
if sub_child.type == "call_expression":
|
||||
function_name = sub_child.children[0].text.decode("utf8")
|
||||
arguments_node = sub_child.children[1]
|
||||
parameters = extract_arguments(arguments_node)
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
result = [{function_name: parameters}]
|
||||
return result
|
||||
|
||||
|
||||
def ast_parse(input_str, language="Python"):
|
||||
if language == "Python":
|
||||
cleaned_input = input_str.strip("[]'")
|
||||
parsed = ast.parse(cleaned_input, mode="eval")
|
||||
extracted = []
|
||||
if isinstance(parsed.body, ast.Call):
|
||||
extracted.append(resolve_ast_call(parsed.body))
|
||||
else:
|
||||
for elem in parsed.body.elts:
|
||||
extracted.append(resolve_ast_call(elem))
|
||||
return extracted
|
||||
elif language == "Java":
|
||||
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
|
||||
elif language == "JavaScript":
|
||||
return parse_javascript_function_call(input_str[1:-1])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported language: {language}")
|
||||
|
||||
|
||||
def resolve_ast_call(elem):
|
||||
# Handle nested attributes for deeply nested module paths
|
||||
func_parts = []
|
||||
func_part = elem.func
|
||||
while isinstance(func_part, ast.Attribute):
|
||||
func_parts.append(func_part.attr)
|
||||
func_part = func_part.value
|
||||
if isinstance(func_part, ast.Name):
|
||||
func_parts.append(func_part.id)
|
||||
func_name = ".".join(reversed(func_parts))
|
||||
args_dict = {}
|
||||
# Parse when args are simply passed as an unnamed dictionary arg
|
||||
for arg in elem.args:
|
||||
if isinstance(arg, ast.Dict):
|
||||
for key, value in zip(arg.keys, arg.values):
|
||||
if isinstance(key, ast.Constant):
|
||||
arg_name = key.value
|
||||
output = resolve_ast_by_type(value)
|
||||
args_dict[arg_name] = output
|
||||
for arg in elem.keywords:
|
||||
output = resolve_ast_by_type(arg.value)
|
||||
args_dict[arg.arg] = output
|
||||
return {func_name: args_dict}
|
||||
|
||||
|
||||
def resolve_ast_by_type(value):
|
||||
if isinstance(value, ast.Constant):
|
||||
if value.value is Ellipsis:
|
||||
output = "..."
|
||||
else:
|
||||
output = value.value
|
||||
elif isinstance(value, ast.UnaryOp):
|
||||
output = -value.operand.value
|
||||
elif isinstance(value, ast.List):
|
||||
output = [resolve_ast_by_type(v) for v in value.elts]
|
||||
elif isinstance(value, ast.Dict):
|
||||
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
|
||||
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
|
||||
output = value.value
|
||||
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
|
||||
output = eval(ast.unparse(value))
|
||||
elif isinstance(value, ast.Name):
|
||||
output = value.id
|
||||
elif isinstance(value, ast.Call):
|
||||
if len(value.keywords) == 0:
|
||||
output = ast.unparse(value)
|
||||
else:
|
||||
output = resolve_ast_call(value)
|
||||
elif isinstance(value, ast.Tuple):
|
||||
output = tuple(resolve_ast_by_type(v) for v in value.elts)
|
||||
elif isinstance(value, ast.Lambda):
|
||||
output = eval(ast.unparse(value.body[0].value))
|
||||
elif isinstance(value, ast.Ellipsis):
|
||||
output = "..."
|
||||
elif isinstance(value, ast.Subscript):
|
||||
try:
|
||||
output = ast.unparse(value.body[0].value)
|
||||
except:
|
||||
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
||||
else:
|
||||
raise Exception(f"Unsupported AST type: {type(value)}")
|
||||
return output
|
||||
|
||||
|
||||
def decode_ast(result, language="Python"):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decoded_output = ast_parse(func, language)
|
||||
return decoded_output
|
||||
|
||||
|
||||
def decode_execute(result):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decode_output = ast_parse(func)
|
||||
execution_list = []
|
||||
for function_call in decode_output:
|
||||
for key, value in function_call.items():
|
||||
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
|
||||
return execution_list
|
||||
989
llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py
Normal file
989
llama_stack/providers/inline/scoring/basic/utils/bfcl/checker.py
Normal file
|
|
@ -0,0 +1,989 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
# Comment out for now until we actually use the rest checker in evals
|
||||
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
|
||||
|
||||
|
||||
class NoAPIKeyError(Exception):
|
||||
def __init__(self):
|
||||
self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
|
||||
|
||||
|
||||
JAVA_TYPE_CONVERSION = {
|
||||
"byte": int,
|
||||
"short": int,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"double": float,
|
||||
"long": int,
|
||||
"boolean": bool,
|
||||
"char": str,
|
||||
"Array": list,
|
||||
"ArrayList": list,
|
||||
"Set": set,
|
||||
"HashMap": dict,
|
||||
"Hashtable": dict,
|
||||
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
|
||||
"Stack": list,
|
||||
"String": str,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
JS_TYPE_CONVERSION = {
|
||||
"String": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"Bigint": int,
|
||||
"Boolean": bool,
|
||||
"dict": dict,
|
||||
"array": list,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# We switch to conditional import for the following two imports to avoid unnecessary installations.
|
||||
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
|
||||
# from js_type_converter import js_type_converter
|
||||
# from java_type_converter import java_type_converter
|
||||
|
||||
PYTHON_TYPE_MAPPING = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"tuple": list,
|
||||
"dict": dict,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# This is the list of types that we need to recursively check its values
|
||||
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
|
||||
|
||||
|
||||
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
|
||||
|
||||
|
||||
#### Helper functions for AST ####
|
||||
def find_description(func_descriptions, name):
|
||||
if type(func_descriptions) == list:
|
||||
for func_description in func_descriptions:
|
||||
if func_description["name"] == name:
|
||||
return func_description
|
||||
return None
|
||||
else:
|
||||
# it is a dict, there is only one function
|
||||
return func_descriptions
|
||||
|
||||
|
||||
def get_possible_answer_type(possible_answer: list):
|
||||
for answer in possible_answer:
|
||||
if answer != "": # Optional parameter
|
||||
return type(answer)
|
||||
return None
|
||||
|
||||
|
||||
def type_checker(
|
||||
param: str,
|
||||
value,
|
||||
possible_answer: list,
|
||||
expected_type_description: str,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
):
|
||||
# NOTE: This type checker only supports nested type checking for one level deep.
|
||||
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
|
||||
|
||||
result: Any = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"is_variable": False,
|
||||
"error_type": "type_error:simple",
|
||||
}
|
||||
|
||||
is_variable = False
|
||||
# check for the case where a variable is used instead of a actual value.
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if possible_answer_type != expected_type_converted:
|
||||
is_variable = True
|
||||
|
||||
# value is the same type as in function description
|
||||
if type(value) == expected_type_converted:
|
||||
# We don't need to do recursive check for simple types
|
||||
if nested_type_converted == None:
|
||||
result["is_variable"] = is_variable
|
||||
return result
|
||||
else:
|
||||
for possible_answer_item in possible_answer:
|
||||
flag = True # Each parameter should match to at least one possible answer type.
|
||||
# Here, we assume that each item should be the same type. We could also relax it.
|
||||
if type(possible_answer_item) == list:
|
||||
for value_item in value:
|
||||
checker_result = type_checker(
|
||||
param,
|
||||
value_item,
|
||||
possible_answer_item,
|
||||
str(nested_type_converted),
|
||||
nested_type_converted,
|
||||
None,
|
||||
)
|
||||
if not checker_result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": [], "is_variable": is_variable}
|
||||
|
||||
result["valid"] = False
|
||||
result["error"] = [
|
||||
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
|
||||
]
|
||||
result["error_type"] = "type_error:nested"
|
||||
|
||||
# value is not as expected, check for the case where a variable is used instead of a actual value
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if type(value) == possible_answer_type:
|
||||
result["is_variable"] = True
|
||||
return result
|
||||
|
||||
result["valid"] = False
|
||||
result["error"].append(
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:simple"
|
||||
return result
|
||||
|
||||
|
||||
def standardize_string(input_string: str):
|
||||
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
|
||||
# It will also convert all the single quotes to double quotes
|
||||
# This is used to compare the model output with the possible answers
|
||||
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
|
||||
regex_string = r"[ \,\.\/\-\_\*\^]"
|
||||
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
|
||||
|
||||
|
||||
def string_checker(param: str, model_output: str, possible_answer: list):
|
||||
standardize_possible_answer = []
|
||||
standardize_model_output = standardize_string(model_output)
|
||||
for i in range(len(possible_answer)):
|
||||
if type(possible_answer[i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[i]))
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
|
||||
],
|
||||
"error_type": "value_error:string",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def list_checker(param: str, model_output: list, possible_answer: list):
|
||||
# Convert the tuple to a list
|
||||
|
||||
standardize_model_output = list(model_output)
|
||||
|
||||
# If the element in the list is a string, we need to standardize it
|
||||
for i in range(len(standardize_model_output)):
|
||||
if type(standardize_model_output[i]) == str:
|
||||
standardize_model_output[i] = standardize_string(model_output[i])
|
||||
|
||||
standardize_possible_answer: Any = []
|
||||
# We also need to standardize the possible answers
|
||||
for i in range(len(possible_answer)):
|
||||
standardize_possible_answer.append([])
|
||||
for j in range(len(possible_answer[i])):
|
||||
if type(possible_answer[i][j]) == str:
|
||||
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
|
||||
else:
|
||||
standardize_possible_answer[i].append(possible_answer[i][j])
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
|
||||
],
|
||||
"error_type": "value_error:list/tuple",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def dict_checker(param: str, model_output: dict, possible_answers: list):
|
||||
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
|
||||
# The current dataset only contains simple dictionaries, so this is sufficient.
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
for i in range(len(possible_answers)):
|
||||
if possible_answers[i] == "":
|
||||
continue
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
|
||||
flag = True
|
||||
|
||||
possible_answer = possible_answers[i]
|
||||
# possible_anwer is a single dictionary
|
||||
|
||||
for key, value in model_output.items():
|
||||
if key not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
standardize_value = value
|
||||
# If the value is a string, we need to standardize it
|
||||
if type(value) == str:
|
||||
standardize_value = standardize_string(value)
|
||||
|
||||
# We also need to standardize the possible answers if they are string
|
||||
standardize_possible_answer = []
|
||||
for i in range(len(possible_answer[key])):
|
||||
if type(possible_answer[key][i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
|
||||
else:
|
||||
standardize_possible_answer.append(possible_answer[key][i])
|
||||
|
||||
if standardize_value not in standardize_possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
|
||||
)
|
||||
result["error_type"] = "value_error:dict_value"
|
||||
flag = False
|
||||
break
|
||||
|
||||
for key, value in possible_answer.items():
|
||||
if key not in model_output and "" not in value:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_dict_checker(param: str, model_output: list, possible_answers: list):
|
||||
# This function takes in a list of dictionaries and checks if each dictionary is valid
|
||||
# The order of the dictionaries in the list must match the order of the possible answers
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
|
||||
|
||||
for answer_index in range(len(possible_answers)):
|
||||
flag = True # True means so far, all dictionaries are valid
|
||||
|
||||
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
|
||||
if len(model_output) != len(possible_answers[answer_index]):
|
||||
result["valid"] = False
|
||||
result["error"] = ["Wrong number of dictionaries in the list."]
|
||||
result["error_type"] = "value_error:list_dict_count"
|
||||
flag = False
|
||||
continue
|
||||
|
||||
for dict_index in range(len(model_output)):
|
||||
result = dict_checker(
|
||||
param,
|
||||
model_output[dict_index],
|
||||
[possible_answers[answer_index][dict_index]],
|
||||
)
|
||||
if not result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def simple_function_checker(
|
||||
func_description: dict,
|
||||
model_output: dict,
|
||||
possible_answer: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
possible_answer = list(possible_answer.values())[0]
|
||||
# Extract function name and parameters details
|
||||
func_name = func_description["name"]
|
||||
param_details = func_description["parameters"]["properties"]
|
||||
required_params = func_description["parameters"]["required"]
|
||||
|
||||
# Initialize a result dictionary
|
||||
result = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"error_type": "simple_function_checker:unclear",
|
||||
}
|
||||
|
||||
# Check if function name matches
|
||||
if func_name not in model_output:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Function name {repr(func_name)} not found in model output."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:wrong_func_name"
|
||||
return result
|
||||
|
||||
model_params = model_output[func_name]
|
||||
|
||||
# Check for required parameters in model output
|
||||
for param in required_params:
|
||||
if param not in model_params:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:missing_required"
|
||||
return result
|
||||
|
||||
# Validate types and values for each parameter in model output
|
||||
for param, value in model_params.items():
|
||||
if param not in param_details or param not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:unexpected_param"
|
||||
return result
|
||||
|
||||
full_param_details = param_details[param]
|
||||
expected_type_description = full_param_details["type"] # This is a string
|
||||
is_variable = False
|
||||
nested_type_converted = None
|
||||
|
||||
if language == "Java":
|
||||
from evals.utils.bfcl.java_type_converter import java_type_converter
|
||||
|
||||
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JAVA_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:java"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
|
||||
value = java_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = java_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "JavaScript":
|
||||
from evals.utils.bfcl.js_type_converter import js_type_converter
|
||||
|
||||
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JS_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:js"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
|
||||
value = js_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = js_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "Python":
|
||||
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
|
||||
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
|
||||
|
||||
# We convert all tuple value to list when the expected type is tuple.
|
||||
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
|
||||
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
|
||||
if expected_type_description == "tuple" and type(value) == tuple:
|
||||
value = list(value)
|
||||
|
||||
# Allow python auto conversion from int to float
|
||||
if language == "Python" and expected_type_description == "float" and type(value) == int:
|
||||
value = float(value)
|
||||
|
||||
# Type checking
|
||||
# In fact, we only check for Python here.
|
||||
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
|
||||
type_check_result = type_checker(
|
||||
param,
|
||||
value,
|
||||
possible_answer[param],
|
||||
expected_type_description,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
)
|
||||
is_variable = type_check_result["is_variable"]
|
||||
if not type_check_result["valid"]:
|
||||
return type_check_result
|
||||
|
||||
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
|
||||
# We can just treat the variable as a string and use the normal flow.
|
||||
if not is_variable:
|
||||
# Special handle for dictionaries
|
||||
if expected_type_converted == dict:
|
||||
result = dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for list of dictionaries
|
||||
elif expected_type_converted == list and nested_type_converted == dict:
|
||||
result = list_dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for strings
|
||||
elif expected_type_converted == str:
|
||||
# We don't check for case sensitivity for string, as long as it's not a variable
|
||||
result = string_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
elif expected_type_converted == list:
|
||||
result = list_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Check if the value is within the possible answers
|
||||
if value not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
|
||||
)
|
||||
result["error_type"] = "value_error:others"
|
||||
return result
|
||||
|
||||
# Check for optional parameters not provided but allowed
|
||||
for param in possible_answer:
|
||||
if param not in model_params and "" not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Optional parameter {repr(param)} not provided and not marked as optional."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:missing_optional"
|
||||
return result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parallel_function_checker_enforce_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_enforce_order:wrong_count",
|
||||
}
|
||||
|
||||
func_name_list = list(possible_answers.keys())
|
||||
possible_answers_list = []
|
||||
|
||||
for key, value in possible_answers.items():
|
||||
possible_answers_list.append({key: value})
|
||||
|
||||
for i in range(len(possible_answers_list)):
|
||||
func_description = find_description(func_descriptions, func_name_list[i])
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[i],
|
||||
possible_answers_list[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
if not result["valid"]:
|
||||
return result
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def parallel_function_checker_no_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_no_order:wrong_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
|
||||
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
|
||||
# It must be this way because we need ground truth to fetch the correct function description
|
||||
for i in range(len(possible_answers)):
|
||||
# possible_answers[i] is a dictionary with only one key
|
||||
func_name_expected = list(possible_answers[i].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
|
||||
all_errors = []
|
||||
|
||||
for index in range(len(model_output)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[index],
|
||||
possible_answers[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_output_item": model_output[index],
|
||||
"possible_answer_item": possible_answers[i],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "parallel_function_checker_no_order:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def multiple_function_checker(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "multiple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
# possible_answers is a list of only one dictionary with only one key
|
||||
func_name_expected = list(possible_answers[0].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
return simple_function_checker(
|
||||
func_description,
|
||||
model_output[0],
|
||||
possible_answers[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
if type(exec_output) != type(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == dict:
|
||||
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
|
||||
# This happens when the key is a timestamp or a random number.
|
||||
if is_sanity_check:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
for key, value in expected_result.items():
|
||||
if key not in exec_output:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
for key, value in exec_output.items():
|
||||
if key not in expected_result:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == list:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:list_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
#### Helper functions for Exec ####
|
||||
def executable_checker_simple(
|
||||
function_call: str,
|
||||
expected_result,
|
||||
expected_result_type: str,
|
||||
is_sanity_check=False,
|
||||
):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
exec_dict: Any = {}
|
||||
|
||||
try:
|
||||
exec(
|
||||
"from executable_python_function import *" + "\nresult=" + function_call,
|
||||
exec_dict,
|
||||
)
|
||||
exec_output = exec_dict["result"]
|
||||
except NoAPIKeyError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Error in execution: {repr(function_call)}. Error: {str(e)}"
|
||||
)
|
||||
result["error_type"] = "executable_checker:execution_error"
|
||||
return result
|
||||
|
||||
# We need to special handle the case where the execution result is a tuple and convert it to a list
|
||||
# Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
|
||||
if isinstance(exec_output, tuple):
|
||||
exec_output = list(exec_output)
|
||||
|
||||
if expected_result_type == "exact_match":
|
||||
if exec_output != expected_result:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
elif expected_result_type == "real_time_match":
|
||||
# Allow for 5% difference
|
||||
if (type(expected_result) == float or type(expected_result) == int) and (
|
||||
type(exec_output) == float or type(exec_output) == int
|
||||
):
|
||||
if not (
|
||||
expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
<= exec_output
|
||||
<= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
):
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
else:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
else:
|
||||
# structural match
|
||||
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
|
||||
if not pattern_match_result["valid"]:
|
||||
return pattern_match_result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def executable_checker_parallel_no_order(
|
||||
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
|
||||
):
|
||||
if len(decoded_result) != len(expected_exec_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
|
||||
],
|
||||
"error_type": "value_error:exec_result_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
for i in range(len(expected_exec_result)):
|
||||
all_errors = []
|
||||
for index in range(len(decoded_result)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = executable_checker_simple(
|
||||
decoded_result[index],
|
||||
expected_exec_result[i],
|
||||
expected_exec_result_type[i],
|
||||
False,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_executed_output": (
|
||||
result["model_executed_output"] if "model_executed_output" in result else None
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "executable_checker:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
|
||||
#### Main function ####
|
||||
def executable_checker_rest(func_call, idx):
|
||||
# Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
|
||||
EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution
|
||||
with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
|
||||
EVAL_GROUND_TRUTH = f.readlines()
|
||||
if "https://geocode.maps.co" in func_call:
|
||||
time.sleep(2)
|
||||
if "requests_get" in func_call:
|
||||
func_call = func_call.replace("requests_get", "requests.get")
|
||||
try:
|
||||
response = eval(func_call)
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution failed. {str(e)}"],
|
||||
"error_type": "executable_checker_rest:execution_error",
|
||||
}
|
||||
|
||||
try:
|
||||
if response.status_code == 200:
|
||||
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
|
||||
try:
|
||||
if isinstance(eval_GT_json, dict):
|
||||
if isinstance(response.json(), dict):
|
||||
if set(eval_GT_json.keys()) == set(response.json().keys()):
|
||||
return {"valid": True, "error": [], "error_type": ""}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dictionary, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
|
||||
elif isinstance(eval_GT_json, list):
|
||||
if isinstance(response.json(), list):
|
||||
if len(eval_GT_json) != len(response.json()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Response list length inconsistency."],
|
||||
"error_type": "value_error:exec_result_rest_count",
|
||||
}
|
||||
|
||||
else:
|
||||
for i in range(len(eval_GT_json)):
|
||||
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dict or list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
|
||||
],
|
||||
"error_type": "executable_checker_rest:response_format_error",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution result status code is not 200, got {response.status_code}"],
|
||||
"error_type": "executable_checker_rest:wrong_status_code",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Cannot get status code of the response. Error: {str(e)}"],
|
||||
"error_type": "executable_checker_rest:cannot_get_status_code",
|
||||
}
|
||||
|
||||
|
||||
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
|
||||
if "parallel" in test_category:
|
||||
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
elif "multiple" in test_category:
|
||||
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
else:
|
||||
if len(model_output) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
return simple_function_checker(
|
||||
func_description[0],
|
||||
model_output[0],
|
||||
possible_answer[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def exec_checker(decoded_result: list, func_description: dict, test_category: str):
|
||||
if "multiple" in test_category or "parallel" in test_category:
|
||||
return executable_checker_parallel_no_order(
|
||||
decoded_result,
|
||||
func_description["execution_result"],
|
||||
func_description["execution_result_type"],
|
||||
)
|
||||
|
||||
else:
|
||||
if len(decoded_result) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_exec_checker:wrong_count",
|
||||
}
|
||||
return executable_checker_simple(
|
||||
decoded_result[0],
|
||||
func_description["execution_result"][0],
|
||||
func_description["execution_result_type"][0],
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
def is_empty_output(decoded_output):
|
||||
# This function is a patch to the ast decoder for relevance detection
|
||||
# Sometimes the ast decoder will parse successfully, but the input doens't really have a function call
|
||||
# [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
|
||||
if not is_function_calling_format_output(decoded_output):
|
||||
return True
|
||||
if len(decoded_output) == 0:
|
||||
return True
|
||||
if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
|
||||
return True
|
||||
|
||||
|
||||
def is_function_calling_format_output(decoded_output):
|
||||
# Ensure the output is a list of dictionaries
|
||||
if type(decoded_output) == list:
|
||||
for item in decoded_output:
|
||||
if type(item) != dict:
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
|
||||
import it from here so that we can centrally manage things as necessary.
|
||||
"""
|
||||
|
||||
# These currently work with tree-sitter 0.23.0
|
||||
# NOTE: Don't import tree-sitter or any of the language modules in the main module
|
||||
# because not all environments have them. Import lazily inside functions where needed.
|
||||
|
||||
import importlib
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import tree_sitter
|
||||
|
||||
|
||||
def get_language(language: str) -> "tree_sitter.Language":
|
||||
import tree_sitter
|
||||
|
||||
language_module_name = f"tree_sitter_{language}"
|
||||
try:
|
||||
language_module = importlib.import_module(language_module_name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
f"Language {language} is not found. Please install the tree-sitter-{language} package."
|
||||
) from exc
|
||||
return tree_sitter.Language(language_module.language())
|
||||
|
||||
|
||||
def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
|
||||
import tree_sitter
|
||||
|
||||
lang = get_language(language)
|
||||
return tree_sitter.Parser(lang, **kwargs)
|
||||
|
|
@ -3,7 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlmAsJudgeScoringConfig(BaseModel): ...
|
||||
class LlmAsJudgeScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
|
|
@ -34,7 +34,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
print(
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
|
|
@ -46,7 +46,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime("%H:%M:%S.%f")[:-3]
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
span_context = (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
|
|
@ -74,7 +74,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime("%H:%M:%S.%f")[:-3]
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import json
|
|||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
|
|
@ -124,8 +124,8 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
trace_id,
|
||||
service_name,
|
||||
(span_id if not parent_span_id else None),
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -143,8 +143,8 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
trace_id,
|
||||
parent_span_id,
|
||||
span.name,
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
|
||||
json.dumps(dict(span.attributes)),
|
||||
span.status.status_code.name,
|
||||
span.kind.name,
|
||||
|
|
@ -161,7 +161,7 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
(
|
||||
span_id,
|
||||
event.name,
|
||||
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(event.timestamp / 1e9, timezone.utc).isoformat(),
|
||||
json.dumps(dict(event.attributes)),
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleTelemetryImpl
|
||||
|
||||
impl = SampleTelemetryImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleTelemetryImpl(Telemetry):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -76,6 +76,7 @@ class CodeExecutionRequest:
|
|||
only_last_cell_fail: bool = True
|
||||
seed: int = 0
|
||||
strip_fpaths_in_stderr: bool = True
|
||||
use_bwrap: bool = True
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
|
|
@ -103,8 +104,6 @@ _set_seeds()\
|
|||
|
||||
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
|
||||
with tempfile.TemporaryDirectory() as dpath:
|
||||
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
|
||||
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
|
||||
code_fpath = os.path.join(dpath, "code.py")
|
||||
with open(code_fpath, "w") as f:
|
||||
f.write(script)
|
||||
|
|
@ -118,6 +117,13 @@ _set_seeds()\
|
|||
MPLBACKEND="module://matplotlib_custom_backend",
|
||||
PYTHONPATH=f"{DIRNAME}:{python_path}",
|
||||
)
|
||||
|
||||
if req.use_bwrap:
|
||||
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
|
||||
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
|
||||
else:
|
||||
cmd = [sys.executable, "-c", script]
|
||||
|
||||
stdout, stderr, returncode = do_subprocess(
|
||||
cmd=cmd,
|
||||
env=env,
|
||||
|
|
@ -162,7 +168,7 @@ def process_matplotlib_response(response, matplotlib_dump_dir: str):
|
|||
image_paths = []
|
||||
for i, img in enumerate(images):
|
||||
# create new directory for each day to better organize data:
|
||||
dump_dname = datetime.today().strftime("%Y-%m-%d")
|
||||
dump_dname = datetime.today().strftime("%Y-%m-%d") # noqa: DTZ002 - we don't care about timezones here since we are displaying the date
|
||||
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
|
||||
dump_dpath.mkdir(parents=True, exist_ok=True)
|
||||
# save image into a file
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
|
@ -61,7 +62,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
script = kwargs["code"]
|
||||
req = CodeExecutionRequest(scripts=[script])
|
||||
# Use environment variable to control bwrap usage
|
||||
force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")
|
||||
req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap)
|
||||
res = self.code_executor.execute(req)
|
||||
pieces = [res["process_status"]]
|
||||
for out_type in ["stdout", "stderr"]:
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterToolConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RagToolRuntimeConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
|
|||
db_path: str
|
||||
|
||||
@classmethod
|
||||
def sample_config(cls) -> Dict[str, Any]:
|
||||
return {"db_path": "{env.CHROMADB_PATH}"}
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> Dict[str, Any]:
|
||||
return {"db_path": db_path}
|
||||
|
|
|
|||
|
|
@ -7,11 +7,9 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_dependencies
|
||||
|
||||
|
|
@ -39,13 +37,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
Api.tool_groups,
|
||||
],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.agents,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.agents.sample",
|
||||
config_class="llama_stack.providers.remote.agents.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.eval,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=[],
|
||||
pip_packages=["tree_sitter"],
|
||||
module="llama_stack.providers.inline.eval.meta_reference",
|
||||
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
|||
|
|
@ -68,15 +68,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.inference.sentence_transformers",
|
||||
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.sample",
|
||||
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
|||
|
|
@ -27,27 +27,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.safety.prompt_guard",
|
||||
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=[
|
||||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
module="llama_stack.providers.inline.safety.meta_reference",
|
||||
config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
deprecation_error="""
|
||||
Provider `inline::meta-reference` for API `safety` does not work with the latest Llama Stack.
|
||||
|
||||
- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
|
||||
- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
|
||||
- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
|
||||
|
||||
""",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
provider_type="inline::llama-guard",
|
||||
|
|
@ -67,15 +46,6 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest
|
|||
module="llama_stack.providers.inline.safety.code_scanner",
|
||||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.safety.sample",
|
||||
config_class="llama_stack.providers.remote.safety.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
|||
|
|
@ -7,11 +7,9 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -28,13 +26,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.telemetry.meta_reference",
|
||||
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.telemetry,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.telemetry.sample",
|
||||
config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -92,16 +92,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.vector_io,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sample",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.vector_io.sample",
|
||||
config_class="llama_stack.providers.remote.vector_io.sample.SampleVectorIOConfig",
|
||||
),
|
||||
api_dependencies=[],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleAgentsImpl
|
||||
|
||||
impl = SampleAgentsImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleAgentsImpl(Agents):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -3,9 +3,10 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
|
|
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
|
|||
|
||||
|
||||
class HuggingfaceDatasetIOConfig(BaseModel):
|
||||
kvstore: KVStoreConfig = SqliteKVStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix()
|
||||
) # Uses SQLite config specific to HF storage
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="huggingface_datasetio.db",
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -20,3 +21,15 @@ class DatabricksImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN}",
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import RunpodImplConfig
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RunpodImplConfig, _deps):
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = RunpodInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -21,3 +21,10 @@ class RunpodImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.RUNPOD_URL:}",
|
||||
"api_token": "${env.RUNPOD_API_TOKEN:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator
|
|||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.models.llama.datatypes import Message
|
||||
|
||||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleInferenceImpl
|
||||
|
||||
impl = SampleInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleInferenceImpl(Inference):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleSafetyImpl
|
||||
|
||||
impl = SampleSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleSafetyImpl(Safety):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
# these are the safety shields the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -14,3 +14,9 @@ class BingSearchToolConfig(BaseModel):
|
|||
|
||||
api_key: Optional[str] = None
|
||||
top_k: int = 3
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.BING_API_KEY:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelContextProtocolConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -13,3 +13,9 @@ class WolframAlphaToolConfig(BaseModel):
|
|||
"""Configuration for WolframAlpha Tool Runtime"""
|
||||
|
||||
api_key: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.WOLFRAM_ALPHA_API_KEY:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -24,3 +24,9 @@ class QdrantVectorIOConfig(BaseModel):
|
|||
timeout: Optional[int] = None
|
||||
host: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.QDRANT_API_KEY}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleVectorIOConfig, _deps) -> Any:
|
||||
from .sample import SampleVectorIOImpl
|
||||
|
||||
impl = SampleVectorIOImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleVectorIOConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
|
||||
from .config import SampleVectorIOConfig
|
||||
|
||||
|
||||
class SampleVectorIOImpl(VectorIO):
|
||||
def __init__(self, config: SampleVectorIOConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
# these are the vector dbs the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
@ -13,4 +15,6 @@ class WeaviateRequestProviderData(BaseModel):
|
|||
|
||||
|
||||
class WeaviateVectorIOConfig(BaseModel):
|
||||
pass
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ class ColumnName(Enum):
|
|||
generated_answer = "generated_answer"
|
||||
context = "context"
|
||||
dialog = "dialog"
|
||||
function = "function"
|
||||
language = "language"
|
||||
id = "id"
|
||||
ground_truth = "ground_truth"
|
||||
|
||||
|
||||
VALID_SCHEMAS_FOR_SCORING = [
|
||||
|
|
@ -37,6 +41,15 @@ VALID_SCHEMAS_FOR_SCORING = [
|
|||
ColumnName.generated_answer.value: StringType(),
|
||||
ColumnName.context.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.generated_answer.value: StringType(),
|
||||
ColumnName.function.value: StringType(),
|
||||
ColumnName.language.value: StringType(),
|
||||
ColumnName.id.value: StringType(),
|
||||
ColumnName.ground_truth.value: StringType(),
|
||||
},
|
||||
]
|
||||
|
||||
VALID_SCHEMAS_FOR_EVAL = [
|
||||
|
|
@ -50,6 +63,15 @@ VALID_SCHEMAS_FOR_EVAL = [
|
|||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.completion_input.value: CompletionInputType(),
|
||||
},
|
||||
{
|
||||
ColumnName.input_query.value: StringType(),
|
||||
ColumnName.expected_answer.value: StringType(),
|
||||
ColumnName.generated_answer.value: StringType(),
|
||||
ColumnName.function.value: StringType(),
|
||||
ColumnName.language.value: StringType(),
|
||||
ColumnName.id.value: StringType(),
|
||||
ColumnName.ground_truth.value: StringType(),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import logging
|
|||
import queue
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ class TraceContext:
|
|||
span_id=generate_short_uuid(),
|
||||
trace_id=self.trace_id,
|
||||
name=name,
|
||||
start_time=datetime.now(),
|
||||
start_time=datetime.now(timezone.utc),
|
||||
parent_span_id=current_span.span_id if current_span else None,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
|
@ -203,7 +203,7 @@ class TelemetryHandler(logging.Handler):
|
|||
UnstructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=datetime.now(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
message=self.format(record),
|
||||
severity=severity(record.levelname),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -45,14 +45,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -23,7 +23,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
|
|
@ -43,14 +44,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -47,14 +48,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -50,14 +51,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -46,14 +47,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -62,7 +62,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -81,14 +82,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -56,14 +56,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
@ -88,7 +100,8 @@ providers:
|
|||
max_results: 3
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config: {}
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
|
||||
- provider_id: code-interpreter
|
||||
provider_type: inline::code-interpreter
|
||||
config: {}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -50,14 +51,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
@ -82,7 +95,8 @@ providers:
|
|||
max_results: 3
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config: {}
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
|
||||
- provider_id: code-interpreter
|
||||
provider_type: inline::code-interpreter
|
||||
config: {}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -50,14 +51,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -55,14 +56,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -50,14 +51,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -55,14 +56,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -50,14 +51,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -57,14 +58,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -32,7 +32,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -51,14 +52,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -34,7 +34,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -53,14 +54,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -47,14 +48,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
|
|||
|
|
@ -49,14 +49,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
@ -90,7 +102,8 @@ providers:
|
|||
config: {}
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config: {}
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -46,14 +47,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
@ -87,7 +100,8 @@ providers:
|
|||
config: {}
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config: {}
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
|
||||
|
|
|
|||
|
|
@ -226,6 +226,22 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"chat_completion_input": {"type": "string"},
|
||||
},
|
||||
),
|
||||
DatasetInput(
|
||||
dataset_id="bfcl",
|
||||
provider_id="huggingface",
|
||||
url=URL(uri="https://huggingface.co/datasets/llamastack/bfcl_v3"),
|
||||
metadata={
|
||||
"path": "llamastack/bfcl_v3",
|
||||
"split": "train",
|
||||
},
|
||||
dataset_schema={
|
||||
"function": {"type": "string"},
|
||||
"language": {"type": "string"},
|
||||
"ground_truth": {"type": "string"},
|
||||
"id": {"type": "string"},
|
||||
"chat_completion_input": {"type": "string"},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
default_benchmarks = [
|
||||
|
|
@ -249,6 +265,11 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
dataset_id="math_500",
|
||||
scoring_functions=["basic::regex_parser_math_response"],
|
||||
),
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-bfcl",
|
||||
dataset_id="bfcl",
|
||||
scoring_functions=["basic::bfcl"],
|
||||
),
|
||||
]
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
|
|
|
|||
|
|
@ -54,7 +54,8 @@ providers:
|
|||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
@ -73,14 +74,26 @@ providers:
|
|||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config: {}
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
|
|
@ -203,6 +216,24 @@ datasets:
|
|||
split: test
|
||||
dataset_id: math_500
|
||||
provider_id: huggingface
|
||||
- dataset_schema:
|
||||
function:
|
||||
type: string
|
||||
language:
|
||||
type: string
|
||||
ground_truth:
|
||||
type: string
|
||||
id:
|
||||
type: string
|
||||
chat_completion_input:
|
||||
type: string
|
||||
url:
|
||||
uri: https://huggingface.co/datasets/llamastack/bfcl_v3
|
||||
metadata:
|
||||
path: llamastack/bfcl_v3
|
||||
split: train
|
||||
dataset_id: bfcl
|
||||
provider_id: huggingface
|
||||
scoring_fns: []
|
||||
benchmarks:
|
||||
- dataset_id: simpleqa
|
||||
|
|
@ -225,6 +256,11 @@ benchmarks:
|
|||
- basic::regex_parser_math_response
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-math-500
|
||||
- dataset_id: bfcl
|
||||
scoring_functions:
|
||||
- basic::bfcl
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-bfcl
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
|
|
|
|||
|
|
@ -4,9 +4,4 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
from .passthrough import get_distribution_template # noqa: F401
|
||||
|
|
@ -1,9 +1,10 @@
|
|||
version: '2'
|
||||
distribution_spec:
|
||||
description: Use for running LLM inference with the endpoint that compatible with Llama Stack API
|
||||
description: Use Passthrough hosted llama-stack endpoint for LLM inference
|
||||
providers:
|
||||
inference:
|
||||
- remote::passthrough
|
||||
- inline::sentence-transformers
|
||||
vector_io:
|
||||
- inline::faiss
|
||||
- remote::chromadb
|
||||
|
|
@ -26,6 +27,7 @@ distribution_spec:
|
|||
tool_runtime:
|
||||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
- remote::wolfram-alpha
|
||||
- inline::code-interpreter
|
||||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue